diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..77b7ba25dcb92f592bf0df5c8f500332d13187ab
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,30 @@
+*.pth
+*.pt
+*.onnx
+*.pk
+*.model
+*.zip
+*.tar
+*.pyc
+*.log
+*.o
+*.so
+*.a
+*.exe
+*.out
+.idea
+.DS_Store
+__pycache__/
+*.swp
+.vscode/
+.env
+.log
+*.pid
+*.ipynb*
+*.mp4
+build/
+dist/
+.cache/
+server_cache/
+app/.gradio/
+*.pkl
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bdb356ab0608ea3a6145926c0d6906ad4217ee5c
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,22 @@
+# Follow https://verdantfox.com/blog/how-to-use-git-pre-commit-hooks-the-hard-way-and-the-easy-way
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.11.0
+ hooks:
+ - id: ruff
+ args: [--fix, --respect-gitignore, --config=pyproject.toml]
+ - id: ruff-format
+ args: [--config=pyproject.toml]
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-toml
+ - id: check-added-large-files
+ args: ['--maxkb=3000'] # Allow files up to 3MB
+ - id: check-case-conflict
+ - id: check-merge-conflict
+ - id: debug-statements
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 6b94e2481927b013b3776ef7cc011e9094341871..35408da19ac0a419ed5692305855e51124358fa7 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,295 @@
-# LightX2V
+
+# ⚡️ LightX2V: Light Video Generation Inference Framework
+
+[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
+[Ask DeepWiki](https://deepwiki.com/ModelTC/lightx2v)
+[Docs | English](https://lightx2v-en.readthedocs.io/en/latest)
+[Docs | 中文](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest)
+[Papers | 中文](https://lightx2v-papers-zhcn.readthedocs.io/zh-cn/latest)
+[Docker](https://hub.docker.com/r/lightx2v/lightx2v/tags)
+
+**\[ [English](README.md) | 中文 \]**
+
+
+
+--------------------------------------------------------------------------------
+
+**LightX2V** is an advanced, lightweight video generation inference framework designed to deliver efficient, high-performance video synthesis. This unified platform integrates multiple state-of-the-art video generation techniques and supports diverse generation tasks such as text-to-video (T2V) and image-to-video (I2V). **X2V means converting different input modalities (X, such as text or images) into video output (V)**.
+
+> 🌐 **Try it online now!** Experience LightX2V with no installation required: **[LightX2V Online Service](https://x2v.light-ai.top/login)** - a free, lightweight, and fast AI digital-human video generation platform.
+
+## :fire: Latest News
+
+- **December 4, 2025:** 🚀 Added support for GGUF-format model inference, as well as deployment on Cambricon MLU590 and MetaX C500 hardware.
+
+- **November 24, 2025:** 🚀 We released 4-step distilled models for HunyuanVideo-1.5! These models support **ultra-fast 4-step inference** without CFG and achieve roughly a **25x speedup** over the standard 50-step inference. Base and FP8-quantized versions are now available: [Hy1.5-Distill-Models](https://huggingface.co/lightx2v/Hy1.5-Distill-Models).
+
+- **November 21, 2025:** 🚀 We added day-0 support for the [HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5) video generation model. With the same number of GPUs, LightX2V delivers more than a 2x speedup and supports deployment on lower-VRAM GPUs (e.g., a 24GB RTX 4090). It supports CFG parallelism / Ulysses parallelism, efficient offloading, and TeaCache/MagCache, and also runs on domestic chips such as MetaX and Cambricon. More models, including step-distilled and VAE-distilled variants, will soon be published on our [HuggingFace page](https://huggingface.co/lightx2v). Quantized models and a lightweight VAE are already available: [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models) for quantized inference and the [HunyuanVideo-1.5 lightweight TAE](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors) for fast VAE decoding. See the tutorial [here](https://github.com/ModelTC/LightX2V/tree/main/scripts/hunyuan_video_15), or browse the [examples directory](https://github.com/ModelTC/LightX2V/tree/main/examples) for code samples.
+
+
+## 🏆 Benchmark Results (Updated 2025.12.01)
+
+### 📊 Performance Comparison Across Inference Frameworks (H100)
+
+| Framework | GPUs | Step Time | Speedup |
+|-----------|---------|---------|---------|
+| Diffusers | 1 | 9.77s/it | 1x |
+| xDiT | 1 | 8.93s/it | 1.1x |
+| FastVideo | 1 | 7.35s/it | 1.3x |
+| SGL-Diffusion | 1 | 6.13s/it | 1.6x |
+| **LightX2V** | 1 | **5.18s/it** | **1.9x** 🚀 |
+| FastVideo | 8 | 2.94s/it | 1x |
+| xDiT | 8 | 2.70s/it | 1.1x |
+| SGL-Diffusion | 8 | 1.19s/it | 2.5x |
+| **LightX2V** | 8 | **0.75s/it** | **3.9x** 🚀 |
+
+### 📊 Performance Comparison Across Inference Frameworks (RTX 4090D)
+
+| Framework | GPUs | Step Time | Speedup |
+|-----------|---------|---------|---------|
+| Diffusers | 1 | 30.50s/it | 1x |
+| FastVideo | 1 | 22.66s/it | 1.3x |
+| xDiT | 1 | OOM | OOM |
+| SGL-Diffusion | 1 | OOM | OOM |
+| **LightX2V** | 1 | **20.26s/it** | **1.5x** 🚀 |
+| FastVideo | 8 | 15.48s/it | 1x |
+| xDiT | 8 | OOM | OOM |
+| SGL-Diffusion | 8 | OOM | OOM |
+| **LightX2V** | 8 | **4.75s/it** | **3.3x** 🚀 |
+
+### 📊 Performance Comparison Across LightX2V Configurations
+
+| Framework | GPU | Configuration | Step Time | Speedup |
+|-----------|-----|---------------|-----------|---------------|
+| **LightX2V** | H100 | 8 GPUs + cfg | 0.75s/it | 1x |
+| **LightX2V** | H100 | 8 GPUs + no cfg | 0.39s/it | 1.9x |
+| **LightX2V** | H100 | **8 GPUs + no cfg + fp8** | **0.35s/it** | **2.1x** 🚀 |
+| **LightX2V** | 4090D | 8 GPUs + cfg | 4.75s/it | 1x |
+| **LightX2V** | 4090D | 8 GPUs + no cfg | 3.13s/it | 1.5x |
+| **LightX2V** | 4090D | **8 GPUs + no cfg + fp8** | **2.35s/it** | **2.0x** 🚀 |
+
+**Note**: All results above were measured on Wan2.1-I2V-14B-480P (40 steps, 81 frames). In addition, 4-step distilled models are available on our [HuggingFace page](https://huggingface.co/lightx2v).
+
+
+## 💡 Quick Start
+
+
+For detailed usage instructions, please refer to our documentation: **[English Docs](https://lightx2v-en.readthedocs.io/en/latest/) | [Chinese Docs](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)**
+
+**We strongly recommend using the Docker environment, which is the simplest and fastest way to get set up. See the Quick Start chapter of the documentation for details.**
+
+### Install from Git
+```bash
+pip install -v git+https://github.com/ModelTC/LightX2V.git
+```
+
+### Build from Source
+```bash
+git clone https://github.com/ModelTC/LightX2V.git
+cd LightX2V
+uv pip install -v . # pip install -v .
+```
+
+### (Optional) Install Attention / Quantization Operators
+For instructions on installing attention operators, please refer to our documentation: **[English Docs](https://lightx2v-en.readthedocs.io/en/latest/getting_started/quickstart.html#step-4-install-attention-operators) | [Chinese Docs](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/quickstart.html#id9)**
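+
+As a quick sanity check (a minimal sketch, not part of the official install steps), the snippet below reports which attention backends are importable in the current environment; the module names mirror the probes used by `app/gradio_demo.py`:
+
+```python
+# Check which attention backends can be imported (availability check only).
+import importlib.util
+
+ATTENTION_BACKENDS = {
+    "flash_attn2": "flash_attn",
+    "flash_attn3": "flash_attn_interface",
+    "sage_attn2": "sageattention",
+    "torch_sdpa": "torch",
+}
+
+for op_name, module_name in ATTENTION_BACKENDS.items():
+    available = importlib.util.find_spec(module_name) is not None
+    print(f"{op_name}: {'available' if available else 'not installed'}")
+```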
+
+### Usage Example
+```python
+# examples/wan/wan_i2v.py
+"""
+Wan2.2 image-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for Wan2.2 I2V task
+# For wan2.1, use model_cls="wan2.1"
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe",
+ task="i2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="configs/wan22/wan_moe_i2v.json"
+# )
+
+# Enable offloading to significantly reduce VRAM usage with minimal speed impact
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For Wan models, supports both "block" and "phase"
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Create generator manually with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=40,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
+ sample_shift=5.0,
+)
+
+# Generation parameters
+seed = 42
+prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+image_path = "/path/to/img_0.jpg"
+save_result_path = "/path/to/save_results/output.mp4"
+
+# Generate video
+pipe.generate(
+ seed=seed,
+ image_path=image_path,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
+
+```
+
+> 💡 **More examples**: For more use cases, including advanced configurations such as quantization, offloading, and caching, see the [examples directory](https://github.com/ModelTC/LightX2V/tree/main/examples).
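+
+As a hypothetical variation on the example above (all paths are placeholders), the generator can also be created from one of the bundled config JSON files instead of passing each parameter explicitly:
+
+```python
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+    model_path="/path/to/Wan2.2-I2V-A14B",
+    model_cls="wan2.2_moe",
+    task="i2v",
+)
+
+# Load generation settings from a config file shipped with the repository
+pipe.create_generator(config_json="configs/wan22/wan_moe_i2v.json")
+
+pipe.generate(
+    seed=42,
+    image_path="/path/to/img_0.jpg",
+    prompt="A white cat wearing sunglasses sits on a surfboard at the beach.",
+    negative_prompt="blurry, distorted, low quality",
+    save_result_path="/path/to/save_results/output.mp4",
+)
+```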
+
+## 🤖 Supported Model Ecosystem
+
+### Official Open-Source Models
+- ✅ [HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)
+- ✅ [Wan2.1 & Wan2.2](https://huggingface.co/Wan-AI/)
+- ✅ [Qwen-Image](https://huggingface.co/Qwen/Qwen-Image)
+- ✅ [Qwen-Image-Edit](https://huggingface.co/spaces/Qwen/Qwen-Image-Edit)
+- ✅ [Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)
+
+### Quantized and Distilled Models / LoRAs (**🚀 Recommended: 4-step inference**)
+- ✅ [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- ✅ [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+- ✅ [Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras)
+- ✅ [Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras)
+
+### Lightweight Autoencoder Models (**🚀 Recommended: fast inference + low memory footprint**)
+- ✅ [Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
+
+### Autoregressive Models
+- ✅ [Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
+- ✅ [Self-Forcing](https://github.com/guandeh17/Self-Forcing)
+- ✅ [Matrix-Game-2.0](https://huggingface.co/Skywork/Matrix-Game-2.0)
+
+🔔 Follow our [HuggingFace page](https://huggingface.co/lightx2v) to get our team's latest models as soon as they are released.
+
+💡 See the [model structure documentation](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/model_structure.html) to get started with LightX2V quickly.
+
+## 🚀 Frontends
+
+We provide several ways to deploy a frontend interface:
+
+- **🎨 Gradio UI**: A clean, easy-to-use web interface, ideal for quick trials and prototyping
+  - 📖 [Gradio Deployment Guide](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
+- **🎯 ComfyUI**: A powerful node-based workflow interface that supports complex video generation tasks
+  - 📖 [ComfyUI Deployment Guide](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_comfyui.html)
+- **🚀 One-Click Windows Deployment**: A convenient deployment option for Windows users, with automatic environment configuration and intelligent parameter tuning
+  - 📖 [One-Click Windows Deployment Guide](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)
+
+**💡 Recommendations**:
+- **First-time users**: Start with the one-click Windows deployment
+- **Advanced users**: Use ComfyUI for more customization options
+- **Quick trials**: The Gradio UI offers the most intuitive experience
+
+## 🚀 Key Features
+
+### 🎯 **Extreme Performance Optimization**
+- **🔥 SOTA Inference Speed**: Up to **20x** acceleration through step distillation and system-level optimization (single GPU)
+- **⚡️ Revolutionary 4-Step Distillation**: Compresses the original 40-50 inference steps down to just 4, with no CFG required
+- **🛠️ Advanced Operator Support**: Integrates top-tier operators, including [Sage Attention](https://github.com/thu-ml/SageAttention), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Radial Attention](https://github.com/mit-han-lab/radial-attention), [q8-kernel](https://github.com/KONAKONA666/q8_kernels), [sgl-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel), and [vllm](https://github.com/vllm-project/vllm)
+
+### 💾 **Resource-Efficient Deployment**
+- **💡 Beyond Hardware Limits**: Run 14B models and generate 480P/720P videos with **only 8GB of VRAM + 16GB of RAM**
+- **🔧 Smart Parameter Offloading**: An advanced disk-CPU-GPU three-tier offloading architecture with fine-grained phase/block-level control (see the sketch after this list)
+- **⚙️ Comprehensive Quantization Support**: Supports multiple quantization schemes such as `w8a8-int8`, `w8a8-fp8`, and `w4a4-nvfp4`
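+
+A minimal sketch of the offloading knobs mentioned above, reusing the `enable_offload` API from the Quick Start example (paths are placeholders; the settings shown are an assumption of what a low-VRAM configuration might look like):
+
+```python
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+    model_path="/path/to/Wan2.2-I2V-A14B",
+    model_cls="wan2.2_moe",
+    task="i2v",
+)
+
+# Aggressive offloading for low-VRAM GPUs: "phase" granularity swaps weights
+# between CPU and GPU in finer-grained units than "block".
+pipe.enable_offload(
+    cpu_offload=True,
+    offload_granularity="phase",
+    text_encoder_offload=True,
+    image_encoder_offload=True,
+    vae_offload=True,
+)
+```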
+
+### 🎨 **Rich Feature Ecosystem**
+- **📈 Smart Feature Caching**: Intelligent caching that eliminates redundant computation and improves efficiency
+- **🔄 Parallel Inference Acceleration**: Multi-GPU parallel processing for a significant performance boost
+- **📱 Flexible Deployment Options**: Supports Gradio, service-oriented deployment, ComfyUI, and more
+- **🎛️ Dynamic-Resolution Inference**: Adaptive resolution adjustment for better generation quality
+- **🎞️ Video Frame Interpolation**: RIFE-based frame interpolation for smooth frame-rate upscaling
+
+
+## 📚 Technical Documentation
+
+### 📖 **Method Tutorials**
+- [Model Quantization](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) - A comprehensive guide to quantization strategies
+- [Feature Caching](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) - A deep dive into the smart caching mechanism
+- [Attention Mechanisms](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) - State-of-the-art attention operators
+- [Parameter Offloading](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/offload.html) - The three-tier storage architecture
+- [Parallel Inference](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) - Multi-GPU acceleration strategies
+- [Variable-Resolution Inference](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) - The U-shaped resolution strategy
+- [Step Distillation](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html) - The 4-step inference technique
+- [Video Frame Interpolation](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/video_frame_interpolation.html) - RIFE-based frame interpolation
+
+### 🛠️ **Deployment Guides**
+- [Low-Resource Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_resource.html) - An optimized 8GB-VRAM solution
+- [Low-Latency Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_latency.html) - Optimizations for ultra-fast inference
+- [Gradio Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) - Setting up the web interface
+- [Service Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html) - Production-grade API service deployment
+- [LoRA Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/lora_deploy.html) - Flexible LoRA deployment
+
+## 🧾 Contributing
+
+We use automated pre-commit hooks to ensure code quality and keep the project's code formatting consistent.
+
+> [!TIP]
+> **Setup instructions:**
+>
+> 1. Install the required dependencies:
+> ```shell
+> pip install ruff pre-commit
+> ```
+>
+> 2. Run before committing:
+> ```shell
+> pre-commit run --all-files
+> ```
+
+Thank you for helping improve LightX2V!
+
+## 🤝 Acknowledgments
+
+We sincerely thank all the model repositories and research communities that inspired and supported the development of LightX2V. This framework is built on the collective efforts of the open-source community.
+
+## 🌟 Star History
+
+[Star History Chart](https://star-history.com/#ModelTC/lightx2v&Timeline)
+
+## ✏️ Citation
+
+If you find LightX2V useful for your research, please consider citing our work:
+
+```bibtex
+@misc{lightx2v,
+ author = {LightX2V Contributors},
+ title = {LightX2V: Light Video Generation Inference Framework},
+ year = {2025},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/ModelTC/lightx2v}},
+}
+```
+
+## 📞 Contact & Support
+
+For questions, suggestions, or support, feel free to reach out through:
+- 🐛 [GitHub Issues](https://github.com/ModelTC/lightx2v/issues) - Bug reports and feature requests
+
+---
+
+
+Built with ❤️ by the LightX2V team
+
diff --git a/app/README.md b/app/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2be884b7e04244e2fb210d27ea7b7978abc072b7
--- /dev/null
+++ b/app/README.md
@@ -0,0 +1,13 @@
+# Gradio Demo
+
+Please refer to our Gradio deployment docs:
+
+[English doc: Gradio Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html)
+
+[中文文档: Gradio 部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
+
+## 🚀 Quick Start (快速开始)
+
+For Windows users, we provide a convenient one-click deployment solution with automatic environment configuration and intelligent parameter optimization. Please refer to the [One-Click Gradio Launch](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_local_windows.html) section for detailed instructions.
+
+对于Windows用户,我们提供了便捷的一键部署方式,支持自动环境配置和智能参数优化。详细操作请参考[一键启动Gradio](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)章节。
diff --git a/app/gradio_demo.py b/app/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..448ec3363692bae777afeccf3ef82c9450ffc53b
--- /dev/null
+++ b/app/gradio_demo.py
@@ -0,0 +1,1442 @@
+import argparse
+import gc
+import glob
+import importlib.util
+import json
+import os
+
+os.environ["PROFILING_DEBUG_LEVEL"] = "2"
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+os.environ["DTYPE"] = "BF16"
+import random
+from datetime import datetime
+
+import gradio as gr
+import psutil
+import torch
+from loguru import logger
+
+from lightx2v.utils.input_info import set_input_info
+from lightx2v.utils.set_config import get_default_config
+
+try:
+ from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+except ImportError:
+ apply_rope_with_cos_sin_cache_inplace = None
+
+
+logger.add(
+ "inference_logs.log",
+ rotation="100 MB",
+ encoding="utf-8",
+ enqueue=True,
+ backtrace=True,
+ diagnose=True,
+)
+
+MAX_NUMPY_SEED = 2**32 - 1
+
+
+def scan_model_path_contents(model_path):
+ """Scan model_path directory and return available files and subdirectories"""
+ if not model_path or not os.path.exists(model_path):
+ return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
+
+ dirs = []
+ files = []
+ safetensors_dirs = []
+ pth_files = []
+
+ try:
+ for item in os.listdir(model_path):
+ item_path = os.path.join(model_path, item)
+ if os.path.isdir(item_path):
+ dirs.append(item)
+ # Check if directory contains safetensors files
+ if glob.glob(os.path.join(item_path, "*.safetensors")):
+ safetensors_dirs.append(item)
+ elif os.path.isfile(item_path):
+ files.append(item)
+ if item.endswith(".pth"):
+ pth_files.append(item)
+ except Exception as e:
+ logger.warning(f"Failed to scan directory: {e}")
+
+ return {
+ "dirs": sorted(dirs),
+ "files": sorted(files),
+ "safetensors_dirs": sorted(safetensors_dirs),
+ "pth_files": sorted(pth_files),
+ }
+
+
+def get_dit_choices(model_path, model_type="wan2.1"):
+ """Get Diffusion model options (filtered by model type)"""
+ contents = scan_model_path_contents(model_path)
+ excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
+ fp8_supported = is_fp8_supported_gpu()
+
+ if model_type == "wan2.1":
+ # wan2.1: filter files/dirs containing wan2.1 or Wan2.1
+ def is_valid(name):
+ name_lower = name.lower()
+ if "wan2.1" not in name_lower:
+ return False
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return not any(kw in name_lower for kw in excluded_keywords)
+ else:
+ # wan2.2: filter files/dirs containing wan2.2 or Wan2.2
+ def is_valid(name):
+ name_lower = name.lower()
+ if "wan2.2" not in name_lower:
+ return False
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return not any(kw in name_lower for kw in excluded_keywords)
+
+ # Filter matching directories and files
+ dir_choices = [d for d in contents["dirs"] if is_valid(d)]
+ file_choices = [f for f in contents["files"] if is_valid(f)]
+ choices = dir_choices + file_choices
+ return choices if choices else [""]
+
+
+def get_high_noise_choices(model_path):
+ """Get high noise model options (files/dirs containing high_noise)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ def is_valid(name):
+ name_lower = name.lower()
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return "high_noise" in name_lower or "high-noise" in name_lower
+
+ dir_choices = [d for d in contents["dirs"] if is_valid(d)]
+ file_choices = [f for f in contents["files"] if is_valid(f)]
+ choices = dir_choices + file_choices
+ return choices if choices else [""]
+
+
+def get_low_noise_choices(model_path):
+ """Get low noise model options (files/dirs containing low_noise)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ def is_valid(name):
+ name_lower = name.lower()
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return "low_noise" in name_lower or "low-noise" in name_lower
+
+ dir_choices = [d for d in contents["dirs"] if is_valid(d)]
+ file_choices = [f for f in contents["files"] if is_valid(f)]
+ choices = dir_choices + file_choices
+ return choices if choices else [""]
+
+
+def get_t5_choices(model_path):
+ """Get T5 model options (.pth or .safetensors files containing t5 keyword)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ # Filter from .pth files
+ pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # Filter from .safetensors files
+ safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # Filter from directories containing safetensors
+ safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
+
+ choices = pth_choices + safetensors_choices + safetensors_dir_choices
+ return choices if choices else [""]
+
+
+def get_clip_choices(model_path):
+ """Get CLIP model options (.pth or .safetensors files containing clip keyword)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ # Filter from .pth files
+ pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # Filter from .safetensors files
+ safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # Filter from directories containing safetensors
+ safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
+
+ choices = pth_choices + safetensors_choices + safetensors_dir_choices
+ return choices if choices else [""]
+
+
+def get_vae_choices(model_path):
+ """Get VAE model options (.pth or .safetensors files containing vae/VAE/tae keyword)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ # Filter from .pth files
+ pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
+
+ # Filter from .safetensors files
+ safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
+
+ # Filter from directories containing safetensors
+ safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
+
+ choices = pth_choices + safetensors_choices + safetensors_dir_choices
+ return choices if choices else [""]
+
+
+def detect_quant_scheme(model_name):
+ """Automatically detect quantization scheme from model name
+ - If model name contains "int8" → "int8"
+ - If model name contains "fp8" and device supports → "fp8"
+ - Otherwise return None (no quantization)
+ """
+ if not model_name:
+ return None
+ name_lower = model_name.lower()
+ if "int8" in name_lower:
+ return "int8"
+ elif "fp8" in name_lower:
+ if is_fp8_supported_gpu():
+ return "fp8"
+ else:
+ # Device doesn't support fp8, return None (use default precision)
+ return None
+ return None
+
+
+def update_model_path_options(model_path, model_type="wan2.1"):
+ """Update all model path selectors when model_path or model_type changes"""
+ dit_choices = get_dit_choices(model_path, model_type)
+ high_noise_choices = get_high_noise_choices(model_path)
+ low_noise_choices = get_low_noise_choices(model_path)
+ t5_choices = get_t5_choices(model_path)
+ clip_choices = get_clip_choices(model_path)
+ vae_choices = get_vae_choices(model_path)
+
+ return (
+ gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
+ gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
+ gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
+ gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
+ gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
+ gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
+ )
+
+
+def generate_random_seed():
+ return random.randint(0, MAX_NUMPY_SEED)
+
+
+def is_module_installed(module_name):
+ try:
+ spec = importlib.util.find_spec(module_name)
+ return spec is not None
+ except ModuleNotFoundError:
+ return False
+
+
+def get_available_quant_ops():
+    # Map quantization operator names to the Python modules that provide them
+    op_modules = {
+        "vllm": "vllm",
+        "sgl": "sgl_kernel",
+        "q8f": "q8_kernels",
+    }
+    return [(op_name, is_module_installed(module)) for op_name, module in op_modules.items()]
+
+
+def get_available_attn_ops():
+    # Map attention operator names to the Python modules that provide them
+    op_modules = {
+        "flash_attn2": "flash_attn",
+        "flash_attn3": "flash_attn_interface",
+        "sage_attn2": "sageattention",
+        "sage_attn3": "sageattn3",
+        "torch_sdpa": "torch",
+    }
+    return [(op_name, is_module_installed(module)) for op_name, module in op_modules.items()]
+
+
+def get_gpu_memory(gpu_idx=0):
+ if not torch.cuda.is_available():
+ return 0
+ try:
+ with torch.cuda.device(gpu_idx):
+ memory_info = torch.cuda.mem_get_info()
+ total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
+ return total_memory
+ except Exception as e:
+ logger.warning(f"Failed to get GPU memory: {e}")
+ return 0
+
+
+def get_cpu_memory():
+ available_bytes = psutil.virtual_memory().available
+ return available_bytes / 1024**3
+
+
+def cleanup_memory():
+ gc.collect()
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+    # On POSIX systems, also ask the OS to flush filesystem buffers
+    if os.name == "posix":
+        try:
+            os.system("sync")
+        except OSError:
+            pass
+
+
+def generate_unique_filename(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ return os.path.join(output_dir, f"{timestamp}.mp4")
+
+
+def is_fp8_supported_gpu():
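+    # FP8 kernels require compute capability 8.9 (Ada, e.g. RTX 40 series) or 9.0+ (Hopper and newer).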
+ if not torch.cuda.is_available():
+ return False
+ compute_capability = torch.cuda.get_device_capability(0)
+ major, minor = compute_capability
+ return (major == 8 and minor == 9) or (major >= 9)
+
+
+def is_ada_architecture_gpu():
+ if not torch.cuda.is_available():
+ return False
+ try:
+ gpu_name = torch.cuda.get_device_name(0).upper()
+ ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
+ return any(keyword in gpu_name for keyword in ada_keywords)
+ except Exception as e:
+ logger.warning(f"Failed to get GPU name: {e}")
+ return False
+
+
+def get_quantization_options(model_path):
+ """Get quantization options dynamically based on model_path"""
+
+ # Check subdirectories
+ subdirs = ["original", "fp8", "int8"]
+ has_subdirs = {subdir: os.path.exists(os.path.join(model_path, subdir)) for subdir in subdirs}
+
+ # Check original files in root directory
+ t5_bf16_exists = os.path.exists(os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth"))
+ clip_fp16_exists = os.path.exists(os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"))
+
+ # Generate options
+ def get_choices(has_subdirs, original_type, fp8_type, int8_type, fallback_type, has_original_file=False):
+ choices = []
+ if has_subdirs["original"]:
+ choices.append(original_type)
+ if has_subdirs["fp8"]:
+ choices.append(fp8_type)
+ if has_subdirs["int8"]:
+ choices.append(int8_type)
+
+ # If no subdirectories but original file exists, add original type
+ if has_original_file:
+            if original_type not in choices:
+ choices.append(original_type)
+
+ # If no options at all, use default value
+ if not choices:
+ choices = [fallback_type]
+
+ return choices, choices[0]
+
+ # DIT options
+ dit_choices, dit_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16")
+
+ # T5 options - check if original file exists
+ t5_choices, t5_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16", t5_bf16_exists)
+
+ # CLIP options - check if original file exists
+ clip_choices, clip_default = get_choices(has_subdirs, "fp16", "fp8", "int8", "fp16", clip_fp16_exists)
+
+ return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
+
+
+def determine_model_cls(model_type, dit_name, high_noise_name):
+ """Determine model_cls based on model type and file name"""
+ # Determine file name to check
+ if model_type == "wan2.1":
+ check_name = dit_name.lower() if dit_name else ""
+ is_distill = "4step" in check_name
+ return "wan2.1_distill" if is_distill else "wan2.1"
+ else:
+ # wan2.2
+ check_name = high_noise_name.lower() if high_noise_name else ""
+ is_distill = "4step" in check_name
+ return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
+
+
+global_runner = None
+current_config = None
+cur_dit_path = None
+cur_t5_path = None
+cur_clip_path = None
+
+available_quant_ops = get_available_quant_ops()
+quant_op_choices = []
+for op_name, is_installed in available_quant_ops:
+ status_text = "✅ Installed" if is_installed else "❌ Not Installed"
+ display_text = f"{op_name} ({status_text})"
+ quant_op_choices.append((op_name, display_text))
+
+available_attn_ops = get_available_attn_ops()
+# Priority order
+attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
+# Sort by priority, installed ones first, uninstalled ones last
+attn_op_choices = []
+attn_op_dict = dict(available_attn_ops)
+
+# Add installed ones first (by priority)
+for op_name in attn_priority:
+ if op_name in attn_op_dict and attn_op_dict[op_name]:
+ status_text = "✅ Installed"
+ display_text = f"{op_name} ({status_text})"
+ attn_op_choices.append((op_name, display_text))
+
+# Add uninstalled ones (by priority)
+for op_name in attn_priority:
+ if op_name in attn_op_dict and not attn_op_dict[op_name]:
+ status_text = "❌ Not Installed"
+ display_text = f"{op_name} ({status_text})"
+ attn_op_choices.append((op_name, display_text))
+
+# Add other operators not in priority list (installed ones first)
+other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
+for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # Installed ones first
+ status_text = "✅ Installed" if is_installed else "❌ Not Installed"
+ display_text = f"{op_name} ({status_text})"
+ attn_op_choices.append((op_name, display_text))
+
+
+def run_inference(
+ prompt,
+ negative_prompt,
+ save_result_path,
+ infer_steps,
+ num_frames,
+ resolution,
+ seed,
+ sample_shift,
+ enable_cfg,
+ cfg_scale,
+ fps,
+ use_tiling_vae,
+ lazy_load,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ model_path_input,
+ model_type_input,
+ task_type_input,
+ dit_path_input,
+ high_noise_path_input,
+ low_noise_path_input,
+ t5_path_input,
+ clip_path_input,
+ vae_path_input,
+ image_path=None,
+):
+ cleanup_memory()
+
+ quant_op = quant_op.split("(")[0].strip()
+ attention_type = attention_type.split("(")[0].strip()
+
+ global global_runner, current_config, model_path, model_cls
+ global cur_dit_path, cur_t5_path, cur_clip_path
+
+ task = task_type_input
+ model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
+ logger.info(f"Auto-determined model_cls: {model_cls} (Model type: {model_type_input})")
+
+ if model_type_input == "wan2.1":
+ dit_quant_detected = detect_quant_scheme(dit_path_input)
+ else:
+ dit_quant_detected = detect_quant_scheme(high_noise_path_input)
+ t5_quant_detected = detect_quant_scheme(t5_path_input)
+ clip_quant_detected = detect_quant_scheme(clip_path_input)
+ logger.info(f"Auto-detected quantization scheme - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")
+
+ if model_path_input and model_path_input.strip():
+ model_path = model_path_input.strip()
+
+ if os.path.exists(os.path.join(model_path, "config.json")):
+ with open(os.path.join(model_path, "config.json"), "r") as f:
+ model_config = json.load(f)
+ else:
+ model_config = {}
+
+ save_result_path = generate_unique_filename(output_dir)
+
+    # detect_quant_scheme() returns None for non-quantized checkpoints
+    is_dit_quant = dit_quant_detected is not None
+    is_t5_quant = t5_quant_detected is not None
+    is_clip_quant = clip_quant_detected is not None
+
+ dit_quantized_ckpt = None
+ dit_original_ckpt = None
+ high_noise_quantized_ckpt = None
+ low_noise_quantized_ckpt = None
+ high_noise_original_ckpt = None
+ low_noise_original_ckpt = None
+
+ if is_dit_quant:
+ dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
+ if "wan2.1" in model_cls:
+ dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
+ else:
+ high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
+ low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
+ else:
+        dit_quant_scheme = None
+        dit_quantized_ckpt = "Default"
+ if "wan2.1" in model_cls:
+ dit_original_ckpt = os.path.join(model_path, dit_path_input)
+ else:
+ high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
+ low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)
+
+ # Use frontend-selected T5 path
+ if is_t5_quant:
+ t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
+ t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
+ t5_original_ckpt = None
+ else:
+ t5_quantized_ckpt = None
+ t5_quant_scheme = None
+ t5_original_ckpt = os.path.join(model_path, t5_path_input)
+
+ # Use frontend-selected CLIP path
+ if is_clip_quant:
+ clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
+ clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
+ clip_original_ckpt = None
+ else:
+ clip_quantized_ckpt = None
+ clip_quant_scheme = None
+ clip_original_ckpt = os.path.join(model_path, clip_path_input)
+
+ if model_type_input == "wan2.1":
+ current_dit_path = dit_path_input
+ else:
+ current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
+
+ current_t5_path = t5_path_input
+ current_clip_path = clip_path_input
+
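+    # Re-create the runner when lazy loading / module unloading is requested, when no
+    # runner exists yet, or when any selected checkpoint differs from the previous run.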
+ needs_reinit = (
+ lazy_load
+ or unload_modules
+ or global_runner is None
+ or current_config is None
+ or cur_dit_path is None
+ or cur_dit_path != current_dit_path
+ or cur_t5_path is None
+ or cur_t5_path != current_t5_path
+ or cur_clip_path is None
+ or cur_clip_path != current_clip_path
+ )
+
+    # A guidance scale of 1 effectively disables CFG
+    enable_cfg = cfg_scale != 1
+
+ vae_name_lower = vae_path_input.lower() if vae_path_input else ""
+ use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
+ use_lightvae = "lightvae" in vae_name_lower
+ need_scaled = "lighttae" in vae_name_lower
+
+ logger.info(f"VAE configuration - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
+
+    config_gradio = {
+ "infer_steps": infer_steps,
+ "target_video_length": num_frames,
+ "target_width": int(resolution.split("x")[0]),
+ "target_height": int(resolution.split("x")[1]),
+ "self_attn_1_type": attention_type,
+ "cross_attn_1_type": attention_type,
+ "cross_attn_2_type": attention_type,
+ "enable_cfg": enable_cfg,
+ "sample_guide_scale": cfg_scale,
+ "sample_shift": sample_shift,
+ "fps": fps,
+ "feature_caching": "NoCaching",
+ "do_mm_calib": False,
+ "parallel_attn_type": None,
+ "parallel_vae": False,
+ "max_area": False,
+ "vae_stride": (4, 8, 8),
+ "patch_size": (1, 2, 2),
+ "lora_path": None,
+ "strength_model": 1.0,
+ "use_prompt_enhancer": False,
+ "text_len": 512,
+ "denoising_step_list": [1000, 750, 500, 250],
+ "cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
+ "offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
+ "t5_cpu_offload": t5_cpu_offload,
+ "clip_cpu_offload": clip_cpu_offload,
+ "vae_cpu_offload": vae_cpu_offload,
+ "dit_quantized": is_dit_quant,
+ "dit_quant_scheme": dit_quant_scheme,
+ "dit_quantized_ckpt": dit_quantized_ckpt,
+ "dit_original_ckpt": dit_original_ckpt,
+ "high_noise_quantized_ckpt": high_noise_quantized_ckpt,
+ "low_noise_quantized_ckpt": low_noise_quantized_ckpt,
+ "high_noise_original_ckpt": high_noise_original_ckpt,
+ "low_noise_original_ckpt": low_noise_original_ckpt,
+ "t5_original_ckpt": t5_original_ckpt,
+ "t5_quantized": is_t5_quant,
+ "t5_quantized_ckpt": t5_quantized_ckpt,
+ "t5_quant_scheme": t5_quant_scheme,
+ "clip_original_ckpt": clip_original_ckpt,
+ "clip_quantized": is_clip_quant,
+ "clip_quantized_ckpt": clip_quantized_ckpt,
+ "clip_quant_scheme": clip_quant_scheme,
+ "vae_path": os.path.join(model_path, vae_path_input),
+ "use_tiling_vae": use_tiling_vae,
+ "use_tae": use_tae,
+ "use_lightvae": use_lightvae,
+ "need_scaled": need_scaled,
+ "lazy_load": lazy_load,
+ "rope_chunk": rope_chunk,
+ "rope_chunk_size": rope_chunk_size,
+ "clean_cuda_cache": clean_cuda_cache,
+ "unload_modules": unload_modules,
+ "seq_parallel": False,
+ "warm_up_cpu_buffers": False,
+ "boundary_step_index": 2,
+ "boundary": 0.900,
+ "use_image_encoder": False if "wan2.2" in model_cls else True,
+ "rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
+ }
+
+ args = argparse.Namespace(
+ model_cls=model_cls,
+ seed=seed,
+ task=task,
+ model_path=model_path,
+ prompt_enhancer=None,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image_path=image_path,
+ save_result_path=save_result_path,
+ return_result_tensor=False,
+ )
+
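+    # Later updates take precedence: defaults < argparse namespace < model config.json < UI settings.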
+ config = get_default_config()
+ config.update({k: v for k, v in vars(args).items()})
+ config.update(model_config)
+    config.update(config_gradio)
+
+ logger.info(f"Using model: {model_path}")
+ logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
+
+ # Initialize or reuse the runner
+ runner = global_runner
+ if needs_reinit:
+ if runner is not None:
+ del runner
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ from lightx2v.infer import init_runner # noqa
+
+        runner = init_runner(config)
+
+        current_config = config
+ cur_dit_path = current_dit_path
+ cur_t5_path = current_t5_path
+ cur_clip_path = current_clip_path
+
+ if not lazy_load:
+ global_runner = runner
+ else:
+ runner.config = config
+
+    input_info = set_input_info(args)
+    runner.run_pipeline(input_info)
+ cleanup_memory()
+
+ return save_result_path
+
+
+def handle_lazy_load_change(lazy_load_enabled):
+ """Handle lazy_load checkbox change to automatically enable unload_modules"""
+ return gr.update(value=lazy_load_enabled)
+
+
+def auto_configure(resolution):
+ """Auto-configure inference options based on machine configuration and resolution"""
+ default_config = {
+ "lazy_load_val": False,
+ "rope_chunk_val": False,
+ "rope_chunk_size_val": 100,
+ "clean_cuda_cache_val": False,
+ "cpu_offload_val": False,
+ "offload_granularity_val": "block",
+ "t5_cpu_offload_val": False,
+ "clip_cpu_offload_val": False,
+ "vae_cpu_offload_val": False,
+ "unload_modules_val": False,
+ "attention_type_val": attn_op_choices[0][1],
+ "quant_op_val": quant_op_choices[0][1],
+ "use_tiling_vae_val": False,
+ }
+
+ gpu_memory = round(get_gpu_memory())
+ cpu_memory = round(get_cpu_memory())
+
+ attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
+
+ if is_ada_architecture_gpu():
+ quant_op_priority = ["q8f", "vllm", "sgl"]
+ else:
+ quant_op_priority = ["vllm", "sgl", "q8f"]
+
+ for op in attn_priority:
+ if dict(available_attn_ops).get(op):
+ default_config["attention_type_val"] = dict(attn_op_choices)[op]
+ break
+
+ for op in quant_op_priority:
+ if dict(available_quant_ops).get(op):
+ default_config["quant_op_val"] = dict(quant_op_choices)[op]
+ break
+
+ if resolution in [
+ "1280x720",
+ "720x1280",
+ "1280x544",
+ "544x1280",
+ "1104x832",
+ "832x1104",
+ "960x960",
+ ]:
+ res = "720p"
+ elif resolution in [
+ "960x544",
+ "544x960",
+ ]:
+ res = "540p"
+ else:
+ res = "480p"
+
+ if res == "720p":
+ gpu_rules = [
+ (80, {}),
+ (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
+ (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
+ (
+ 24,
+ {
+ "cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ },
+ ),
+ (
+ 16,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ "rope_chunk_val": True,
+ "rope_chunk_size_val": 100,
+ },
+ ),
+ (
+ 8,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ "rope_chunk_val": True,
+ "rope_chunk_size_val": 100,
+ "clean_cuda_cache_val": True,
+ },
+ ),
+ ]
+
+ else:
+ gpu_rules = [
+ (80, {}),
+ (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
+ (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
+ (
+ 24,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ },
+ ),
+ (
+ 16,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ },
+ ),
+ (
+ 8,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ },
+ ),
+ ]
+
+ cpu_rules = [
+ (128, {}),
+ (64, {}),
+ (32, {"unload_modules_val": True}),
+ (
+ 16,
+ {
+ "lazy_load_val": True,
+ "unload_modules_val": True,
+ },
+ ),
+ ]
+
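+    # Apply the first rule whose memory threshold is met; the rule lists above are
+    # ordered from the largest threshold to the smallest.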
+ for threshold, updates in gpu_rules:
+ if gpu_memory >= threshold:
+ default_config.update(updates)
+ break
+
+ for threshold, updates in cpu_rules:
+ if cpu_memory >= threshold:
+ default_config.update(updates)
+ break
+
+ return (
+ gr.update(value=default_config["lazy_load_val"]),
+ gr.update(value=default_config["rope_chunk_val"]),
+ gr.update(value=default_config["rope_chunk_size_val"]),
+ gr.update(value=default_config["clean_cuda_cache_val"]),
+ gr.update(value=default_config["cpu_offload_val"]),
+ gr.update(value=default_config["offload_granularity_val"]),
+ gr.update(value=default_config["t5_cpu_offload_val"]),
+ gr.update(value=default_config["clip_cpu_offload_val"]),
+ gr.update(value=default_config["vae_cpu_offload_val"]),
+ gr.update(value=default_config["unload_modules_val"]),
+ gr.update(value=default_config["attention_type_val"]),
+ gr.update(value=default_config["quant_op_val"]),
+ gr.update(value=default_config["use_tiling_vae_val"]),
+ )
+
+
+css = """
+ .main-content { max-width: 1600px; margin: auto; padding: 20px; }
+ .warning { color: #ff6b6b; font-weight: bold; }
+
+ /* Model configuration area styles */
+ .model-config {
+ margin-bottom: 20px !important;
+ border: 1px solid #e0e0e0;
+ border-radius: 12px;
+ padding: 15px;
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
+ }
+
+ /* Input parameters area styles */
+ .input-params {
+ margin-bottom: 20px !important;
+ border: 1px solid #e0e0e0;
+ border-radius: 12px;
+ padding: 15px;
+ background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
+ }
+
+ /* Output video area styles */
+ .output-video {
+ border: 1px solid #e0e0e0;
+ border-radius: 12px;
+ padding: 20px;
+ background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
+ min-height: 400px;
+ }
+
+ /* Generate button styles */
+ .generate-btn {
+ width: 100%;
+ margin-top: 20px;
+ padding: 15px 30px !important;
+ font-size: 18px !important;
+ font-weight: bold !important;
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+ border: none !important;
+ border-radius: 10px !important;
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
+ transition: all 0.3s ease !important;
+ }
+ .generate-btn:hover {
+ transform: translateY(-2px);
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
+ }
+
+ /* Accordion header styles */
+ .model-config .gr-accordion-header,
+ .input-params .gr-accordion-header,
+ .output-video .gr-accordion-header {
+ font-size: 20px !important;
+ font-weight: bold !important;
+ padding: 15px !important;
+ }
+
+ /* Optimize spacing */
+ .gr-row {
+ margin-bottom: 15px;
+ }
+
+ /* Video player styles */
+ .output-video video {
+ border-radius: 10px;
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
+ }
+ """
+
+
+def main():
+    with gr.Blocks(title="LightX2V (Lightweight Video Inference and Generation Engine)", css=css) as demo:
+        gr.Markdown("# 🎬 LightX2V Video Generator")
+        gr.HTML("")
+ # Main layout: left and right columns
+ with gr.Row():
+ # Left: configuration and input area
+ with gr.Column(scale=5):
+ # Model configuration area
+ with gr.Accordion("🗂️ Model Configuration", open=True, elem_classes=["model-config"]):
+ # FP8 support notice
+ if not is_fp8_supported_gpu():
+ gr.Markdown("⚠️ **Your device does not support FP8 inference**. Models containing FP8 have been automatically hidden.")
+
+ # Hidden state components
+ model_path_input = gr.Textbox(value=model_path, visible=False)
+
+ # Model type + Task type
+ with gr.Row():
+ model_type_input = gr.Radio(
+ label="Model Type",
+ choices=["wan2.1", "wan2.2"],
+ value="wan2.1",
+ info="wan2.2 requires separate high noise and low noise models",
+ )
+ task_type_input = gr.Radio(
+ label="Task Type",
+ choices=["i2v", "t2v"],
+ value="i2v",
+ info="i2v: Image-to-video, t2v: Text-to-video",
+ )
+
+ # wan2.1: Diffusion model (single row)
+ with gr.Row() as wan21_row:
+ dit_path_input = gr.Dropdown(
+ label="🎨 Diffusion Model",
+ choices=get_dit_choices(model_path, "wan2.1"),
+ value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
+ allow_custom_value=True,
+ visible=True,
+ )
+
+ # wan2.2 specific: high noise model + low noise model (hidden by default)
+ with gr.Row(visible=False) as wan22_row:
+ high_noise_path_input = gr.Dropdown(
+ label="🔊 High Noise Model",
+ choices=get_high_noise_choices(model_path),
+ value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+ low_noise_path_input = gr.Dropdown(
+ label="🔇 Low Noise Model",
+ choices=get_low_noise_choices(model_path),
+ value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+
+ # Text encoder (single row)
+ with gr.Row():
+ t5_path_input = gr.Dropdown(
+ label="📝 Text Encoder",
+ choices=get_t5_choices(model_path),
+ value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+
+ # Image encoder + VAE decoder
+ with gr.Row():
+ clip_path_input = gr.Dropdown(
+ label="🖼️ Image Encoder",
+ choices=get_clip_choices(model_path),
+ value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+ vae_path_input = gr.Dropdown(
+ label="🎞️ VAE Decoder",
+ choices=get_vae_choices(model_path),
+ value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+
+ # Attention operator and quantization matrix multiplication operator
+ with gr.Row():
+ attention_type = gr.Dropdown(
+ label="⚡ Attention Operator",
+ choices=[op[1] for op in attn_op_choices],
+ value=attn_op_choices[0][1] if attn_op_choices else "",
+ info="Use appropriate attention operators to accelerate inference",
+ )
+ quant_op = gr.Dropdown(
+ label="Quantization Matmul Operator",
+ choices=[op[1] for op in quant_op_choices],
+ value=quant_op_choices[0][1],
+ info="Select quantization matrix multiplication operator to accelerate inference",
+ interactive=True,
+ )
+
+ # Determine if model is distill version
+ def is_distill_model(model_type, dit_path, high_noise_path):
+ """Determine if model is distill version based on model type and path"""
+ if model_type == "wan2.1":
+ check_name = dit_path.lower() if dit_path else ""
+ else:
+ check_name = high_noise_path.lower() if high_noise_path else ""
+ return "4step" in check_name
+
+ # Model type change event
+ def on_model_type_change(model_type, model_path_val):
+ if model_type == "wan2.2":
+ return gr.update(visible=False), gr.update(visible=True), gr.update()
+ else:
+ # Update wan2.1 Diffusion model options
+ dit_choices = get_dit_choices(model_path_val, "wan2.1")
+ return (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
+ )
+
+ model_type_input.change(
+ fn=on_model_type_change,
+ inputs=[model_type_input, model_path_input],
+ outputs=[wan21_row, wan22_row, dit_path_input],
+ )
+
+ # Input parameters area
+ with gr.Accordion("📥 Input Parameters", open=True, elem_classes=["input-params"]):
+ # Image input (shown for i2v)
+ with gr.Row(visible=True) as image_input_row:
+ image_path = gr.Image(
+ label="Input Image",
+ type="filepath",
+ height=300,
+ interactive=True,
+ )
+
+ # Task type change event
+ def on_task_type_change(task_type):
+ return gr.update(visible=(task_type == "i2v"))
+
+ task_type_input.change(
+ fn=on_task_type_change,
+ inputs=[task_type_input],
+ outputs=[image_input_row],
+ )
+
+ with gr.Row():
+ with gr.Column():
+ prompt = gr.Textbox(
+ label="Prompt",
+ lines=3,
+ placeholder="Describe the video content...",
+ max_lines=5,
+ )
+ with gr.Column():
+ negative_prompt = gr.Textbox(
+ label="Negative Prompt",
+ lines=3,
+ placeholder="What you don't want to appear in the video...",
+ max_lines=5,
+ value="Camera shake, bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
+ )
+ with gr.Column():
+ resolution = gr.Dropdown(
+ choices=[
+ # 720p
+ ("1280x720 (16:9, 720p)", "1280x720"),
+ ("720x1280 (9:16, 720p)", "720x1280"),
+ ("1280x544 (21:9, 720p)", "1280x544"),
+ ("544x1280 (9:21, 720p)", "544x1280"),
+ ("1104x832 (4:3, 720p)", "1104x832"),
+ ("832x1104 (3:4, 720p)", "832x1104"),
+ ("960x960 (1:1, 720p)", "960x960"),
+ # 540p / 480p
+ ("960x544 (16:9, 540p)", "960x544"),
+ ("544x960 (9:16, 540p)", "544x960"),
+ ("832x480 (16:9, 480p)", "832x480"),
+ ("480x832 (9:16, 480p)", "480x832"),
+ ("832x624 (4:3, 480p)", "832x624"),
+ ("624x832 (3:4, 480p)", "624x832"),
+ ("720x720 (1:1, 480p)", "720x720"),
+ ("512x512 (1:1, 480p)", "512x512"),
+ ],
+ value="832x480",
+ label="Maximum Resolution",
+ )
+
+ with gr.Column(scale=9):
+ seed = gr.Slider(
+ label="Random Seed",
+ minimum=0,
+ maximum=MAX_NUMPY_SEED,
+ step=1,
+ value=generate_random_seed(),
+ )
+ with gr.Column():
+ default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
+ default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
+ default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
+
+ if default_is_distill:
+ infer_steps = gr.Slider(
+ label="Inference Steps",
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=4,
+ info="Distill model inference steps default to 4.",
+ )
+ else:
+ infer_steps = gr.Slider(
+ label="Inference Steps",
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=40,
+ info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
+ )
+
+ # Dynamically update inference steps when model path changes
+ def update_infer_steps(model_type, dit_path, high_noise_path):
+ is_distill = is_distill_model(model_type, dit_path, high_noise_path)
+ if is_distill:
+ return gr.update(minimum=1, maximum=100, value=4, interactive=True)
+ else:
+ return gr.update(minimum=1, maximum=100, value=40, interactive=True)
+
+ # Listen to model path changes
+ dit_path_input.change(
+ fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[infer_steps],
+ )
+ high_noise_path_input.change(
+ fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[infer_steps],
+ )
+ model_type_input.change(
+ fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[infer_steps],
+ )
+
+ # Set default CFG based on model class
+ # CFG scale factor: default to 1 for distill, otherwise 5
+ default_cfg_scale = 1 if default_is_distill else 5
+ # enable_cfg is not exposed to frontend, automatically set based on cfg_scale
+ # If cfg_scale == 1, then enable_cfg = False, otherwise enable_cfg = True
+ default_enable_cfg = False if default_cfg_scale == 1 else True
+ enable_cfg = gr.Checkbox(
+ label="Enable Classifier-Free Guidance",
+ value=default_enable_cfg,
+ visible=False, # Hidden, not exposed to frontend
+ )
+
+ with gr.Row():
+ sample_shift = gr.Slider(
+ label="Distribution Shift",
+ value=5,
+ minimum=0,
+ maximum=10,
+ step=1,
+ info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.",
+ )
+ cfg_scale = gr.Slider(
+ label="CFG Scale Factor",
+ minimum=1,
+ maximum=10,
+ step=1,
+ value=default_cfg_scale,
+ info="Controls the influence strength of the prompt. Higher values give more influence to the prompt. When value is 1, CFG is automatically disabled.",
+ )
+
+ # Update enable_cfg based on cfg_scale
+ def update_enable_cfg(cfg_scale_val):
+ """Automatically set enable_cfg based on cfg_scale value"""
+ if cfg_scale_val == 1:
+ return gr.update(value=False)
+ else:
+ return gr.update(value=True)
+
+ # Dynamically update CFG scale factor and enable_cfg when model path changes
+ def update_cfg_scale(model_type, dit_path, high_noise_path):
+ is_distill = is_distill_model(model_type, dit_path, high_noise_path)
+ if is_distill:
+ new_cfg_scale = 1
+ else:
+ new_cfg_scale = 5
+ new_enable_cfg = False if new_cfg_scale == 1 else True
+ return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
+
+ dit_path_input.change(
+ fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[cfg_scale, enable_cfg],
+ )
+ high_noise_path_input.change(
+ fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[cfg_scale, enable_cfg],
+ )
+ model_type_input.change(
+ fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[cfg_scale, enable_cfg],
+ )
+
+ cfg_scale.change(
+ fn=update_enable_cfg,
+ inputs=[cfg_scale],
+ outputs=[enable_cfg],
+ )
+
+ with gr.Row():
+ fps = gr.Slider(
+ label="Frames Per Second (FPS)",
+ minimum=8,
+ maximum=30,
+ step=1,
+ value=16,
+ info="Frames per second of the video. Higher FPS results in smoother videos.",
+ )
+ num_frames = gr.Slider(
+ label="Total Frames",
+ minimum=16,
+ maximum=120,
+ step=1,
+ value=81,
+ info="Total number of frames in the video. More frames result in longer videos.",
+ )
+
+ save_result_path = gr.Textbox(
+ label="Output Video Path",
+ value=generate_unique_filename(output_dir),
+ info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.",
+ visible=False, # Hide output path, auto-generated
+ )
+
+ with gr.Column(scale=4):
+ with gr.Accordion("📤 Generated Video", open=True, elem_classes=["output-video"]):
+ output_video = gr.Video(
+ label="",
+ height=600,
+ autoplay=True,
+ show_label=False,
+ )
+
+ infer_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg", elem_classes=["generate-btn"])
+
+ rope_chunk = gr.Checkbox(label="Chunked Rotary Position Embedding", value=False, visible=False)
+ rope_chunk_size = gr.Slider(label="Rotary Embedding Chunk Size", value=100, minimum=100, maximum=10000, step=100, visible=False)
+ unload_modules = gr.Checkbox(label="Unload Modules", value=False, visible=False)
+ clean_cuda_cache = gr.Checkbox(label="Clean CUDA Memory Cache", value=False, visible=False)
+ cpu_offload = gr.Checkbox(label="CPU Offloading", value=False, visible=False)
+ lazy_load = gr.Checkbox(label="Enable Lazy Loading", value=False, visible=False)
+ offload_granularity = gr.Dropdown(label="Dit Offload Granularity", choices=["block", "phase"], value="phase", visible=False)
+ t5_cpu_offload = gr.Checkbox(label="T5 CPU Offloading", value=False, visible=False)
+ clip_cpu_offload = gr.Checkbox(label="CLIP CPU Offloading", value=False, visible=False)
+ vae_cpu_offload = gr.Checkbox(label="VAE CPU Offloading", value=False, visible=False)
+ use_tiling_vae = gr.Checkbox(label="VAE Tiling Inference", value=False, visible=False)
+
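+ # The hidden tuning controls above are populated by auto_configure(), which picks offload/attention defaults from the detected GPU and CPU memory (see resolution.change and demo.load below).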
+ resolution.change(
+ fn=auto_configure,
+ inputs=[resolution],
+ outputs=[
+ lazy_load,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ use_tiling_vae,
+ ],
+ )
+
+ demo.load(
+ fn=lambda res: auto_configure(res),
+ inputs=[resolution],
+ outputs=[
+ lazy_load,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ use_tiling_vae,
+ ],
+ )
+
+ infer_btn.click(
+ fn=run_inference,
+ inputs=[
+ prompt,
+ negative_prompt,
+ save_result_path,
+ infer_steps,
+ num_frames,
+ resolution,
+ seed,
+ sample_shift,
+ enable_cfg,
+ cfg_scale,
+ fps,
+ use_tiling_vae,
+ lazy_load,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ model_path_input,
+ model_type_input,
+ task_type_input,
+ dit_path_input,
+ high_noise_path_input,
+ low_noise_path_input,
+ t5_path_input,
+ clip_path_input,
+ vae_path_input,
+ image_path,
+ ],
+ outputs=output_video,
+ )
+
+ demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Lightweight Video Generation")
+ parser.add_argument("--model_path", type=str, required=True, help="Model folder path")
+ parser.add_argument("--server_port", type=int, default=7862, help="Server port")
+ parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server IP")
+ parser.add_argument("--output_dir", type=str, default="./outputs", help="Output video save directory")
+ args = parser.parse_args()
+
+ global model_path, model_cls, output_dir
+ model_path = args.model_path
+ model_cls = "wan2.1"
+ output_dir = args.output_dir
+
+ main()
diff --git a/app/gradio_demo_zh.py b/app/gradio_demo_zh.py
new file mode 100644
index 0000000000000000000000000000000000000000..229409f7bf8df4cce3c056fbe334e680abd8fc7f
--- /dev/null
+++ b/app/gradio_demo_zh.py
@@ -0,0 +1,1442 @@
+import argparse
+import gc
+import glob
+import importlib.util
+import json
+import os
+
+os.environ["PROFILING_DEBUG_LEVEL"] = "2"
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+os.environ["DTYPE"] = "BF16"
+import random
+from datetime import datetime
+
+import gradio as gr
+import psutil
+import torch
+from loguru import logger
+
+from lightx2v.utils.input_info import set_input_info
+from lightx2v.utils.set_config import get_default_config
+
+try:
+ from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+except ImportError:
+ apply_rope_with_cos_sin_cache_inplace = None
+
+
+logger.add(
+ "inference_logs.log",
+ rotation="100 MB",
+ encoding="utf-8",
+ enqueue=True,
+ backtrace=True,
+ diagnose=True,
+)
+
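+# Largest value accepted as a 32-bit (NumPy-compatible) random seed.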
+MAX_NUMPY_SEED = 2**32 - 1
+
+
+def scan_model_path_contents(model_path):
+ """扫描 model_path 目录,返回可用的文件和子目录"""
+ if not model_path or not os.path.exists(model_path):
+ return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
+
+ dirs = []
+ files = []
+ safetensors_dirs = []
+ pth_files = []
+
+ try:
+ for item in os.listdir(model_path):
+ item_path = os.path.join(model_path, item)
+ if os.path.isdir(item_path):
+ dirs.append(item)
+ # 检查目录是否包含 safetensors 文件
+ if glob.glob(os.path.join(item_path, "*.safetensors")):
+ safetensors_dirs.append(item)
+ elif os.path.isfile(item_path):
+ files.append(item)
+ if item.endswith(".pth"):
+ pth_files.append(item)
+ except Exception as e:
+ logger.warning(f"扫描目录失败: {e}")
+
+ return {
+ "dirs": sorted(dirs),
+ "files": sorted(files),
+ "safetensors_dirs": sorted(safetensors_dirs),
+ "pth_files": sorted(pth_files),
+ }
+
+
+def get_dit_choices(model_path, model_type="wan2.1"):
+ """获取 Diffusion 模型可选项(根据模型类型筛选)"""
+ contents = scan_model_path_contents(model_path)
+ excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
+ fp8_supported = is_fp8_supported_gpu()
+
+ if model_type == "wan2.1":
+ # wan2.1: 筛选包含 wan2.1 或 Wan2.1 的文件/目录
+ def is_valid(name):
+ name_lower = name.lower()
+ if "wan2.1" not in name_lower:
+ return False
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return not any(kw in name_lower for kw in excluded_keywords)
+ else:
+ # wan2.2: 筛选包含 wan2.2 或 Wan2.2 的文件/目录
+ def is_valid(name):
+ name_lower = name.lower()
+ if "wan2.2" not in name_lower:
+ return False
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return not any(kw in name_lower for kw in excluded_keywords)
+
+ # 筛选符合条件的目录和文件
+ dir_choices = [d for d in contents["dirs"] if is_valid(d)]
+ file_choices = [f for f in contents["files"] if is_valid(f)]
+ choices = dir_choices + file_choices
+ return choices if choices else [""]
+
+
+def get_high_noise_choices(model_path):
+ """获取高噪模型可选项(包含 high_noise 的文件/目录)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ def is_valid(name):
+ name_lower = name.lower()
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return "high_noise" in name_lower or "high-noise" in name_lower
+
+ dir_choices = [d for d in contents["dirs"] if is_valid(d)]
+ file_choices = [f for f in contents["files"] if is_valid(f)]
+ choices = dir_choices + file_choices
+ return choices if choices else [""]
+
+
+def get_low_noise_choices(model_path):
+ """获取低噪模型可选项(包含 low_noise 的文件/目录)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ def is_valid(name):
+ name_lower = name.lower()
+ if not fp8_supported and "fp8" in name_lower:
+ return False
+ return "low_noise" in name_lower or "low-noise" in name_lower
+
+ dir_choices = [d for d in contents["dirs"] if is_valid(d)]
+ file_choices = [f for f in contents["files"] if is_valid(f)]
+ choices = dir_choices + file_choices
+ return choices if choices else [""]
+
+
+def get_t5_choices(model_path):
+ """获取 T5 模型可选项(.pth 或 .safetensors 文件,包含 t5 关键字)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ # 从 .pth 文件中筛选
+ pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # 从 .safetensors 文件中筛选
+ safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # 从包含 safetensors 的目录中筛选
+ safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
+
+ choices = pth_choices + safetensors_choices + safetensors_dir_choices
+ return choices if choices else [""]
+
+
+def get_clip_choices(model_path):
+ """获取 CLIP 模型可选项(.pth 或 .safetensors 文件,包含 clip 关键字)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ # 从 .pth 文件中筛选
+ pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # 从 .safetensors 文件中筛选
+ safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
+
+ # 从包含 safetensors 的目录中筛选
+ safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
+
+ choices = pth_choices + safetensors_choices + safetensors_dir_choices
+ return choices if choices else [""]
+
+
+def get_vae_choices(model_path):
+ """获取 VAE 模型可选项(.pth 或 .safetensors 文件,包含 vae/VAE/tae 关键字)"""
+ contents = scan_model_path_contents(model_path)
+ fp8_supported = is_fp8_supported_gpu()
+
+ # 从 .pth 文件中筛选
+ pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
+
+ # 从 .safetensors 文件中筛选
+ safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
+
+ # 从包含 safetensors 的目录中筛选
+ safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
+
+ choices = pth_choices + safetensors_choices + safetensors_dir_choices
+ return choices if choices else [""]
+
+
+def detect_quant_scheme(model_name):
+ """根据模型名字自动检测量化精度
+ - 如果模型名字包含 "int8" → "int8"
+ - 如果模型名字包含 "fp8" 且设备支持 → "fp8"
+ - 否则返回 None(表示不使用量化)
+ """
+ if not model_name:
+ return None
+ name_lower = model_name.lower()
+ if "int8" in name_lower:
+ return "int8"
+ elif "fp8" in name_lower:
+ if is_fp8_supported_gpu():
+ return "fp8"
+ else:
+ # 设备不支持fp8,返回None(使用默认精度)
+ return None
+ return None
+
+
+def update_model_path_options(model_path, model_type="wan2.1"):
+ """当 model_path 或 model_type 改变时,更新所有模型路径选择器"""
+ dit_choices = get_dit_choices(model_path, model_type)
+ high_noise_choices = get_high_noise_choices(model_path)
+ low_noise_choices = get_low_noise_choices(model_path)
+ t5_choices = get_t5_choices(model_path)
+ clip_choices = get_clip_choices(model_path)
+ vae_choices = get_vae_choices(model_path)
+
+ return (
+ gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
+ gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
+ gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
+ gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
+ gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
+ gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
+ )
+
+
+def generate_random_seed():
+ return random.randint(0, MAX_NUMPY_SEED)
+
+
+def is_module_installed(module_name):
+ try:
+ spec = importlib.util.find_spec(module_name)
+ return spec is not None
+ except ModuleNotFoundError:
+ return False
+
+
+def get_available_quant_ops():
+ # Probe optional quantized-matmul backends; each entry is (op_name, is_installed).
+ op_modules = [("vllm", "vllm"), ("sgl", "sgl_kernel"), ("q8f", "q8_kernels")]
+ return [(op_name, is_module_installed(module)) for op_name, module in op_modules]
+
+
+def get_available_attn_ops():
+ # Probe optional attention backends; each entry is (op_name, is_installed).
+ op_modules = [
+ ("flash_attn2", "flash_attn"),
+ ("flash_attn3", "flash_attn_interface"),
+ ("sage_attn2", "sageattention"),
+ ("sage_attn3", "sageattn3"),
+ ("torch_sdpa", "torch"),
+ ]
+ return [(op_name, is_module_installed(module)) for op_name, module in op_modules]
+
+
+def get_gpu_memory(gpu_idx=0):
+ if not torch.cuda.is_available():
+ return 0
+ try:
+ with torch.cuda.device(gpu_idx):
+ memory_info = torch.cuda.mem_get_info()
+ total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
+ return total_memory
+ except Exception as e:
+ logger.warning(f"获取GPU内存失败: {e}")
+ return 0
+
+
+def get_cpu_memory():
+ available_bytes = psutil.virtual_memory().available
+ return available_bytes / 1024**3
+
+
+def cleanup_memory():
+ gc.collect()
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ # On POSIX systems, ask the OS to flush filesystem buffers; ignore any failure.
+ if os.name == "posix":
+ try:
+ os.system("sync")
+ except Exception: # noqa
+ pass
+
+
+def generate_unique_filename(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ return os.path.join(output_dir, f"{timestamp}.mp4")
+
+
+def is_fp8_supported_gpu():
+ if not torch.cuda.is_available():
+ return False
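+ # FP8 kernels require compute capability 8.9 (Ada Lovelace) or 9.0 and newer (Hopper and later).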
+ compute_capability = torch.cuda.get_device_capability(0)
+ major, minor = compute_capability
+ return (major == 8 and minor == 9) or (major >= 9)
+
+
+def is_ada_architecture_gpu():
+ if not torch.cuda.is_available():
+ return False
+ try:
+ gpu_name = torch.cuda.get_device_name(0).upper()
+ ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
+ return any(keyword in gpu_name for keyword in ada_keywords)
+ except Exception as e:
+ logger.warning(f"Failed to get GPU name: {e}")
+ return False
+
+
+def get_quantization_options(model_path):
+ """根据model_path动态获取量化选项"""
+
+ # 检查子目录
+ subdirs = ["original", "fp8", "int8"]
+ has_subdirs = {subdir: os.path.exists(os.path.join(model_path, subdir)) for subdir in subdirs}
+
+ # 检查根目录下的原始文件
+ t5_bf16_exists = os.path.exists(os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth"))
+ clip_fp16_exists = os.path.exists(os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"))
+
+ # 生成选项
+ def get_choices(has_subdirs, original_type, fp8_type, int8_type, fallback_type, has_original_file=False):
+ choices = []
+ if has_subdirs["original"]:
+ choices.append(original_type)
+ if has_subdirs["fp8"]:
+ choices.append(fp8_type)
+ if has_subdirs["int8"]:
+ choices.append(int8_type)
+
+ # 如果没有子目录但有原始文件,添加原始类型
+ if has_original_file:
+ if not choices or "original" not in choices:
+ choices.append(original_type)
+
+ # 如果没有任何选项,使用默认值
+ if not choices:
+ choices = [fallback_type]
+
+ return choices, choices[0]
+
+ # DIT选项
+ dit_choices, dit_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16")
+
+ # T5选项 - 检查是否有原始文件
+ t5_choices, t5_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16", t5_bf16_exists)
+
+ # CLIP选项 - 检查是否有原始文件
+ clip_choices, clip_default = get_choices(has_subdirs, "fp16", "fp8", "int8", "fp16", clip_fp16_exists)
+
+ return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
+
+
+def determine_model_cls(model_type, dit_name, high_noise_name):
+ """根据模型类型和文件名确定 model_cls"""
+ # 确定要检查的文件名
+ if model_type == "wan2.1":
+ check_name = dit_name.lower() if dit_name else ""
+ is_distill = "4step" in check_name
+ return "wan2.1_distill" if is_distill else "wan2.1"
+ else:
+ # wan2.2
+ check_name = high_noise_name.lower() if high_noise_name else ""
+ is_distill = "4step" in check_name
+ return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
+
+
+global_runner = None
+current_config = None
+cur_dit_path = None
+cur_t5_path = None
+cur_clip_path = None
+
+available_quant_ops = get_available_quant_ops()
+quant_op_choices = []
+for op_name, is_installed in available_quant_ops:
+ status_text = "✅ 已安装" if is_installed else "❌ 未安装"
+ display_text = f"{op_name} ({status_text})"
+ quant_op_choices.append((op_name, display_text))
+
+available_attn_ops = get_available_attn_ops()
+# 优先级顺序
+attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
+# 按优先级排序,已安装的在前,未安装的在后
+attn_op_choices = []
+attn_op_dict = dict(available_attn_ops)
+
+# 先添加已安装的(按优先级)
+for op_name in attn_priority:
+ if op_name in attn_op_dict and attn_op_dict[op_name]:
+ status_text = "✅ 已安装"
+ display_text = f"{op_name} ({status_text})"
+ attn_op_choices.append((op_name, display_text))
+
+# 再添加未安装的(按优先级)
+for op_name in attn_priority:
+ if op_name in attn_op_dict and not attn_op_dict[op_name]:
+ status_text = "❌ 未安装"
+ display_text = f"{op_name} ({status_text})"
+ attn_op_choices.append((op_name, display_text))
+
+# 添加其他不在优先级列表中的算子(已安装的在前)
+other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
+for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # 已安装的在前
+ status_text = "✅ 已安装" if is_installed else "❌ 未安装"
+ display_text = f"{op_name} ({status_text})"
+ attn_op_choices.append((op_name, display_text))
+
+
+def run_inference(
+ prompt,
+ negative_prompt,
+ save_result_path,
+ infer_steps,
+ num_frames,
+ resolution,
+ seed,
+ sample_shift,
+ enable_cfg,
+ cfg_scale,
+ fps,
+ use_tiling_vae,
+ lazy_load,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ model_path_input,
+ model_type_input,
+ task_type_input,
+ dit_path_input,
+ high_noise_path_input,
+ low_noise_path_input,
+ t5_path_input,
+ clip_path_input,
+ vae_path_input,
+ image_path=None,
+):
+ cleanup_memory()
+
+ quant_op = quant_op.split("(")[0].strip()
+ attention_type = attention_type.split("(")[0].strip()
+
+ global global_runner, current_config, model_path, model_cls
+ global cur_dit_path, cur_t5_path, cur_clip_path
+
+ task = task_type_input
+ model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
+ logger.info(f"自动确定 model_cls: {model_cls} (模型类型: {model_type_input})")
+
+ if model_type_input == "wan2.1":
+ dit_quant_detected = detect_quant_scheme(dit_path_input)
+ else:
+ dit_quant_detected = detect_quant_scheme(high_noise_path_input)
+ t5_quant_detected = detect_quant_scheme(t5_path_input)
+ clip_quant_detected = detect_quant_scheme(clip_path_input)
+ logger.info(f"自动检测量化精度 - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")
+
+ if model_path_input and model_path_input.strip():
+ model_path = model_path_input.strip()
+
+ if os.path.exists(os.path.join(model_path, "config.json")):
+ with open(os.path.join(model_path, "config.json"), "r") as f:
+ model_config = json.load(f)
+ else:
+ model_config = {}
+
+ save_result_path = generate_unique_filename(output_dir)
+
+ is_dit_quant = dit_quant_detected != "bf16"
+ is_t5_quant = t5_quant_detected != "bf16"
+ is_clip_quant = clip_quant_detected != "fp16"
+
+ dit_quantized_ckpt = None
+ dit_original_ckpt = None
+ high_noise_quantized_ckpt = None
+ low_noise_quantized_ckpt = None
+ high_noise_original_ckpt = None
+ low_noise_original_ckpt = None
+
+ if is_dit_quant:
+ dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
+ if "wan2.1" in model_cls:
+ dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
+ else:
+ high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
+ low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
+ else:
+ dit_quantized_ckpt = "Default"
+ if "wan2.1" in model_cls:
+ dit_original_ckpt = os.path.join(model_path, dit_path_input)
+ else:
+ high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
+ low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)
+
+ # 使用前端选择的 T5 路径
+ if is_t5_quant:
+ t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
+ t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
+ t5_original_ckpt = None
+ else:
+ t5_quantized_ckpt = None
+ t5_quant_scheme = None
+ t5_original_ckpt = os.path.join(model_path, t5_path_input)
+
+ # 使用前端选择的 CLIP 路径
+ if is_clip_quant:
+ clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
+ clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
+ clip_original_ckpt = None
+ else:
+ clip_quantized_ckpt = None
+ clip_quant_scheme = None
+ clip_original_ckpt = os.path.join(model_path, clip_path_input)
+
+ if model_type_input == "wan2.1":
+ current_dit_path = dit_path_input
+ else:
+ current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
+
+ current_t5_path = t5_path_input
+ current_clip_path = clip_path_input
+
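+ # Rebuild the runner when lazy loading / module unloading is enabled, or when any selected DIT/T5/CLIP checkpoint differs from the cached one.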
+ needs_reinit = (
+ lazy_load
+ or unload_modules
+ or global_runner is None
+ or current_config is None
+ or cur_dit_path is None
+ or cur_dit_path != current_dit_path
+ or cur_t5_path is None
+ or cur_t5_path != current_t5_path
+ or cur_clip_path is None
+ or cur_clip_path != current_clip_path
+ )
+
+ if cfg_scale == 1:
+ enable_cfg = False
+ else:
+ enable_cfg = True
+
+ vae_name_lower = vae_path_input.lower() if vae_path_input else ""
+ use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
+ use_lightvae = "lightvae" in vae_name_lower
+ need_scaled = "lighttae" in vae_name_lower
+
+ logger.info(f"VAE 配置 - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
+
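+ # Overrides collected from the Gradio UI; applied after config.json so they take precedence.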
+ config_gradio = {
+ "infer_steps": infer_steps,
+ "target_video_length": num_frames,
+ "target_width": int(resolution.split("x")[0]),
+ "target_height": int(resolution.split("x")[1]),
+ "self_attn_1_type": attention_type,
+ "cross_attn_1_type": attention_type,
+ "cross_attn_2_type": attention_type,
+ "enable_cfg": enable_cfg,
+ "sample_guide_scale": cfg_scale,
+ "sample_shift": sample_shift,
+ "fps": fps,
+ "feature_caching": "NoCaching",
+ "do_mm_calib": False,
+ "parallel_attn_type": None,
+ "parallel_vae": False,
+ "max_area": False,
+ "vae_stride": (4, 8, 8),
+ "patch_size": (1, 2, 2),
+ "lora_path": None,
+ "strength_model": 1.0,
+ "use_prompt_enhancer": False,
+ "text_len": 512,
+ "denoising_step_list": [1000, 750, 500, 250],
+ "cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
+ "offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
+ "t5_cpu_offload": t5_cpu_offload,
+ "clip_cpu_offload": clip_cpu_offload,
+ "vae_cpu_offload": vae_cpu_offload,
+ "dit_quantized": is_dit_quant,
+ "dit_quant_scheme": dit_quant_scheme,
+ "dit_quantized_ckpt": dit_quantized_ckpt,
+ "dit_original_ckpt": dit_original_ckpt,
+ "high_noise_quantized_ckpt": high_noise_quantized_ckpt,
+ "low_noise_quantized_ckpt": low_noise_quantized_ckpt,
+ "high_noise_original_ckpt": high_noise_original_ckpt,
+ "low_noise_original_ckpt": low_noise_original_ckpt,
+ "t5_original_ckpt": t5_original_ckpt,
+ "t5_quantized": is_t5_quant,
+ "t5_quantized_ckpt": t5_quantized_ckpt,
+ "t5_quant_scheme": t5_quant_scheme,
+ "clip_original_ckpt": clip_original_ckpt,
+ "clip_quantized": is_clip_quant,
+ "clip_quantized_ckpt": clip_quantized_ckpt,
+ "clip_quant_scheme": clip_quant_scheme,
+ "vae_path": os.path.join(model_path, vae_path_input),
+ "use_tiling_vae": use_tiling_vae,
+ "use_tae": use_tae,
+ "use_lightvae": use_lightvae,
+ "need_scaled": need_scaled,
+ "lazy_load": lazy_load,
+ "rope_chunk": rope_chunk,
+ "rope_chunk_size": rope_chunk_size,
+ "clean_cuda_cache": clean_cuda_cache,
+ "unload_modules": unload_modules,
+ "seq_parallel": False,
+ "warm_up_cpu_buffers": False,
+ "boundary_step_index": 2,
+ "boundary": 0.900,
+ "use_image_encoder": False if "wan2.2" in model_cls else True,
+ "rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
+ }
+
+ args = argparse.Namespace(
+ model_cls=model_cls,
+ seed=seed,
+ task=task,
+ model_path=model_path,
+ prompt_enhancer=None,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image_path=image_path,
+ save_result_path=save_result_path,
+ return_result_tensor=False,
+ )
+
+ config = get_default_config()
+ config.update({k: v for k, v in vars(args).items()})
+ config.update(model_config)
+ config.update(config_gradio)
+
+ logger.info(f"使用模型: {model_path}")
+ logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
+
+ # Initialize or reuse the runner
+ runner = global_runner
+ if needs_reinit:
+ if runner is not None:
+ del runner
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ from lightx2v.infer import init_runner # noqa
+
+ runner = init_runner(config)
+ input_info = set_input_info(args)
+
+ current_config = config
+ cur_dit_path = current_dit_path
+ cur_t5_path = current_t5_path
+ cur_clip_path = current_clip_path
+
+ if not lazy_load:
+ global_runner = runner
+ else:
+ runner.config = config
+
+ runner.run_pipeline(input_info)
+ cleanup_memory()
+
+ return save_result_path
+
+
+def handle_lazy_load_change(lazy_load_enabled):
+ """Handle lazy_load checkbox change to automatically enable unload_modules"""
+ return gr.update(value=lazy_load_enabled)
+
+
+def auto_configure(resolution):
+ """根据机器配置和分辨率自动设置推理选项"""
+ default_config = {
+ "lazy_load_val": False,
+ "rope_chunk_val": False,
+ "rope_chunk_size_val": 100,
+ "clean_cuda_cache_val": False,
+ "cpu_offload_val": False,
+ "offload_granularity_val": "block",
+ "t5_cpu_offload_val": False,
+ "clip_cpu_offload_val": False,
+ "vae_cpu_offload_val": False,
+ "unload_modules_val": False,
+ "attention_type_val": attn_op_choices[0][1],
+ "quant_op_val": quant_op_choices[0][1],
+ "use_tiling_vae_val": False,
+ }
+
+ gpu_memory = round(get_gpu_memory())
+ cpu_memory = round(get_cpu_memory())
+
+ attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
+
+ if is_ada_architecture_gpu():
+ quant_op_priority = ["q8f", "vllm", "sgl"]
+ else:
+ quant_op_priority = ["vllm", "sgl", "q8f"]
+
+ for op in attn_priority:
+ if dict(available_attn_ops).get(op):
+ default_config["attention_type_val"] = dict(attn_op_choices)[op]
+ break
+
+ for op in quant_op_priority:
+ if dict(available_quant_ops).get(op):
+ default_config["quant_op_val"] = dict(quant_op_choices)[op]
+ break
+
+ if resolution in [
+ "1280x720",
+ "720x1280",
+ "1280x544",
+ "544x1280",
+ "1104x832",
+ "832x1104",
+ "960x960",
+ ]:
+ res = "720p"
+ elif resolution in [
+ "960x544",
+ "544x960",
+ ]:
+ res = "540p"
+ else:
+ res = "480p"
+
+ if res == "720p":
+ gpu_rules = [
+ (80, {}),
+ (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
+ (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
+ (
+ 24,
+ {
+ "cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ },
+ ),
+ (
+ 16,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ "rope_chunk_val": True,
+ "rope_chunk_size_val": 100,
+ },
+ ),
+ (
+ 8,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ "rope_chunk_val": True,
+ "rope_chunk_size_val": 100,
+ "clean_cuda_cache_val": True,
+ },
+ ),
+ ]
+
+ else:
+ gpu_rules = [
+ (80, {}),
+ (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
+ (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
+ (
+ 24,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ },
+ ),
+ (
+ 16,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ },
+ ),
+ (
+ 8,
+ {
+ "cpu_offload_val": True,
+ "t5_cpu_offload_val": True,
+ "vae_cpu_offload_val": True,
+ "clip_cpu_offload_val": True,
+ "use_tiling_vae_val": True,
+ "offload_granularity_val": "phase",
+ },
+ ),
+ ]
+
+ cpu_rules = [
+ (128, {}),
+ (64, {}),
+ (32, {"unload_modules_val": True}),
+ (
+ 16,
+ {
+ "lazy_load_val": True,
+ "unload_modules_val": True,
+ },
+ ),
+ ]
+
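+ # Rules are ordered from largest to smallest memory threshold (GB); the first threshold the machine meets wins.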
+ for threshold, updates in gpu_rules:
+ if gpu_memory >= threshold:
+ default_config.update(updates)
+ break
+
+ for threshold, updates in cpu_rules:
+ if cpu_memory >= threshold:
+ default_config.update(updates)
+ break
+
+ return (
+ gr.update(value=default_config["lazy_load_val"]),
+ gr.update(value=default_config["rope_chunk_val"]),
+ gr.update(value=default_config["rope_chunk_size_val"]),
+ gr.update(value=default_config["clean_cuda_cache_val"]),
+ gr.update(value=default_config["cpu_offload_val"]),
+ gr.update(value=default_config["offload_granularity_val"]),
+ gr.update(value=default_config["t5_cpu_offload_val"]),
+ gr.update(value=default_config["clip_cpu_offload_val"]),
+ gr.update(value=default_config["vae_cpu_offload_val"]),
+ gr.update(value=default_config["unload_modules_val"]),
+ gr.update(value=default_config["attention_type_val"]),
+ gr.update(value=default_config["quant_op_val"]),
+ gr.update(value=default_config["use_tiling_vae_val"]),
+ )
+
+
+css = """
+ .main-content { max-width: 1600px; margin: auto; padding: 20px; }
+ .warning { color: #ff6b6b; font-weight: bold; }
+
+ /* 模型配置区域样式 */
+ .model-config {
+ margin-bottom: 20px !important;
+ border: 1px solid #e0e0e0;
+ border-radius: 12px;
+ padding: 15px;
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
+ }
+
+ /* 输入参数区域样式 */
+ .input-params {
+ margin-bottom: 20px !important;
+ border: 1px solid #e0e0e0;
+ border-radius: 12px;
+ padding: 15px;
+ background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
+ }
+
+ /* 输出视频区域样式 */
+ .output-video {
+ border: 1px solid #e0e0e0;
+ border-radius: 12px;
+ padding: 20px;
+ background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
+ min-height: 400px;
+ }
+
+ /* 生成按钮样式 */
+ .generate-btn {
+ width: 100%;
+ margin-top: 20px;
+ padding: 15px 30px !important;
+ font-size: 18px !important;
+ font-weight: bold !important;
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+ border: none !important;
+ border-radius: 10px !important;
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
+ transition: all 0.3s ease !important;
+ }
+ .generate-btn:hover {
+ transform: translateY(-2px);
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
+ }
+
+ /* Accordion 标题样式 */
+ .model-config .gr-accordion-header,
+ .input-params .gr-accordion-header,
+ .output-video .gr-accordion-header {
+ font-size: 20px !important;
+ font-weight: bold !important;
+ padding: 15px !important;
+ }
+
+ /* 优化间距 */
+ .gr-row {
+ margin-bottom: 15px;
+ }
+
+ /* 视频播放器样式 */
+ .output-video video {
+ border-radius: 10px;
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
+ }
+ """
+
+
+def main():
+ with gr.Blocks(title="Lightx2v (轻量级视频推理和生成引擎)", css=css) as demo:
+ gr.Markdown("# 🎬 LightX2V 视频生成器")
+ gr.HTML("")
+ # 主布局:左右分栏
+ with gr.Row():
+ # 左侧:配置和输入区域
+ with gr.Column(scale=5):
+ # 模型配置区域
+ with gr.Accordion("🗂️ 模型配置", open=True, elem_classes=["model-config"]):
+ # FP8 支持提示
+ if not is_fp8_supported_gpu():
+ gr.Markdown("⚠️ **您的设备不支持fp8推理**,已自动隐藏包含fp8的模型选项。")
+
+ # 隐藏的状态组件
+ model_path_input = gr.Textbox(value=model_path, visible=False)
+
+ # 模型类型 + 任务类型
+ with gr.Row():
+ model_type_input = gr.Radio(
+ label="模型类型",
+ choices=["wan2.1", "wan2.2"],
+ value="wan2.1",
+ info="wan2.2 需要分别指定高噪模型和低噪模型",
+ )
+ task_type_input = gr.Radio(
+ label="任务类型",
+ choices=["i2v", "t2v"],
+ value="i2v",
+ info="i2v: 图生视频, t2v: 文生视频",
+ )
+
+ # wan2.1:Diffusion模型(单独一行)
+ with gr.Row() as wan21_row:
+ dit_path_input = gr.Dropdown(
+ label="🎨 Diffusion模型",
+ choices=get_dit_choices(model_path, "wan2.1"),
+ value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
+ allow_custom_value=True,
+ visible=True,
+ )
+
+ # wan2.2 专用:高噪模型 + 低噪模型(默认隐藏)
+ with gr.Row(visible=False) as wan22_row:
+ high_noise_path_input = gr.Dropdown(
+ label="🔊 高噪模型",
+ choices=get_high_noise_choices(model_path),
+ value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+ low_noise_path_input = gr.Dropdown(
+ label="🔇 低噪模型",
+ choices=get_low_noise_choices(model_path),
+ value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+
+ # 文本编码器(单独一行)
+ with gr.Row():
+ t5_path_input = gr.Dropdown(
+ label="📝 文本编码器",
+ choices=get_t5_choices(model_path),
+ value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+
+ # 图像编码器 + VAE解码器
+ with gr.Row():
+ clip_path_input = gr.Dropdown(
+ label="🖼️ 图像编码器",
+ choices=get_clip_choices(model_path),
+ value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+ vae_path_input = gr.Dropdown(
+ label="🎞️ VAE解码器",
+ choices=get_vae_choices(model_path),
+ value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
+ allow_custom_value=True,
+ )
+
+ # 注意力算子和量化矩阵乘法算子
+ with gr.Row():
+ attention_type = gr.Dropdown(
+ label="⚡ 注意力算子",
+ choices=[op[1] for op in attn_op_choices],
+ value=attn_op_choices[0][1] if attn_op_choices else "",
+ info="使用适当的注意力算子加速推理",
+ )
+ quant_op = gr.Dropdown(
+ label="量化矩阵乘法算子",
+ choices=[op[1] for op in quant_op_choices],
+ value=quant_op_choices[0][1],
+ info="选择量化矩阵乘法算子以加速推理",
+ interactive=True,
+ )
+
+ # 判断模型是否是 distill 版本
+ def is_distill_model(model_type, dit_path, high_noise_path):
+ """根据模型类型和路径判断是否是 distill 版本"""
+ if model_type == "wan2.1":
+ check_name = dit_path.lower() if dit_path else ""
+ else:
+ check_name = high_noise_path.lower() if high_noise_path else ""
+ return "4step" in check_name
+
+ # 模型类型切换事件
+ def on_model_type_change(model_type, model_path_val):
+ if model_type == "wan2.2":
+ return gr.update(visible=False), gr.update(visible=True), gr.update()
+ else:
+ # 更新 wan2.1 的 Diffusion 模型选项
+ dit_choices = get_dit_choices(model_path_val, "wan2.1")
+ return (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
+ )
+
+ model_type_input.change(
+ fn=on_model_type_change,
+ inputs=[model_type_input, model_path_input],
+ outputs=[wan21_row, wan22_row, dit_path_input],
+ )
+
+ # 输入参数区域
+ with gr.Accordion("📥 输入参数", open=True, elem_classes=["input-params"]):
+ # 图片输入(i2v 时显示)
+ with gr.Row(visible=True) as image_input_row:
+ image_path = gr.Image(
+ label="输入图像",
+ type="filepath",
+ height=300,
+ interactive=True,
+ )
+
+ # 任务类型切换事件
+ def on_task_type_change(task_type):
+ return gr.update(visible=(task_type == "i2v"))
+
+ task_type_input.change(
+ fn=on_task_type_change,
+ inputs=[task_type_input],
+ outputs=[image_input_row],
+ )
+
+ with gr.Row():
+ with gr.Column():
+ prompt = gr.Textbox(
+ label="提示词",
+ lines=3,
+ placeholder="描述视频内容...",
+ max_lines=5,
+ )
+ with gr.Column():
+ negative_prompt = gr.Textbox(
+ label="负向提示词",
+ lines=3,
+ placeholder="不希望出现在视频中的内容...",
+ max_lines=5,
+ value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ )
+ with gr.Column():
+ resolution = gr.Dropdown(
+ choices=[
+ # 720p
+ ("1280x720 (16:9, 720p)", "1280x720"),
+ ("720x1280 (9:16, 720p)", "720x1280"),
+ ("1280x544 (21:9, 720p)", "1280x544"),
+ ("544x1280 (9:21, 720p)", "544x1280"),
+ ("1104x832 (4:3, 720p)", "1104x832"),
+ ("832x1104 (3:4, 720p)", "832x1104"),
+ ("960x960 (1:1, 720p)", "960x960"),
+ # 540p / 480p
+ ("960x544 (16:9, 540p)", "960x544"),
+ ("544x960 (9:16, 540p)", "544x960"),
+ ("832x480 (16:9, 480p)", "832x480"),
+ ("480x832 (9:16, 480p)", "480x832"),
+ ("832x624 (4:3, 480p)", "832x624"),
+ ("624x832 (3:4, 480p)", "624x832"),
+ ("720x720 (1:1, 480p)", "720x720"),
+ ("512x512 (1:1, 480p)", "512x512"),
+ ],
+ value="832x480",
+ label="最大分辨率",
+ )
+
+ with gr.Column(scale=9):
+ seed = gr.Slider(
+ label="随机种子",
+ minimum=0,
+ maximum=MAX_NUMPY_SEED,
+ step=1,
+ value=generate_random_seed(),
+ )
+ with gr.Column():
+ default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
+ default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
+ default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
+
+ if default_is_distill:
+ infer_steps = gr.Slider(
+ label="推理步数",
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=4,
+ info="蒸馏模型推理步数默认为4。",
+ )
+ else:
+ infer_steps = gr.Slider(
+ label="推理步数",
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=40,
+ info="视频生成的推理步数。增加步数可能提高质量但降低速度。",
+ )
+
+ # 当模型路径改变时,动态更新推理步数
+ def update_infer_steps(model_type, dit_path, high_noise_path):
+ is_distill = is_distill_model(model_type, dit_path, high_noise_path)
+ if is_distill:
+ return gr.update(minimum=1, maximum=100, value=4, interactive=True)
+ else:
+ return gr.update(minimum=1, maximum=100, value=40, interactive=True)
+
+ # 监听模型路径变化
+ dit_path_input.change(
+ fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[infer_steps],
+ )
+ high_noise_path_input.change(
+ fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[infer_steps],
+ )
+ model_type_input.change(
+ fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[infer_steps],
+ )
+
+ # 根据模型类别设置默认CFG
+ # CFG缩放因子:distill 时默认为 1,否则默认为 5
+ default_cfg_scale = 1 if default_is_distill else 5
+ # enable_cfg 不暴露到前端,根据 cfg_scale 自动设置
+ # 如果 cfg_scale == 1,则 enable_cfg = False,否则 enable_cfg = True
+ default_enable_cfg = False if default_cfg_scale == 1 else True
+ enable_cfg = gr.Checkbox(
+ label="启用无分类器引导",
+ value=default_enable_cfg,
+ visible=False, # 隐藏,不暴露到前端
+ )
+
+ with gr.Row():
+ sample_shift = gr.Slider(
+ label="分布偏移",
+ value=5,
+ minimum=0,
+ maximum=10,
+ step=1,
+ info="控制样本分布偏移的程度。值越大表示偏移越明显。",
+ )
+ cfg_scale = gr.Slider(
+ label="CFG缩放因子",
+ minimum=1,
+ maximum=10,
+ step=1,
+ value=default_cfg_scale,
+ info="控制提示词的影响强度。值越高,提示词的影响越大。当值为1时,自动禁用CFG。",
+ )
+
+ # 根据 cfg_scale 更新 enable_cfg
+ def update_enable_cfg(cfg_scale_val):
+ """根据 cfg_scale 的值自动设置 enable_cfg"""
+ if cfg_scale_val == 1:
+ return gr.update(value=False)
+ else:
+ return gr.update(value=True)
+
+ # 当模型路径改变时,动态更新 CFG 缩放因子和 enable_cfg
+ def update_cfg_scale(model_type, dit_path, high_noise_path):
+ is_distill = is_distill_model(model_type, dit_path, high_noise_path)
+ if is_distill:
+ new_cfg_scale = 1
+ else:
+ new_cfg_scale = 5
+ new_enable_cfg = False if new_cfg_scale == 1 else True
+ return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
+
+ dit_path_input.change(
+ fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[cfg_scale, enable_cfg],
+ )
+ high_noise_path_input.change(
+ fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[cfg_scale, enable_cfg],
+ )
+ model_type_input.change(
+ fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
+ inputs=[model_type_input, dit_path_input, high_noise_path_input],
+ outputs=[cfg_scale, enable_cfg],
+ )
+
+ cfg_scale.change(
+ fn=update_enable_cfg,
+ inputs=[cfg_scale],
+ outputs=[enable_cfg],
+ )
+
+ with gr.Row():
+ fps = gr.Slider(
+ label="每秒帧数(FPS)",
+ minimum=8,
+ maximum=30,
+ step=1,
+ value=16,
+ info="视频的每秒帧数。较高的FPS会产生更流畅的视频。",
+ )
+ num_frames = gr.Slider(
+ label="总帧数",
+ minimum=16,
+ maximum=120,
+ step=1,
+ value=81,
+ info="视频中的总帧数。更多帧数会产生更长的视频。",
+ )
+
+ save_result_path = gr.Textbox(
+ label="输出视频路径",
+ value=generate_unique_filename(output_dir),
+ info="必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
+ visible=False, # 隐藏输出路径,自动生成
+ )
+
+ with gr.Column(scale=4):
+ with gr.Accordion("📤 生成的视频", open=True, elem_classes=["output-video"]):
+ output_video = gr.Video(
+ label="",
+ height=600,
+ autoplay=True,
+ show_label=False,
+ )
+
+ infer_btn = gr.Button("🎬 生成视频", variant="primary", size="lg", elem_classes=["generate-btn"])
+
+ rope_chunk = gr.Checkbox(label="分块旋转位置编码", value=False, visible=False)
+ rope_chunk_size = gr.Slider(label="旋转编码块大小", value=100, minimum=100, maximum=10000, step=100, visible=False)
+ unload_modules = gr.Checkbox(label="卸载模块", value=False, visible=False)
+ clean_cuda_cache = gr.Checkbox(label="清理CUDA内存缓存", value=False, visible=False)
+ cpu_offload = gr.Checkbox(label="CPU卸载", value=False, visible=False)
+ lazy_load = gr.Checkbox(label="启用延迟加载", value=False, visible=False)
+ offload_granularity = gr.Dropdown(label="Dit卸载粒度", choices=["block", "phase"], value="phase", visible=False)
+ t5_cpu_offload = gr.Checkbox(label="T5 CPU卸载", value=False, visible=False)
+ clip_cpu_offload = gr.Checkbox(label="CLIP CPU卸载", value=False, visible=False)
+ vae_cpu_offload = gr.Checkbox(label="VAE CPU卸载", value=False, visible=False)
+ use_tiling_vae = gr.Checkbox(label="VAE分块推理", value=False, visible=False)
+
+ resolution.change(
+ fn=auto_configure,
+ inputs=[resolution],
+ outputs=[
+ lazy_load,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ use_tiling_vae,
+ ],
+ )
+
+ demo.load(
+ fn=lambda res: auto_configure(res),
+ inputs=[resolution],
+ outputs=[
+ lazy_load,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ use_tiling_vae,
+ ],
+ )
+
+ infer_btn.click(
+ fn=run_inference,
+ inputs=[
+ prompt,
+ negative_prompt,
+ save_result_path,
+ infer_steps,
+ num_frames,
+ resolution,
+ seed,
+ sample_shift,
+ enable_cfg,
+ cfg_scale,
+ fps,
+ use_tiling_vae,
+ lazy_load,
+ cpu_offload,
+ offload_granularity,
+ t5_cpu_offload,
+ clip_cpu_offload,
+ vae_cpu_offload,
+ unload_modules,
+ attention_type,
+ quant_op,
+ rope_chunk,
+ rope_chunk_size,
+ clean_cuda_cache,
+ model_path_input,
+ model_type_input,
+ task_type_input,
+ dit_path_input,
+ high_noise_path_input,
+ low_noise_path_input,
+ t5_path_input,
+ clip_path_input,
+ vae_path_input,
+ image_path,
+ ],
+ outputs=output_video,
+ )
+
+ demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="轻量级视频生成")
+ parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径")
+ parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
+ parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
+ parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录")
+ args = parser.parse_args()
+
+ global model_path, model_cls, output_dir
+ model_path = args.model_path
+ model_cls = "wan2.1"
+ output_dir = args.output_dir
+
+ main()
diff --git a/app/run_gradio.sh b/app/run_gradio.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f5dbcc55605f33c7d6504a7e3f7e1da372236eb2
--- /dev/null
+++ b/app/run_gradio.sh
@@ -0,0 +1,181 @@
+#!/bin/bash
+
+# Lightx2v Gradio Demo Startup Script
+# Supports both Image-to-Video (i2v) and Text-to-Video (t2v) modes
+
+# ==================== Configuration Area ====================
+# ⚠️ Important: Please modify the following paths according to your actual environment
+
+# 🚨 Storage Performance Tips 🚨
+# 💾 Strongly recommend storing model files on SSD solid-state drives!
+# 📈 SSD can significantly improve model loading speed and inference performance
+# 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience
+
+
+# Lightx2v project root directory path
+# Example: /home/user/lightx2v or /data/video_gen/lightx2v
+lightx2v_path=/data/video_gen/lightx2v_debug/LightX2V
+
+# Model path configuration
+# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
+model_path=/models/
+
+# Server configuration
+server_name="0.0.0.0"
+server_port=8033
+
+# Output directory configuration
+output_dir="./outputs"
+
+# GPU configuration
+gpu_id=0
+
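+# Example launch from this directory (illustrative values; adjust paths to your setup):
+#   bash run_gradio.sh --lang en --port 8033 --gpu 0 --model_path /data/models --output_dir ./outputs
+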
+# ==================== Environment Variables Setup ====================
+export CUDA_VISIBLE_DEVICES=$gpu_id
+export CUDA_LAUNCH_BLOCKING=1
+export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export PROFILING_DEBUG_LEVEL=2
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+# ==================== Parameter Parsing ====================
+# Default interface language
+lang="zh"
+
+# Parse command-line arguments
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --lang)
+ lang="$2"
+ shift 2
+ ;;
+ --port)
+ server_port="$2"
+ shift 2
+ ;;
+ --gpu)
+ gpu_id="$2"
+ export CUDA_VISIBLE_DEVICES=$gpu_id
+ shift 2
+ ;;
+ --output_dir)
+ output_dir="$2"
+ shift 2
+ ;;
+ --model_path)
+ model_path="$2"
+ shift 2
+ ;;
+ --help)
+ echo "🎬 Lightx2v Gradio Demo Startup Script"
+ echo "=========================================="
+ echo "Usage: $0 [options]"
+ echo ""
+ echo "📋 Available options:"
+ echo " --lang zh|en Interface language (default: zh)"
+ echo " zh: Chinese interface"
+ echo " en: English interface"
+ echo " --port PORT Server port (default: 8032)"
+ echo " --gpu GPU_ID GPU device ID (default: 0)"
+ echo " --model_path PATH Model path (default: configured in script)"
+ echo " --output_dir DIR Output video save directory (default: ./outputs)"
+ echo " --help Show this help message"
+ echo ""
+ echo "📝 Notes:"
+ echo " - Task type (i2v/t2v) and model type are selected in the web UI"
+ echo " - Model class is auto-detected based on selected diffusion model"
+ echo " - Edit script to configure model paths before first use"
+ echo " - Ensure required Python dependencies are installed"
+ echo " - Recommended to use GPU with 8GB+ VRAM"
+ echo " - 🚨 Strongly recommend storing models on SSD for better performance"
+ exit 0
+ ;;
+ *)
+ echo "Unknown parameter: $1"
+ echo "Use --help to see help information"
+ exit 1
+ ;;
+ esac
+done
+
+# ==================== Parameter Validation ====================
+if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
+ echo "Error: Language must be 'zh' or 'en'"
+ exit 1
+fi
+
+# Check if model path exists
+if [[ ! -d "$model_path" ]]; then
+ echo "❌ Error: Model path does not exist"
+ echo "📁 Path: $model_path"
+ echo "🔧 Solutions:"
+ echo " 1. Check model path configuration in script"
+ echo " 2. Ensure model files are properly downloaded"
+ echo " 3. Verify path permissions are correct"
+ echo " 4. 💾 Recommend storing models on SSD for faster loading"
+ exit 1
+fi
+
+# Select demo file based on language
+if [[ "$lang" == "zh" ]]; then
+ demo_file="gradio_demo_zh.py"
+ echo "🌏 Using Chinese interface"
+else
+ demo_file="gradio_demo.py"
+ echo "🌏 Using English interface"
+fi
+
+# Check if demo file exists
+if [[ ! -f "$demo_file" ]]; then
+ echo "❌ Error: Demo file does not exist"
+ echo "📄 File: $demo_file"
+ echo "🔧 Solutions:"
+ echo " 1. Ensure script is run in the correct directory"
+ echo " 2. Check if file has been renamed or moved"
+ echo " 3. Re-clone or download project files"
+ exit 1
+fi
+
+# ==================== System Information Display ====================
+echo "=========================================="
+echo "🚀 Lightx2v Gradio Demo Starting..."
+echo "=========================================="
+echo "📁 Project path: $lightx2v_path"
+echo "🤖 Model path: $model_path"
+echo "🌏 Interface language: $lang"
+echo "🖥️ GPU device: $gpu_id"
+echo "🌐 Server address: $server_name:$server_port"
+echo "📁 Output directory: $output_dir"
+echo "📝 Note: Task type and model class are selected in web UI"
+echo "=========================================="
+
+# Display system resource information
+echo "💻 System resource information:"
+free -h | grep -E "Mem|Swap"
+echo ""
+
+# Display GPU information
+if command -v nvidia-smi &> /dev/null; then
+ echo "🎮 GPU information:"
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits | head -1
+ echo ""
+fi
+
+# ==================== Start Demo ====================
+echo "🎬 Starting Gradio demo..."
+echo "📱 Please access in browser: http://$server_name:$server_port"
+echo "⏹️ Press Ctrl+C to stop service"
+echo "🔄 First startup may take several minutes to load resources..."
+echo "=========================================="
+
+# Start Python demo
+python $demo_file \
+ --model_path "$model_path" \
+ --server_name "$server_name" \
+ --server_port "$server_port" \
+ --output_dir "$output_dir"
+
+# Display final system resource usage
+echo ""
+echo "=========================================="
+echo "📊 Final system resource usage:"
+free -h | grep -E "Mem|Swap"
diff --git a/app/run_gradio_win.bat b/app/run_gradio_win.bat
new file mode 100644
index 0000000000000000000000000000000000000000..7ea43cd8ba0afa3e243731dd2c2d0acfc2c9e770
--- /dev/null
+++ b/app/run_gradio_win.bat
@@ -0,0 +1,196 @@
+@echo off
+chcp 65001 >nul
+echo 🎬 LightX2V Gradio Windows Startup Script
+echo ==========================================
+
+REM ==================== Configuration Area ====================
+REM ⚠️ Important: Please modify the following paths according to your actual environment
+
+REM 🚨 Storage Performance Tips 🚨
+REM 💾 Strongly recommended: store model files on an SSD (solid-state drive)!
+REM 📈 An SSD significantly improves model loading speed and inference performance
+REM 🐌 Mechanical hard drives (HDDs) may load models slowly and degrade the overall experience
+
+REM LightX2V project root directory path
+REM Example: D:\LightX2V
+set lightx2v_path=/path/to/LightX2V
+
+REM Model path configuration
+REM Model root directory path
+REM Example: D:\models\LightX2V
+set model_path=/path/to/models
+
+REM Server configuration
+set server_name=127.0.0.1
+set server_port=8032
+
+REM Output directory configuration
+set output_dir=./outputs
+
+REM GPU configuration
+set gpu_id=0
+
+REM ==================== Environment Variables Setup ====================
+set CUDA_VISIBLE_DEVICES=%gpu_id%
+set PYTHONPATH=%lightx2v_path%;%PYTHONPATH%
+set PROFILING_DEBUG_LEVEL=2
+set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+REM ==================== Parameter Parsing ====================
+REM Default interface language
+set lang=zh
+
+REM Parse command line arguments
+:parse_args
+if "%1"=="" goto :end_parse
+if "%1"=="--lang" (
+ set lang=%2
+ shift
+ shift
+ goto :parse_args
+)
+if "%1"=="--port" (
+ set server_port=%2
+ shift
+ shift
+ goto :parse_args
+)
+if "%1"=="--gpu" (
+ set gpu_id=%2
+ REM Use %2 directly: %gpu_id% inside this parenthesized block is expanded at parse time and would still hold the previous value
+ set CUDA_VISIBLE_DEVICES=%2
+ shift
+ shift
+ goto :parse_args
+)
+if "%1"=="--output_dir" (
+ set output_dir=%2
+ shift
+ shift
+ goto :parse_args
+)
+if "%1"=="--help" (
+ echo 🎬 LightX2V Gradio Windows Startup Script
+ echo ==========================================
+ echo Usage: %0 [options]
+ echo.
+ echo 📋 Available options:
+ echo --lang zh^|en Interface language (default: zh)
+ echo zh: Chinese interface
+ echo en: English interface
+ echo --port PORT Server port (default: 8032)
+ echo --gpu GPU_ID GPU device ID (default: 0)
+ echo --output_dir OUTPUT_DIR
+ echo Output video save directory (default: ./outputs)
+ echo --help Show this help message
+ echo.
+ echo 🚀 Usage examples:
+ echo %0 # Default startup
+ echo %0 --lang zh --port 8032 # Start with specified parameters
+ echo %0 --lang en --port 7860 # English interface
+ echo %0 --gpu 1 --port 8032 # Use GPU 1
+ echo %0 --output_dir ./custom_output # Use custom output directory
+ echo.
+ echo 📝 Notes:
+ echo - Edit script to configure model path before first use
+ echo - Ensure required Python dependencies are installed
+ echo - Recommended to use GPU with 8GB+ VRAM
+ echo - 🚨 Strongly recommend storing models on SSD for better performance
+ pause
+ exit /b 0
+)
+echo Unknown parameter: %1
+echo Use --help to see help information
+pause
+exit /b 1
+
+:end_parse
+
+REM ==================== Parameter Validation ====================
+if "%lang%"=="zh" goto :valid_lang
+if "%lang%"=="en" goto :valid_lang
+echo Error: Language must be 'zh' or 'en'
+pause
+exit /b 1
+
+:valid_lang
+
+REM Check if model path exists
+if not exist "%model_path%" (
+ echo ❌ Error: Model path does not exist
+ echo 📁 Path: %model_path%
+ echo 🔧 Solutions:
+ echo 1. Check model path configuration in script
+ echo 2. Ensure model files are properly downloaded
+ echo 3. Verify path permissions are correct
+ echo 4. 💾 Recommend storing models on SSD for faster loading
+ pause
+ exit /b 1
+)
+
+REM Select demo file based on language
+if "%lang%"=="zh" (
+ set demo_file=gradio_demo_zh.py
+ echo 🌏 Using Chinese interface
+) else (
+ set demo_file=gradio_demo.py
+ echo 🌏 Using English interface
+)
+
+REM Check if demo file exists
+if not exist "%demo_file%" (
+ echo ❌ Error: Demo file does not exist
+ echo 📄 File: %demo_file%
+ echo 🔧 Solutions:
+ echo 1. Ensure script is run in the correct directory
+ echo 2. Check if file has been renamed or moved
+ echo 3. Re-clone or download project files
+ pause
+ exit /b 1
+)
+
+REM ==================== System Information Display ====================
+echo ==========================================
+echo 🚀 LightX2V Gradio Starting...
+echo ==========================================
+echo 📁 Project path: %lightx2v_path%
+echo 🤖 Model path: %model_path%
+echo 🌏 Interface language: %lang%
+echo 🖥️ GPU device: %gpu_id%
+echo 🌐 Server address: %server_name%:%server_port%
+echo 📁 Output directory: %output_dir%
+echo ==========================================
+
+REM Display system resource information
+echo 💻 System resource information:
+wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table
+
+REM Display GPU information
+nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits >nul 2>&1
+if errorlevel 1 (
+ echo 🎮 GPU information: Unable to get GPU info
+) else (
+ echo 🎮 GPU information:
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits
+)
+
+REM ==================== Start Demo ====================
+echo 🎬 Starting Gradio demo...
+echo 📱 Please access in browser: http://%server_name%:%server_port%
+echo ⏹️ Press Ctrl+C to stop service
+echo 🔄 First startup may take several minutes to load resources...
+echo ==========================================
+
+REM Start Python demo
+python %demo_file% ^
+ --model_path "%model_path%" ^
+ --server_name %server_name% ^
+ --server_port %server_port% ^
+ --output_dir "%output_dir%"
+
+REM Display final system resource usage
+echo.
+echo ==========================================
+echo 📊 Final system resource usage:
+wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table
+
+pause
diff --git a/assets/figs/offload/fig1_en.png b/assets/figs/offload/fig1_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..acee7c1fa89bf9f8564dd47a844e8553480be9fe
Binary files /dev/null and b/assets/figs/offload/fig1_en.png differ
diff --git a/assets/figs/offload/fig1_zh.png b/assets/figs/offload/fig1_zh.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e1bbe2d92787d27c34c8fe742e5076e70b19161
Binary files /dev/null and b/assets/figs/offload/fig1_zh.png differ
diff --git a/assets/figs/offload/fig2_en.png b/assets/figs/offload/fig2_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..606523e857d192c9c21245b149eae4b63062e4c0
Binary files /dev/null and b/assets/figs/offload/fig2_en.png differ
diff --git a/assets/figs/offload/fig2_zh.png b/assets/figs/offload/fig2_zh.png
new file mode 100644
index 0000000000000000000000000000000000000000..bdabf44507c69b6ff099734e4d8eadf70edcf1f2
Binary files /dev/null and b/assets/figs/offload/fig2_zh.png differ
diff --git a/assets/figs/offload/fig3_en.png b/assets/figs/offload/fig3_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..585c27930d499615d1809d5586c1f0f71554fc35
Binary files /dev/null and b/assets/figs/offload/fig3_en.png differ
diff --git a/assets/figs/offload/fig3_zh.png b/assets/figs/offload/fig3_zh.png
new file mode 100644
index 0000000000000000000000000000000000000000..b6147797b57b3a4c9fedd69dc1713a1a1e12e54e
Binary files /dev/null and b/assets/figs/offload/fig3_zh.png differ
diff --git a/assets/figs/offload/fig4_en.png b/assets/figs/offload/fig4_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..9cc2a32c3a132c7ae2c174dd9de5f1852c0d933e
Binary files /dev/null and b/assets/figs/offload/fig4_en.png differ
diff --git a/assets/figs/offload/fig4_zh.png b/assets/figs/offload/fig4_zh.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e3442295196f81522ab030e559f444e3288e1be
Binary files /dev/null and b/assets/figs/offload/fig4_zh.png differ
diff --git a/assets/figs/offload/fig5_en.png b/assets/figs/offload/fig5_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..489e20f5b6a9574689a125cf7fecf7dbd82336de
Binary files /dev/null and b/assets/figs/offload/fig5_en.png differ
diff --git a/assets/figs/offload/fig5_zh.png b/assets/figs/offload/fig5_zh.png
new file mode 100644
index 0000000000000000000000000000000000000000..3cd30e007d973774b799c47ed83c1bcfce28dc27
Binary files /dev/null and b/assets/figs/offload/fig5_zh.png differ
diff --git a/assets/figs/portabl_windows/pic1.png b/assets/figs/portabl_windows/pic1.png
new file mode 100644
index 0000000000000000000000000000000000000000..04721dafb8ccd1f5ce5041dbbcc5b2c22c83ffc4
Binary files /dev/null and b/assets/figs/portabl_windows/pic1.png differ
diff --git a/assets/figs/portabl_windows/pic_gradio_en.png b/assets/figs/portabl_windows/pic_gradio_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..f08fa4cf1b0797045ad24ae9e52a70577396f6f2
Binary files /dev/null and b/assets/figs/portabl_windows/pic_gradio_en.png differ
diff --git a/assets/figs/portabl_windows/pic_gradio_zh.png b/assets/figs/portabl_windows/pic_gradio_zh.png
new file mode 100644
index 0000000000000000000000000000000000000000..86441fa54ad0d12140a083439c608eb41d2b8a0b
Binary files /dev/null and b/assets/figs/portabl_windows/pic_gradio_zh.png differ
diff --git a/assets/figs/step_distill/fig_01.png b/assets/figs/step_distill/fig_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..cdbca0a24540f1829376f132dbb66f4332d948a0
Binary files /dev/null and b/assets/figs/step_distill/fig_01.png differ
diff --git a/assets/img_lightx2v.png b/assets/img_lightx2v.png
new file mode 100644
index 0000000000000000000000000000000000000000..1067b67b49899403b4cb44b74de4c00683f63e98
Binary files /dev/null and b/assets/img_lightx2v.png differ
diff --git a/assets/inputs/audio/multi_person/config.json b/assets/inputs/audio/multi_person/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e0cf1e4fbbccacfa099e4559fa6a31288d3108c
--- /dev/null
+++ b/assets/inputs/audio/multi_person/config.json
@@ -0,0 +1,12 @@
+{
+ "talk_objects": [
+ {
+ "audio": "p1.mp3",
+ "mask": "p1_mask.png"
+ },
+ {
+ "audio": "p2.mp3",
+ "mask": "p2_mask.png"
+ }
+ ]
+}
diff --git a/assets/inputs/audio/multi_person/config_multi_template.json b/assets/inputs/audio/multi_person/config_multi_template.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e0cf1e4fbbccacfa099e4559fa6a31288d3108c
--- /dev/null
+++ b/assets/inputs/audio/multi_person/config_multi_template.json
@@ -0,0 +1,12 @@
+{
+ "talk_objects": [
+ {
+ "audio": "p1.mp3",
+ "mask": "p1_mask.png"
+ },
+ {
+ "audio": "p2.mp3",
+ "mask": "p2_mask.png"
+ }
+ ]
+}
diff --git a/assets/inputs/audio/multi_person/config_single_template.json b/assets/inputs/audio/multi_person/config_single_template.json
new file mode 100644
index 0000000000000000000000000000000000000000..306a61b465058f8265a3596acab24306622d0f30
--- /dev/null
+++ b/assets/inputs/audio/multi_person/config_single_template.json
@@ -0,0 +1,8 @@
+{
+ "talk_objects": [
+ {
+ "audio": "p1.mp3",
+ "mask": "p1_mask.png"
+ }
+ ]
+}
diff --git a/assets/inputs/audio/multi_person/p1.mp3 b/assets/inputs/audio/multi_person/p1.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..41c1042f79cdecc04d2e87493067ea43142b2cad
Binary files /dev/null and b/assets/inputs/audio/multi_person/p1.mp3 differ
diff --git a/assets/inputs/audio/multi_person/p1_mask.png b/assets/inputs/audio/multi_person/p1_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..7cb8d84541a08927a71556f004b2b077a3d5a27a
Binary files /dev/null and b/assets/inputs/audio/multi_person/p1_mask.png differ
diff --git a/assets/inputs/audio/multi_person/p2.mp3 b/assets/inputs/audio/multi_person/p2.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..a2ce0dd648bf43852ab9baaa1bb9dc45f095df98
Binary files /dev/null and b/assets/inputs/audio/multi_person/p2.mp3 differ
diff --git a/assets/inputs/audio/multi_person/p2_mask.png b/assets/inputs/audio/multi_person/p2_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..05c1265091bce025deb000e066263e7fe8e733f6
Binary files /dev/null and b/assets/inputs/audio/multi_person/p2_mask.png differ
diff --git a/assets/inputs/audio/multi_person/seko_input.png b/assets/inputs/audio/multi_person/seko_input.png
new file mode 100644
index 0000000000000000000000000000000000000000..db8bf104725b83ef60894f50b3211aa85847f373
Binary files /dev/null and b/assets/inputs/audio/multi_person/seko_input.png differ
diff --git a/assets/inputs/audio/seko_input.mp3 b/assets/inputs/audio/seko_input.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..a09e2c90448a64dc8501b3eca6a9ceda376ef2e1
Binary files /dev/null and b/assets/inputs/audio/seko_input.mp3 differ
diff --git a/assets/inputs/audio/seko_input.png b/assets/inputs/audio/seko_input.png
new file mode 100644
index 0000000000000000000000000000000000000000..cefda29965efd3e3a3c68e142b1ce8f82dfc93da
Binary files /dev/null and b/assets/inputs/audio/seko_input.png differ
diff --git a/assets/inputs/imgs/flf2v_input_first_frame-fs8.png b/assets/inputs/imgs/flf2v_input_first_frame-fs8.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c67ec443ea2c609ca9576be4168fd879c2d215b
Binary files /dev/null and b/assets/inputs/imgs/flf2v_input_first_frame-fs8.png differ
diff --git a/assets/inputs/imgs/flf2v_input_last_frame-fs8.png b/assets/inputs/imgs/flf2v_input_last_frame-fs8.png
new file mode 100644
index 0000000000000000000000000000000000000000..51f3fbc5233e4d4215cb4c536c6374a5058cd40a
Binary files /dev/null and b/assets/inputs/imgs/flf2v_input_last_frame-fs8.png differ
diff --git a/assets/inputs/imgs/girl.png b/assets/inputs/imgs/girl.png
new file mode 100644
index 0000000000000000000000000000000000000000..8174bf2122301340e7aa9ceb005e1964def965d2
Binary files /dev/null and b/assets/inputs/imgs/girl.png differ
diff --git a/assets/inputs/imgs/img_0.jpg b/assets/inputs/imgs/img_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f6dcced590cbd18f5eda26bc2362b76ed714847
Binary files /dev/null and b/assets/inputs/imgs/img_0.jpg differ
diff --git a/assets/inputs/imgs/img_1.jpg b/assets/inputs/imgs/img_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2dee840820801bab1ebc05300952bbb9a4945d6a
Binary files /dev/null and b/assets/inputs/imgs/img_1.jpg differ
diff --git a/assets/inputs/imgs/img_2.jpg b/assets/inputs/imgs/img_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..944f15ab131c77613c11c668a0ced4d531c5f9f5
Binary files /dev/null and b/assets/inputs/imgs/img_2.jpg differ
diff --git a/assets/inputs/imgs/snake.png b/assets/inputs/imgs/snake.png
new file mode 100644
index 0000000000000000000000000000000000000000..19085802f5b04656357bf8de812181f3ed161152
Binary files /dev/null and b/assets/inputs/imgs/snake.png differ
diff --git a/configs/attentions/wan_i2v_flash.json b/configs/attentions/wan_i2v_flash.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c7a57e9fb24deebb0b54632e99ab4b585321af
--- /dev/null
+++ b/configs/attentions/wan_i2v_flash.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_i2v_nbhd_480p.json b/configs/attentions/wan_i2v_nbhd_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..2ca3bca552ae6df8085d78ced97b95c5c7540d9f
--- /dev/null
+++ b/configs/attentions/wan_i2v_nbhd_480p.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "nbhd_attn",
+ "nbhd_attn_setting": {
+ "coefficient": [1.0, 0.5, 0.25, 0.25],
+ "min_width": 2.0
+ },
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_i2v_nbhd_720p.json b/configs/attentions/wan_i2v_nbhd_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..10a99cf101849807e53939898936d33d099cbf4f
--- /dev/null
+++ b/configs/attentions/wan_i2v_nbhd_720p.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "nbhd_attn",
+ "nbhd_attn_setting": {
+ "coefficient": [1.0, 0.5, 0.25, 0.25],
+ "min_width": 2.0
+ },
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_i2v_radial.json b/configs/attentions/wan_i2v_radial.json
new file mode 100644
index 0000000000000000000000000000000000000000..428517a9a0a9aa7e107f81814a7dd86063d7d403
--- /dev/null
+++ b/configs/attentions/wan_i2v_radial.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "radial_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_i2v_sage.json b/configs/attentions/wan_i2v_sage.json
new file mode 100644
index 0000000000000000000000000000000000000000..9a278a507854d5bf33c79a8d601c5eaff9b848b7
--- /dev/null
+++ b/configs/attentions/wan_i2v_sage.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_i2v_svg.json b/configs/attentions/wan_i2v_svg.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3d677d3a50c9c373ef13bb1b1d8aeac99019973
--- /dev/null
+++ b/configs/attentions/wan_i2v_svg.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "svg_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_i2v_svg2.json b/configs/attentions/wan_i2v_svg2.json
new file mode 100644
index 0000000000000000000000000000000000000000..4f707076b7a9fa88872a5a8be7c1c3a33d4d988b
--- /dev/null
+++ b/configs/attentions/wan_i2v_svg2.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "svg2_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/attentions/wan_t2v_sparge.json b/configs/attentions/wan_t2v_sparge.json
new file mode 100644
index 0000000000000000000000000000000000000000..64cce21b769c62bbb77dc509c3bc1abdcfaf948d
--- /dev/null
+++ b/configs/attentions/wan_t2v_sparge.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "sparge": true,
+ "sparge_ckpt": "/path/to/sparge_wan2.1_t2v_1.3B.pt"
+}
diff --git a/configs/bench/lightx2v_1.json b/configs/bench/lightx2v_1.json
new file mode 100644
index 0000000000000000000000000000000000000000..9a278a507854d5bf33c79a8d601c5eaff9b848b7
--- /dev/null
+++ b/configs/bench/lightx2v_1.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/bench/lightx2v_2.json b/configs/bench/lightx2v_2.json
new file mode 100644
index 0000000000000000000000000000000000000000..9a278a507854d5bf33c79a8d601c5eaff9b848b7
--- /dev/null
+++ b/configs/bench/lightx2v_2.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/bench/lightx2v_3.json b/configs/bench/lightx2v_3.json
new file mode 100644
index 0000000000000000000000000000000000000000..2ef47231ff608ec6214e753accf7b9b04a980cfa
--- /dev/null
+++ b/configs/bench/lightx2v_3.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "use_tiling_vae": true
+}
diff --git a/configs/bench/lightx2v_3_distill.json b/configs/bench/lightx2v_3_distill.json
new file mode 100644
index 0000000000000000000000000000000000000000..7cc8d37b9d99646740ed67a872feac8ad67a720b
--- /dev/null
+++ b/configs/bench/lightx2v_3_distill.json
@@ -0,0 +1,22 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "use_tiling_vae": true
+}
diff --git a/configs/bench/lightx2v_4.json b/configs/bench/lightx2v_4.json
new file mode 100644
index 0000000000000000000000000000000000000000..7d83a749111684490a6ea14210ef93f5bc9fd226
--- /dev/null
+++ b/configs/bench/lightx2v_4.json
@@ -0,0 +1,35 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "feature_caching": "Tea",
+ "coefficients": [
+ [
+ 2.57151496e05,
+ -3.54229917e04,
+ 1.40286849e03,
+ -1.35890334e01,
+ 1.32517977e-01
+ ],
+ [
+ -3.02331670e02,
+ 2.23948934e02,
+ -5.25463970e01,
+ 5.87348440e00,
+ -2.01973289e-01
+ ]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.2,
+ "use_tiling_vae": true
+}
diff --git a/configs/bench/lightx2v_5.json b/configs/bench/lightx2v_5.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9a72814b490f761e7045f27d7fa062fbf454b8a
--- /dev/null
+++ b/configs/bench/lightx2v_5.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 0.8,
+ "t5_cpu_offload": true,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "use_tiling_vae": true
+}
diff --git a/configs/bench/lightx2v_5_distill.json b/configs/bench/lightx2v_5_distill.json
new file mode 100644
index 0000000000000000000000000000000000000000..8d215c9a9b2ca1a7b28f9d596ad041147afe7956
--- /dev/null
+++ b/configs/bench/lightx2v_5_distill.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 0.8,
+ "t5_cpu_offload": true,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "use_tiling_vae": true
+}
diff --git a/configs/bench/lightx2v_6.json b/configs/bench/lightx2v_6.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9a72814b490f761e7045f27d7fa062fbf454b8a
--- /dev/null
+++ b/configs/bench/lightx2v_6.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 0.8,
+ "t5_cpu_offload": true,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "use_tiling_vae": true
+}
diff --git a/configs/bench/lightx2v_6_distill.json b/configs/bench/lightx2v_6_distill.json
new file mode 100644
index 0000000000000000000000000000000000000000..8d215c9a9b2ca1a7b28f9d596ad041147afe7956
--- /dev/null
+++ b/configs/bench/lightx2v_6_distill.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 0.8,
+ "t5_cpu_offload": true,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "use_tiling_vae": true
+}
diff --git a/configs/caching/adacache/wan_i2v_ada.json b/configs/caching/adacache/wan_i2v_ada.json
new file mode 100644
index 0000000000000000000000000000000000000000..90e635c1eddf00aa0d1d922db8062222e04baa22
--- /dev/null
+++ b/configs/caching/adacache/wan_i2v_ada.json
@@ -0,0 +1,15 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "seed": 442,
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Ada"
+}
diff --git a/configs/caching/adacache/wan_t2v_ada.json b/configs/caching/adacache/wan_t2v_ada.json
new file mode 100644
index 0000000000000000000000000000000000000000..88ea1781cd26671dd3da2192a067dbd475fd99f0
--- /dev/null
+++ b/configs/caching/adacache/wan_t2v_ada.json
@@ -0,0 +1,15 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Ada"
+}
diff --git a/configs/caching/custom/wan_i2v_custom_480p.json b/configs/caching/custom/wan_i2v_custom_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..4980e8fdf84f4792948d647234d34448429e2660
--- /dev/null
+++ b/configs/caching/custom/wan_i2v_custom_480p.json
@@ -0,0 +1,21 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "seed": 442,
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Custom",
+ "coefficients": [
+ [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
+ [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/custom/wan_i2v_custom_720p.json b/configs/caching/custom/wan_i2v_custom_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..7a72be1a3eb0793a099dbedeb338ef0cefbc5594
--- /dev/null
+++ b/configs/caching/custom/wan_i2v_custom_720p.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 1280,
+ "target_width": 720,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Custom",
+ "coefficients": [
+ [
+ 8.10705460e03,
+ 2.13393892e03,
+ -3.72934672e02,
+ 1.66203073e01,
+ -4.17769401e-02
+ ],
+ [
+ -114.36346466,
+ 65.26524496,
+ -18.82220707,
+ 4.91518089,
+ -0.23412683
+ ]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/custom/wan_t2v_custom_14b.json b/configs/caching/custom/wan_t2v_custom_14b.json
new file mode 100644
index 0000000000000000000000000000000000000000..cdd24283810ac1b9cf2b55bc3d1a50cd9924bea0
--- /dev/null
+++ b/configs/caching/custom/wan_t2v_custom_14b.json
@@ -0,0 +1,33 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Custom",
+ "coefficients": [
+ [
+ -3.03318725e05,
+ 4.90537029e04,
+ -2.65530556e03,
+ 5.87365115e01,
+ -3.15583525e-01
+ ],
+ [
+ -5784.54975374,
+ 5449.50911966,
+ -1811.16591783,
+ 256.27178429,
+ -13.02252404
+ ]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/custom/wan_t2v_custom_1_3b.json b/configs/caching/custom/wan_t2v_custom_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..1412ba2544c25fb0659bc3bae7ebf6fa7b2674f5
--- /dev/null
+++ b/configs/caching/custom/wan_t2v_custom_1_3b.json
@@ -0,0 +1,33 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Custom",
+ "coefficients": [
+ [
+ -5.21862437e04,
+ 9.23041404e03,
+ -5.28275948e02,
+ 1.36987616e01,
+ -4.99875664e-02
+ ],
+ [
+ 2.39676752e03,
+ -1.31110545e03,
+ 2.01331979e02,
+ -8.29855975e00,
+ 1.37887774e-01
+ ]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/dualblock/wan_t2v_1_3b.json b/configs/caching/dualblock/wan_t2v_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..f81f56ce85ef7fa3ac425e89abccce1ff9d52896
--- /dev/null
+++ b/configs/caching/dualblock/wan_t2v_1_3b.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "DualBlock",
+ "residual_diff_threshold": 0.03,
+ "downsample_factor": 2
+}
diff --git a/configs/caching/dynamicblock/wan_t2v_1_3b.json b/configs/caching/dynamicblock/wan_t2v_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..fa24ba9e8a66b303fff19ae49f0e1d93cd444ce2
--- /dev/null
+++ b/configs/caching/dynamicblock/wan_t2v_1_3b.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "DynamicBlock",
+ "residual_diff_threshold": 0.003,
+ "downsample_factor": 2
+}
diff --git a/configs/caching/firstblock/wan_t2v_1_3b.json b/configs/caching/firstblock/wan_t2v_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..7dd393d735831fcab6b0d96907a8f50c74739c10
--- /dev/null
+++ b/configs/caching/firstblock/wan_t2v_1_3b.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "FirstBlock",
+ "residual_diff_threshold": 0.02,
+ "downsample_factor": 2
+}
diff --git a/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json b/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..0c7329e19b9f2d3da1e3c14777fc2057f6726cea
--- /dev/null
+++ b/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json
@@ -0,0 +1,109 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ },
+ "feature_caching": "Mag",
+ "magcache_calibration": false,
+ "magcache_K": 6,
+ "magcache_thresh": 0.24,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [
+ [
+ 1.0,
+ 0.98783,
+ 0.97559,
+ 0.98311,
+ 0.98202,
+ 0.9888,
+ 0.98762,
+ 0.98957,
+ 0.99052,
+ 0.99383,
+ 0.98857,
+ 0.99065,
+ 0.98845,
+ 0.99057,
+ 0.98957,
+ 0.98601,
+ 0.98823,
+ 0.98756,
+ 0.98808,
+ 0.98721,
+ 0.98571,
+ 0.98543,
+ 0.98157,
+ 0.98411,
+ 0.97952,
+ 0.98149,
+ 0.9774,
+ 0.97825,
+ 0.97355,
+ 0.97085,
+ 0.97056,
+ 0.96588,
+ 0.96113,
+ 0.9567,
+ 0.94961,
+ 0.93973,
+ 0.93217,
+ 0.91878,
+ 0.90955,
+ 0.92617
+ ],
+ [
+ 1.0,
+ 0.98993,
+ 0.97593,
+ 0.98319,
+ 0.98225,
+ 0.98878,
+ 0.98759,
+ 0.98971,
+ 0.99043,
+ 0.99384,
+ 0.9886,
+ 0.99068,
+ 0.98847,
+ 0.99057,
+ 0.98961,
+ 0.9861,
+ 0.98823,
+ 0.98759,
+ 0.98814,
+ 0.98724,
+ 0.98572,
+ 0.98544,
+ 0.98165,
+ 0.98413,
+ 0.97953,
+ 0.9815,
+ 0.97742,
+ 0.97826,
+ 0.97361,
+ 0.97087,
+ 0.97055,
+ 0.96587,
+ 0.96124,
+ 0.95681,
+ 0.94969,
+ 0.93988,
+ 0.93224,
+ 0.91896,
+ 0.90954,
+ 0.92616
+ ]
+ ]
+}
diff --git a/configs/caching/magcache/wan_i2v_mag_480p.json b/configs/caching/magcache/wan_i2v_mag_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..c8588ac2a8124f57691cf7658740bf06fe4416bf
--- /dev/null
+++ b/configs/caching/magcache/wan_i2v_mag_480p.json
@@ -0,0 +1,104 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Mag",
+ "magcache_calibration": false,
+ "magcache_K": 6,
+ "magcache_thresh": 0.24,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [
+ [
+ 1.0,
+ 0.98783,
+ 0.97559,
+ 0.98311,
+ 0.98202,
+ 0.9888,
+ 0.98762,
+ 0.98957,
+ 0.99052,
+ 0.99383,
+ 0.98857,
+ 0.99065,
+ 0.98845,
+ 0.99057,
+ 0.98957,
+ 0.98601,
+ 0.98823,
+ 0.98756,
+ 0.98808,
+ 0.98721,
+ 0.98571,
+ 0.98543,
+ 0.98157,
+ 0.98411,
+ 0.97952,
+ 0.98149,
+ 0.9774,
+ 0.97825,
+ 0.97355,
+ 0.97085,
+ 0.97056,
+ 0.96588,
+ 0.96113,
+ 0.9567,
+ 0.94961,
+ 0.93973,
+ 0.93217,
+ 0.91878,
+ 0.90955,
+ 0.92617
+ ],
+ [
+ 1.0,
+ 0.98993,
+ 0.97593,
+ 0.98319,
+ 0.98225,
+ 0.98878,
+ 0.98759,
+ 0.98971,
+ 0.99043,
+ 0.99384,
+ 0.9886,
+ 0.99068,
+ 0.98847,
+ 0.99057,
+ 0.98961,
+ 0.9861,
+ 0.98823,
+ 0.98759,
+ 0.98814,
+ 0.98724,
+ 0.98572,
+ 0.98544,
+ 0.98165,
+ 0.98413,
+ 0.97953,
+ 0.9815,
+ 0.97742,
+ 0.97826,
+ 0.97361,
+ 0.97087,
+ 0.97055,
+ 0.96587,
+ 0.96124,
+ 0.95681,
+ 0.94969,
+ 0.93988,
+ 0.93224,
+ 0.91896,
+ 0.90954,
+ 0.92616
+ ]
+ ]
+}
diff --git a/configs/caching/magcache/wan_i2v_mag_calibration_480p.json b/configs/caching/magcache/wan_i2v_mag_calibration_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..d331d33abddb417661b0ef68112b73e3377ca5ae
--- /dev/null
+++ b/configs/caching/magcache/wan_i2v_mag_calibration_480p.json
@@ -0,0 +1,104 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Mag",
+ "magcache_calibration": true,
+ "magcache_K": 6,
+ "magcache_thresh": 0.24,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [
+ [
+ 1.0,
+ 0.98783,
+ 0.97559,
+ 0.98311,
+ 0.98202,
+ 0.9888,
+ 0.98762,
+ 0.98957,
+ 0.99052,
+ 0.99383,
+ 0.98857,
+ 0.99065,
+ 0.98845,
+ 0.99057,
+ 0.98957,
+ 0.98601,
+ 0.98823,
+ 0.98756,
+ 0.98808,
+ 0.98721,
+ 0.98571,
+ 0.98543,
+ 0.98157,
+ 0.98411,
+ 0.97952,
+ 0.98149,
+ 0.9774,
+ 0.97825,
+ 0.97355,
+ 0.97085,
+ 0.97056,
+ 0.96588,
+ 0.96113,
+ 0.9567,
+ 0.94961,
+ 0.93973,
+ 0.93217,
+ 0.91878,
+ 0.90955,
+ 0.92617
+ ],
+ [
+ 1.0,
+ 0.98993,
+ 0.97593,
+ 0.98319,
+ 0.98225,
+ 0.98878,
+ 0.98759,
+ 0.98971,
+ 0.99043,
+ 0.99384,
+ 0.9886,
+ 0.99068,
+ 0.98847,
+ 0.99057,
+ 0.98961,
+ 0.9861,
+ 0.98823,
+ 0.98759,
+ 0.98814,
+ 0.98724,
+ 0.98572,
+ 0.98544,
+ 0.98165,
+ 0.98413,
+ 0.97953,
+ 0.9815,
+ 0.97742,
+ 0.97826,
+ 0.97361,
+ 0.97087,
+ 0.97055,
+ 0.96587,
+ 0.96124,
+ 0.95681,
+ 0.94969,
+ 0.93988,
+ 0.93224,
+ 0.91896,
+ 0.90954,
+ 0.92616
+ ]
+ ]
+}
diff --git a/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json b/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..3c5c5f8710b3f1c62cc7712b0357f08d4b4f3237
--- /dev/null
+++ b/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json
@@ -0,0 +1,130 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ },
+ "feature_caching": "Mag",
+ "magcache_calibration": false,
+ "magcache_K": 4,
+ "magcache_thresh": 0.12,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [
+ [
+ 1.0,
+ 1.0124,
+ 1.00166,
+ 0.99791,
+ 0.99682,
+ 0.99634,
+ 0.99567,
+ 0.99416,
+ 0.99578,
+ 0.9957,
+ 0.99511,
+ 0.99535,
+ 0.99552,
+ 0.99541,
+ 0.9954,
+ 0.99489,
+ 0.99518,
+ 0.99484,
+ 0.99481,
+ 0.99415,
+ 0.99419,
+ 0.99396,
+ 0.99388,
+ 0.99349,
+ 0.99309,
+ 0.9927,
+ 0.99228,
+ 0.99171,
+ 0.99137,
+ 0.99068,
+ 0.99005,
+ 0.98944,
+ 0.98849,
+ 0.98758,
+ 0.98644,
+ 0.98504,
+ 0.9836,
+ 0.98202,
+ 0.97977,
+ 0.97717,
+ 0.9741,
+ 0.97003,
+ 0.96538,
+ 0.9593,
+ 0.95086,
+ 0.94013,
+ 0.92402,
+ 0.90241,
+ 0.86821,
+ 0.81838
+ ],
+ [
+ 1.0,
+ 1.02213,
+ 1.0041,
+ 1.00061,
+ 0.99762,
+ 0.99685,
+ 0.99586,
+ 0.99422,
+ 0.99575,
+ 0.99563,
+ 0.99506,
+ 0.99531,
+ 0.99549,
+ 0.99539,
+ 0.99536,
+ 0.99485,
+ 0.99514,
+ 0.99478,
+ 0.99479,
+ 0.99413,
+ 0.99416,
+ 0.99393,
+ 0.99386,
+ 0.99349,
+ 0.99304,
+ 0.9927,
+ 0.99226,
+ 0.9917,
+ 0.99135,
+ 0.99063,
+ 0.99003,
+ 0.98942,
+ 0.98849,
+ 0.98757,
+ 0.98643,
+ 0.98503,
+ 0.98359,
+ 0.98201,
+ 0.97978,
+ 0.97718,
+ 0.97411,
+ 0.97002,
+ 0.96541,
+ 0.95933,
+ 0.95089,
+ 0.94019,
+ 0.92414,
+ 0.9026,
+ 0.86868,
+ 0.81939
+ ]
+ ]
+}
diff --git a/configs/caching/magcache/wan_t2v_mag_1_3b.json b/configs/caching/magcache/wan_t2v_mag_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..aaec17e20fb7d59ce854bbc611efcbf72413cf9f
--- /dev/null
+++ b/configs/caching/magcache/wan_t2v_mag_1_3b.json
@@ -0,0 +1,125 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Mag",
+ "magcache_calibration": false,
+ "magcache_K": 4,
+ "magcache_thresh": 0.12,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [
+ [
+ 1.0,
+ 1.0124,
+ 1.00166,
+ 0.99791,
+ 0.99682,
+ 0.99634,
+ 0.99567,
+ 0.99416,
+ 0.99578,
+ 0.9957,
+ 0.99511,
+ 0.99535,
+ 0.99552,
+ 0.99541,
+ 0.9954,
+ 0.99489,
+ 0.99518,
+ 0.99484,
+ 0.99481,
+ 0.99415,
+ 0.99419,
+ 0.99396,
+ 0.99388,
+ 0.99349,
+ 0.99309,
+ 0.9927,
+ 0.99228,
+ 0.99171,
+ 0.99137,
+ 0.99068,
+ 0.99005,
+ 0.98944,
+ 0.98849,
+ 0.98758,
+ 0.98644,
+ 0.98504,
+ 0.9836,
+ 0.98202,
+ 0.97977,
+ 0.97717,
+ 0.9741,
+ 0.97003,
+ 0.96538,
+ 0.9593,
+ 0.95086,
+ 0.94013,
+ 0.92402,
+ 0.90241,
+ 0.86821,
+ 0.81838
+ ],
+ [
+ 1.0,
+ 1.02213,
+ 1.0041,
+ 1.00061,
+ 0.99762,
+ 0.99685,
+ 0.99586,
+ 0.99422,
+ 0.99575,
+ 0.99563,
+ 0.99506,
+ 0.99531,
+ 0.99549,
+ 0.99539,
+ 0.99536,
+ 0.99485,
+ 0.99514,
+ 0.99478,
+ 0.99479,
+ 0.99413,
+ 0.99416,
+ 0.99393,
+ 0.99386,
+ 0.99349,
+ 0.99304,
+ 0.9927,
+ 0.99226,
+ 0.9917,
+ 0.99135,
+ 0.99063,
+ 0.99003,
+ 0.98942,
+ 0.98849,
+ 0.98757,
+ 0.98643,
+ 0.98503,
+ 0.98359,
+ 0.98201,
+ 0.97978,
+ 0.97718,
+ 0.97411,
+ 0.97002,
+ 0.96541,
+ 0.95933,
+ 0.95089,
+ 0.94019,
+ 0.92414,
+ 0.9026,
+ 0.86868,
+ 0.81939
+ ]
+ ]
+}
diff --git a/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json b/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..ac5b5fc2b0fe73751ac4fb617077dde43c4cf95c
--- /dev/null
+++ b/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json
@@ -0,0 +1,125 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Mag",
+ "magcache_calibration": true,
+ "magcache_K": 4,
+ "magcache_thresh": 0.12,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [
+ [
+ 1.0,
+ 1.0124,
+ 1.00166,
+ 0.99791,
+ 0.99682,
+ 0.99634,
+ 0.99567,
+ 0.99416,
+ 0.99578,
+ 0.9957,
+ 0.99511,
+ 0.99535,
+ 0.99552,
+ 0.99541,
+ 0.9954,
+ 0.99489,
+ 0.99518,
+ 0.99484,
+ 0.99481,
+ 0.99415,
+ 0.99419,
+ 0.99396,
+ 0.99388,
+ 0.99349,
+ 0.99309,
+ 0.9927,
+ 0.99228,
+ 0.99171,
+ 0.99137,
+ 0.99068,
+ 0.99005,
+ 0.98944,
+ 0.98849,
+ 0.98758,
+ 0.98644,
+ 0.98504,
+ 0.9836,
+ 0.98202,
+ 0.97977,
+ 0.97717,
+ 0.9741,
+ 0.97003,
+ 0.96538,
+ 0.9593,
+ 0.95086,
+ 0.94013,
+ 0.92402,
+ 0.90241,
+ 0.86821,
+ 0.81838
+ ],
+ [
+ 1.0,
+ 1.02213,
+ 1.0041,
+ 1.00061,
+ 0.99762,
+ 0.99685,
+ 0.99586,
+ 0.99422,
+ 0.99575,
+ 0.99563,
+ 0.99506,
+ 0.99531,
+ 0.99549,
+ 0.99539,
+ 0.99536,
+ 0.99485,
+ 0.99514,
+ 0.99478,
+ 0.99479,
+ 0.99413,
+ 0.99416,
+ 0.99393,
+ 0.99386,
+ 0.99349,
+ 0.99304,
+ 0.9927,
+ 0.99226,
+ 0.9917,
+ 0.99135,
+ 0.99063,
+ 0.99003,
+ 0.98942,
+ 0.98849,
+ 0.98757,
+ 0.98643,
+ 0.98503,
+ 0.98359,
+ 0.98201,
+ 0.97978,
+ 0.97718,
+ 0.97411,
+ 0.97002,
+ 0.96541,
+ 0.95933,
+ 0.95089,
+ 0.94019,
+ 0.92414,
+ 0.9026,
+ 0.86868,
+ 0.81939
+ ]
+ ]
+}
diff --git a/configs/caching/taylorseer/wan_t2v_taylorseer.json b/configs/caching/taylorseer/wan_t2v_taylorseer.json
new file mode 100644
index 0000000000000000000000000000000000000000..bb4b4414eaab046ce4452b4df7bfb79d6a6e60df
--- /dev/null
+++ b/configs/caching/taylorseer/wan_t2v_taylorseer.json
@@ -0,0 +1,15 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "TaylorSeer"
+}
diff --git a/configs/caching/teacache/wan_i2v_tea_480p.json b/configs/caching/teacache/wan_i2v_tea_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..cde8fc6b465ac97743e6a0fd32c10b4fa6545f2a
--- /dev/null
+++ b/configs/caching/teacache/wan_i2v_tea_480p.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Tea",
+ "coefficients": [
+ [
+ 2.57151496e05,
+ -3.54229917e04,
+ 1.40286849e03,
+ -1.35890334e01,
+ 1.32517977e-01
+ ],
+ [
+ -3.02331670e02,
+ 2.23948934e02,
+ -5.25463970e01,
+ 5.87348440e00,
+ -2.01973289e-01
+ ]
+ ],
+ "use_ret_steps": true,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/teacache/wan_i2v_tea_720p.json b/configs/caching/teacache/wan_i2v_tea_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..2ff56d0135e7412c4cf0ec6372057e30bc5c6a3e
--- /dev/null
+++ b/configs/caching/teacache/wan_i2v_tea_720p.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 1280,
+ "target_width": 720,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Tea",
+ "coefficients": [
+ [
+ 8.10705460e03,
+ 2.13393892e03,
+ -3.72934672e02,
+ 1.66203073e01,
+ -4.17769401e-02
+ ],
+ [
+ -114.36346466,
+ 65.26524496,
+ -18.82220707,
+ 4.91518089,
+ -0.23412683
+ ]
+ ],
+ "use_ret_steps": true,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json b/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..cdb266ecd46a79f90519e690c8e37fa4511427bd
--- /dev/null
+++ b/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json
@@ -0,0 +1,33 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "feature_caching": "Tea",
+ "coefficients": [
+ [
+ -5.21862437e04,
+ 9.23041404e03,
+ -5.28275948e02,
+ 1.36987616e01,
+ -4.99875664e-02
+ ],
+ [
+ 2.39676752e03,
+ -1.31110545e03,
+ 2.01331979e02,
+ -8.29855975e00,
+ 1.37887774e-01
+ ]
+ ],
+ "use_ret_steps": true,
+ "teacache_thresh": 0.26
+}
diff --git a/configs/caching/teacache/wan_ti2v_tea_720p.json b/configs/caching/teacache/wan_ti2v_tea_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..1ea52413d71cb0f78be6a873690b95ee7c1a5659
--- /dev/null
+++ b/configs/caching/teacache/wan_ti2v_tea_720p.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [4, 16, 16],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "feature_caching": "Tea",
+ "coefficients": [
+ [],
+ [ 1.57472669e+05, -1.15702395e+05, 3.10761669e+04, -3.83116651e+03, 2.21608777e+02, -4.81179567e+00]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.26,
+ "use_image_encoder": false
+}
diff --git a/configs/causvid/wan_i2v_causvid.json b/configs/causvid/wan_i2v_causvid.json
new file mode 100644
index 0000000000000000000000000000000000000000..3a1acf127754861573e99d9b8e6d074abc56cd9f
--- /dev/null
+++ b/configs/causvid/wan_i2v_causvid.json
@@ -0,0 +1,29 @@
+{
+ "infer_steps": 20,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "num_fragments": 3,
+ "num_frames": 21,
+ "num_frame_per_block": 7,
+ "num_blocks": 3,
+ "frame_seq_length": 1560,
+ "denoising_step_list": [
+ 999,
+ 934,
+ 862,
+ 756,
+ 603,
+ 410,
+ 250,
+ 140,
+ 74
+ ]
+}
diff --git a/configs/causvid/wan_t2v_causvid.json b/configs/causvid/wan_t2v_causvid.json
new file mode 100644
index 0000000000000000000000000000000000000000..253f453b44c1046aa496c109c49204d71e013ca7
--- /dev/null
+++ b/configs/causvid/wan_t2v_causvid.json
@@ -0,0 +1,29 @@
+{
+ "infer_steps": 9,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "num_fragments": 3,
+ "num_frames": 21,
+ "num_frame_per_block": 3,
+ "num_blocks": 7,
+ "frame_seq_length": 1560,
+ "denoising_step_list": [
+ 999,
+ 934,
+ 862,
+ 756,
+ 603,
+ 410,
+ 250,
+ 140,
+ 74
+ ]
+}
diff --git a/configs/changing_resolution/wan_i2v.json b/configs/changing_resolution/wan_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..7b828d461e7f8e13b6883dc1bf3fd0a83ca847e4
--- /dev/null
+++ b/configs/changing_resolution/wan_i2v.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "changing_resolution": true,
+ "resolution_rate": [
+ 0.75
+ ],
+ "changing_resolution_steps": [
+ 20
+ ]
+}
diff --git a/configs/changing_resolution/wan_i2v_U.json b/configs/changing_resolution/wan_i2v_U.json
new file mode 100644
index 0000000000000000000000000000000000000000..0f8077504b18376115300d505e527e8e6ad2403c
--- /dev/null
+++ b/configs/changing_resolution/wan_i2v_U.json
@@ -0,0 +1,22 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "changing_resolution": true,
+ "resolution_rate": [
+ 1.0,
+ 0.75
+ ],
+ "changing_resolution_steps": [
+ 5,
+ 25
+ ]
+}
diff --git a/configs/changing_resolution/wan_t2v.json b/configs/changing_resolution/wan_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..78782775cedb9e340bf00ba5e17903583635ffa7
--- /dev/null
+++ b/configs/changing_resolution/wan_t2v.json
@@ -0,0 +1,21 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "changing_resolution": true,
+ "resolution_rate": [
+ 0.75
+ ],
+ "changing_resolution_steps": [
+ 25
+ ]
+}
diff --git a/configs/changing_resolution/wan_t2v_U.json b/configs/changing_resolution/wan_t2v_U.json
new file mode 100644
index 0000000000000000000000000000000000000000..68531e6639263ba7f32e90c0612b7f7836d7ceab
--- /dev/null
+++ b/configs/changing_resolution/wan_t2v_U.json
@@ -0,0 +1,23 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "changing_resolution": true,
+ "resolution_rate": [
+ 1.0,
+ 0.75
+ ],
+ "changing_resolution_steps": [
+ 10,
+ 35
+ ]
+}
diff --git a/configs/changing_resolution/wan_t2v_U_teacache.json b/configs/changing_resolution/wan_t2v_U_teacache.json
new file mode 100644
index 0000000000000000000000000000000000000000..60ccc67a30176fd66cfbdeae2e568941c35984c3
--- /dev/null
+++ b/configs/changing_resolution/wan_t2v_U_teacache.json
@@ -0,0 +1,42 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "changing_resolution": true,
+ "resolution_rate": [
+ 1.0,
+ 0.75
+ ],
+ "changing_resolution_steps": [
+ 10,
+ 35
+ ],
+ "feature_caching": "Tea",
+ "coefficients": [
+ [
+ -5.21862437e04,
+ 9.23041404e03,
+ -5.28275948e02,
+ 1.36987616e01,
+ -4.99875664e-02
+ ],
+ [
+ 2.39676752e03,
+ -1.31110545e03,
+ 2.01331979e02,
+ -8.29855975e00,
+ 1.37887774e-01
+ ]
+ ],
+ "use_ret_steps": false,
+ "teacache_thresh": 0.1
+}
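
Note on the changing_resolution configs above: each file pairs a resolution_rate list with a changing_resolution_steps list of the same length. A reasonable reading is that each rate is the scale factor applied to the target resolution until the matching step boundary, with the remaining steps running at full resolution. The sketch below illustrates that mapping only; the boundary semantics and the helper name are assumptions, not the repository's scheduler code.

import json

def resolution_for_step(cfg: dict, step: int) -> tuple[int, int]:
    """Return the (height, width) assumed to be used at a given denoising step."""
    h, w = cfg["target_height"], cfg["target_width"]
    if not cfg.get("changing_resolution", False):
        return h, w
    for rate, boundary in zip(cfg["resolution_rate"], cfg["changing_resolution_steps"]):
        if step < boundary:                      # assumed boundary semantics
            return int(h * rate), int(w * rate)
    return h, w                                  # later steps run at the target resolution

with open("configs/changing_resolution/wan_t2v_U.json") as f:
    cfg = json.load(f)
print([resolution_for_step(cfg, s) for s in (0, 20, 45)])
# e.g. (480, 832) at step 0, (360, 624) at step 20, (480, 832) at step 45
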
diff --git a/configs/deploy/wan_i2v.json b/configs/deploy/wan_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..9f1fbde2bdc8507a2e14bb36eb45f62f282a5ce2
--- /dev/null
+++ b/configs/deploy/wan_i2v.json
@@ -0,0 +1,30 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "sub_servers": {
+ "dit": [
+ "http://localhost:9000"
+ ],
+ "prompt_enhancer": [
+ "http://localhost:9001"
+ ],
+ "text_encoders": [
+ "http://localhost:9002"
+ ],
+ "image_encoder": [
+ "http://localhost:9003"
+ ],
+ "vae_model": [
+ "http://localhost:9004"
+ ]
+ }
+}
diff --git a/configs/deploy/wan_t2v.json b/configs/deploy/wan_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..43656d7fa0a2c45125677ca2fb6e85ef37d9087a
--- /dev/null
+++ b/configs/deploy/wan_t2v.json
@@ -0,0 +1,31 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "sub_servers": {
+ "dit": [
+ "http://localhost:9000"
+ ],
+ "prompt_enhancer": [
+ "http://localhost:9001"
+ ],
+ "text_encoders": [
+ "http://localhost:9002"
+ ],
+ "image_encoder": [
+ "http://localhost:9003"
+ ],
+ "vae_model": [
+ "http://localhost:9004"
+ ]
+ }
+}
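
The deploy configs only declare where each sub-service listens; the request schema lives elsewhere. As a hedged illustration of how a client could consume the sub_servers map, the sketch below takes the first URL registered for each role and issues a probe request. The /health path and the requests-based client are placeholders, not the project's actual API.

import json
import requests

with open("configs/deploy/wan_t2v.json") as f:
    cfg = json.load(f)

for role, urls in cfg["sub_servers"].items():
    base = urls[0]  # a real client might round-robin over all registered URLs
    try:
        resp = requests.get(f"{base}/health", timeout=2)  # placeholder endpoint
        print(f"{role}: {base} -> {resp.status_code}")
    except requests.RequestException as exc:
        print(f"{role}: {base} unreachable ({exc.__class__.__name__})")
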
diff --git a/configs/dist_infer/wan22_moe_i2v_cfg.json b/configs/dist_infer/wan22_moe_i2v_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..bafa409ea8077198c458db5a0496e7a726b75074
--- /dev/null
+++ b/configs/dist_infer/wan22_moe_i2v_cfg.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.900,
+ "use_image_encoder": false,
+ "parallel": {
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json b/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..91a15a0d6200a00e3f882079a8da2aac009b0fd8
--- /dev/null
+++ b/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "boundary": 0.900,
+ "use_image_encoder": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_moe_i2v_ulysses.json b/configs/dist_infer/wan22_moe_i2v_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..eb0759b39a65be5f02f6e94bcd055de5c94ce406
--- /dev/null
+++ b/configs/dist_infer/wan22_moe_i2v_ulysses.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.900,
+ "use_image_encoder": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/dist_infer/wan22_moe_t2v_cfg.json b/configs/dist_infer/wan22_moe_t2v_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..f47aad82f81120969cb2b177dddb05280fda7968
--- /dev/null
+++ b/configs/dist_infer/wan22_moe_t2v_cfg.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 4.0,
+ 3.0
+ ],
+ "sample_shift": 12.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.875,
+ "parallel": {
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json b/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..645e6d7f5216dc3033461427ad39f6a06f792c98
--- /dev/null
+++ b/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 4.0,
+ 3.0
+ ],
+ "sample_shift": 12.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.875,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_moe_t2v_ulysses.json b/configs/dist_infer/wan22_moe_t2v_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..03aaf9246536c23330cb1f502c66f815abef8fe3
--- /dev/null
+++ b/configs/dist_infer/wan22_moe_t2v_ulysses.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 4.0,
+ 3.0
+ ],
+ "sample_shift": 12.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.875,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/dist_infer/wan22_ti2v_i2v_cfg.json b/configs/dist_infer/wan22_ti2v_i2v_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..4754ff30d382e3dd0732bca8e3ba51f9613ac4a7
--- /dev/null
+++ b/configs/dist_infer/wan22_ti2v_i2v_cfg.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "use_image_encoder": false,
+ "parallel": {
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json b/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..9e340f400f80ad10a579e07c3c294265cff18e94
--- /dev/null
+++ b/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json
@@ -0,0 +1,27 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "use_image_encoder": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_ti2v_i2v_ulysses.json b/configs/dist_infer/wan22_ti2v_i2v_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..56d0f65e5aa66d7c7b42b2c7a651a4ab6c1c1930
--- /dev/null
+++ b/configs/dist_infer/wan22_ti2v_i2v_ulysses.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "use_image_encoder": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/dist_infer/wan22_ti2v_t2v_cfg.json b/configs/dist_infer/wan22_ti2v_t2v_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..c88cb2a10144a23cda3daebf5bc06d341db5a5e0
--- /dev/null
+++ b/configs/dist_infer/wan22_ti2v_t2v_cfg.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "parallel": {
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json b/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..79f4f03f0d1f16bf0825aabd1975c3105ea19fb6
--- /dev/null
+++ b/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan22_ti2v_t2v_ulysses.json b/configs/dist_infer/wan22_ti2v_t2v_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..12b18509e0d98a2bbe9c89b054d96bfa79d66f5e
--- /dev/null
+++ b/configs/dist_infer/wan22_ti2v_t2v_ulysses.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "fps": 24,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json b/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..001f1f0c6e813e6c3d3a7fcd1014a973037b91f5
--- /dev/null
+++ b/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan_i2v_dist_ring.json b/configs/dist_infer/wan_i2v_dist_ring.json
new file mode 100644
index 0000000000000000000000000000000000000000..fe2608aae5ed52a1873a3df346b00bf59e54e404
--- /dev/null
+++ b/configs/dist_infer/wan_i2v_dist_ring.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ring"
+ }
+}
diff --git a/configs/dist_infer/wan_i2v_dist_ulysses.json b/configs/dist_infer/wan_i2v_dist_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..a8af87968ceb00ee70a3bd8a2608e52cb88ca738
--- /dev/null
+++ b/configs/dist_infer/wan_i2v_dist_ulysses.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/dist_infer/wan_t2v_dist_cfg.json b/configs/dist_infer/wan_t2v_dist_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..49bc48161ab648cf7bdb9358b10e42b3376aba0e
--- /dev/null
+++ b/configs/dist_infer/wan_t2v_dist_cfg.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan_t2v_dist_cfg_ring.json b/configs/dist_infer/wan_t2v_dist_cfg_ring.json
new file mode 100644
index 0000000000000000000000000000000000000000..50450f43e56cb27770305903996a3ee0674f0096
--- /dev/null
+++ b/configs/dist_infer/wan_t2v_dist_cfg_ring.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ring",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json b/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..bd6808dfd7c8c270221a966dea79da3ad54848be
--- /dev/null
+++ b/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+}
diff --git a/configs/dist_infer/wan_t2v_dist_ring.json b/configs/dist_infer/wan_t2v_dist_ring.json
new file mode 100644
index 0000000000000000000000000000000000000000..a28549bd6c930885adcc691e63d7024e041d7861
--- /dev/null
+++ b/configs/dist_infer/wan_t2v_dist_ring.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ring"
+ }
+}
diff --git a/configs/dist_infer/wan_t2v_dist_ulysses.json b/configs/dist_infer/wan_t2v_dist_ulysses.json
new file mode 100644
index 0000000000000000000000000000000000000000..c9b9bc2ec71a41ea5f468dc6cf09e072e764e248
--- /dev/null
+++ b/configs/dist_infer/wan_t2v_dist_ulysses.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
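
Across the dist_infer configs, the parallel block combines sequence parallelism (seq_p_size, with a ulysses or ring attention layout) and CFG parallelism (cfg_p_size). The number of ranks you launch has to cover the product of the enabled sizes. The check below is a minimal sketch of that bookkeeping under that assumption; it is not the repository's launcher.

import json
import os

def expected_world_size(cfg: dict) -> int:
    # Assumption: seq-parallel and cfg-parallel groups multiply to the world size.
    par = cfg.get("parallel", {})
    return par.get("seq_p_size", 1) * par.get("cfg_p_size", 1)

with open("configs/dist_infer/wan_t2v_dist_cfg_ulysses.json") as f:
    cfg = json.load(f)

need = expected_world_size(cfg)                 # 4 * 2 = 8 for this config
have = int(os.environ.get("WORLD_SIZE", "1"))   # set by torchrun
if have != need:
    print(f"launch {need} ranks, e.g. torchrun --nproc_per_node={need} ... (got {have})")
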
diff --git a/configs/distill/wan_i2v_distill_4step_cfg.json b/configs/distill/wan_i2v_distill_4step_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..519f58a4ed1778b4a8b979482a9f7b7b34ea1571
--- /dev/null
+++ b/configs/distill/wan_i2v_distill_4step_cfg.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/distill/wan_i2v_distill_4step_cfg_4090.json b/configs/distill/wan_i2v_distill_4step_cfg_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..c6863a558a5d1d9eafc8ce7825f64cfc2ee58be8
--- /dev/null
+++ b/configs/distill/wan_i2v_distill_4step_cfg_4090.json
@@ -0,0 +1,29 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/distill/wan_i2v_distill_4step_cfg_4090_lora.json b/configs/distill/wan_i2v_distill_4step_cfg_4090_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..100eb453da684c3c6c1e8dd751b68671b149af55
--- /dev/null
+++ b/configs/distill/wan_i2v_distill_4step_cfg_4090_lora.json
@@ -0,0 +1,35 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f",
+ "lora_configs": [
+ {
+ "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/distill/wan_i2v_distill_4step_cfg_lora.json b/configs/distill/wan_i2v_distill_4step_cfg_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..51797dd123006864bf4dedc81cf1c76b21751eaf
--- /dev/null
+++ b/configs/distill/wan_i2v_distill_4step_cfg_lora.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "denoising_step_list": [1000, 750, 500, 250],
+ "lora_configs": [
+ {
+ "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/distill/wan_t2v_distill_4step_cfg.json b/configs/distill/wan_t2v_distill_4step_cfg.json
new file mode 100644
index 0000000000000000000000000000000000000000..89ea0675604af80a6042cd0caa36a744b7d38098
--- /dev/null
+++ b/configs/distill/wan_t2v_distill_4step_cfg.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/distill/wan_t2v_distill_4step_cfg_4090.json b/configs/distill/wan_t2v_distill_4step_cfg_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..0255d0923a7f21a832b09d9943044e1c355d8efd
--- /dev/null
+++ b/configs/distill/wan_t2v_distill_4step_cfg_4090.json
@@ -0,0 +1,30 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 6,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json b/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json
new file mode 100644
index 0000000000000000000000000000000000000000..f05e7283308161084ab5280d32186c7ee06d9460
--- /dev/null
+++ b/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json
@@ -0,0 +1,22 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "enable_dynamic_cfg": true,
+ "cfg_scale": 4.0,
+ "cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/distill/wan_t2v_distill_4step_cfg_lora.json b/configs/distill/wan_t2v_distill_4step_cfg_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..0388af4dd0c11cc2e7bd0c3668608e3ac9a2e946
--- /dev/null
+++ b/configs/distill/wan_t2v_distill_4step_cfg_lora.json
@@ -0,0 +1,21 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "denoising_step_list": [1000, 750, 500, 250],
+ "lora_configs": [
+ {
+ "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/distill/wan_t2v_distill_4step_cfg_lora_4090.json b/configs/distill/wan_t2v_distill_4step_cfg_lora_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..5e87e18fdb3303202cb1e943785b15f728f42f88
--- /dev/null
+++ b/configs/distill/wan_t2v_distill_4step_cfg_lora_4090.json
@@ -0,0 +1,36 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 6,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f",
+ "lora_configs": [
+ {
+ "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
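
The distill configs run only four denoising steps, with denoising_step_list pinning which of the original 1000 training timesteps are visited (1000, 750, 500, 250) rather than letting infer_steps pick a uniform schedule. The sketch below just reads that list back and normalizes it for inspection; the normalization shown is illustrative, not the scheduler used by the repository.

import json

with open("configs/distill/wan_t2v_distill_4step_cfg.json") as f:
    cfg = json.load(f)

steps = cfg["denoising_step_list"]              # [1000, 750, 500, 250]
assert len(steps) == cfg["infer_steps"]         # one entry per inference step
normalized = [t / 1000.0 for t in steps]        # 1.0, 0.75, 0.5, 0.25 (illustrative)
print(list(zip(steps, normalized)))
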
diff --git a/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..ce108bbdce5d95a17bc39768648ec2f106b61fd9
--- /dev/null
+++ b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
diff --git a/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16_dist.json b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..8049eda8d56f7eaa6fd0c52b5509c14a1fda38c9
--- /dev/null
+++ b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16_dist.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 7.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2",
+ "parallel": {
+ "seq_p_attn_type": "ulysses-4090",
+ "seq_p_size": 8
+ }
+}
diff --git a/configs/hunyuan_video_15/4090/hy15_t2v_480p_fp8.json b/configs/hunyuan_video_15/4090/hy15_t2v_480p_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..6096bf90a780e3f1eae249d4b5984071c2fbebf6
--- /dev/null
+++ b/configs/hunyuan_video_15/4090/hy15_t2v_480p_fp8.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false,
+ "dit_quantized_ckpt": "/path/to/480p_t2v_fp8.safetensors",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "qwen25vl_quantized_ckpt": "/path/to/qwen25vl_fp8.safetensors",
+ "qwen25vl_quantized": true,
+ "qwen25vl_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..2668b2b0b42c937837bbf8a2ce96924e450eae65
--- /dev/null
+++ b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn3",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
diff --git a/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16_dist.json b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..d39a77310efddbca9a3361e00cf1856cbe812108
--- /dev/null
+++ b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16_dist.json
@@ -0,0 +1,22 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn3",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false,
+ "parallel": {
+ "seq_p_attn_type": "ulysses",
+ "seq_p_size": 8
+ }
+}
diff --git a/configs/hunyuan_video_15/5090/hy15_t2v_480p_fp8.json b/configs/hunyuan_video_15/5090/hy15_t2v_480p_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..303adcd24821b3fd89bf2abd45df35b7498e5aca
--- /dev/null
+++ b/configs/hunyuan_video_15/5090/hy15_t2v_480p_fp8.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn3",
+ "qwen25vl_cpu_offload": true,
+ "dit_quantized_ckpt": "/path/to/480p_t2v_fp8.safetensors",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "qwen25vl_quantized_ckpt": "/path/to/qwen25vl_fp8.safetensors",
+ "qwen25vl_quantized": true,
+ "qwen25vl_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache.json b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache.json
new file mode 100644
index 0000000000000000000000000000000000000000..91fe4b27c668bc1f4512a793bb158394a8ff472d
--- /dev/null
+++ b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2",
+ "feature_caching": "Mag",
+ "magcache_calibration": false,
+ "magcache_K": 6,
+ "magcache_thresh": 0.24,
+ "magcache_retention_ratio": 0.2,
+ "magcache_ratios": [[1.0, 1.01562, 1.00781, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99219, 0.99219, 0.99219, 0.98828, 0.98828, 0.98828, 0.98828, 0.98438, 0.98438, 0.98047, 0.98047, 0.97656, 0.97266, 0.96484, 0.95703, 0.94922, 0.92969, 0.91016, 0.88672], [1.0, 1.02344, 1.00781, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99219, 0.99219, 0.99219, 0.99219, 0.98828, 0.98828, 0.98828, 0.98438, 0.98438, 0.98047, 0.98047, 0.97656, 0.97266, 0.96484, 0.95703, 0.94922, 0.93359, 0.91016, 0.88672]]
+}
diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache_calibration.json b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache_calibration.json
new file mode 100644
index 0000000000000000000000000000000000000000..ba180071145dccc2617387780b9d091188e9dd4e
--- /dev/null
+++ b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache_calibration.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2",
+ "feature_caching": "Mag",
+ "magcache_calibration": true
+}
diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_480p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_teacache.json
new file mode 100644
index 0000000000000000000000000000000000000000..6872b0b2048a12b6b975133ee59e7f9cdea284df
--- /dev/null
+++ b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_teacache.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "feature_caching": "Tea",
+ "coefficients": [8.08528429e+03 ,-2.44607178e+03, 2.49489589e+02, -9.10697865e+00, 1.20261379e-01],
+ "teacache_thresh": 0.15,
+ "cpu_offload": false,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_720p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_i2v_720p_teacache.json
new file mode 100644
index 0000000000000000000000000000000000000000..252a80b9b3fdb57e4db219506631e56c27ffa98f
--- /dev/null
+++ b/configs/hunyuan_video_15/cache/hy_15_i2v_720p_teacache.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "720p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 7.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "feature_caching": "Tea",
+ "coefficients": [3.84300014e+03, -1.39247433e+03, 1.69167679e+02, -7.07679232e+00, 1.02419011e-01],
+ "teacache_thresh": 0.15,
+ "cpu_offload": false,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
diff --git a/configs/hunyuan_video_15/cache/hy_15_t2v_480p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_t2v_480p_teacache.json
new file mode 100644
index 0000000000000000000000000000000000000000..d2925d37deb8b0c068ace1f1b3b3bb8f38760bff
--- /dev/null
+++ b/configs/hunyuan_video_15/cache/hy_15_t2v_480p_teacache.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "feature_caching": "Tea",
+ "coefficients": [-2.97190924e+04, 2.22834983e+04, -4.37418360e+03, 3.39340251e+02, -1.01365855e+01, 1.29101768e-01],
+ "teacache_thresh": 0.15,
+ "cpu_offload": false,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
diff --git a/configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json
new file mode 100644
index 0000000000000000000000000000000000000000..97b24eac3dd07ae1047b20ea60fcef36fd021a40
--- /dev/null
+++ b/configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "729p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "feature_caching": "Tea",
+ "coefficients": [-3.08907507e+04, 1.67786188e+04, -3.19178643e+03, 2.60740519e+02, -8.19205881e+00, 1.07913775e-01],
+ "teacache_thresh": 0.15,
+ "cpu_offload": false,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
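
The *_teacache.json files combine feature_caching: "Tea", a per-model polynomial in coefficients, and a teacache_thresh. In TeaCache-style caching, the polynomial rescales the relative change of the transformer's modulated input between steps, and the rescaled values are accumulated until they exceed the threshold, at which point the full model is re-run instead of reusing the cached residual. The sketch below shows that accumulation logic in isolation, assuming the per-step relative-L1 changes were computed elsewhere; it is not the repository's cache implementation.

import json
import numpy as np

with open("configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json") as f:
    cfg = json.load(f)

poly = np.poly1d(cfg["coefficients"])   # highest-order coefficient first
thresh = cfg["teacache_thresh"]

rel_l1_changes = [0.02, 0.03, 0.01, 0.05, 0.04]   # placeholder per-step values
accumulated, recompute_steps = 0.0, []
for step, rel in enumerate(rel_l1_changes):
    accumulated += float(poly(rel))
    if step == 0 or accumulated >= thresh:
        recompute_steps.append(step)    # run the full transformer at this step
        accumulated = 0.0
    # otherwise the cached residual would be reused
print(recompute_steps)
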
diff --git a/configs/hunyuan_video_15/fp8comm/hy15_i2v_480p_int8_offload_dist_fp8_comm.json b/configs/hunyuan_video_15/fp8comm/hy15_i2v_480p_int8_offload_dist_fp8_comm.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9fc93364e7dbaaec0e8acef4804dcf13ed499c9
--- /dev/null
+++ b/configs/hunyuan_video_15/fp8comm/hy15_i2v_480p_int8_offload_dist_fp8_comm.json
@@ -0,0 +1,23 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": false,
+ "attn_type": "sage_attn3",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false,
+ "dit_quantized_ckpt": "/path/to/quant_model.safetensors",
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-q8f",
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_fp8_comm": true,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json b/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..32624592b19679d3dc50e2d8baf236232cbe2f57
--- /dev/null
+++ b/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json
@@ -0,0 +1,11 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2"
+}
diff --git a/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json b/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..f993a5caeeb347c5e612aac024fc374e9b18c817
--- /dev/null
+++ b/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json
@@ -0,0 +1,11 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "720p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 7.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2"
+}
diff --git a/configs/hunyuan_video_15/hunyuan_video_i2v_720p_cfg_distilled.json b/configs/hunyuan_video_15/hunyuan_video_i2v_720p_cfg_distilled.json
new file mode 100644
index 0000000000000000000000000000000000000000..a467d446b0071724c928c7583444b7da9614cfe0
--- /dev/null
+++ b/configs/hunyuan_video_15/hunyuan_video_i2v_720p_cfg_distilled.json
@@ -0,0 +1,11 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "720p_i2v_distilled",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 9.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": false,
+ "attn_type": "sage_attn2"
+}
diff --git a/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json b/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..28ca4b9d6cf1da8530a4a17333942342b913ec90
--- /dev/null
+++ b/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json
@@ -0,0 +1,12 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2"
+}
diff --git a/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json b/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json
new file mode 100644
index 0000000000000000000000000000000000000000..f4d9ceb734956865d5a90a1cb30bdfe18a130f2a
--- /dev/null
+++ b/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 4,
+ "transformer_model_name": "480p_t2v",
+ "fps": 16,
+ "target_video_length": 81,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": -1.0,
+ "enable_cfg": false,
+ "attn_type": "sage_attn2",
+ "dit_original_ckpt": "hunyuanvideo-1.5/distill_models/480p_t2v/distill_model.safetensors",
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json b/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..f8923c1520129693bbfcd0e92396014b357b2995
--- /dev/null
+++ b/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json
@@ -0,0 +1,12 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "720p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 9.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2"
+}
diff --git a/configs/hunyuan_video_15/lightae/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/lightae/hy15_t2v_480p_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..6fc26bdb0e4ce8009fb72d66e480d720cb00f138
--- /dev/null
+++ b/configs/hunyuan_video_15/lightae/hy15_t2v_480p_bf16.json
@@ -0,0 +1,14 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 7.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "use_tae": true,
+ "tae_path": "/path/to/lighttae"
+}
diff --git a/configs/hunyuan_video_15/offload/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/offload/hy15_t2v_480p_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..1055492b1922da21c4efb9c096f5d14e31c9c96e
--- /dev/null
+++ b/configs/hunyuan_video_15/offload/hy15_t2v_480p_bf16.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 7.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "sage_attn2",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "vae_cpu_offload": false,
+ "byt5_cpu_offload": false,
+ "qwen25vl_cpu_offload": true,
+ "siglip_cpu_offload": false
+}
diff --git a/configs/hunyuan_video_15/quant/hy15_t2v_480p_fp8.json b/configs/hunyuan_video_15/quant/hy15_t2v_480p_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..196a5b7eccfa5536daf34abd0aff5935498c0a9a
--- /dev/null
+++ b/configs/hunyuan_video_15/quant/hy15_t2v_480p_fp8.json
@@ -0,0 +1,15 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_t2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "aspect_ratio": "16:9",
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 7.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "dit_quantized_ckpt": "/path/to/quant_model.safetensors",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json b/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..4f774fff6b1380d224e838b7403ee79e1428492e
--- /dev/null
+++ b/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 50,
+ "transformer_model_name": "480p_i2v",
+ "fps": 24,
+ "target_video_length": 121,
+ "vae_stride": [4, 16, 16],
+ "sample_shift": 5.0,
+ "sample_guide_scale": 6.0,
+ "enable_cfg": true,
+ "attn_type": "flash_attn3",
+ "video_super_resolution": {
+ "sr_version": "720p_sr_distilled",
+ "flow_shift": 2.0,
+ "base_resolution": "480p",
+ "guidance_scale": 1.0,
+ "num_inference_steps": 6,
+ "use_meanflow": true
+ }
+}
diff --git a/configs/matrix_game2/matrix_game2_gta_drive.json b/configs/matrix_game2/matrix_game2_gta_drive.json
new file mode 100644
index 0000000000000000000000000000000000000000..d1952d813a4f56557f55d588a7414f18599595fa
--- /dev/null
+++ b/configs/matrix_game2/matrix_game2_gta_drive.json
@@ -0,0 +1,72 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 150,
+ "num_output_frames": 150,
+ "text_len": 512,
+ "target_height": 352,
+ "target_width": 640,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "local_attn_size": 6,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 880,
+ "num_output_frames": 150,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 908.8427, 713.9794]
+ },
+ "sub_model_folder": "gta_distilled_model",
+ "sub_model_name": "gta_keyboard2dim.safetensors",
+ "mode": "gta_drive",
+ "streaming": false,
+ "action_config": {
+ "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+ "enable_keyboard": true,
+ "enable_mouse": true,
+ "heads_num": 16,
+ "hidden_size": 128,
+ "img_hidden_size": 1536,
+ "keyboard_dim_in": 4,
+ "keyboard_hidden_dim": 1024,
+ "mouse_dim_in": 2,
+ "mouse_hidden_dim": 1024,
+ "mouse_qk_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "patch_size": [
+ 1,
+ 2,
+ 2
+ ],
+ "qk_norm": true,
+ "qkv_bias": false,
+ "rope_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "rope_theta": 256,
+ "vae_time_compression_ratio": 4,
+ "windows_size": 3
+ },
+ "dim": 1536,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "inject_sample_info": false,
+ "model_type": "i2v",
+ "num_heads": 12,
+ "num_layers": 30,
+ "out_dim": 16
+}
diff --git a/configs/matrix_game2/matrix_game2_gta_drive_streaming.json b/configs/matrix_game2/matrix_game2_gta_drive_streaming.json
new file mode 100644
index 0000000000000000000000000000000000000000..584eab096d7c8c380d927c8518996c48f02c1119
--- /dev/null
+++ b/configs/matrix_game2/matrix_game2_gta_drive_streaming.json
@@ -0,0 +1,72 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 360,
+ "num_output_frames": 360,
+ "text_len": 512,
+ "target_height": 352,
+ "target_width": 640,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "local_attn_size": 6,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 880,
+ "num_output_frames": 360,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 908.8427, 713.9794]
+ },
+ "sub_model_folder": "gta_distilled_model",
+ "sub_model_name": "gta_keyboard2dim.safetensors",
+ "mode": "gta_drive",
+ "streaming": true,
+ "action_config": {
+ "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+ "enable_keyboard": true,
+ "enable_mouse": true,
+ "heads_num": 16,
+ "hidden_size": 128,
+ "img_hidden_size": 1536,
+ "keyboard_dim_in": 4,
+ "keyboard_hidden_dim": 1024,
+ "mouse_dim_in": 2,
+ "mouse_hidden_dim": 1024,
+ "mouse_qk_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "patch_size": [
+ 1,
+ 2,
+ 2
+ ],
+ "qk_norm": true,
+ "qkv_bias": false,
+ "rope_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "rope_theta": 256,
+ "vae_time_compression_ratio": 4,
+ "windows_size": 3
+ },
+ "dim": 1536,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "inject_sample_info": false,
+ "model_type": "i2v",
+ "num_heads": 12,
+ "num_layers": 30,
+ "out_dim": 16
+}
diff --git a/configs/matrix_game2/matrix_game2_templerun.json b/configs/matrix_game2/matrix_game2_templerun.json
new file mode 100644
index 0000000000000000000000000000000000000000..63f1aab5e4b5ff43535e6e2cf504beafd550f7a4
--- /dev/null
+++ b/configs/matrix_game2/matrix_game2_templerun.json
@@ -0,0 +1,65 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 150,
+ "num_output_frames": 150,
+ "text_len": 512,
+ "target_height": 352,
+ "target_width": 640,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "local_attn_size": 6,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 880,
+ "num_output_frames": 150,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 908.8427, 713.9794]
+ },
+ "sub_model_folder": "templerun_distilled_model",
+ "sub_model_name": "templerun_7dim_onlykey.safetensors",
+ "mode": "templerun",
+ "streaming": false,
+ "action_config": {
+ "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+ "enable_keyboard": true,
+ "enable_mouse": false,
+ "heads_num": 16,
+ "hidden_size": 128,
+ "img_hidden_size": 1536,
+ "keyboard_dim_in": 7,
+ "keyboard_hidden_dim": 1024,
+ "patch_size": [
+ 1,
+ 2,
+ 2
+ ],
+ "qk_norm": true,
+ "qkv_bias": false,
+ "rope_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "rope_theta": 256,
+ "vae_time_compression_ratio": 4,
+ "windows_size": 3
+ },
+ "dim": 1536,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "inject_sample_info": false,
+ "model_type": "i2v",
+ "num_heads": 12,
+ "num_layers": 30,
+ "out_dim": 16
+}
diff --git a/configs/matrix_game2/matrix_game2_templerun_streaming.json b/configs/matrix_game2/matrix_game2_templerun_streaming.json
new file mode 100644
index 0000000000000000000000000000000000000000..90fd2955856a5182c19deac209b9e87e789d8338
--- /dev/null
+++ b/configs/matrix_game2/matrix_game2_templerun_streaming.json
@@ -0,0 +1,72 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 360,
+ "num_output_frames": 360,
+ "text_len": 512,
+ "target_height": 352,
+ "target_width": 640,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "local_attn_size": 6,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 880,
+ "num_output_frames": 360,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 908.8427, 713.9794]
+ },
+ "sub_model_folder": "templerun_distilled_model",
+ "sub_model_name": "templerun_7dim_onlykey.safetensors",
+ "mode": "templerun",
+ "streaming": true,
+ "action_config": {
+ "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+ "enable_keyboard": true,
+ "enable_mouse": true,
+ "heads_num": 16,
+ "hidden_size": 128,
+ "img_hidden_size": 1536,
+ "keyboard_dim_in": 4,
+ "keyboard_hidden_dim": 1024,
+ "mouse_dim_in": 2,
+ "mouse_hidden_dim": 1024,
+ "mouse_qk_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "patch_size": [
+ 1,
+ 2,
+ 2
+ ],
+ "qk_norm": true,
+ "qkv_bias": false,
+ "rope_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "rope_theta": 256,
+ "vae_time_compression_ratio": 4,
+ "windows_size": 3
+ },
+ "dim": 1536,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "inject_sample_info": false,
+ "model_type": "i2v",
+ "num_heads": 12,
+ "num_layers": 30,
+ "out_dim": 16
+}
diff --git a/configs/matrix_game2/matrix_game2_universal.json b/configs/matrix_game2/matrix_game2_universal.json
new file mode 100644
index 0000000000000000000000000000000000000000..0b801d8492b19a9d7567534f775bb3804af53a1d
--- /dev/null
+++ b/configs/matrix_game2/matrix_game2_universal.json
@@ -0,0 +1,72 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 150,
+ "num_output_frames": 150,
+ "text_len": 512,
+ "target_height": 352,
+ "target_width": 640,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "local_attn_size": 6,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 880,
+ "num_output_frames": 150,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 908.8427, 713.9794]
+ },
+ "sub_model_folder": "base_distilled_model",
+ "sub_model_name": "base_distill.safetensors",
+ "mode": "universal",
+ "streaming": false,
+ "action_config": {
+ "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+ "enable_keyboard": true,
+ "enable_mouse": true,
+ "heads_num": 16,
+ "hidden_size": 128,
+ "img_hidden_size": 1536,
+ "keyboard_dim_in": 4,
+ "keyboard_hidden_dim": 1024,
+ "mouse_dim_in": 2,
+ "mouse_hidden_dim": 1024,
+ "mouse_qk_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "patch_size": [
+ 1,
+ 2,
+ 2
+ ],
+ "qk_norm": true,
+ "qkv_bias": false,
+ "rope_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "rope_theta": 256,
+ "vae_time_compression_ratio": 4,
+ "windows_size": 3
+ },
+ "dim": 1536,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "inject_sample_info": false,
+ "model_type": "i2v",
+ "num_heads": 12,
+ "num_layers": 30,
+ "out_dim": 16
+}
diff --git a/configs/matrix_game2/matrix_game2_universal_streaming.json b/configs/matrix_game2/matrix_game2_universal_streaming.json
new file mode 100644
index 0000000000000000000000000000000000000000..f2c4575a78dfe3fed7f02819ad6f70758b637866
--- /dev/null
+++ b/configs/matrix_game2/matrix_game2_universal_streaming.json
@@ -0,0 +1,72 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 360,
+ "num_output_frames": 360,
+ "text_len": 512,
+ "target_height": 352,
+ "target_width": 640,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "local_attn_size": 6,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 880,
+ "num_output_frames": 360,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 908.8427, 713.9794]
+ },
+ "sub_model_folder": "base_distilled_model",
+ "sub_model_name": "base_distill.safetensors",
+ "mode": "universal",
+ "streaming": true,
+ "action_config": {
+ "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+ "enable_keyboard": true,
+ "enable_mouse": true,
+ "heads_num": 16,
+ "hidden_size": 128,
+ "img_hidden_size": 1536,
+ "keyboard_dim_in": 4,
+ "keyboard_hidden_dim": 1024,
+ "mouse_dim_in": 2,
+ "mouse_hidden_dim": 1024,
+ "mouse_qk_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "patch_size": [
+ 1,
+ 2,
+ 2
+ ],
+ "qk_norm": true,
+ "qkv_bias": false,
+ "rope_dim_list": [
+ 8,
+ 28,
+ 28
+ ],
+ "rope_theta": 256,
+ "vae_time_compression_ratio": 4,
+ "windows_size": 3
+ },
+ "dim": 1536,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "inject_sample_info": false,
+ "model_type": "i2v",
+ "num_heads": 12,
+ "num_layers": 30,
+ "out_dim": 16
+}
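
The Matrix-Game 2 configs differ mainly in mode, streaming, and the action_config: gta_drive and universal take a 4-dimensional keyboard input plus a 2-dimensional mouse input, while the non-streaming templerun config is keyboard-only with keyboard_dim_in of 7. A small shape check like the one below makes the expected action arrays explicit; the (frames, dims) layout is an assumption for illustration, not the model's documented input contract.

import json
import numpy as np

with open("configs/matrix_game2/matrix_game2_gta_drive.json") as f:
    cfg = json.load(f)

act = cfg["action_config"]
frames = cfg["num_output_frames"]

keyboard = np.zeros((frames, act["keyboard_dim_in"]), dtype=np.float32)
mouse = np.zeros((frames, act["mouse_dim_in"]), dtype=np.float32) if act["enable_mouse"] else None

assert keyboard.shape == (150, 4)
assert mouse is not None and mouse.shape == (150, 2)
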
diff --git a/configs/model_pipeline.json b/configs/model_pipeline.json
new file mode 100644
index 0000000000000000000000000000000000000000..7f2f62b0a4f9f360e03b7ce3969e46ec1d5c65df
--- /dev/null
+++ b/configs/model_pipeline.json
@@ -0,0 +1,179 @@
+{
+ "data":
+ {
+ "t2v": {
+ "wan2.1-1.3B": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": [],
+ "outputs": ["output_video"]
+ }
+ },
+ "multi_stage": {
+ "text_encoder": {
+ "inputs": [],
+ "outputs": ["text_encoder_output"]
+ },
+ "dit": {
+ "inputs": ["text_encoder_output"],
+ "outputs": ["latents"]
+ },
+ "vae_decoder": {
+ "inputs": ["latents"],
+ "outputs": ["output_video"]
+ }
+ }
+ },
+ "self-forcing-dmd": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": [],
+ "outputs": ["output_video"]
+ }
+ }
+ }
+ },
+ "i2v": {
+ "wan2.1-14B-480P": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": ["input_image"],
+ "outputs": ["output_video"]
+ }
+ },
+ "multi_stage": {
+ "text_encoder": {
+ "inputs": ["input_image"],
+ "outputs": ["text_encoder_output"]
+ },
+ "image_encoder": {
+ "inputs": ["input_image"],
+ "outputs": ["clip_encoder_output"]
+ },
+ "vae_encoder": {
+ "inputs": ["input_image"],
+ "outputs": ["vae_encoder_output"]
+ },
+ "dit": {
+ "inputs": [
+ "clip_encoder_output",
+ "vae_encoder_output",
+ "text_encoder_output"
+ ],
+ "outputs": ["latents"]
+ },
+ "vae_decoder": {
+ "inputs": ["latents"],
+ "outputs": ["output_video"]
+ }
+ }
+ },
+ "matrix-game2-gta-drive": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": ["input_image"],
+ "outputs": ["output_video"]
+ }
+ }
+ },
+ "matrix-game2-universal": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": ["input_image"],
+ "outputs": ["output_video"]
+ }
+ }
+ },
+ "matrix-game2-templerun": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": ["input_image"],
+ "outputs": ["output_video"]
+ }
+ }
+ }
+ },
+ "s2v": {
+ "SekoTalk": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": ["input_image", "input_audio"],
+ "outputs": ["output_video"]
+ }
+ },
+ "multi_stage": {
+ "text_encoder": {
+ "inputs": ["input_image"],
+ "outputs": ["text_encoder_output"]
+ },
+ "image_encoder": {
+ "inputs": ["input_image"],
+ "outputs": ["clip_encoder_output"]
+ },
+ "vae_encoder": {
+ "inputs": ["input_image"],
+ "outputs": ["vae_encoder_output"]
+ },
+ "segment_dit": {
+ "inputs": [
+ "input_audio",
+ "clip_encoder_output",
+ "vae_encoder_output",
+ "text_encoder_output"
+ ],
+ "outputs": ["output_video"]
+ }
+ }
+ }
+ },
+ "animate": {
+ "wan2.2_animate": {
+ "single_stage": {
+ "pipeline": {
+ "inputs": ["input_image","input_video"],
+ "outputs": ["output_video"]
+ }
+ }
+ }
+ }
+
+ },
+ "meta": {
+ "special_types": {
+ "input_image": "IMAGE",
+ "input_audio": "AUDIO",
+ "input_video": "VIDEO",
+ "latents": "TENSOR",
+ "output_video": "VIDEO"
+ },
+ "model_name_inner_to_outer": {
+ "seko_talk": "SekoTalk"
+ },
+ "model_name_outer_to_inner": {},
+ "monitor": {
+ "subtask_created_timeout": 1800,
+ "subtask_pending_timeout": 1800,
+ "subtask_running_timeouts": {
+ "t2v-wan2.1-1.3B-multi_stage-dit": 300,
+ "t2v-wan2.1-1.3B-single_stage-pipeline": 300,
+ "t2v-self-forcing-dmd-single_stage-pipeline": 300,
+ "i2v-wan2.1-14B-480P-multi_stage-dit": 600,
+ "i2v-wan2.1-14B-480P-single_stage-pipeline": 600,
+ "i2v-SekoTalk-Distill-single_stage-pipeline": 3600,
+ "i2v-SekoTalk-Distill-multi_stage-segment_dit": 3600
+ },
+ "worker_avg_window": 20,
+ "worker_offline_timeout": 5,
+ "worker_min_capacity": 20,
+ "worker_min_cnt": 1,
+ "worker_max_cnt": 10,
+ "task_timeout": 3600,
+ "schedule_ratio_high": 0.25,
+ "schedule_ratio_low": 0.02,
+ "ping_timeout": 30,
+ "user_max_active_tasks": 3,
+ "user_max_daily_tasks": 100,
+ "user_visit_frequency": 0.05
+ }
+ }
+}
diff --git a/configs/offload/block/qwen_image_i2i_2509_block.json b/configs/offload/block/qwen_image_i2i_2509_block.json
new file mode 100644
index 0000000000000000000000000000000000000000..7859a0a71e918004a639dfcd4efc68eda4a85b67
--- /dev/null
+++ b/configs/offload/block/qwen_image_i2i_2509_block.json
@@ -0,0 +1,66 @@
+{
+ "batchsize": 1,
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 40,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "transformer_in_channels": 64,
+ "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 64,
+ "_auto_resize": true,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "sage_attn2",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "CONDITION_IMAGE_SIZE": 147456,
+ "USE_IMAGE_ID_IN_PROMPT": true
+}
diff --git a/configs/offload/block/qwen_image_i2i_block.json b/configs/offload/block/qwen_image_i2i_block.json
new file mode 100644
index 0000000000000000000000000000000000000000..7153b4a37070914f5d73c4f3ab4a87cdcd988949
--- /dev/null
+++ b/configs/offload/block/qwen_image_i2i_block.json
@@ -0,0 +1,66 @@
+{
+ "batchsize": 1,
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 50,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "transformer_in_channels": 64,
+ "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 64,
+ "_auto_resize": true,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "CONDITION_IMAGE_SIZE": 1048576,
+ "USE_IMAGE_ID_IN_PROMPT": false
+}
diff --git a/configs/offload/block/qwen_image_t2i_block.json b/configs/offload/block/qwen_image_t2i_block.json
new file mode 100644
index 0000000000000000000000000000000000000000..5b6313fcad6338c255398ed94932b0de01489362
--- /dev/null
+++ b/configs/offload/block/qwen_image_t2i_block.json
@@ -0,0 +1,86 @@
+{
+ "batchsize": 1,
+ "_comment": "格式: '宽高比': [width, height]",
+ "aspect_ratios": {
+ "1:1": [
+ 1328,
+ 1328
+ ],
+ "16:9": [
+ 1664,
+ 928
+ ],
+ "9:16": [
+ 928,
+ 1664
+ ],
+ "4:3": [
+ 1472,
+ 1140
+ ],
+ "3:4": [
+ 142,
+ 184
+ ]
+ },
+ "aspect_ratio": "16:9",
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 50,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 34,
+ "_auto_resize": false,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block"
+}
diff --git a/configs/offload/block/wan_i2v_block.json b/configs/offload/block/wan_i2v_block.json
new file mode 100644
index 0000000000000000000000000000000000000000..4363093b7e5270d4d7bafbce10fe3b2546900001
--- /dev/null
+++ b/configs/offload/block/wan_i2v_block.json
@@ -0,0 +1,23 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "clip_cpu_offload": false
+}
diff --git a/configs/offload/block/wan_t2v_1_3b.json b/configs/offload/block/wan_t2v_1_3b.json
new file mode 100644
index 0000000000000000000000000000000000000000..e03f8a7e776510bac40ae743004eb3b416a19125
--- /dev/null
+++ b/configs/offload/block/wan_t2v_1_3b.json
@@ -0,0 +1,17 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "t5_cpu_offload": true,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl",
+ "unload_modules": false,
+ "use_tiling_vae": false
+}
diff --git a/configs/offload/block/wan_t2v_block.json b/configs/offload/block/wan_t2v_block.json
new file mode 100644
index 0000000000000000000000000000000000000000..2f926bce9b84b77afd81ae3df585e3a3563346f1
--- /dev/null
+++ b/configs/offload/block/wan_t2v_block.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f",
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "clip_cpu_offload": false
+}
diff --git a/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json b/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json
new file mode 100644
index 0000000000000000000000000000000000000000..fa0348eed97babb7b09e4f52ca1a205d628b32f6
--- /dev/null
+++ b/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json
@@ -0,0 +1,28 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "dit_quantized_ckpt": "/path/to/dit_quant_model",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-vllm",
+ "t5_cpu_offload": true,
+ "t5_quantized": true,
+ "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
+ "t5_quant_scheme": "fp8",
+ "clip_quantized": true,
+ "clip_quantized_ckpt": "/path/to/clip-fp8.pth",
+ "clip_quant_scheme": "fp8",
+ "use_tiling_vae": true,
+ "use_tae": true,
+ "tae_path": "/path/to/taew2_1.pth",
+ "lazy_load": true
+}
diff --git a/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json b/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json
new file mode 100644
index 0000000000000000000000000000000000000000..0516a9572af23f0dd91013fc9b011b5343e80731
--- /dev/null
+++ b/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json
@@ -0,0 +1,30 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 1280,
+ "target_width": 720,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "dit_quantized_ckpt": "/path/to/dit_quant_model",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-vllm",
+ "t5_cpu_offload": true,
+ "t5_quantized": true,
+ "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
+ "t5_quant_scheme": "fp8",
+ "clip_quantized": true,
+ "clip_quantized_ckpt": "/path/to/clip-fp8.pth",
+ "clip_quant_scheme": "fp8",
+ "use_tiling_vae": true,
+ "use_tae": true,
+ "tae_path": "/path/to/taew2_1.pth",
+ "lazy_load": true,
+ "rotary_chunk": true,
+ "clean_cuda_cache": true
+}
diff --git a/configs/offload/phase/wan_i2v_phase.json b/configs/offload/phase/wan_i2v_phase.json
new file mode 100644
index 0000000000000000000000000000000000000000..c5bd698bd4345f99ed4d9f721fff98fd394c8311
--- /dev/null
+++ b/configs/offload/phase/wan_i2v_phase.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "t5_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_tiling_vae": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/offload/phase/wan_t2v_phase.json b/configs/offload/phase/wan_t2v_phase.json
new file mode 100644
index 0000000000000000000000000000000000000000..b12036aa844b36b349f17b14ad7d0129faabcc45
--- /dev/null
+++ b/configs/offload/phase/wan_t2v_phase.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "t5_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_tiling_vae": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/quantization/gguf/wan_i2v_q4_k.json b/configs/quantization/gguf/wan_i2v_q4_k.json
new file mode 100644
index 0000000000000000000000000000000000000000..e92f6b216c83dafca6e4e404851265f552d5ae81
--- /dev/null
+++ b/configs/quantization/gguf/wan_i2v_q4_k.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "dit_quantized": true,
+ "dit_quant_scheme": "gguf-Q4_K_S"
+}
diff --git a/configs/quantization/wan_i2v.json b/configs/quantization/wan_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..7624495e62fc56b9824d09303a91668c3b959eed
--- /dev/null
+++ b/configs/quantization/wan_i2v.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "dit_quantized_ckpt": "/path/to/int8/model",
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-vllm"
+}
diff --git a/configs/quantization/wan_i2v_q8f.json b/configs/quantization/wan_i2v_q8f.json
new file mode 100644
index 0000000000000000000000000000000000000000..15c5b1245e88dc225f6c5e3e8f3d484e2ddb26ce
--- /dev/null
+++ b/configs/quantization/wan_i2v_q8f.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "int8-q8f"
+}
diff --git a/configs/quantization/wan_i2v_torchao.json b/configs/quantization/wan_i2v_torchao.json
new file mode 100644
index 0000000000000000000000000000000000000000..ca561b87cc079e6c43a7b2d11f38c8c92cc485e6
--- /dev/null
+++ b/configs/quantization/wan_i2v_torchao.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-torchao",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-torchao",
+ "clip_quantized": true,
+ "clip_quant_scheme": "int8-torchao"
+}
diff --git a/configs/qwen_image/qwen_image_i2i.json b/configs/qwen_image/qwen_image_i2i.json
new file mode 100644
index 0000000000000000000000000000000000000000..0e0fd3881b53caeab6abc17b903f7241fffbfee2
--- /dev/null
+++ b/configs/qwen_image/qwen_image_i2i.json
@@ -0,0 +1,64 @@
+{
+ "batchsize": 1,
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 50,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "transformer_in_channels": 64,
+ "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 64,
+ "_auto_resize": true,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0,
+ "CONDITION_IMAGE_SIZE": 1048576,
+ "USE_IMAGE_ID_IN_PROMPT": false
+}
diff --git a/configs/qwen_image/qwen_image_i2i_2509.json b/configs/qwen_image/qwen_image_i2i_2509.json
new file mode 100644
index 0000000000000000000000000000000000000000..974988e428650ab9320852bd4edbb9f5551e092b
--- /dev/null
+++ b/configs/qwen_image/qwen_image_i2i_2509.json
@@ -0,0 +1,64 @@
+{
+ "batchsize": 1,
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 40,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "transformer_in_channels": 64,
+ "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 64,
+ "_auto_resize": true,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0,
+ "CONDITION_IMAGE_SIZE": 147456,
+ "USE_IMAGE_ID_IN_PROMPT": true
+}
diff --git a/configs/qwen_image/qwen_image_i2i_2509_quant.json b/configs/qwen_image/qwen_image_i2i_2509_quant.json
new file mode 100644
index 0000000000000000000000000000000000000000..c6e73581288181269c5da3a917190d591db8986a
--- /dev/null
+++ b/configs/qwen_image/qwen_image_i2i_2509_quant.json
@@ -0,0 +1,67 @@
+{
+ "batchsize": 1,
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 40,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "transformer_in_channels": 64,
+ "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 64,
+ "_auto_resize": true,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0,
+ "CONDITION_IMAGE_SIZE": 147456,
+ "USE_IMAGE_ID_IN_PROMPT": true,
+ "dit_quantized": true,
+ "dit_quantized_ckpt": "/path/to/qwen_2509_fp8.safetensors",
+ "dit_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/qwen_image/qwen_image_i2i_lora.json b/configs/qwen_image/qwen_image_i2i_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..b0bbc49282aef7b54fe32d09729b495ae697b5aa
--- /dev/null
+++ b/configs/qwen_image/qwen_image_i2i_lora.json
@@ -0,0 +1,70 @@
+{
+ "batchsize": 1,
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 8,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "transformer_in_channels": 64,
+ "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 64,
+ "_auto_resize": true,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0,
+ "CONDITION_IMAGE_SIZE": 1048576,
+ "USE_IMAGE_ID_IN_PROMPT": false,
+ "lora_configs": [
+ {
+ "path": "/path/to/Qwen-Image-Edit-Lightning-4steps-V1.0.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/qwen_image/qwen_image_t2i.json b/configs/qwen_image/qwen_image_t2i.json
new file mode 100644
index 0000000000000000000000000000000000000000..5581d50a2d5637396485114385b5dc40e9568461
--- /dev/null
+++ b/configs/qwen_image/qwen_image_t2i.json
@@ -0,0 +1,85 @@
+{
+ "batchsize": 1,
+ "_comment": "格式: '宽高比': [width, height]",
+ "aspect_ratios": {
+ "1:1": [
+ 1328,
+ 1328
+ ],
+ "16:9": [
+ 1664,
+ 928
+ ],
+ "9:16": [
+ 928,
+ 1664
+ ],
+ "4:3": [
+ 1472,
+ 1140
+ ],
+ "3:4": [
+ 142,
+ 184
+ ]
+ },
+ "aspect_ratio": "16:9",
+ "num_channels_latents": 16,
+ "vae_scale_factor": 8,
+ "infer_steps": 50,
+ "guidance_embeds": false,
+ "num_images_per_prompt": 1,
+ "vae_latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "vae_latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "vae_z_dim": 16,
+ "feature_caching": "NoCaching",
+ "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ "prompt_template_encode_start_idx": 34,
+ "_auto_resize": false,
+ "num_layers": 60,
+ "attention_out_dim": 3072,
+ "attention_dim_head": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
+ "attn_type": "flash_attn3",
+ "do_true_cfg": true,
+ "true_cfg_scale": 4.0
+}
diff --git a/configs/seko_talk/5090/seko_talk_5090_bf16.json b/configs/seko_talk/5090/seko_talk_5090_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..a6cab78438410e463f554427fb9c6653d19875be
--- /dev/null
+++ b/configs/seko_talk/5090/seko_talk_5090_bf16.json
@@ -0,0 +1,23 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn3",
+ "cross_attn_1_type": "sage_attn3",
+ "cross_attn_2_type": "sage_attn3",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": false,
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false,
+ "vae_cpu_offload": false
+}
diff --git a/configs/seko_talk/5090/seko_talk_5090_int8.json b/configs/seko_talk/5090/seko_talk_5090_int8.json
new file mode 100644
index 0000000000000000000000000000000000000000..664983b908669400d34f1ff8bcb14683a6e1d1c9
--- /dev/null
+++ b/configs/seko_talk/5090/seko_talk_5090_int8.json
@@ -0,0 +1,29 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn3",
+ "cross_attn_1_type": "sage_attn3",
+ "cross_attn_2_type": "sage_attn3",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-q8f",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-q8f"
+}
diff --git a/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json b/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json
new file mode 100644
index 0000000000000000000000000000000000000000..3dfc9d533204841113ddd21db22f9095f34a7ae7
--- /dev/null
+++ b/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json
@@ -0,0 +1,33 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn3",
+ "cross_attn_1_type": "sage_attn3",
+ "cross_attn_2_type": "sage_attn3",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-q8f",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-q8f",
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses-4090"
+ }
+}
diff --git a/configs/seko_talk/A800/seko_talk_A800_int8.json b/configs/seko_talk/A800/seko_talk_A800_int8.json
new file mode 100644
index 0000000000000000000000000000000000000000..a5da24c2a907ea377a3403d4fb14b826814d07cf
--- /dev/null
+++ b/configs/seko_talk/A800/seko_talk_A800_int8.json
@@ -0,0 +1,22 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-vllm",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-vllm",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-vllm"
+}
diff --git a/configs/seko_talk/A800/seko_talk_A800_int8_dist_2gpu.json b/configs/seko_talk/A800/seko_talk_A800_int8_dist_2gpu.json
new file mode 100644
index 0000000000000000000000000000000000000000..f99ba12190fcc7dbc58b4dd9d60dacfa7f66e970
--- /dev/null
+++ b/configs/seko_talk/A800/seko_talk_A800_int8_dist_2gpu.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-vllm",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-vllm",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-vllm",
+ "parallel": {
+ "seq_p_size": 2,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/A800/seko_talk_A800_int8_dist_4gpu.json b/configs/seko_talk/A800/seko_talk_A800_int8_dist_4gpu.json
new file mode 100644
index 0000000000000000000000000000000000000000..c0d0c45df8bf7061170987f1815047c38ec72ec0
--- /dev/null
+++ b/configs/seko_talk/A800/seko_talk_A800_int8_dist_4gpu.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-vllm",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-vllm",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-vllm",
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json b/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json
new file mode 100644
index 0000000000000000000000000000000000000000..278d37664a0e79c0b76492c1e03982cdb291dd68
--- /dev/null
+++ b/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-vllm",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-vllm",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-vllm",
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/L40s/1gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/1gpu/seko_talk_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..fd1afa4d23e8d77779d02481c081fb2ab6d16a87
--- /dev/null
+++ b/configs/seko_talk/L40s/1gpu/seko_talk_bf16.json
@@ -0,0 +1,23 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 0.8,
+ "t5_cpu_offload": false,
+ "clip_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false
+}
diff --git a/configs/seko_talk/L40s/1gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/1gpu/seko_talk_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..d88f4542282f4e1e2b3fc2b37c1f37d3cfbf8e60
--- /dev/null
+++ b/configs/seko_talk/L40s/1gpu/seko_talk_fp8.json
@@ -0,0 +1,27 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8",
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true
+}
diff --git a/configs/seko_talk/L40s/2gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/2gpu/seko_talk_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..4cd8a3f099a5a887ad9d500224a22ef24c396740
--- /dev/null
+++ b/configs/seko_talk/L40s/2gpu/seko_talk_bf16.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "parallel": {
+ "seq_p_size": 2,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/L40s/2gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/2gpu/seko_talk_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..b45e87f0d8786fc3797dcf21c4e0fd7562f6fb08
--- /dev/null
+++ b/configs/seko_talk/L40s/2gpu/seko_talk_fp8.json
@@ -0,0 +1,31 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8",
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "parallel": {
+ "seq_p_size": 2,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/L40s/4gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/4gpu/seko_talk_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..a3f7b6fb3b38f135dc2f2ead5514f3eebfac2526
--- /dev/null
+++ b/configs/seko_talk/L40s/4gpu/seko_talk_bf16.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/L40s/4gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/4gpu/seko_talk_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..ec682a6e0a48c7be36c4769435d2d3a004cd8f5d
--- /dev/null
+++ b/configs/seko_talk/L40s/4gpu/seko_talk_fp8.json
@@ -0,0 +1,31 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8",
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/L40s/8gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/8gpu/seko_talk_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..079439fe1b3fee28435fecb837f3638145814422
--- /dev/null
+++ b/configs/seko_talk/L40s/8gpu/seko_talk_bf16.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/L40s/8gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/8gpu/seko_talk_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..d50f7ebb896e29d762830b20e67d0c14a30c318e
--- /dev/null
+++ b/configs/seko_talk/L40s/8gpu/seko_talk_fp8.json
@@ -0,0 +1,31 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8",
+ "cpu_offload": false,
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": true,
+ "vae_cpu_offload": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/mlu/seko_talk_bf16.json b/configs/seko_talk/mlu/seko_talk_bf16.json
new file mode 100644
index 0000000000000000000000000000000000000000..71741879328b0e9425bc00a9fc3c0c339791dc9e
--- /dev/null
+++ b/configs/seko_talk/mlu/seko_talk_bf16.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "mlu_sage_attn",
+ "cross_attn_1_type": "mlu_sage_attn",
+ "cross_attn_2_type": "mlu_sage_attn",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "rope_type": "torch",
+ "modulate_type": "torch"
+}
diff --git a/configs/seko_talk/mlu/seko_talk_int8.json b/configs/seko_talk/mlu/seko_talk_int8.json
new file mode 100644
index 0000000000000000000000000000000000000000..673acf8619199f07eb251a5f9a5a08296f07e2d3
--- /dev/null
+++ b/configs/seko_talk/mlu/seko_talk_int8.json
@@ -0,0 +1,29 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "mlu_sage_attn",
+ "cross_attn_1_type": "mlu_sage_attn",
+ "cross_attn_2_type": "mlu_sage_attn",
+ "seed": 42,
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "clip_quantized": false,
+ "clip_quant_scheme": "int8-tmo",
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-tmo",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-tmo",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-tmo",
+ "modulate_type": "torch",
+ "rope_type": "torch",
+ "ln_type": "Default",
+ "rms_type": "Default"
+}
diff --git a/configs/seko_talk/mlu/seko_talk_int8_dist.json b/configs/seko_talk/mlu/seko_talk_int8_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..fe5f515260e81d5f1d444be919afcd8f04b1493f
--- /dev/null
+++ b/configs/seko_talk/mlu/seko_talk_int8_dist.json
@@ -0,0 +1,33 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "mlu_sage_attn",
+ "cross_attn_1_type": "mlu_sage_attn",
+ "cross_attn_2_type": "mlu_sage_attn",
+ "seed": 42,
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "clip_quantized": false,
+ "clip_quant_scheme": "int8-tmo",
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-tmo",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-tmo",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-tmo",
+ "modulate_type": "torch",
+ "rope_type": "torch",
+ "ln_type": "Default",
+ "rms_type": "Default",
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/multi_person/01_base.json b/configs/seko_talk/multi_person/01_base.json
new file mode 100644
index 0000000000000000000000000000000000000000..417f62ce1caa0175ba79a13f76634a7224b0a621
--- /dev/null
+++ b/configs/seko_talk/multi_person/01_base.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false
+}
diff --git a/configs/seko_talk/multi_person/02_base_fp8.json b/configs/seko_talk/multi_person/02_base_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..232b3987ea72a1979dd1d3af77252006847a5644
--- /dev/null
+++ b/configs/seko_talk/multi_person/02_base_fp8.json
@@ -0,0 +1,22 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8"
+}
diff --git a/configs/seko_talk/multi_person/03_dist.json b/configs/seko_talk/multi_person/03_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..5cc4cbd6e57a42c61b4683ed63e3d20ebdc237eb
--- /dev/null
+++ b/configs/seko_talk/multi_person/03_dist.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/multi_person/04_dist_fp8.json b/configs/seko_talk/multi_person/04_dist_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..f91b3fd57af5084d443a207607f59d18044609b3
--- /dev/null
+++ b/configs/seko_talk/multi_person/04_dist_fp8.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8",
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/multi_person/15_base_compile.json b/configs/seko_talk/multi_person/15_base_compile.json
new file mode 100644
index 0000000000000000000000000000000000000000..73cb92aacc0ecf1556d67c49a2fc8342777f3b37
--- /dev/null
+++ b/configs/seko_talk/multi_person/15_base_compile.json
@@ -0,0 +1,27 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "compile": true,
+ "compile_shapes": [
+ [
+ 480,
+ 832
+ ],
+ [
+ 720,
+ 1280
+ ]
+ ]
+}
diff --git a/configs/seko_talk/seko_talk_01_base.json b/configs/seko_talk/seko_talk_01_base.json
new file mode 100644
index 0000000000000000000000000000000000000000..417f62ce1caa0175ba79a13f76634a7224b0a621
--- /dev/null
+++ b/configs/seko_talk/seko_talk_01_base.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false
+}
diff --git a/configs/seko_talk/seko_talk_02_fp8.json b/configs/seko_talk/seko_talk_02_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..15cdbc046e57090cb3265f0b9489806e55438d21
--- /dev/null
+++ b/configs/seko_talk/seko_talk_02_fp8.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-sgl",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_03_dist.json b/configs/seko_talk/seko_talk_03_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..f7543370eac394efa72ed618ff65bd4389c62ffe
--- /dev/null
+++ b/configs/seko_talk/seko_talk_03_dist.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ }
+}
diff --git a/configs/seko_talk/seko_talk_04_fp8_dist.json b/configs/seko_talk/seko_talk_04_fp8_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..443295169fec8be2aaaa07e89f898092db61bb6d
--- /dev/null
+++ b/configs/seko_talk/seko_talk_04_fp8_dist.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_05_offload_fp8_4090.json b/configs/seko_talk/seko_talk_05_offload_fp8_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..96f8629d1c3a2e7614ed54eb51b0a8171784743a
--- /dev/null
+++ b/configs/seko_talk/seko_talk_05_offload_fp8_4090.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": false,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_cpu_offload": false,
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f",
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false,
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-q8f",
+ "vae_cpu_offload": false,
+ "use_tiling_vae": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/seko_talk/seko_talk_05_offload_fp8_4090_dist.json b/configs/seko_talk/seko_talk_05_offload_fp8_4090_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..a2d66f64d4c78b4b131960624780c83f8ed4fa22
--- /dev/null
+++ b/configs/seko_talk/seko_talk_05_offload_fp8_4090_dist.json
@@ -0,0 +1,36 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": false,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_cpu_offload": false,
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f",
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false,
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-q8f",
+ "vae_cpu_offload": false,
+ "use_tiling_vae": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses-4090"
+ }
+}
diff --git a/configs/seko_talk/seko_talk_06_offload_fp8_H100.json b/configs/seko_talk/seko_talk_06_offload_fp8_H100.json
new file mode 100644
index 0000000000000000000000000000000000000000..d5a5ae9b3948234a362576a886ead22a2e0a43df
--- /dev/null
+++ b/configs/seko_talk/seko_talk_06_offload_fp8_H100.json
@@ -0,0 +1,30 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": true,
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl",
+ "clip_cpu_offload": false,
+ "audio_encoder_cpu_offload": false,
+ "audio_adapter_cpu_offload": false,
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "vae_cpu_offload": false,
+ "use_tiling_vae": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_07_dist_offload.json b/configs/seko_talk/seko_talk_07_dist_offload.json
new file mode 100644
index 0000000000000000000000000000000000000000..b6a64eaa5c67e2b1c83b6cb0287b85496a9816bf
--- /dev/null
+++ b/configs/seko_talk/seko_talk_07_dist_offload.json
@@ -0,0 +1,28 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ },
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": true,
+ "clip_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "offload_ratio": 1,
+ "use_tiling_vae": true,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": false
+}
diff --git a/configs/seko_talk/seko_talk_08_5B_base.json b/configs/seko_talk/seko_talk_08_5B_base.json
new file mode 100644
index 0000000000000000000000000000000000000000..855670fd5827e2ad0e39c8f557b28e2f444d01df
--- /dev/null
+++ b/configs/seko_talk/seko_talk_08_5B_base.json
@@ -0,0 +1,31 @@
+{
+ "infer_steps": 4,
+ "target_fps": 24,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 121,
+ "resize_mode": "adaptive",
+ "text_len": 512,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "fps": 24,
+ "use_image_encoder": false,
+ "use_31_block": false,
+ "lora_configs": [
+ {
+ "path": "/mnt/aigc/rtxiang/pretrain/qianhai_weights/lora_model.safetensors",
+ "strength": 0.125
+ }
+ ]
+}
diff --git a/configs/seko_talk/seko_talk_09_base_fixed_min_area.json b/configs/seko_talk/seko_talk_09_base_fixed_min_area.json
new file mode 100644
index 0000000000000000000000000000000000000000..2b04ae5ee3c1c4f6a0b54a3c8f9044e05cd2ea0e
--- /dev/null
+++ b/configs/seko_talk/seko_talk_09_base_fixed_min_area.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "fixed_min_area",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false
+}
diff --git a/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json b/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json
new file mode 100644
index 0000000000000000000000000000000000000000..27f8afe412df886de87aff1dbcd1a9bb465a1037
--- /dev/null
+++ b/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "fixed_min_area",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ },
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json b/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json
new file mode 100644
index 0000000000000000000000000000000000000000..0ea9de6f852dd2b09cce33976b77181a9de949c6
--- /dev/null
+++ b/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json
@@ -0,0 +1,30 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "fixed_shape",
+ "fixed_shape": [
+ 240,
+ 320
+ ],
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ },
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json b/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json
new file mode 100644
index 0000000000000000000000000000000000000000..6952e86f1e592926ee0546af99efc82ffad678c0
--- /dev/null
+++ b/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json
@@ -0,0 +1,31 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 17,
+ "prev_frame_length": 1,
+ "resize_mode": "fixed_shape",
+ "fixed_shape": [
+ 480,
+ 480
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json b/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json
new file mode 100644
index 0000000000000000000000000000000000000000..7c9abea232a2894375a0a8186d0881df8864c2ab
--- /dev/null
+++ b/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json
@@ -0,0 +1,62 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "bucket_shape": {
+ "0.667": [
+ [
+ 480,
+ 832
+ ],
+ [
+ 544,
+ 960
+ ]
+ ],
+ "1.500": [
+ [
+ 832,
+ 480
+ ],
+ [
+ 960,
+ 544
+ ]
+ ],
+ "1.000": [
+ [
+ 480,
+ 480
+ ],
+ [
+ 576,
+ 576
+ ],
+ [
+ 704,
+ 704
+ ]
+ ]
+ },
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json b/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json
new file mode 100644
index 0000000000000000000000000000000000000000..9ad4fd7330a0f71d17c9100ca54cb96430fb0cd8
--- /dev/null
+++ b/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json
@@ -0,0 +1,63 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 17,
+ "prev_frame_length": 1,
+ "resize_mode": "adaptive",
+ "bucket_shape": {
+ "0.667": [
+ [
+ 480,
+ 832
+ ],
+ [
+ 544,
+ 960
+ ]
+ ],
+ "1.500": [
+ [
+ 832,
+ 480
+ ],
+ [
+ 960,
+ 544
+ ]
+ ],
+ "1.000": [
+ [
+ 480,
+ 480
+ ],
+ [
+ 576,
+ 576
+ ],
+ [
+ 704,
+ 704
+ ]
+ ]
+ },
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
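
Note on the two realtime configs above: "bucket_shape" maps an aspect-ratio key ("0.667", "1.000", "1.500") to a list of candidate [height, width] buckets. A minimal sketch of how such a table could be consumed follows; the key semantics, the [height, width] ordering, and the selection rule are assumptions for illustration, not the repository's actual resize code.

    # Hypothetical bucket selection for a "bucket_shape" table like the ones above.
    # Assumes keys are approximate height/width ratios and entries are [height, width].
    def pick_bucket_shape(bucket_shape, img_h, img_w):
        ratio = img_h / img_w
        key = min(bucket_shape, key=lambda k: abs(float(k) - ratio))  # nearest ratio bin
        # Within the bin, take the bucket whose pixel area is closest to the input's.
        return min(bucket_shape[key], key=lambda hw: abs(hw[0] * hw[1] - img_h * img_w))

    buckets = {"0.667": [[480, 832], [544, 960]], "1.000": [[480, 480], [576, 576], [704, 704]]}
    print(pick_bucket_shape(buckets, 1080, 1920))  # -> [544, 960]
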
diff --git a/configs/seko_talk/seko_talk_15_base_compile.json b/configs/seko_talk/seko_talk_15_base_compile.json
new file mode 100644
index 0000000000000000000000000000000000000000..19547bc498b76474a1904ef6a8c095143153c795
--- /dev/null
+++ b/configs/seko_talk/seko_talk_15_base_compile.json
@@ -0,0 +1,59 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "compile": true,
+ "compile_shapes": [
+ [
+ 480,
+ 832
+ ],
+ [
+ 544,
+ 960
+ ],
+ [
+ 720,
+ 1280
+ ],
+ [
+ 832,
+ 480
+ ],
+ [
+ 960,
+ 544
+ ],
+ [
+ 1280,
+ 720
+ ],
+ [
+ 480,
+ 480
+ ],
+ [
+ 576,
+ 576
+ ],
+ [
+ 704,
+ 704
+ ],
+ [
+ 960,
+ 960
+ ]
+ ]
+}
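
The "compile" / "compile_shapes" pair above (also used by the fp8 distributed variant below) enumerates the resolutions the compiled model is expected to serve, which suggests a one-pass-per-shape warm-up so each static resolution is traced before real requests arrive. A minimal sketch of that idea with torch.compile, using a stand-in module rather than the real DiT:

    import torch
    from torch import nn

    class TinyBlock(nn.Module):  # placeholder module, illustration only
        def __init__(self):
            super().__init__()
            self.proj = nn.Conv2d(16, 16, 1)

        def forward(self, x):
            return self.proj(x)

    compile_shapes = [(480, 832), (544, 960), (480, 480)]  # subset of the list above
    model = torch.compile(TinyBlock())
    with torch.no_grad():
        for h, w in compile_shapes:
            # One forward pass per listed shape; the dummy tensor layout is arbitrary here.
            model(torch.randn(1, 16, h // 8, w // 8))
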
diff --git a/configs/seko_talk/seko_talk_16_fp8_dist_compile.json b/configs/seko_talk/seko_talk_16_fp8_dist_compile.json
new file mode 100644
index 0000000000000000000000000000000000000000..f2ef9e412c96358f7343ac3a938fe77a06658cfc
--- /dev/null
+++ b/configs/seko_talk/seko_talk_16_fp8_dist_compile.json
@@ -0,0 +1,71 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-sgl",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl",
+ "compile": true,
+ "compile_shapes": [
+ [
+ 480,
+ 832
+ ],
+ [
+ 544,
+ 960
+ ],
+ [
+ 720,
+ 1280
+ ],
+ [
+ 832,
+ 480
+ ],
+ [
+ 960,
+ 544
+ ],
+ [
+ 1280,
+ 720
+ ],
+ [
+ 480,
+ 480
+ ],
+ [
+ 576,
+ 576
+ ],
+ [
+ 704,
+ 704
+ ],
+ [
+ 960,
+ 960
+ ]
+ ]
+}
diff --git a/configs/seko_talk/seko_talk_17_base_vsr.json b/configs/seko_talk/seko_talk_17_base_vsr.json
new file mode 100644
index 0000000000000000000000000000000000000000..7f39cc56842a2d68120419afdb67cef7ad5df105
--- /dev/null
+++ b/configs/seko_talk/seko_talk_17_base_vsr.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 2,
+ "target_fps": 25,
+ "video_duration": 1,
+ "audio_sr": 16000,
+ "target_video_length": 25,
+ "resize_mode": "fixed_shape",
+ "fixed_shape": [
+ 192,
+ 320
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "video_super_resolution": {
+ "scale": 2.0,
+ "seed": 0,
+ "model_path": "/base_code/FlashVSR/examples/WanVSR/FlashVSR"
+ }
+}
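
The VSR config above renders at a fixed 192x320 and hands the frames to a 2.0x super-resolution stage, so the delivered resolution would be 384x640, assuming the scale applies uniformly to both dimensions:

    # Output size after the "video_super_resolution" stage (uniform 2.0x scale assumed).
    h, w, scale = 192, 320, 2.0
    print(int(h * scale), int(w * scale))  # 384 640
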
diff --git a/configs/seko_talk/seko_talk_22_nbhd_attn.json b/configs/seko_talk/seko_talk_22_nbhd_attn.json
new file mode 100644
index 0000000000000000000000000000000000000000..f1ad262450f66f208c2b019c30ef6eeda79b8b79
--- /dev/null
+++ b/configs/seko_talk/seko_talk_22_nbhd_attn.json
@@ -0,0 +1,16 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "nbhd_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false
+}
diff --git a/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json b/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json
new file mode 100644
index 0000000000000000000000000000000000000000..fba7ca4b8cf73f75501f2ac0476cc7afd045cdf4
--- /dev/null
+++ b/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json
@@ -0,0 +1,28 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "nbhd_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-sgl",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json b/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json
new file mode 100644
index 0000000000000000000000000000000000000000..cc812be7e4d0a6343edb55d4e594716edaefd41b
--- /dev/null
+++ b/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json
@@ -0,0 +1,74 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 360,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "nbhd_attn",
+ "nbhd_attn_setting": {
+ "coefficient": [1.0, 0.25, 0.056]
+ },
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_attn_type": "ulysses"
+ },
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-sgl",
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl",
+ "compile": true,
+ "compile_shapes": [
+ [
+ 480,
+ 832
+ ],
+ [
+ 544,
+ 960
+ ],
+ [
+ 720,
+ 1280
+ ],
+ [
+ 832,
+ 480
+ ],
+ [
+ 960,
+ 544
+ ],
+ [
+ 1280,
+ 720
+ ],
+ [
+ 480,
+ 480
+ ],
+ [
+ 576,
+ 576
+ ],
+ [
+ 704,
+ 704
+ ],
+ [
+ 960,
+ 960
+ ]
+ ]
+}
diff --git a/configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json b/configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json
new file mode 100644
index 0000000000000000000000000000000000000000..a509d5efbf96c91c46808b069ec7e59435b14493
--- /dev/null
+++ b/configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json
@@ -0,0 +1,40 @@
+{
+ "infer_steps": 2,
+ "target_fps": 16,
+ "video_duration": 5,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 1,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "use_31_block": false,
+ "cpu_offload": false,
+ "offload_granularity": "block",
+ "offload_ratio": 1,
+ "t5_cpu_offload": true,
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-q8f",
+ "clip_cpu_offload": true,
+ "clip_quantized": false,
+ "audio_encoder_cpu_offload": true,
+ "audio_adapter_cpu_offload": true,
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "int8-q8f",
+ "vae_cpu_offload": true,
+ "use_tiling_vae": false,
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-q8f",
+ "resize_mode": "fixed_shape",
+ "fixed_shape": [
+ 832,
+ 480
+ ],
+ "parallel": {
+ "seq_p_size": 8,
+ "seq_p_fp8_comm": true,
+ "seq_p_attn_type": "ulysses-4090"
+ }
+}
diff --git a/configs/seko_talk/seko_talk_28_f2v.json b/configs/seko_talk/seko_talk_28_f2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..cf9bebdfa90eb0efd164da7924f2cd022446b03a
--- /dev/null
+++ b/configs/seko_talk/seko_talk_28_f2v.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 4,
+ "target_fps": 16,
+ "video_duration": 12,
+ "audio_sr": 16000,
+ "target_video_length": 81,
+ "prev_frame_length": 1,
+ "resize_mode": "adaptive",
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 1.0,
+ "sample_shift": 5,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "use_31_block": false,
+ "f2v_process": true,
+ "lora_configs": [
+ {
+ "path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/self_forcing/wan_t2v_sf.json b/configs/self_forcing/wan_t2v_sf.json
new file mode 100644
index 0000000000000000000000000000000000000000..df541f92c8fbf80554fef6a7227648d59ded74f4
--- /dev/null
+++ b/configs/self_forcing/wan_t2v_sf.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn2",
+ "cross_attn_1_type": "flash_attn2",
+ "cross_attn_2_type": "flash_attn2",
+ "seed": 0,
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "sf_config": {
+ "sf_type": "dmd",
+ "local_attn_size": -1,
+ "shift": 5.0,
+ "num_frame_per_block": 3,
+ "num_transformer_blocks": 30,
+ "frame_seq_length": 1560,
+ "num_output_frames": 21,
+ "num_inference_steps": 1000,
+ "denoising_step_list": [1000.0000, 937.5000, 833.3333, 625.0000]
+ }
+}
diff --git a/configs/sparse_attn/spas_sage_attn/wan_i2v.json b/configs/sparse_attn/spas_sage_attn/wan_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..5fffe94bb689118c5123c486d2849559d2174b55
--- /dev/null
+++ b/configs/sparse_attn/spas_sage_attn/wan_i2v.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "sage_attn",
+ "cross_attn_1_type": "sage_attn",
+ "cross_attn_2_type": "sage_attn",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/sparse_attn/spas_sage_attn/wan_t2v.json b/configs/sparse_attn/spas_sage_attn/wan_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..d2d045647a86d2d3afef015936936c3867b66b43
--- /dev/null
+++ b/configs/sparse_attn/spas_sage_attn/wan_t2v.json
@@ -0,0 +1,14 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "spas_sage_attn",
+ "cross_attn_1_type": "spas_sage_attn",
+ "cross_attn_2_type": "spas_sage_attn",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/video_frame_interpolation/wan_t2v.json b/configs/video_frame_interpolation/wan_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..02cc5498142a3f1be298cf02cba5faf615fe7e1b
--- /dev/null
+++ b/configs/video_frame_interpolation/wan_t2v.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 24,
+ "model_path": "/path to flownet.pkl"
+ }
+}
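
For the frame-interpolation config above, the post-processing step changes only the frame rate, not the clip duration: given the 81 generated frames and a native 16 fps (the rate used throughout the seko_talk configs in this patch; assumed here, since this file does not set it), retiming to the requested "target_fps" of 24 needs 121 frames. A minimal back-of-the-envelope sketch, not the repository's code:

    # Frames required after interpolating from src_fps to target_fps at constant duration.
    # src_fps = 16 is an assumption borrowed from the other configs in this patch.
    def interpolated_frame_count(num_frames, src_fps, target_fps):
        duration_s = (num_frames - 1) / src_fps
        return int(round(duration_s * target_fps)) + 1

    print(interpolated_frame_count(81, 16, 24))  # -> 121
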
diff --git a/configs/volcengine_voices_list.json b/configs/volcengine_voices_list.json
new file mode 100644
index 0000000000000000000000000000000000000000..51316103a23d05e99cfdfa6328104c94dcc9b13b
--- /dev/null
+++ b/configs/volcengine_voices_list.json
@@ -0,0 +1,4727 @@
+{
+ "voices": [
+ {
+ "name": "Vivi 2.0",
+ "voice_type": "zh_female_vv_uranus_bigtts",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "可爱女生",
+ "voice_type": "ICL_zh_female_keainvsheng_tob",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "大壹",
+ "voice_type": "zh_male_dayi_saturn_bigtts",
+ "gender": "male",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "魅力女友",
+ "voice_type": "ICL_zh_female_tiaopigongzhu_tob",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "黑猫侦探社咪仔",
+ "voice_type": "zh_female_mizai_saturn_bigtts",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "爽朗少年",
+ "voice_type": "ICL_zh_male_shuanglangshaonian_tob",
+ "gender": "male",
+ "version": "2.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "鸡汤女",
+ "voice_type": "zh_female_jitangnv_saturn_bigtts",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "天才同桌",
+ "voice_type": "ICL_zh_male_tiancaitongzhuo_tob",
+ "gender": "male",
+ "version": "2.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "魅力女友",
+ "voice_type": "zh_female_meilinvyou_saturn_bigtts",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "流畅女声",
+ "voice_type": "zh_female_santongyongns_saturn_bigtts",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅逸辰",
+ "voice_type": "zh_male_ruyayichen_saturn_bigtts",
+ "gender": "male",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "知性灿灿",
+ "voice_type": "saturn_zh_female_cancan_tob",
+ "gender": "female",
+ "version": "2.0",
+ "resource_id": "seed-tts-2.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔女神",
+ "voice_type": "ICL_zh_female_wenrounvshen_239eff5e8ffa_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "沪普男",
+ "voice_type": "zh_male_hupunan_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "纯真少女",
+ "voice_type": "ICL_zh_female_chunzhenshaonv_e588402fb8ad_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷酷哥哥",
+ "voice_type": "zh_male_lengkugege_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "fear",
+ "hate",
+ "coldness",
+ "neutral",
+ "depressed"
+ ]
+ },
+ {
+ "name": "理性圆子",
+ "voice_type": "ICL_zh_female_lixingyuanzi_cs",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Tina老师",
+ "voice_type": "zh_female_yingyujiaoyu_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "教育场景",
+ "language": [
+ "chinese",
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Lauren",
+ "voice_type": "en_female_lauren_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "内敛才俊",
+ "voice_type": "ICL_zh_male_neiliancaijun_e991be511569_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "悠悠君子",
+ "voice_type": "zh_male_M100_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Vivi",
+ "voice_type": "zh_female_vv_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "粤语小溏",
+ "voice_type": "zh_female_yueyunv_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "奶气小生",
+ "voice_type": "ICL_zh_male_xiaonaigou_edf58cf28b8b_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "甜心小美",
+ "voice_type": "zh_female_tianxinxiaomei_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "sad",
+ "fear",
+ "hate",
+ "neutral"
+ ]
+ },
+ {
+ "name": "清甜桃桃",
+ "voice_type": "ICL_zh_female_qingtiantaotao_cs",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Energetic Male II",
+ "voice_type": "en_male_campaign_jamal_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温暖少年",
+ "voice_type": "ICL_zh_male_yangyang_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "文静毛毛",
+ "voice_type": "zh_female_maomao_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "亲切女声",
+ "voice_type": "zh_female_qinqienvsheng_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "鲁班七号",
+ "voice_type": "zh_male_lubanqihao_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "精灵向导",
+ "voice_type": "ICL_zh_female_jinglingxiangdao_1beb294a9e3e_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "高冷御姐",
+ "voice_type": "zh_female_gaolengyujie_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "surprised",
+ "fear",
+ "hate",
+ "excited",
+ "coldness",
+ "neutral"
+ ]
+ },
+ {
+ "name": "清晰小雪",
+ "voice_type": "ICL_zh_female_qingxixiaoxue_cs",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Gotham Hero",
+ "voice_type": "en_male_chris_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅公子",
+ "voice_type": "ICL_zh_male_flc_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "倾心少女",
+ "voice_type": "ICL_zh_female_qiuling_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "机灵小伙",
+ "voice_type": "ICL_zh_male_shenmi_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "林潇",
+ "voice_type": "zh_female_yangmi_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "闷油瓶小哥",
+ "voice_type": "ICL_zh_male_menyoupingxiaoge_ffed9fc2fee7_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲娇霸总",
+ "voice_type": "zh_male_aojiaobazong_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "angry",
+ "hate",
+ "neutral"
+ ]
+ },
+ {
+ "name": "清甜莓莓",
+ "voice_type": "ICL_zh_female_qingtianmeimei_cs",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Flirty Female",
+ "voice_type": "en_female_product_darcie_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "悬疑解说",
+ "voice_type": "zh_male_changtianyi_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "醇厚低音",
+ "voice_type": "ICL_zh_male_buyan_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "元气甜妹",
+ "voice_type": "ICL_zh_female_wuxi_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "玲玲姐姐",
+ "voice_type": "zh_female_linzhiling_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "黯刃秦主",
+ "voice_type": "ICL_zh_male_anrenqinzhu_cd62e63dcdab_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "广州德哥",
+ "voice_type": "zh_male_guangzhoudege_emo_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "angry",
+ "fear",
+ "neutral"
+ ]
+ },
+ {
+ "name": "开朗婷婷",
+ "voice_type": "ICL_zh_female_kailangtingting_cs",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Peaceful Female",
+ "voice_type": "en_female_emotional_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅青年",
+ "voice_type": "zh_male_ruyaqingnian_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "咆哮小哥",
+ "voice_type": "ICL_zh_male_BV144_paoxiaoge_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "知心姐姐",
+ "voice_type": "ICL_zh_female_wenyinvsheng_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "春日部姐姐",
+ "voice_type": "zh_female_jiyejizi2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "霸道总裁",
+ "voice_type": "ICL_zh_male_badaozongcai_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "京腔侃爷",
+ "voice_type": "zh_male_jingqiangkanye_emo_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "angry",
+ "surprised",
+ "hate",
+ "neutral"
+ ]
+ },
+ {
+ "name": "清新沐沐",
+ "voice_type": "ICL_zh_male_qingxinmumu_cs",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Nara",
+ "voice_type": "en_female_nara_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "霸气青叔",
+ "voice_type": "zh_male_baqiqingshu_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "和蔼奶奶",
+ "voice_type": "ICL_zh_female_heainainai_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "阳光阿辰",
+ "voice_type": "zh_male_qingyiyuxuan_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "唐僧",
+ "voice_type": "zh_male_tangseng_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "妩媚可人",
+ "voice_type": "ICL_zh_female_ganli_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "邻居阿姨",
+ "voice_type": "zh_female_linjuayi_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "angry",
+ "surprised",
+ "coldness",
+ "neutral",
+ "depressed"
+ ]
+ },
+ {
+ "name": "爽朗小阳",
+ "voice_type": "ICL_zh_male_shuanglangxiaoyang_cs",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Bruce",
+ "voice_type": "en_male_bruce_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "擎苍",
+ "voice_type": "zh_male_qingcang_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "邻居阿姨",
+ "voice_type": "ICL_zh_female_linjuayi_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "快乐小东",
+ "voice_type": "zh_male_xudong_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "庄周",
+ "voice_type": "zh_male_zhuangzhou_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "邪魅御姐",
+ "voice_type": "ICL_zh_female_xiangliangya_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "优柔公子",
+ "voice_type": "zh_male_yourougongzi_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "angry",
+ "fear",
+ "hate",
+ "excited",
+ "neutral",
+ "depressed"
+ ]
+ },
+ {
+ "name": "清新波波",
+ "voice_type": "ICL_zh_male_qingxinbobo_cs",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Michael",
+ "voice_type": "en_male_michael_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "活力小哥",
+ "voice_type": "zh_male_yangguangqingnian_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔小雅",
+ "voice_type": "zh_female_wenrouxiaoya_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷酷哥哥",
+ "voice_type": "ICL_zh_male_lengkugege_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "猪八戒",
+ "voice_type": "zh_male_zhubajie_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "嚣张小哥",
+ "voice_type": "ICL_zh_male_ms_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅男友",
+ "voice_type": "zh_male_ruyayichen_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "fear",
+ "excited",
+ "coldness",
+ "neutral"
+ ]
+ },
+ {
+ "name": "温婉珊珊",
+ "voice_type": "ICL_zh_female_wenwanshanshan_cs",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Cartoon Chef",
+ "voice_type": "ICL_en_male_cc_sha_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "古风少御",
+ "voice_type": "zh_female_gufengshaoyu_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "天才童声",
+ "voice_type": "zh_male_tiancaitongsheng_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "纯澈女生",
+ "voice_type": "ICL_zh_female_feicui_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "感冒电音姐姐",
+ "voice_type": "zh_female_ganmaodianyin_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "油腻大叔",
+ "voice_type": "ICL_zh_male_you_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "俊朗男友",
+ "voice_type": "zh_male_junlangnanyou_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "surprised",
+ "fear",
+ "neutral"
+ ]
+ },
+ {
+ "name": "甜美小雨",
+ "voice_type": "ICL_zh_female_tianmeixiaoyu_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Lucas",
+ "voice_type": "zh_male_M100_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔淑女",
+ "voice_type": "zh_female_wenroushunv_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "猴哥",
+ "voice_type": "zh_male_sunwukong_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "初恋女友",
+ "voice_type": "ICL_zh_female_yuxin_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "直率英子",
+ "voice_type": "zh_female_naying_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "孤傲公子",
+ "voice_type": "ICL_zh_male_guaogongzi_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "北京小爷",
+ "voice_type": "zh_male_beijingxiaoye_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "angry",
+ "surprised",
+ "fear",
+ "excited",
+ "coldness",
+ "neutral"
+ ]
+ },
+ {
+ "name": "热情艾娜",
+ "voice_type": "ICL_zh_female_reqingaina_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Sophie",
+ "voice_type": "zh_female_sophie_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "反卷青年",
+ "voice_type": "zh_male_fanjuanqingnian_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "audiobook",
+ "scene": "有声阅读",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "熊二",
+ "voice_type": "zh_male_xionger_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "贴心闺蜜",
+ "voice_type": "ICL_zh_female_xnx_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "女雷神",
+ "voice_type": "zh_female_leidian_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "胡子叔叔",
+ "voice_type": "ICL_zh_male_huzi_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "柔美女友",
+ "voice_type": "zh_female_roumeinvyou_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "surprised",
+ "fear",
+ "hate",
+ "excited",
+ "coldness",
+ "neutral"
+ ]
+ },
+ {
+ "name": "甜美小橘",
+ "voice_type": "ICL_zh_female_tianmeixiaoju_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Daisy",
+ "voice_type": "en_female_dacey_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "佩奇猪",
+ "voice_type": "zh_female_peiqi_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔白月光",
+ "voice_type": "ICL_zh_female_yry_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "豫州子轩",
+ "voice_type": "zh_male_yuzhouzixuan_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "性感魅惑",
+ "voice_type": "ICL_zh_female_luoqing_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "阳光青年",
+ "voice_type": "zh_male_yangguangqingnian_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "fear",
+ "excited",
+ "coldness",
+ "neutral"
+ ]
+ },
+ {
+ "name": "沉稳明仔",
+ "voice_type": "ICL_zh_male_chenwenmingzai_cs_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Owen",
+ "voice_type": "en_male_charlie_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "武则天",
+ "voice_type": "zh_female_wuzetian_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "开朗学长",
+ "voice_type": "en_male_jason_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "呆萌川妹",
+ "voice_type": "zh_female_daimengchuanmei_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病弱公子",
+ "voice_type": "ICL_zh_male_bingruogongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "魅力女友",
+ "voice_type": "zh_female_meilinvyou_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese"
+ ],
+ "emotions": [
+ "sad",
+ "fear",
+ "neutral"
+ ]
+ },
+ {
+ "name": "亲切小卓",
+ "voice_type": "ICL_zh_male_qinqiexiaozhuo_cs_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Luna",
+ "voice_type": "en_female_sarah_new_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "顾姐",
+ "voice_type": "zh_female_gujie_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "魅力苏菲",
+ "voice_type": "zh_female_sophie_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "广西远舟",
+ "voice_type": "zh_male_guangxiyuanzhou_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "邪魅女王",
+ "voice_type": "ICL_zh_female_bingjiao3_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "爽快思思",
+ "voice_type": "zh_female_shuangkuaisisi_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "chinese",
+ "en_gb"
+ ],
+ "emotions": [
+ "happy",
+ "sad",
+ "angry",
+ "surprised",
+ "excited",
+ "coldness",
+ "neutral"
+ ]
+ },
+ {
+ "name": "灵动欣欣",
+ "voice_type": "ICL_zh_female_lingdongxinxin_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Michael",
+ "voice_type": "ICL_en_male_michael_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "樱桃丸子",
+ "voice_type": "zh_female_yingtaowanzi_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "贴心妹妹",
+ "voice_type": "ICL_zh_female_yilin_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "双节棍小哥",
+ "voice_type": "zh_male_zhoujielun_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲慢青年",
+ "voice_type": "ICL_zh_male_aomanqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Candice",
+ "voice_type": "en_female_candice_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "en_us"
+ ],
+ "emotions": [
+ "angry",
+ "neutral",
+ "ASMR",
+ "happy",
+ "chat",
+ "chat",
+ "warm",
+ "affectionate"
+ ]
+ },
+ {
+ "name": "乖巧可儿",
+ "voice_type": "ICL_zh_female_guaiqiaokeer_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Charlie",
+ "voice_type": "ICL_en_female_cc_cm_v1_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "广告解说",
+ "voice_type": "zh_male_chunhui_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "甜美桃子",
+ "voice_type": "zh_female_tianmeitaozi_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "湾湾小何",
+ "voice_type": "zh_female_wanwanxiaohe_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "醋精男生",
+ "voice_type": "ICL_zh_male_cujingnansheng_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Serena",
+ "voice_type": "en_female_skye_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "en_us"
+ ],
+ "emotions": [
+ "sad",
+ "angry",
+ "neutral",
+ "ASMR",
+ "happy",
+ "chat",
+ "chat",
+ "warm",
+ "affectionate"
+ ]
+ },
+ {
+ "name": "暖心茜茜",
+ "voice_type": "ICL_zh_female_nuanxinqianqian_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Big Boogie",
+ "voice_type": "ICL_en_male_oogie2_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "少儿故事",
+ "voice_type": "zh_female_shaoergushi_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "清新女声",
+ "voice_type": "zh_female_qingxinnvsheng_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "湾区大叔",
+ "voice_type": "zh_female_wanqudashu_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "撒娇男友",
+ "voice_type": "ICL_zh_male_sajiaonanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Glen",
+ "voice_type": "en_male_glen_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "en_us"
+ ],
+ "emotions": [
+ "sad",
+ "angry",
+ "neutral",
+ "ASMR",
+ "happy",
+ "chat",
+ "chat",
+ "warm",
+ "affectionate"
+ ]
+ },
+ {
+ "name": "软萌团子",
+ "voice_type": "ICL_zh_female_ruanmengtuanzi_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Frosty Man",
+ "voice_type": "ICL_en_male_frosty1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "四郎",
+ "voice_type": "zh_male_silang_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "知性女声",
+ "voice_type": "zh_female_zhixingnvsheng_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "广州德哥",
+ "voice_type": "zh_male_guozhoudege_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔男友",
+ "voice_type": "ICL_zh_male_wenrounanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Sylus",
+ "voice_type": "en_male_sylus_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "en_us"
+ ],
+ "emotions": [
+ "sad",
+ "angry",
+ "neutral",
+ "ASMR",
+ "happy",
+ "chat",
+ "chat",
+ "warm",
+ "affectionate",
+ "authoritative"
+ ]
+ },
+ {
+ "name": "阳光洋洋",
+ "voice_type": "ICL_zh_male_yangguangyangyang_cs_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "The Grinch",
+ "voice_type": "ICL_en_male_grinch2_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "俏皮女声",
+ "voice_type": "zh_female_qiaopinvsheng_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "清爽男大",
+ "voice_type": "zh_male_qingshuangnanda_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "浩宇小哥",
+ "voice_type": "zh_male_haoyuxiaoge_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温顺少年",
+ "voice_type": "ICL_zh_male_wenshunshaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Corey",
+ "voice_type": "en_male_corey_emo_v2_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": [
+ "sad",
+ "angry",
+ "neutral",
+ "ASMR",
+ "happy",
+ "chat",
+ "chat",
+ "warm",
+ "affectionate",
+ "authoritative"
+ ]
+ },
+ {
+ "name": "软萌糖糖",
+ "voice_type": "ICL_zh_female_ruanmengtangtang_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Zayne",
+ "voice_type": "ICL_en_male_zayne_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "懒音绵宝",
+ "voice_type": "zh_male_lanxiaoyang_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "邻家女孩",
+ "voice_type": "zh_female_linjianvhai_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "北京小爷",
+ "voice_type": "zh_male_beijingxiaoye_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "粘人男友",
+ "voice_type": "ICL_zh_male_naigounanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Nadia",
+ "voice_type": "en_female_nadia_tips_emo_v2_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_emotion",
+ "scene": "多情感",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": [
+ "sad",
+ "angry",
+ "neutral",
+ "ASMR",
+ "happy",
+ "chat",
+ "chat",
+ "warm",
+ "affectionate"
+ ]
+ },
+ {
+ "name": "秀丽倩倩",
+ "voice_type": "ICL_zh_female_xiuliqianqian_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Jigsaw",
+ "voice_type": "ICL_en_male_cc_jigsaw_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "亮嗓萌仔",
+ "voice_type": "zh_male_dongmanhaimian_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "渊博小叔",
+ "voice_type": "zh_male_yuanboxiaoshu_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "京腔侃爷/Harmony",
+ "voice_type": "zh_male_jingqiangkanye_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "撒娇男生",
+ "voice_type": "ICL_zh_male_sajiaonansheng_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "开心小鸿",
+ "voice_type": "ICL_zh_female_kaixinxiaohong_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Chucky",
+ "voice_type": "ICL_en_male_cc_chucky_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "磁性解说男声/Morgan",
+ "voice_type": "zh_male_jieshuonansheng_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "阳光青年",
+ "voice_type": "zh_male_yangguangqingnian_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "妹坨洁儿",
+ "voice_type": "zh_female_meituojieer_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "accent",
+ "scene": "趣味口音",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "活泼男友",
+ "voice_type": "ICL_zh_male_huoponanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "轻盈朵朵",
+ "voice_type": "ICL_zh_female_qingyingduoduo_cs_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Clown Man",
+ "voice_type": "ICL_en_male_cc_penny_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "鸡汤妹妹/Hope",
+ "voice_type": "zh_female_jitangmeimei_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "甜美小源",
+ "voice_type": "zh_female_tianmeixiaoyuan_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "甜系男友",
+ "voice_type": "ICL_zh_male_tianxinanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "暖阳女声",
+ "voice_type": "zh_female_kefunvsheng_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "customer_service",
+ "scene": "客服场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Kevin McCallister",
+ "voice_type": "ICL_en_male_kevin2_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "贴心女声/Candy",
+ "voice_type": "zh_female_tiexinnvsheng_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "清澈梓梓",
+ "voice_type": "zh_female_qingchezizi_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "活力青年",
+ "voice_type": "ICL_zh_male_huoliqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Xavier",
+ "voice_type": "ICL_en_male_xavier1_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "萌丫头/Cutey",
+ "voice_type": "zh_female_mengyatou_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "video_dubbing",
+ "scene": "视频配音",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "解说小明",
+ "voice_type": "zh_male_jieshuoxiaoming_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "开朗青年",
+ "voice_type": "ICL_zh_male_kailangqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Noah",
+ "voice_type": "ICL_en_male_cc_dracula_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "开朗姐姐",
+ "voice_type": "zh_female_kailangjiejie_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷漠兄长",
+ "voice_type": "ICL_zh_male_lengmoxiongzhang_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Adam",
+ "voice_type": "en_male_adam_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "邻家男孩",
+ "voice_type": "zh_male_linjiananhai_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "翩翩公子",
+ "voice_type": "ICL_zh_male_pianpiangongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Amanda",
+ "voice_type": "en_female_amanda_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "甜美悦悦",
+ "voice_type": "zh_female_tianmeiyueyue_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "懵懂青年",
+ "voice_type": "ICL_zh_male_mengdongqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Jackson",
+ "voice_type": "en_male_jackson_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "心灵鸡汤",
+ "voice_type": "zh_female_xinlingjitang_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷脸兄长",
+ "voice_type": "ICL_zh_male_lenglianxiongzhang_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Delicate Girl",
+ "voice_type": "en_female_daisy_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "知性温婉",
+ "voice_type": "ICL_zh_female_zhixingwenwan_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇少年",
+ "voice_type": "ICL_zh_male_bingjiaoshaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Dave",
+ "voice_type": "en_male_dave_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "暖心体贴",
+ "voice_type": "ICL_zh_male_nuanxintitie_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇男友",
+ "voice_type": "ICL_zh_male_bingjiaonanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Hades",
+ "voice_type": "en_male_hades_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "开朗轻快",
+ "voice_type": "ICL_zh_male_kailangqingkuai_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病弱少年",
+ "voice_type": "ICL_zh_male_bingruoshaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Onez",
+ "voice_type": "en_female_onez_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "活泼爽朗",
+ "voice_type": "ICL_zh_male_huoposhuanglang_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "意气少年",
+ "voice_type": "ICL_zh_male_yiqishaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Emily",
+ "voice_type": "en_female_emily_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "率真小伙",
+ "voice_type": "ICL_zh_male_shuaizhenxiaohuo_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "干净少年",
+ "voice_type": "ICL_zh_male_ganjingshaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Daniel",
+ "voice_type": "zh_male_xudong_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔小哥",
+ "voice_type": "zh_male_wenrouxiaoge_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷漠男友",
+ "voice_type": "ICL_zh_male_lengmonanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Alastor",
+ "voice_type": "ICL_en_male_cc_alastor_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "灿灿/Shiny",
+ "voice_type": "zh_female_cancan_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "精英青年",
+ "voice_type": "ICL_zh_male_jingyingqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Smith",
+ "voice_type": "en_male_smith_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "爽快思思/Skye",
+ "voice_type": "zh_female_shuangkuaisisi_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "热血少年",
+ "voice_type": "ICL_zh_male_rexueshaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Anna",
+ "voice_type": "en_female_anna_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_gb"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温暖阿虎/Alvin",
+ "voice_type": "zh_male_wennuanahu_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "清爽少年",
+ "voice_type": "ICL_zh_male_qingshuangshaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Ethan",
+ "voice_type": "ICL_en_male_aussie_v1_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_au"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "少年梓辛/Brayan",
+ "voice_type": "zh_male_shaonianzixin_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese",
+ "en_us"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "中二青年",
+ "voice_type": "ICL_zh_male_zhongerqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Sarah",
+ "voice_type": "en_female_sarah_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_au"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔文雅",
+ "voice_type": "ICL_zh_female_wenrouwenya_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "general",
+ "scene": "通用场景",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "凌云青年",
+ "voice_type": "ICL_zh_male_lingyunqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Dryw",
+ "voice_type": "en_male_dryw_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "en_au"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "自负青年",
+ "voice_type": "ICL_zh_male_zifuqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Diana",
+ "voice_type": "multi_female_maomao_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "不羁青年",
+ "voice_type": "ICL_zh_male_bujiqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Lucía",
+ "voice_type": "multi_male_M100_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅君子",
+ "voice_type": "ICL_zh_male_ruyajunzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Sofía",
+ "voice_type": "multi_female_sophie_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "低音沉郁",
+ "voice_type": "ICL_zh_male_diyinchenyu_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "Daníel",
+ "voice_type": "multi_male_xudong_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷脸学霸",
+ "voice_type": "ICL_zh_male_lenglianxueba_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "ひかる(光)",
+ "voice_type": "multi_zh_male_youyoujunzi_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅总裁",
+ "voice_type": "ICL_zh_male_ruyazongcai_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "さとみ(智美)",
+ "voice_type": "multi_female_sophie_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "深沉总裁",
+ "voice_type": "ICL_zh_male_shenchenzongcai_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "まさお(正男)",
+ "voice_type": "multi_male_xudong_conversation_wvae_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "小侯爷",
+ "voice_type": "ICL_zh_male_xiaohouye_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "つき(月)",
+ "voice_type": "multi_female_maomao_conversation_wvae_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "孤高公子",
+ "voice_type": "ICL_zh_male_gugaogongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "あけみ(朱美)",
+ "voice_type": "multi_female_gaolengyujie_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "仗剑君子",
+ "voice_type": "ICL_zh_male_zhangjianjunzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "かずね(和音)/Javier or Álvaro",
+ "voice_type": "multi_male_jingqiangkanye_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja",
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温润学者",
+ "voice_type": "ICL_zh_male_wenrunxuezhe_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "はるこ(晴子)/Esmeralda",
+ "voice_type": "multi_female_shuangkuaisisi_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja",
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "亲切青年",
+ "voice_type": "ICL_zh_male_qinqieqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "ひろし(広志)/Roberto",
+ "voice_type": "multi_male_wanqudashu_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "multi_language",
+ "scene": "多语种",
+ "language": [
+ "ja",
+ "es"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔学长",
+ "voice_type": "ICL_zh_male_wenrouxuezhang_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "高冷总裁",
+ "voice_type": "ICL_zh_male_gaolengzongcai_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷峻高智",
+ "voice_type": "ICL_zh_male_lengjungaozhi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "孱弱少爷",
+ "voice_type": "ICL_zh_male_chanruoshaoye_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "自信青年",
+ "voice_type": "ICL_zh_male_zixinqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "青涩青年",
+ "voice_type": "ICL_zh_male_qingseqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "学霸同桌",
+ "voice_type": "ICL_zh_male_xuebatongzhuo_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷傲总裁",
+ "voice_type": "ICL_zh_male_lengaozongcai_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "元气少年",
+ "voice_type": "ICL_zh_male_yuanqishaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "洒脱青年",
+ "voice_type": "ICL_zh_male_satuoqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "直率青年",
+ "voice_type": "ICL_zh_male_zhishuaiqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "斯文青年",
+ "voice_type": "ICL_zh_male_siwenqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "俊逸公子",
+ "voice_type": "ICL_zh_male_junyigongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "仗剑侠客",
+ "voice_type": "ICL_zh_male_zhangjianxiake_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "机甲智能",
+ "voice_type": "ICL_zh_male_jijiaozhineng_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "奶气萌娃",
+ "voice_type": "zh_male_naiqimengwa_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "婆婆",
+ "voice_type": "zh_female_popo_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "高冷御姐",
+ "voice_type": "zh_female_gaolengyujie_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲娇霸总",
+ "voice_type": "zh_male_aojiaobazong_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "魅力女友",
+ "voice_type": "zh_female_meilinvyou_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "深夜播客",
+ "voice_type": "zh_male_shenyeboke_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "柔美女友",
+ "voice_type": "zh_female_sajiaonvyou_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "撒娇学妹",
+ "voice_type": "zh_female_yuanqinvyou_moon_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病弱少女",
+ "voice_type": "ICL_zh_female_bingruoshaonv_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "活泼女孩",
+ "voice_type": "ICL_zh_female_huoponvhai_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "东方浩然",
+ "voice_type": "zh_male_dongfanghaoran_moon_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "绿茶小哥",
+ "voice_type": "ICL_zh_male_lvchaxiaoge_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "娇弱萝莉",
+ "voice_type": "ICL_zh_female_jiaoruoluoli_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷淡疏离",
+ "voice_type": "ICL_zh_male_lengdanshuli_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "憨厚敦实",
+ "voice_type": "ICL_zh_male_hanhoudunshi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "活泼刁蛮",
+ "voice_type": "ICL_zh_female_huopodiaoman_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "固执病娇",
+ "voice_type": "ICL_zh_male_guzhibingjiao_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "撒娇粘人",
+ "voice_type": "ICL_zh_male_sajiaonianren_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲慢娇声",
+ "voice_type": "ICL_zh_female_aomanjiaosheng_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "潇洒随性",
+ "voice_type": "ICL_zh_male_xiaosasuixing_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "诡异神秘",
+ "voice_type": "ICL_zh_male_guiyishenmi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "儒雅才俊",
+ "voice_type": "ICL_zh_male_ruyacaijun_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "正直青年",
+ "voice_type": "ICL_zh_male_zhengzhiqingnian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "娇憨女王",
+ "voice_type": "ICL_zh_female_jiaohannvwang_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇萌妹",
+ "voice_type": "ICL_zh_female_bingjiaomengmei_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "青涩小生",
+ "voice_type": "ICL_zh_male_qingsenaigou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "纯真学弟",
+ "voice_type": "ICL_zh_male_chunzhenxuedi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "优柔帮主",
+ "voice_type": "ICL_zh_male_youroubangzhu_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "优柔公子",
+ "voice_type": "ICL_zh_male_yourougongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "贴心男友",
+ "voice_type": "ICL_zh_male_tiexinnanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "少年将军",
+ "voice_type": "ICL_zh_male_shaonianjiangjun_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇哥哥",
+ "voice_type": "ICL_zh_male_bingjiaogege_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "学霸男同桌",
+ "voice_type": "ICL_zh_male_xuebanantongzhuo_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "幽默叔叔",
+ "voice_type": "ICL_zh_male_youmoshushu_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "假小子",
+ "voice_type": "ICL_zh_female_jiaxiaozi_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "温柔男同桌",
+ "voice_type": "ICL_zh_male_wenrounantongzhuo_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "幽默大爷",
+ "voice_type": "ICL_zh_male_youmodaye_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "枕边低语",
+ "voice_type": "ICL_zh_male_asmryexiu_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "神秘法师",
+ "voice_type": "ICL_zh_male_shenmifashi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "娇喘女声",
+ "voice_type": "zh_female_jiaochuan_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "开朗弟弟",
+ "voice_type": "zh_male_livelybro_mars_bigtts",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "谄媚女声",
+ "voice_type": "zh_female_flattery_mars_bigtts",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "冷峻上司",
+ "voice_type": "ICL_zh_male_lengjunshangsi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "醋精男友",
+ "voice_type": "ICL_zh_male_cujingnanyou_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "风发少年",
+ "voice_type": "ICL_zh_male_fengfashaonian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "磁性男嗓",
+ "voice_type": "ICL_zh_male_cixingnansang_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "成熟总裁",
+ "voice_type": "ICL_zh_male_chengshuzongcai_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲娇精英",
+ "voice_type": "ICL_zh_male_aojiaojingying_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲娇公子",
+ "voice_type": "ICL_zh_male_aojiaogongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "霸道少爷",
+ "voice_type": "ICL_zh_male_badaoshaoye_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "腹黑公子",
+ "voice_type": "ICL_zh_male_fuheigongzi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "暖心学姐",
+ "voice_type": "ICL_zh_female_nuanxinxuejie_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "成熟姐姐",
+ "voice_type": "ICL_zh_female_chengshujiejie_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇姐姐",
+ "voice_type": "ICL_zh_female_bingjiaojiejie_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "妩媚御姐",
+ "voice_type": "ICL_zh_female_wumeiyujie_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲娇女友",
+ "voice_type": "ICL_zh_female_aojiaonvyou_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "贴心女友",
+ "voice_type": "ICL_zh_female_tiexinnvyou_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "性感御姐",
+ "voice_type": "ICL_zh_female_xingganyujie_tob",
+ "gender": "female",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇弟弟",
+ "voice_type": "ICL_zh_male_bingjiaodidi_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲慢少爷",
+ "voice_type": "ICL_zh_male_aomanshaoye_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "傲气凌人",
+ "voice_type": "ICL_zh_male_aiqilingren_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ },
+ {
+ "name": "病娇白莲",
+ "voice_type": "ICL_zh_male_bingjiaobailian_tob",
+ "gender": "male",
+ "version": "1.0",
+ "resource_id": "seed-tts-1.0",
+ "voice_category": "roleplay",
+ "scene": "角色扮演",
+ "language": [
+ "chinese"
+ ],
+ "emotions": []
+ }
+ ],
+ "languages": [
+ {
+ "name": "chinese",
+ "zh": "中文",
+ "en": "Chinese"
+ },
+ {
+ "name": "en_au",
+ "zh": "澳洲英语",
+ "en": "Australian English"
+ },
+ {
+ "name": "en_gb",
+ "zh": "英式英语",
+ "en": "British English"
+ },
+ {
+ "name": "en_us",
+ "zh": "美式英语",
+ "en": "American English"
+ },
+ {
+ "name": "es",
+ "zh": "西语",
+ "en": "Spanish"
+ },
+ {
+ "name": "ja",
+ "zh": "日语",
+ "en": "Japanese"
+ }
+ ],
+ "emotions": [
+ {
+ "name": "ASMR",
+ "zh": "低语",
+ "en": "ASMR"
+ },
+ {
+ "name": "affectionate",
+ "zh": "深情",
+ "en": "affectionate"
+ },
+ {
+ "name": "angry",
+ "zh": "生气",
+ "en": "angry"
+ },
+ {
+ "name": "authoritative",
+ "zh": "权威",
+ "en": "authoritative"
+ },
+ {
+ "name": "chat",
+ "zh": "对话",
+ "en": "chat"
+ },
+ {
+ "name": "coldness",
+ "zh": "冷漠",
+ "en": "coldness"
+ },
+ {
+ "name": "depressed",
+ "zh": "沮丧",
+ "en": "depressed"
+ },
+ {
+ "name": "excited",
+ "zh": "激动",
+ "en": "excited"
+ },
+ {
+ "name": "fear",
+ "zh": "恐惧",
+ "en": "fear"
+ },
+ {
+ "name": "happy",
+ "zh": "开心",
+ "en": "happy"
+ },
+ {
+ "name": "hate",
+ "zh": "厌恶",
+ "en": "hate"
+ },
+ {
+ "name": "neutral",
+ "zh": "中性",
+ "en": "neutral"
+ },
+ {
+ "name": "sad",
+ "zh": "悲伤",
+ "en": "sad"
+ },
+ {
+ "name": "surprised",
+ "zh": "惊讶",
+ "en": "surprised"
+ },
+ {
+ "name": "warm",
+ "zh": "温暖",
+ "en": "warm"
+ }
+ ]
+}
diff --git a/configs/wan/wan_flf2v.json b/configs/wan/wan_flf2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..998f38f05a73f6ce0bd74116764c6ea317f5391d
--- /dev/null
+++ b/configs/wan/wan_flf2v.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": 5,
+ "sample_shift": 16,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/wan/wan_i2v.json b/configs/wan/wan_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..6c2107083c6b94f2dcd36e9eec30664292e3ba9a
--- /dev/null
+++ b/configs/wan/wan_i2v.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 3,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/wan/wan_t2v.json b/configs/wan/wan_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..2e2825047a1d4ad72ae661befd729f26162dbabe
--- /dev/null
+++ b/configs/wan/wan_t2v.json
@@ -0,0 +1,14 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/wan/wan_t2v_enhancer.json b/configs/wan/wan_t2v_enhancer.json
new file mode 100644
index 0000000000000000000000000000000000000000..b27afdfbf5f645eabc9c9e179358b02c5cacf8e9
--- /dev/null
+++ b/configs/wan/wan_t2v_enhancer.json
@@ -0,0 +1,19 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 6,
+ "sample_shift": 8,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "sub_servers": {
+ "prompt_enhancer": [
+ "http://localhost:9001"
+ ]
+ }
+}
diff --git a/configs/wan/wan_vace.json b/configs/wan/wan_vace.json
new file mode 100644
index 0000000000000000000000000000000000000000..23fbff233a7243bb032e46d3013c6dce5be5b0b8
--- /dev/null
+++ b/configs/wan/wan_vace.json
@@ -0,0 +1,13 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5,
+ "sample_shift": 16,
+ "enable_cfg": true,
+ "cpu_offload": false
+}
diff --git a/configs/wan22/wan_animate.json b/configs/wan22/wan_animate.json
new file mode 100644
index 0000000000000000000000000000000000000000..c4c3d59b32eb406c2c30fab6474fd23d74363b49
--- /dev/null
+++ b/configs/wan22/wan_animate.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 20,
+ "target_video_length": 77,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "adapter_attn_type": "flash_attn3",
+ "sample_shift": 5.0,
+ "sample_guide_scale": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "refert_num": 1,
+ "replace_flag": false,
+ "fps": 30
+}
diff --git a/configs/wan22/wan_animate_4090.json b/configs/wan22/wan_animate_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..e0f89657ddea8ca989d2e33d51076cc6e2e3af95
--- /dev/null
+++ b/configs/wan22/wan_animate_4090.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 20,
+ "target_video_length": 77,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "adapter_attn_type": "sage_attn2",
+ "sample_shift": 5.0,
+ "sample_guide_scale": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "refert_num": 1,
+ "replace_flag": false,
+ "fps": 30,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8"
+}
diff --git a/configs/wan22/wan_animate_lora.json b/configs/wan22/wan_animate_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..0e1c8059cf0b716a7778d9665d8409574331a603
--- /dev/null
+++ b/configs/wan22/wan_animate_lora.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 77,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "adapter_attn_type": "flash_attn3",
+ "sample_shift": 5.0,
+ "sample_guide_scale": 1.0,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "refert_num": 1,
+ "replace_flag": false,
+ "fps": 30,
+ "lora_configs": [
+ {
+ "path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/wan22/wan_animate_replace.json b/configs/wan22/wan_animate_replace.json
new file mode 100644
index 0000000000000000000000000000000000000000..c4805a471386291791fdca873c0fc34d9df257d4
--- /dev/null
+++ b/configs/wan22/wan_animate_replace.json
@@ -0,0 +1,18 @@
+{
+ "infer_steps": 20,
+ "target_video_length": 77,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "adapter_attn_type": "flash_attn3",
+ "sample_shift": 5.0,
+ "sample_guide_scale": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": false,
+ "refert_num": 1,
+ "fps": 30,
+ "replace_flag": true
+}
diff --git a/configs/wan22/wan_animate_replace_4090.json b/configs/wan22/wan_animate_replace_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..fab2908b784cfc657d35853ea16e65a7706727cd
--- /dev/null
+++ b/configs/wan22/wan_animate_replace_4090.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 20,
+ "target_video_length": 77,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "adapter_attn_type": "sage_attn2",
+ "sample_shift": 5.0,
+ "sample_guide_scale": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "refert_num": 1,
+ "fps": 30,
+ "replace_flag": true,
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f",
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/wan22/wan_distill_moe_flf2v.json b/configs/wan22/wan_distill_moe_flf2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..d3f90018fa5a2ada4174f18972be495ed4b9d39f
--- /dev/null
+++ b/configs/wan22/wan_distill_moe_flf2v.json
@@ -0,0 +1,28 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 16,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/wan22/wan_distill_moe_flf2v_fp8.json b/configs/wan22/wan_distill_moe_flf2v_fp8.json
new file mode 100644
index 0000000000000000000000000000000000000000..2e982f38b0bb5299a4c84b5b708be625f79fae26
--- /dev/null
+++ b/configs/wan22/wan_distill_moe_flf2v_fp8.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 16,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8"
+}
diff --git a/configs/wan22/wan_distill_moe_flf2v_int8.json b/configs/wan22/wan_distill_moe_flf2v_int8.json
new file mode 100644
index 0000000000000000000000000000000000000000..c3489b04395f9884582810be50778d3fa55f5d80
--- /dev/null
+++ b/configs/wan22/wan_distill_moe_flf2v_int8.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 16,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-vllm",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8"
+}
diff --git a/configs/wan22/wan_distill_moe_flf2v_with_lora.json b/configs/wan22/wan_distill_moe_flf2v_with_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..445ca8fe2dd7c356080a5dd08414faa6e5a096da
--- /dev/null
+++ b/configs/wan22/wan_distill_moe_flf2v_with_lora.json
@@ -0,0 +1,40 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 16,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "lora_configs": [
+ {
+ "name": "low_noise_model",
+ "path": "/path/to/low_noise_lora",
+ "strength": 1.0
+ },
+ {
+ "name": "high_noise_model",
+ "path": "/path/to/high_noise_lora",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/wan22/wan_moe_flf2v.json b/configs/wan22/wan_moe_flf2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..73879fc935900abc1a81b3dc840ca057a122a21e
--- /dev/null
+++ b/configs/wan22/wan_moe_flf2v.json
@@ -0,0 +1,21 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_shift": 16,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.900,
+ "use_image_encoder": false
+}
diff --git a/configs/wan22/wan_moe_i2v.json b/configs/wan22/wan_moe_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..eb29bfe8cefa414cedd0edd352028fd628f96dbb
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v.json
@@ -0,0 +1,20 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "boundary": 0.900,
+ "use_image_encoder": false
+}
diff --git a/configs/wan22/wan_moe_i2v_4090.json b/configs/wan22/wan_moe_i2v_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..d37b9289257dece6a4f5cbcc7a2d49d44b4d1f49
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_4090.json
@@ -0,0 +1,27 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "phase",
+ "boundary": 0.900,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/wan22/wan_moe_i2v_audio.json b/configs/wan22/wan_moe_i2v_audio.json
new file mode 100644
index 0000000000000000000000000000000000000000..a0deb7ce7e08975da20a3e5e5e70d6eb21be8dec
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_audio.json
@@ -0,0 +1,38 @@
+{
+ "infer_steps": 6,
+ "target_fps": 16,
+ "video_duration": 16,
+ "audio_sr": 16000,
+ "text_len": 512,
+ "target_video_length": 81,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 1.0,
+ 1.0
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.900,
+ "use_image_encoder": false,
+ "use_31_block": false,
+ "lora_configs": [
+ {
+ "name": "high_noise_model",
+ "path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
+ "strength": 1.0
+ },
+ {
+ "name": "low_noise_model",
+ "path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/wan22/wan_moe_i2v_distill.json b/configs/wan22/wan_moe_i2v_distill.json
new file mode 100644
index 0000000000000000000000000000000000000000..92f3360d75a3801a359e1b403deba77a57d8c166
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_distill.json
@@ -0,0 +1,28 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ]
+}
diff --git a/configs/wan22/wan_moe_i2v_distill_4090.json b/configs/wan22/wan_moe_i2v_distill_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..279a0f25e88954426b732d70388d216b0c01a643
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_distill_4090.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-q8f"
+}
diff --git a/configs/wan22/wan_moe_i2v_distill_5090.json b/configs/wan22/wan_moe_i2v_distill_5090.json
new file mode 100644
index 0000000000000000000000000000000000000000..04494c2234874188f86c3aac7e2d92d9f311ad68
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_distill_5090.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn3",
+ "cross_attn_1_type": "sage_attn3",
+ "cross_attn_2_type": "sage_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "int8-q8f",
+ "t5_quantized": true,
+ "t5_quant_scheme": "int8-q8f"
+}
diff --git a/configs/wan22/wan_moe_i2v_distill_quant.json b/configs/wan22/wan_moe_i2v_distill_quant.json
new file mode 100644
index 0000000000000000000000000000000000000000..ce4e1c4aeb8524304b2e87c95909d713116cab84
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_distill_quant.json
@@ -0,0 +1,32 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl"
+}
diff --git a/configs/wan22/wan_moe_i2v_distill_with_lora.json b/configs/wan22/wan_moe_i2v_distill_with_lora.json
new file mode 100644
index 0000000000000000000000000000000000000000..f28ddc45126899e4ad8f5365546bca0bf3fb745d
--- /dev/null
+++ b/configs/wan22/wan_moe_i2v_distill_with_lora.json
@@ -0,0 +1,40 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "sage_attn2",
+ "cross_attn_1_type": "sage_attn2",
+ "cross_attn_2_type": "sage_attn2",
+ "sample_guide_scale": [
+ 3.5,
+ 3.5
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "block",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "use_image_encoder": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "lora_configs": [
+ {
+ "name": "high_noise_model",
+ "path": "lightx2v/Wan2.2-Distill-Loras/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
+ "strength": 1.0
+ },
+ {
+ "name": "low_noise_model",
+ "path": "lightx2v/Wan2.2-Distill-Loras/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/wan22/wan_moe_t2v.json b/configs/wan22/wan_moe_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..d8be0d1078ba45f9b915b3e44899ac04fe05c38c
--- /dev/null
+++ b/configs/wan22/wan_moe_t2v.json
@@ -0,0 +1,21 @@
+{
+ "infer_steps": 40,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 720,
+ "target_width": 1280,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 4.0,
+ 3.0
+ ],
+ "sample_shift": 12.0,
+ "enable_cfg": true,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary": 0.875
+}
diff --git a/configs/wan22/wan_moe_t2v_distill.json b/configs/wan22/wan_moe_t2v_distill.json
new file mode 100644
index 0000000000000000000000000000000000000000..eabdf11a0c2232a0bc43e1d7bb1bf76e4ae44455
--- /dev/null
+++ b/configs/wan22/wan_moe_t2v_distill.json
@@ -0,0 +1,34 @@
+{
+ "infer_steps": 4,
+ "target_video_length": 81,
+ "text_len": 512,
+ "target_height": 480,
+ "target_width": 832,
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": [
+ 4.0,
+ 3.0
+ ],
+ "sample_shift": 5.0,
+ "enable_cfg": false,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "boundary_step_index": 2,
+ "denoising_step_list": [
+ 1000,
+ 750,
+ 500,
+ 250
+ ],
+ "lora_configs": [
+ {
+ "name": "low_noise_model",
+ "path": "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
diff --git a/configs/wan22/wan_ti2v_i2v.json b/configs/wan22/wan_ti2v_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..44e9c4e334de58491bc6d4b6ce9e40acdc3628a6
--- /dev/null
+++ b/configs/wan22/wan_ti2v_i2v.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "fps": 24,
+ "use_image_encoder": false
+}
diff --git a/configs/wan22/wan_ti2v_i2v_4090.json b/configs/wan22/wan_ti2v_i2v_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..742f8ab16d3de6d5bb29abdd116bf84604e5c3a7
--- /dev/null
+++ b/configs/wan22/wan_ti2v_i2v_4090.json
@@ -0,0 +1,26 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "fps": 24,
+ "use_image_encoder": false,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "vae_offload_cache": true
+}
diff --git a/configs/wan22/wan_ti2v_t2v.json b/configs/wan22/wan_ti2v_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..a1541e808a1d54508efc83757613619cf16dc6d7
--- /dev/null
+++ b/configs/wan22/wan_ti2v_t2v.json
@@ -0,0 +1,24 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "cpu_offload": false,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "fps": 24
+}
diff --git a/configs/wan22/wan_ti2v_t2v_4090.json b/configs/wan22/wan_ti2v_t2v_4090.json
new file mode 100644
index 0000000000000000000000000000000000000000..c1e5c49aced9fc2d3426aa6f5cc21a0defdc1f32
--- /dev/null
+++ b/configs/wan22/wan_ti2v_t2v_4090.json
@@ -0,0 +1,25 @@
+{
+ "infer_steps": 50,
+ "target_video_length": 121,
+ "text_len": 512,
+ "target_height": 704,
+ "target_width": 1280,
+ "num_channels_latents": 48,
+ "vae_stride": [
+ 4,
+ 16,
+ 16
+ ],
+ "self_attn_1_type": "flash_attn3",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3",
+ "sample_guide_scale": 5.0,
+ "sample_shift": 5.0,
+ "enable_cfg": true,
+ "fps": 24,
+ "cpu_offload": true,
+ "offload_granularity": "model",
+ "t5_cpu_offload": false,
+ "vae_cpu_offload": false,
+ "vae_offload_cache": true
+}
diff --git a/dockerfiles/Dockerfile b/dockerfiles/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..ca883dcf14053ba7daeca5f6f65ff83a2c925607
--- /dev/null
+++ b/dockerfiles/Dockerfile
@@ -0,0 +1,105 @@
+FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel AS base
+
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV LANG=C.UTF-8
+ENV LC_ALL=C.UTF-8
+ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
+
+RUN apt-get update && apt-get install -y vim tmux zip unzip bzip2 wget git git-lfs build-essential libibverbs-dev ca-certificates \
+ curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \
+ libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev libjpeg-dev libpng-dev \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* && git lfs install
+
+RUN conda install conda-forge::ffmpeg=8.0.0 -y && conda clean --all -y
+
+RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U
+
+RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \
+ && python use_existing_torch.py && pip install --no-cache-dir -r requirements/build.txt \
+ && pip install --no-cache-dir --no-build-isolation -v -e .
+
+RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel \
+ && make build && make clean
+
+RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
+ imageio-ffmpeg einops loguru qtorch ftfy av decord matplotlib debugpy
+
+RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
+
+RUN cd flash-attention && python setup.py install && rm -rf build
+
+RUN cd flash-attention/hopper && python setup.py install && rm -rf build
+
+RUN git clone https://github.com/ModelTC/SageAttention.git --depth 1
+
+RUN cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e .
+
+RUN git clone https://github.com/ModelTC/SageAttention-1104.git --depth 1
+
+RUN cd SageAttention-1104/sageattention3_blackwell && python setup.py install && rm -rf build
+
+RUN git clone https://github.com/SandAI-org/MagiAttention.git --recursive
+
+RUN cd MagiAttention && TORCH_CUDA_ARCH_LIST="9.0" pip install --no-cache-dir --no-build-isolation -v -e .
+
+RUN git clone https://github.com/ModelTC/FlashVSR.git --depth 1
+
+RUN cd FlashVSR && pip install --no-cache-dir -v -e .
+
+COPY lightx2v_kernel /app/lightx2v_kernel
+
+RUN git clone https://github.com/NVIDIA/cutlass.git --depth 1 && cd /app/lightx2v_kernel && MAX_JOBS=32 CMAKE_BUILD_PARALLEL_LEVEL=4 \
+ uv build --wheel \
+ -Cbuild-dir=build . \
+ -Ccmake.define.CUTLASS_PATH=/app/cutlass \
+ --verbose \
+ --color=always \
+ --no-build-isolation \
+ && pip install dist/*whl --force-reinstall --no-deps \
+ && rm -rf /app/lightx2v_kernel && rm -rf /app/cutlass
+
+# cloud deploy
+RUN pip install --no-cache-dir aio-pika "asyncpg>=0.27.0" "aioboto3>=12.0.0" PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
+ENV PATH=/root/.cargo/bin:$PATH
+
+RUN cd /opt \
+ && wget https://mirrors.tuna.tsinghua.edu.cn/gnu/libiconv/libiconv-1.15.tar.gz \
+ && tar zxvf libiconv-1.15.tar.gz \
+ && cd libiconv-1.15 \
+ && ./configure \
+ && make \
+ && make install \
+ && rm -rf /opt/libiconv-1.15
+
+RUN cd /opt \
+ && git clone https://github.com/GStreamer/gstreamer.git -b 1.27.2 --depth 1 \
+ && cd gstreamer \
+ && meson setup builddir \
+ && meson compile -C builddir \
+ && meson install -C builddir \
+ && ldconfig \
+ && rm -rf /opt/gstreamer
+
+RUN cd /opt \
+ && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.27.2 --depth 1 \
+ && cd gst-plugins-rs \
+ && cargo build --package gst-plugin-webrtchttp --release \
+ && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/ \
+ && rm -rf /opt/gst-plugins-rs
+
+RUN ldconfig
+
+
+# q8f for base docker
+RUN git clone https://github.com/KONAKONA666/q8_kernels.git --depth 1
+RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build
+
+# q8f for 5090 docker
+# RUN git clone https://github.com/ModelTC/LTX-Video-Q8-Kernels.git --depth 1
+# RUN cd LTX-Video-Q8-Kernels && git submodule init && git submodule update && python setup.py install && rm -rf build
+
+WORKDIR /workspace
diff --git a/dockerfiles/Dockerfile_5090 b/dockerfiles/Dockerfile_5090
new file mode 100644
index 0000000000000000000000000000000000000000..c8bce20aa53f5a010e98c1dbdbc8f7351382b164
--- /dev/null
+++ b/dockerfiles/Dockerfile_5090
@@ -0,0 +1,105 @@
+FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel AS base
+
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV LANG=C.UTF-8
+ENV LC_ALL=C.UTF-8
+ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
+
+RUN apt-get update && apt-get install -y vim tmux zip unzip bzip2 wget git git-lfs build-essential libibverbs-dev ca-certificates \
+ curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \
+ libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev libjpeg-dev libpng-dev \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* && git lfs install
+
+RUN conda install conda-forge::ffmpeg=8.0.0 -y && conda clean --all -y
+
+RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U
+
+RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \
+ && python use_existing_torch.py && pip install --no-cache-dir -r requirements/build.txt \
+ && pip install --no-cache-dir --no-build-isolation -v -e .
+
+RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel \
+ && make build && make clean
+
+RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
+ imageio-ffmpeg einops loguru qtorch ftfy av decord matplotlib debugpy
+
+RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
+
+RUN cd flash-attention && python setup.py install && rm -rf build
+
+RUN cd flash-attention/hopper && python setup.py install && rm -rf build
+
+RUN git clone https://github.com/ModelTC/SageAttention.git --depth 1
+
+RUN cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e .
+
+RUN git clone https://github.com/ModelTC/SageAttention-1104.git --depth 1
+
+RUN cd SageAttention-1104/sageattention3_blackwell && python setup.py install && rm -rf build
+
+RUN git clone https://github.com/SandAI-org/MagiAttention.git --recursive
+
+RUN cd MagiAttention && TORCH_CUDA_ARCH_LIST="9.0" pip install --no-cache-dir --no-build-isolation -v -e .
+
+RUN git clone https://github.com/ModelTC/FlashVSR.git --depth 1
+
+RUN cd FlashVSR && pip install --no-cache-dir -v -e .
+
+COPY lightx2v_kernel /app/lightx2v_kernel
+
+RUN git clone https://github.com/NVIDIA/cutlass.git --depth 1 && cd /app/lightx2v_kernel && MAX_JOBS=32 CMAKE_BUILD_PARALLEL_LEVEL=4 \
+ uv build --wheel \
+ -Cbuild-dir=build . \
+ -Ccmake.define.CUTLASS_PATH=/app/cutlass \
+ --verbose \
+ --color=always \
+ --no-build-isolation \
+ && pip install dist/*whl --force-reinstall --no-deps \
+ && rm -rf /app/lightx2v_kernel && rm -rf /app/cutlass
+
+# cloud deploy
+RUN pip install --no-cache-dir aio-pika "asyncpg>=0.27.0" "aioboto3>=12.0.0" PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
+ENV PATH=/root/.cargo/bin:$PATH
+
+RUN cd /opt \
+ && wget https://mirrors.tuna.tsinghua.edu.cn/gnu/libiconv/libiconv-1.15.tar.gz \
+ && tar zxvf libiconv-1.15.tar.gz \
+ && cd libiconv-1.15 \
+ && ./configure \
+ && make \
+ && make install \
+ && rm -rf /opt/libiconv-1.15
+
+RUN cd /opt \
+ && git clone https://github.com/GStreamer/gstreamer.git -b 1.27.2 --depth 1 \
+ && cd gstreamer \
+ && meson setup builddir \
+ && meson compile -C builddir \
+ && meson install -C builddir \
+ && ldconfig \
+ && rm -rf /opt/gstreamer
+
+RUN cd /opt \
+ && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.27.2 --depth 1 \
+ && cd gst-plugins-rs \
+ && cargo build --package gst-plugin-webrtchttp --release \
+ && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/ \
+ && rm -rf /opt/gst-plugins-rs
+
+RUN ldconfig
+
+
+# q8f for base docker
+# RUN git clone https://github.com/KONAKONA666/q8_kernels.git --depth 1
+# RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build
+
+# q8f for 5090 docker
+RUN git clone https://github.com/ModelTC/LTX-Video-Q8-Kernels.git --depth 1
+RUN cd LTX-Video-Q8-Kernels && git submodule init && git submodule update && python setup.py install && rm -rf build
+
+WORKDIR /workspace
diff --git a/dockerfiles/Dockerfile_cambricon_mlu590 b/dockerfiles/Dockerfile_cambricon_mlu590
new file mode 100644
index 0000000000000000000000000000000000000000..3cfa0278ea2199a52d39a2c79b3a7dd412f00281
--- /dev/null
+++ b/dockerfiles/Dockerfile_cambricon_mlu590
@@ -0,0 +1,31 @@
+FROM cambricon-base/pytorch:v25.10.0-torch2.8.0-torchmlu1.29.1-ubuntu22.04-py310 AS base
+
+WORKDIR /workspace/LightX2V
+
+# Set envs
+ENV PYTHONPATH=/workspace/LightX2V
+ENV LD_LIBRARY_PATH=/usr/local/neuware/lib64:${LD_LIBRARY_PATH}
+
+# Install deps
+RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg && \
+ pip install --no-cache-dir \
+ ftfy \
+ imageio \
+ imageio-ffmpeg \
+ loguru \
+ aiohttp \
+ gguf \
+ diffusers \
+ peft==0.17.0 \
+ transformers==4.57.1
+
+# Copy files
+COPY app app
+COPY assets assets
+COPY configs configs
+COPY lightx2v lightx2v
+COPY lightx2v_kernel lightx2v_kernel
+COPY lightx2v_platform lightx2v_platform
+COPY scripts scripts
+COPY test_cases test_cases
+COPY tools tools
diff --git a/dockerfiles/Dockerfile_cu124 b/dockerfiles/Dockerfile_cu124
new file mode 100644
index 0000000000000000000000000000000000000000..ada01a3b841603e1663229882635bb716a0996bc
--- /dev/null
+++ b/dockerfiles/Dockerfile_cu124
@@ -0,0 +1,76 @@
+FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel AS base
+
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV LANG=C.UTF-8
+ENV LC_ALL=C.UTF-8
+ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
+
+RUN apt-get update && apt-get install -y vim tmux zip unzip wget git git-lfs build-essential libibverbs-dev ca-certificates \
+ curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \
+ libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* && git lfs install
+
+RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U
+
+RUN git clone https://github.com/vllm-project/vllm.git -b v0.10.0 && cd vllm \
+ && python use_existing_torch.py && pip install -r requirements/build.txt \
+ && pip install --no-cache-dir --no-build-isolation -v -e .
+
+RUN git clone https://github.com/sgl-project/sglang.git -b v0.4.10 && cd sglang/sgl-kernel \
+ && make build && make clean
+
+RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
+ imageio-ffmpeg einops loguru qtorch ftfy av decord
+
+RUN conda install conda-forge::ffmpeg=8.0.0 -y && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg && conda clean --all -y
+
+RUN git clone https://github.com/Dao-AILab/flash-attention.git -b v2.8.3 --recursive
+
+RUN cd flash-attention && python setup.py install && rm -rf build
+
+RUN cd flash-attention/hopper && python setup.py install && rm -rf build
+
+RUN git clone https://github.com/ModelTC/SageAttention.git
+
+RUN cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e .
+
+RUN git clone https://github.com/KONAKONA666/q8_kernels.git
+
+RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build
+
+# cloud deploy
+RUN pip install --no-cache-dir aio-pika "asyncpg>=0.27.0" "aioboto3>=12.0.0" PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
+ENV PATH=/root/.cargo/bin:$PATH
+
+RUN cd /opt \
+ && wget https://mirrors.tuna.tsinghua.edu.cn/gnu/libiconv/libiconv-1.15.tar.gz \
+ && tar zxvf libiconv-1.15.tar.gz \
+ && cd libiconv-1.15 \
+ && ./configure \
+ && make \
+ && make install \
+ && rm -rf /opt/libiconv-1.15
+
+RUN cd /opt \
+ && git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
+ && cd gstreamer \
+ && meson setup builddir \
+ && meson compile -C builddir \
+ && meson install -C builddir \
+ && ldconfig \
+ && rm -rf /opt/gstreamer
+
+RUN cd /opt \
+ && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
+ && cd gst-plugins-rs \
+ && cargo build --package gst-plugin-webrtchttp --release \
+ && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/ \
+ && rm -rf /opt/gst-plugins-rs
+
+RUN ldconfig
+
+WORKDIR /workspace
diff --git a/dockerfiles/Dockerfile_deploy b/dockerfiles/Dockerfile_deploy
new file mode 100644
index 0000000000000000000000000000000000000000..ee4b1fa068483f8bcafc636b572b557e36a23250
--- /dev/null
+++ b/dockerfiles/Dockerfile_deploy
@@ -0,0 +1,32 @@
+FROM node:alpine3.21 AS frontend_builder
+COPY lightx2v /opt/lightx2v
+
+RUN cd /opt/lightx2v/deploy/server/frontend \
+ && npm install \
+ && npm run build
+
+FROM lightx2v/lightx2v:25111101-cu128 AS base
+
+RUN mkdir /workspace/LightX2V
+WORKDIR /workspace/LightX2V
+ENV PYTHONPATH=/workspace/LightX2V
+
+# for multi-person & animate
+RUN pip install ultralytics moviepy pydub pyannote.audio onnxruntime decord peft pandas matplotlib loguru sentencepiece
+
+RUN export COMMIT=0e78a118995e66bb27d78518c4bd9a3e95b4e266 \
+ && export TORCH_CUDA_ARCH_LIST="9.0" \
+ && git clone --depth 1 https://github.com/facebookresearch/sam2.git \
+ && cd sam2 \
+ && git fetch --depth 1 origin $COMMIT \
+ && git checkout $COMMIT \
+ && python setup.py install
+
+COPY tools tools
+COPY assets assets
+COPY configs configs
+COPY lightx2v lightx2v
+COPY lightx2v_kernel lightx2v_kernel
+COPY lightx2v_platform lightx2v_platform
+
+COPY --from=frontend_builder /opt/lightx2v/deploy/server/frontend/dist lightx2v/deploy/server/frontend/dist
diff --git a/docs/EN/.readthedocs.yaml b/docs/EN/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dda16a6b056a96ef6afe797884912090a612f567
--- /dev/null
+++ b/docs/EN/.readthedocs.yaml
@@ -0,0 +1,17 @@
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+ os: ubuntu-20.04
+ tools:
+ python: "3.10"
+
+formats:
+ - epub
+
+sphinx:
+ configuration: docs/EN/source/conf.py
+
+python:
+ install:
+ - requirements: requirements-docs.txt
diff --git a/docs/EN/Makefile b/docs/EN/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293
--- /dev/null
+++ b/docs/EN/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/EN/make.bat b/docs/EN/make.bat
new file mode 100644
index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a
--- /dev/null
+++ b/docs/EN/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/EN/source/conf.py b/docs/EN/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..25f127e5b6e95951e941d510fce9cf6788bbaab9
--- /dev/null
+++ b/docs/EN/source/conf.py
@@ -0,0 +1,128 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import logging
+import os
+import sys
+from typing import List
+
+import sphinxcontrib.redoc
+from sphinx.ext import autodoc
+
+logger = logging.getLogger(__name__)
+sys.path.append(os.path.abspath("../.."))
+
+# -- Project information -----------------------------------------------------
+
+project = "Lightx2v"
+copyright = "2025, Lightx2v Team"
+author = "the Lightx2v Team"
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.napoleon",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.intersphinx",
+ "sphinx_copybutton",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.mathjax",
+ "myst_parser",
+ "sphinxarg.ext",
+ "sphinxcontrib.redoc",
+ "sphinxcontrib.openapi",
+]
+
+myst_enable_extensions = [
+ "dollarmath",
+ "amsmath",
+]
+
+html_static_path = ["_static"]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns: List[str] = ["**/*.template.rst"]
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r"\$ "
+copybutton_prompt_is_regexp = True
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_title = project
+html_theme = "sphinx_book_theme"
+# html_theme = 'sphinx_rtd_theme'
+html_logo = "../../../assets/img_lightx2v.png"
+html_theme_options = {
+ "path_to_docs": "docs/EN/source",
+ "repository_url": "https://github.com/ModelTC/lightx2v",
+ "use_repository_button": True,
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+
+# Generate additional rst documentation here.
+def setup(app):
+ # from docs.source.generate_examples import generate_examples
+ # generate_examples()
+ pass
+
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+ "cpuinfo",
+ "torch",
+ "transformers",
+ "psutil",
+ "prometheus_client",
+ "sentencepiece",
+ "lightllmnumpy",
+ "tqdm",
+ "tensorizer",
+]
+
+for mock_target in autodoc_mock_imports:
+ if mock_target in sys.modules:
+ logger.info(
+ "Potentially problematic mock target (%s) found; autodoc_mock_imports cannot mock modules that have already been loaded into sys.modules when the sphinx build starts.",
+ mock_target,
+ )
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+ """Remove note about base class when a class is derived from object."""
+
+ def add_line(self, line: str, source: str, *lineno: int) -> None:
+ if line == " Bases: :py:class:`object`":
+ return
+ super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False
diff --git a/docs/EN/source/deploy_guides/deploy_comfyui.md b/docs/EN/source/deploy_guides/deploy_comfyui.md
new file mode 100644
index 0000000000000000000000000000000000000000..afde7bd5307eda6874432a204da335a4fa733909
--- /dev/null
+++ b/docs/EN/source/deploy_guides/deploy_comfyui.md
@@ -0,0 +1,25 @@
+# ComfyUI Deployment
+
+## ComfyUI-Lightx2vWrapper
+
+The official ComfyUI integration nodes for LightX2V are now available in a dedicated repository, providing a complete modular configuration system and optimization features.
+
+### Project Repository
+
+- GitHub: [https://github.com/ModelTC/ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper)
+
+### Key Features
+
+- Modular Configuration System: Separate nodes for each aspect of video generation
+- Support for both Text-to-Video (T2V) and Image-to-Video (I2V) generation modes
+- Advanced Optimizations:
+ - TeaCache acceleration (up to 3x speedup)
+ - Quantization support (int8, fp8)
+ - Memory optimization with CPU offloading
+ - Lightweight VAE options
+- LoRA Support: Chain multiple LoRA models for customization
+- Multiple Model Support: wan2.1, hunyuan architectures
+
+### Installation and Usage
+
+Please visit the GitHub repository above for detailed installation instructions, usage tutorials, and example workflows.
diff --git a/docs/EN/source/deploy_guides/deploy_gradio.md b/docs/EN/source/deploy_guides/deploy_gradio.md
new file mode 100644
index 0000000000000000000000000000000000000000..5be510feee143be743e73149024b456c40d91c96
--- /dev/null
+++ b/docs/EN/source/deploy_guides/deploy_gradio.md
@@ -0,0 +1,240 @@
+# Gradio Deployment Guide
+
+## 📖 Overview
+
+Lightx2v is a lightweight video inference and generation engine that provides a web interface based on Gradio, supporting both Image-to-Video and Text-to-Video generation modes.
+
+For Windows systems, we provide a convenient one-click deployment solution with automatic environment configuration and intelligent parameter optimization. Please refer to the [One-Click Gradio Startup (Recommended)](./deploy_local_windows.md#one-click-gradio-startup-recommended) section for detailed instructions.
+
+
+
+## 📁 File Structure
+
+```
+LightX2V/app/
+├── gradio_demo.py # English interface demo
+├── gradio_demo_zh.py # Chinese interface demo
+├── run_gradio.sh # Startup script
+├── README.md # Documentation
+├── outputs/ # Generated video save directory
+└── inference_logs.log # Inference logs
+```
+
+This project contains two main demo files:
+- `gradio_demo.py` - English interface version
+- `gradio_demo_zh.py` - Chinese interface version
+
+## 🚀 Quick Start
+
+### Environment Requirements
+
+Follow the [Quick Start Guide](../getting_started/quickstart.md) to install the environment
+
+#### Recommended Optimization Library Configuration
+
+- ✅ [Flash attention](https://github.com/Dao-AILab/flash-attention)
+- ✅ [Sage attention](https://github.com/thu-ml/SageAttention)
+- ✅ [vllm-kernel](https://github.com/vllm-project/vllm)
+- ✅ [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
+- ✅ [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (only supports ADA architecture GPUs)
+
+Install each operator as needed by following the instructions on its project homepage.
+
+### 📥 Model Download
+
+Refer to the [Model Structure Documentation](../getting_started/model_structure.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions.
+
+#### wan2.1 Model Directory Structure
+
+```
+models/
+├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision
+├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization
+├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 quantization
+├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 quantization block storage directory
+├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 quantization block storage directory
+├── Other weights (e.g., t2v)
+├── t5/clip/xlm-roberta-large/google # text and image encoder
+├── vae/lightvae/lighttae # vae
+└── config.json # Model configuration file
+```
+
+#### wan2.2 Model Directory Structure
+
+```
+models/
+├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise original precision
+├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 quantization
+├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 quantization
+├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 quantization block storage directory
+├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise original precision
+├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 quantization
+├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 quantization
+├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 quantization block storage directory
+├── t5/clip/xlm-roberta-large/google # text and image encoder
+├── vae/lightvae/lighttae # vae
+└── config.json # Model configuration file
+```
+
+**📝 Download Instructions**:
+
+- Model weights can be downloaded from HuggingFace:
+ - [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+ - [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+- Text and Image Encoders can be downloaded from [Encoders](https://huggingface.co/lightx2v/Encoders)
+- VAE can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
+- The `xxx_split` directories (e.g., `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`) store the weights as multiple per-block safetensors files, which is useful for devices with limited memory (e.g., 16GB of RAM or less); download them only if your setup needs them (see the download sketch below).
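+
+The following is a minimal download sketch using `huggingface_hub` (the repository IDs are the ones listed above; the `allow_patterns` filter and the `./models` target directory are illustrative assumptions, so adjust them to the files your config actually needs and to the directory layout shown earlier):
+
+```python
+# Minimal download sketch (illustrative; adjust repo_id / patterns / local_dir as needed)
+from huggingface_hub import snapshot_download
+
+# Distilled DiT weights, e.g. only the FP8-quantized wan2.1 i2v file from the listing above
+snapshot_download(
+    repo_id="lightx2v/Wan2.1-Distill-Models",
+    allow_patterns=["wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors"],
+    local_dir="./models",
+)
+
+# Text/image encoders and VAE (download only what your chosen config requires)
+snapshot_download(repo_id="lightx2v/Encoders", local_dir="./models")
+snapshot_download(repo_id="lightx2v/Autoencoders", local_dir="./models")
+```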
+
+### Startup Methods
+
+#### Method 1: Using Startup Script (Recommended)
+
+**Linux Environment:**
+```bash
+# 1. Edit the startup script to configure relevant paths
+cd app/
+vim run_gradio.sh
+
+# Configuration items that need to be modified:
+# - lightx2v_path: Lightx2v project root directory path
+# - model_path: Model root directory path (contains all model files)
+
+# 💾 Important note: Recommend pointing model paths to SSD storage locations
+# Example: /mnt/ssd/models/ or /data/ssd/models/
+
+# 2. Run the startup script
+bash run_gradio.sh
+
+# 3. Or start with parameters
+bash run_gradio.sh --lang en --port 8032
+bash run_gradio.sh --lang zh --port 7862
+```
+
+**Windows Environment:**
+```cmd
+# 1. Edit the startup script to configure relevant paths
+cd app\
+notepad run_gradio_win.bat
+
+# Configuration items that need to be modified:
+# - lightx2v_path: Lightx2v project root directory path
+# - model_path: Model root directory path (contains all model files)
+
+# 💾 Important note: Recommend pointing model paths to SSD storage locations
+# Example: D:\models\ or E:\models\
+
+# 2. Run the startup script
+run_gradio_win.bat
+
+# 3. Or start with parameters
+run_gradio_win.bat --lang en --port 8032
+run_gradio_win.bat --lang zh --port 7862
+```
+
+#### Method 2: Direct Command Line Startup
+
+```bash
+pip install -v git+https://github.com/ModelTC/LightX2V.git
+```
+
+**Linux Environment:**
+
+**English Interface Version:**
+```bash
+python gradio_demo.py \
+ --model_path /path/to/models \
+ --server_name 0.0.0.0 \
+ --server_port 7862
+```
+
+**Chinese Interface Version:**
+```bash
+python gradio_demo_zh.py \
+ --model_path /path/to/models \
+ --server_name 0.0.0.0 \
+ --server_port 7862
+```
+
+**Windows Environment:**
+
+**English Interface Version:**
+```cmd
+python gradio_demo.py ^
+ --model_path D:\models ^
+ --server_name 127.0.0.1 ^
+ --server_port 7862
+```
+
+**Chinese Interface Version:**
+```cmd
+python gradio_demo_zh.py ^
+ --model_path D:\models ^
+ --server_name 127.0.0.1 ^
+ --server_port 7862
+```
+
+**💡 Tip**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface.
+
+## 📋 Command Line Parameters
+
+| Parameter | Type | Required | Default | Description |
+|-----------|------|----------|---------|-------------|
+| `--model_path` | str | ✅ | - | Model root directory path (directory containing all model files) |
+| `--server_port` | int | ❌ | 7862 | Server port |
+| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address |
+| `--output_dir` | str | ❌ | ./outputs | Output video save directory |
+
+**💡 Note**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface.
+
+## 🎯 Features
+
+### Model Configuration
+
+- **Model Type**: Supports wan2.1 and wan2.2 model architectures
+- **Task Type**: Supports Image-to-Video (i2v) and Text-to-Video (t2v) generation modes
+- **Model Selection**: The frontend automatically detects and filters the available model files and infers the quantization precision from the file names
+- **Encoder Configuration**: Supports selection of T5 text encoder, CLIP image encoder, and VAE decoder
+- **Operator Selection**: Supports multiple attention operators and quantized matrix multiplication operators; the system automatically sorts them by installation status
+
+### Input Parameters
+
+- **Prompt**: Describe the expected video content
+- **Negative Prompt**: Specify elements you don't want to appear
+- **Input Image**: Upload input image required in i2v mode
+- **Resolution**: Supports multiple preset resolutions (480p/540p/720p)
+- **Random Seed**: Controls the randomness of generation results
+- **Inference Steps**: Affects the balance between generation quality and speed (defaults to 4 steps for distilled models)
+
+### Video Parameters
+
+- **FPS**: Frames per second
+- **Total Frames**: Video length
+- **CFG Scale Factor**: Controls prompt influence strength (1-10, defaults to 1 for distilled models)
+- **Distribution Shift**: Controls the sampling distribution shift (0-10)
+
+## 🔧 Auto-Configuration Feature
+
+The system automatically configures optimal inference options based on your hardware configuration (GPU VRAM and CPU memory) without manual adjustment. The best configuration is automatically applied on startup, including:
+
+- **GPU Memory Optimization**: Automatically enables CPU offloading, VAE tiling inference, etc. based on VRAM size
+- **CPU Memory Optimization**: Automatically enables lazy loading, module unloading, etc. based on system memory
+- **Operator Selection**: Automatically selects the best installed operators (sorted by priority)
+- **Quantization Configuration**: Automatically detects and applies quantization precision based on model file names
+
+
+### Log Viewing
+
+```bash
+# View inference logs
+tail -f inference_logs.log
+
+# View GPU usage
+nvidia-smi
+
+# View system resources
+htop
+```
+
+Issues and pull requests to improve this project are welcome!
+
+**Note**: Please comply with relevant laws and regulations when using videos generated by this tool, and do not use them for illegal purposes.
diff --git a/docs/EN/source/deploy_guides/deploy_local_windows.md b/docs/EN/source/deploy_guides/deploy_local_windows.md
new file mode 100644
index 0000000000000000000000000000000000000000..e675f9c5af1bae7d2220747a85219a8ccbfab291
--- /dev/null
+++ b/docs/EN/source/deploy_guides/deploy_local_windows.md
@@ -0,0 +1,127 @@
+# Windows Local Deployment Guide
+
+## 📖 Overview
+
+This document provides detailed instructions for deploying LightX2V locally on Windows environments, including batch file inference, Gradio Web interface inference, and other usage methods.
+
+## 🚀 Quick Start
+
+### Environment Requirements
+
+#### Hardware Requirements
+- **GPU**: NVIDIA GPU, recommended 8GB+ VRAM
+- **Memory**: Recommended 16GB+ RAM
+- **Storage**: An SSD is strongly recommended; mechanical hard drives make model loading very slow
+
+
+## 🎯 Usage Methods
+
+### Method 1: Using Batch File Inference
+
+Refer to the [Quick Start Guide](../getting_started/quickstart.md) to set up the environment, then run inference with the provided [batch files](https://github.com/ModelTC/LightX2V/tree/main/scripts/win).
+
+### Method 2: Using Gradio Web Interface Inference
+
+#### Manual Gradio Configuration
+
+Set up the environment following the [Quick Start Guide](../getting_started/quickstart.md), then follow the [Gradio Deployment Guide](./deploy_gradio.md).
+
+#### One-Click Gradio Startup (Recommended)
+
+**📦 Download Software Package**
+- [Quark Cloud](https://pan.quark.cn/s/8af1162d7a15)
+
+**📁 Directory Structure**
+After extraction, ensure the directory structure is as follows:
+
+```
+├── env/ # LightX2V environment directory
+├── LightX2V/ # LightX2V project directory
+├── start_lightx2v.bat # One-click startup script
+├── lightx2v_config.txt # Configuration file
+├── LightX2V使用说明.txt # LightX2V usage instructions
+├── outputs/ # Generated video save directory
+└── models/ # Model storage directory
+```
+
+**📥 Model Download**:
+
+Refer to [Model Structure Documentation](../getting_started/model_structure.md) or [Gradio Deployment Guide](./deploy_gradio.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions.
+
+
+**📋 Configuration Parameters**
+
+Edit the `lightx2v_config.txt` file and modify the following parameters as needed:
+
+```ini
+
+# Interface language (zh: Chinese, en: English)
+lang=en
+
+# Server port
+port=8032
+
+# GPU device ID (0, 1, 2...)
+gpu=0
+
+# Model path
+model_path=models/
+```
+
+**🚀 Start Service**
+
+Double-click to run the `start_lightx2v.bat` file, the script will:
+1. Automatically read configuration file
+2. Verify model paths and file integrity
+3. Start Gradio Web interface
+4. Automatically open browser to access service
+
+
+
+
+**⚠️ Important Notes**:
+- **Display Issues**: If the webpage opens blank or displays abnormally, please run `pip install --upgrade gradio` to upgrade the Gradio version.
+
+### Method 3: Using ComfyUI Inference
+
+This guide explains how to download and use the portable Lightx2v-ComfyUI environment, which lets you skip manual environment configuration. It is suitable for users who want to quickly experience accelerated video generation with Lightx2v on Windows.
+
+#### Download the Windows Portable Environment:
+
+- [Baidu Cloud Download](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid), extraction code: wfid
+
+The portable environment packages all Python runtime dependencies, including the code and dependencies for ComfyUI and LightX2V. After downloading, simply extract it and it is ready to use.
+
+After extraction, the directory structure is as follows:
+
+```shell
+lightx2v_env
+├──📂 ComfyUI # ComfyUI code
+├──📂 portable_python312_embed # Standalone Python environment
+└── run_nvidia_gpu.bat # Windows startup script (double-click to start)
+```
+
+#### Start ComfyUI
+
+Double-click the run_nvidia_gpu.bat file. A Command Prompt window will open and run the program; the first startup may take a while, so please be patient. Once startup completes, your browser will open automatically and display the ComfyUI frontend interface.
+
+
+
+The plugin used by LightX2V-ComfyUI is [ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper). Example workflows can be obtained from this project.
+
+#### Tested Graphics Cards (offload mode)
+
+- Tested model: `Wan2.1-I2V-14B-480P`
+
+| GPU Model | Task Type | VRAM Capacity | Actual Max VRAM Usage | Actual Max RAM Usage |
+|:-----------|:------------|:--------------|:---------------------|:---------------------|
+| 3090Ti | I2V | 24G | 6.1G | 7.1G |
+| 3080Ti | I2V | 12G | 6.1G | 7.1G |
+| 3060Ti | I2V | 8G | 6.1G | 7.1G |
+| 4070Ti Super | I2V | 16G | 6.1G | 7.1G |
+| 4070 | I2V | 12G | 6.1G | 7.1G |
+| 4060 | I2V | 8G | 6.1G | 7.1G |
+
+#### Environment Packaging and Usage Reference
+- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
+- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2)
diff --git a/docs/EN/source/deploy_guides/deploy_service.md b/docs/EN/source/deploy_guides/deploy_service.md
new file mode 100644
index 0000000000000000000000000000000000000000..b34b349622d77ba874b394404794976652547a01
--- /dev/null
+++ b/docs/EN/source/deploy_guides/deploy_service.md
@@ -0,0 +1,88 @@
+# Service Deployment
+
+lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
+
+### Start the Service
+
+```shell
+# Modify the paths in the script
+bash scripts/start_server.sh
+```
+
+The `--port 8000` option means the service will bind to port `8000` on the local machine. You can change this as needed.
+
+### Client Sends Request
+
+```shell
+python scripts/post.py
+```
+
+The service endpoint is: `/v1/tasks/`
+
+The `message` parameter in `scripts/post.py` is as follows:
+
+```python
+message = {
+ "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+ "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ "image_path": "",
+}
+```
+
+`prompt`, `negative_prompt`, and `image_path` are the basic inputs for video generation; `image_path` may be an empty string, indicating that no image input is needed (see the sketch below).
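+
+As a minimal client sketch (assuming the server started by `scripts/start_server.sh` is listening on `http://localhost:8000`; the exact shape of the JSON response is an assumption, and `scripts/post.py` remains the reference implementation), a task can be submitted like this:
+
+```python
+# Minimal task-creation sketch (assumes the server listens on localhost:8000)
+import requests
+
+message = {
+    "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+    "negative_prompt": "",
+    "image_path": "",  # empty string: no image input
+}
+
+resp = requests.post("http://localhost:8000/v1/tasks/", json=message)
+resp.raise_for_status()
+print(resp.json())  # the returned payload (e.g. a task id) depends on the server implementation
+```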
+
+
+### Client Checks Server Status
+
+```shell
+python scripts/check_status.py
+```
+
+The service endpoints include:
+
+1. `/v1/service/status` is used to check the status of the service. It returns whether the service is `busy` or `idle`. The service only accepts new requests when it is `idle`.
+
+2. `/v1/tasks/` is used to get all tasks received and completed by the server.
+
+3. `/v1/tasks/{task_id}/status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
+
+### Client Stops the Current Task on the Server at Any Time
+
+```shell
+python scripts/stop_running_task.py
+```
+
+The service endpoint is: `/v1/tasks/running`
+
+After terminating the task, the server will not exit but will return to waiting for new requests.
+
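+As a rough illustration (the actual logic lives in `scripts/stop_running_task.py`), stopping the running task amounts to a single DELETE request against the endpoint above; the server address is an assumption here:
+
+```python
+import requests
+
+# Stop whatever task is currently running; the server keeps accepting new requests afterwards
+resp = requests.delete("http://localhost:8000/v1/tasks/running")
+print(resp.status_code, resp.text)
+```
+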
+### Starting Multiple Services on a Single Node
+
+On a single node, you can start multiple services with `scripts/start_server.sh` (each service under the same IP must use a different port), or start several services at once with `scripts/start_multi_servers.sh`:
+
+```shell
+num_gpus=8 bash scripts/start_multi_servers.sh
+```
+
+Here, `num_gpus` specifies the number of services to start; the services will listen on consecutive ports beginning at `--start_port`.
+
+### Scheduling Between Multiple Services
+
+```shell
+python scripts/post_multi_servers.py
+```
+
+`post_multi_servers.py` will schedule multiple client requests based on the idle status of the services.
+
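+A minimal sketch of this kind of scheduling, assuming the servers were started on consecutive ports and using the documented `/v1/service/status` endpoint (the response field name is an assumption; see `scripts/post_multi_servers.py` for the real implementation):
+
+```python
+import requests
+
+SERVERS = [f"http://localhost:{port}" for port in range(8000, 8008)]  # e.g. 8 services
+
+def find_idle_server():
+    """Return the first server that reports itself as idle, or None."""
+    for base in SERVERS:
+        try:
+            status = requests.get(f"{base}/v1/service/status", timeout=2).json()
+        except requests.RequestException:
+            continue  # server unreachable, try the next one
+        if status.get("service_status") == "idle":  # field name assumed for illustration
+            return base
+    return None
+
+def dispatch(message):
+    """Send one request to an idle server; the caller retries if none is free."""
+    base = find_idle_server()
+    if base is None:
+        return None
+    return requests.post(f"{base}/v1/tasks/", json=message).json()
+```
+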
+### API Endpoints Summary
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| `/v1/tasks/` | POST | Create video generation task |
+| `/v1/tasks/form` | POST | Create video generation task via form |
+| `/v1/tasks/` | GET | Get all task list |
+| `/v1/tasks/{task_id}/status` | GET | Get status of specified task |
+| `/v1/tasks/{task_id}/result` | GET | Get result video file of specified task |
+| `/v1/tasks/running` | DELETE | Stop currently running task |
+| `/v1/files/download/{file_path}` | GET | Download file |
+| `/v1/service/status` | GET | Get service status |
diff --git a/docs/EN/source/deploy_guides/for_low_latency.md b/docs/EN/source/deploy_guides/for_low_latency.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e8df8a467452ff0223c9f62295e8c74b73b0002
--- /dev/null
+++ b/docs/EN/source/deploy_guides/for_low_latency.md
@@ -0,0 +1,41 @@
+# Deployment for Low Latency Scenarios
+
+In low-latency scenarios, the goal is maximum speed, with VRAM and RAM overhead treated as secondary concerns. We provide two solutions:
+
+## 💡 Solution 1: Inference with Step Distillation Model
+
+For this solution, refer to the [Step Distillation Documentation](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/step_distill.html).
+
+🧠 **Step Distillation** is a very direct way to accelerate inference for video generation models. By distilling from 50 steps down to 4 steps, inference time drops to roughly 4/50 of the original. This solution can also be combined with the following techniques:
+1. [Efficient Attention Mechanism Solution](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/attention.html)
+2. [Model Quantization](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/quantization.html)
+
+## 💡 Solution 2: Inference with Non-Step Distillation Model
+
+Step distillation requires substantial training resources, and the distilled model may exhibit a reduced video dynamic range.
+
+For the original model without step distillation, the following solutions can be used individually or in combination for acceleration:
+
+1. [Parallel Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/parallel.html) for multi-GPU parallel acceleration.
+2. [Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html) to reduce the actual inference steps.
+3. [Efficient Attention Mechanism Solution](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/attention.html) to accelerate Attention inference.
+4. [Model Quantization](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/quantization.html) to accelerate Linear layer inference.
+5. [Variable Resolution Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/changing_resolution.html) to reduce the resolution of intermediate inference steps.
+
+## 💡 Using Tiny VAE
+
+In some cases, the VAE component can be time-consuming. You can use a lightweight VAE for acceleration, which can also reduce some GPU memory usage.
+
+```json
+{
+    "use_tae": true,
+    "tae_path": "/path/to/taew2_1.pth"
+}
+```
+
+The `taew2_1.pth` weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth).
+
+## ⚠️ Note
+
+Some acceleration solutions currently cannot be used together, and we are working to resolve this issue.
+
+If you have any questions, feel free to report bugs or request features in [🐛 GitHub Issues](https://github.com/ModelTC/lightx2v/issues)
diff --git a/docs/EN/source/deploy_guides/for_low_resource.md b/docs/EN/source/deploy_guides/for_low_resource.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad11c7489ff6f5748dcfd2b482bfd93600d9437f
--- /dev/null
+++ b/docs/EN/source/deploy_guides/for_low_resource.md
@@ -0,0 +1,219 @@
+# Lightx2v Low-Resource Deployment Guide
+
+## 📋 Overview
+
+This guide targets hardware resource-constrained environments, particularly configurations with **8GB VRAM + 16/32GB RAM**, and explains in detail how to run Lightx2v 14B models for 480p and 720p video generation.
+
+Lightx2v is a powerful video generation model, but it requires careful optimization to run smoothly in resource-constrained environments. This guide provides a complete solution from hardware selection to software configuration, ensuring you can achieve the best video generation experience under limited hardware conditions.
+
+## 🎯 Target Hardware Configuration
+
+### Recommended Hardware Specifications
+
+**GPU Requirements**:
+- **VRAM**: 8GB (RTX 3060/3070/4060/4060Ti, etc.)
+- **Architecture**: NVIDIA graphics cards with CUDA support
+
+**System Memory**:
+- **Minimum**: 16GB DDR4
+- **Recommended**: 32GB DDR4/DDR5
+- **Memory Speed**: 3200MHz or higher recommended
+
+**Storage Requirements**:
+- **Type**: NVMe SSD strongly recommended
+- **Capacity**: At least 50GB available space
+- **Speed**: Read speed of 3000MB/s or higher recommended
+
+**CPU Requirements**:
+- **Cores**: 8 cores or more recommended
+- **Frequency**: 3.0GHz or higher recommended
+- **Architecture**: Support for AVX2 instruction set
+
+## ⚙️ Core Optimization Strategies
+
+### 1. Environment Optimization
+
+Before running Lightx2v, it's recommended to set the following environment variables to optimize performance:
+
+```bash
+# CUDA memory allocation optimization
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+# Enable CUDA Graph mode to improve inference performance
+export ENABLE_GRAPH_MODE=true
+
+# Use BF16 precision for inference to reduce VRAM usage (default FP32 precision)
+export DTYPE=BF16
+```
+
+**Optimization Details**:
+- `expandable_segments:True`: Allows dynamic expansion of CUDA memory segments, reducing memory fragmentation
+- `ENABLE_GRAPH_MODE=true`: Enables CUDA Graph to reduce kernel launch overhead
+- `DTYPE=BF16`: Uses BF16 precision to reduce VRAM usage while maintaining quality
+
+### 2. Quantization Strategy
+
+Quantization is a key optimization technique in low-resource environments, reducing memory usage by lowering model precision.
+
+#### Quantization Scheme Comparison
+
+**FP8 Quantization** (Recommended for RTX 40 series):
+```python
+# Suitable for GPUs supporting FP8, providing better precision
+dit_quant_scheme = "fp8" # DIT model quantization
+t5_quant_scheme = "fp8" # T5 text encoder quantization
+clip_quant_scheme = "fp8" # CLIP visual encoder quantization
+```
+
+**INT8 Quantization** (Universal solution):
+```python
+# Suitable for all GPUs, minimal memory usage
+dit_quant_scheme = "int8" # 8-bit integer quantization
+t5_quant_scheme = "int8" # Text encoder quantization
+clip_quant_scheme = "int8" # Visual encoder quantization
+```
+
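+As a rough rule of thumb, FP8 is only worthwhile on GPUs whose hardware supports it (Ada/Hopper, compute capability 8.9 or higher); everything else should fall back to INT8. The sketch below illustrates this choice as a plain Python helper; the config keys mirror the ones above, but the capability threshold is an assumption you should verify for your GPU:
+
+```python
+import torch
+
+def pick_quant_scheme() -> dict:
+    """Choose FP8 on GPUs with native FP8 support (Ada/Hopper), INT8 otherwise."""
+    # Requires a CUDA device to be visible
+    major, minor = torch.cuda.get_device_capability()
+    scheme = "fp8" if (major, minor) >= (8, 9) else "int8"
+    return {
+        "dit_quant_scheme": scheme,   # DIT model quantization
+        "t5_quant_scheme": scheme,    # T5 text encoder quantization
+        "clip_quant_scheme": scheme,  # CLIP visual encoder quantization
+    }
+
+print(pick_quant_scheme())
+```
+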
+### 3. Efficient Operator Selection Guide
+
+Choosing the right operators can significantly improve inference speed and reduce memory usage.
+
+#### Attention Operator Selection
+
+**Recommended Priority**:
+1. **[Sage Attention](https://github.com/thu-ml/SageAttention)** (Highest priority)
+
+2. **[Flash Attention](https://github.com/Dao-AILab/flash-attention)** (Universal solution)
+
+#### Matrix Multiplication Operator Selection
+
+**ADA Architecture GPUs** (RTX 40 series):
+
+Recommended priority:
+1. **[q8-kernel](https://github.com/KONAKONA666/q8_kernels)** (Highest performance, ADA architecture only)
+2. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (Balanced solution)
+3. **[vllm-kernel](https://github.com/vllm-project/vllm)** (Universal solution)
+
+**Other Architecture GPUs**:
+1. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (Recommended)
+2. **[vllm-kernel](https://github.com/vllm-project/vllm)** (Alternative)
+
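+Before settling on a configuration, it can help to check which of these optional kernels are actually importable in your environment. The probe below is only a sketch; the import names are the usual ones for these packages but may differ depending on how you installed them:
+
+```python
+import importlib.util
+
+# Candidate attention / GEMM kernel packages and their (assumed) import names
+CANDIDATES = {
+    "Sage Attention": "sageattention",
+    "Flash Attention": "flash_attn",
+    "q8-kernel": "q8_kernels",
+    "sglang-kernel": "sgl_kernel",
+    "vllm-kernel": "vllm",
+}
+
+for name, module in CANDIDATES.items():
+    available = importlib.util.find_spec(module) is not None
+    print(f"{name:<16} {'available' if available else 'not installed'}")
+```
+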
+### 4. Parameter Offloading Strategy
+
+Parameter offloading allows model parameters to be scheduled dynamically across GPU, CPU, and disk, breaking through VRAM limitations.
+
+#### Three-Level Offloading Architecture
+
+```python
+# Disk-CPU-GPU three-level offloading configuration
+cpu_offload = True                # Enable CPU offloading
+t5_cpu_offload = True             # Enable T5 encoder CPU offloading
+offload_granularity = "phase"     # DIT model fine-grained offloading
+t5_offload_granularity = "block"  # T5 encoder fine-grained offloading
+lazy_load = True # Enable lazy loading mechanism
+num_disk_workers = 2 # Disk I/O worker threads
+```
+
+#### Offloading Strategy Details
+
+**Lazy Loading Mechanism**:
+- Model parameters are loaded from disk to CPU on demand
+- Reduces runtime memory usage
+- Supports large models running with limited memory
+
+**Disk Storage Optimization**:
+- Use high-speed SSD to store model parameters
+- Store model files grouped by blocks
+- Refer to the conversion script [documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme.md) and specify the `--save_by_block` parameter during conversion
+
+### 5. VRAM Optimization Techniques
+
+VRAM optimization strategies for 720p video generation.
+
+#### CUDA Memory Management
+
+```python
+# CUDA memory cleanup configuration
+clean_cuda_cache = True # Timely cleanup of GPU cache
+rotary_chunk = True # Rotary position encoding chunked computation
+rotary_chunk_size = 100 # Chunk size, adjustable based on VRAM
+```
+
+#### Chunked Computation Strategy
+
+**Rotary Position Encoding Chunking**:
+- Process long sequences in small chunks
+- Reduce peak VRAM usage
+- Maintain computational precision
+
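+The snippet below is a generic illustration of the chunking idea (not LightX2V's actual implementation): rotary position embeddings are applied to a long sequence in slices of `rotary_chunk_size`, so only one chunk's worth of intermediate tensors is materialized at a time.
+
+```python
+import torch
+
+def apply_rotary_chunked(x, cos, sin, chunk_size=100):
+    """Apply rotary position embedding chunk by chunk along the sequence axis.
+
+    x:   (seq_len, num_heads, head_dim)
+    cos: (seq_len, head_dim), sin: (seq_len, head_dim)
+    """
+    out = torch.empty_like(x)
+    for start in range(0, x.shape[0], chunk_size):
+        end = min(start + chunk_size, x.shape[0])
+        xc = x[start:end]
+        c = cos[start:end].unsqueeze(1)  # broadcast over heads
+        s = sin[start:end].unsqueeze(1)
+        x1, x2 = xc.chunk(2, dim=-1)     # rotate-half formulation of RoPE
+        out[start:end] = xc * c + torch.cat((-x2, x1), dim=-1) * s
+    return out
+```
+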
+### 6. VAE Optimization
+
+VAE (Variational Autoencoder) is a key component in video generation, and optimizing VAE can significantly improve performance.
+
+#### VAE Chunked Inference
+
+```python
+# VAE optimization configuration
+use_tiling_vae = True # Enable VAE chunked inference
+```
+
+#### Lightweight VAE
+
+```python
+# VAE optimization configuration
+use_tae = True # Use lightweight VAE
+tae_path = "/path/to/taew2_1.pth"
+```
+You can download taew2_1.pth [here](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
+
+**VAE Optimization Effects**:
+- Standard VAE: Baseline performance, 100% quality retention
+- Standard VAE chunked: Reduces VRAM usage, increases inference time, 100% quality retention
+- Lightweight VAE: Extremely low VRAM usage, with some loss of video quality
+
+### 7. Model Selection Strategy
+
+Choosing the right model version is crucial for low-resource environments.
+
+#### Recommended Model Comparison
+
+**Distilled Models** (Strongly recommended):
+- ✅ **[Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v)**
+
+- ✅ **[Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v)**
+
+#### Performance Optimization Suggestions
+
+When using the above distilled models, you can further optimize performance:
+- Disable CFG: `"enable_cfg": false`
+- Reduce inference steps: `"infer_steps": 4`
+- Reference configuration files: [config](https://github.com/ModelTC/LightX2V/tree/main/configs/distill)
+
+## 🚀 Complete Configuration Examples
+
+### Pre-configured Templates
+
+- **[14B Model 480p Video Generation Configuration](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json)**
+
+- **[14B Model 720p Video Generation Configuration](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json)**
+
+- **[1.3B Model 720p Video Generation Configuration](https://github.com/ModelTC/LightX2V/tree/main/configs/offload/block/wan_t2v_1_3b.json)**
+ - The inference bottleneck for 1.3B models is the T5 encoder, so the configuration file specifically optimizes for T5
+
+**[Launch Script](https://github.com/ModelTC/LightX2V/tree/main/scripts/wan/run_wan_i2v_lazy_load.sh)**
+
+## 📚 Reference Resources
+
+- [Parameter Offloading Mechanism Documentation](../method_tutorials/offload.md) - In-depth understanding of offloading technology principles
+- [Quantization Technology Guide](../method_tutorials/quantization.md) - Detailed explanation of quantization technology
+- [Gradio Deployment Guide](deploy_gradio.md) - Detailed Gradio deployment instructions
+
+## ⚠️ Important Notes
+
+1. **Hardware Requirements**: Ensure your hardware meets minimum configuration requirements
+2. **Driver Version**: Recommend using the latest NVIDIA drivers (535+)
+3. **CUDA Version**: Ensure CUDA version is compatible with PyTorch (recommend CUDA 11.8+)
+4. **Storage Space**: Reserve sufficient disk space for model caching (at least 50GB)
+5. **Network Environment**: Stable network connection required for initial model download
+6. **Environment Variables**: Be sure to set the recommended environment variables to optimize performance
+
+**Technical Support**: If you encounter issues, please submit an Issue to the project repository.
diff --git a/docs/EN/source/deploy_guides/lora_deploy.md b/docs/EN/source/deploy_guides/lora_deploy.md
new file mode 100644
index 0000000000000000000000000000000000000000..769d1d57b4f3c620f36e077fe35dbb4145196cc2
--- /dev/null
+++ b/docs/EN/source/deploy_guides/lora_deploy.md
@@ -0,0 +1,214 @@
+# LoRA Model Deployment and Related Tools
+
+LoRA (Low-Rank Adaptation) is an efficient model fine-tuning technique that significantly reduces the number of trainable parameters through low-rank matrix decomposition. LightX2V fully supports LoRA technology, including LoRA inference, LoRA extraction, and LoRA merging functions.
+
+## 🎯 LoRA Technical Features
+
+- **Efficient Fine-tuning**: Dramatically reduces training parameters through low-rank adaptation
+- **Flexible Deployment**: Supports dynamic loading and removal of LoRA weights
+- **Multiple Formats**: Supports various LoRA weight formats and naming conventions
+- **Comprehensive Tools**: Provides complete LoRA extraction and merging toolchain
+
+## 📜 LoRA Inference Deployment
+
+### Configuration File Method
+
+Specify LoRA path in configuration file:
+
+```json
+{
+ "lora_configs": [
+ {
+ "path": "/path/to/your/lora.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
+```
+
+**Configuration Parameter Description:**
+
+- `path`: LoRA weight file path; `lora_configs` supports loading multiple LoRAs simultaneously
+- `strength`: LoRA strength coefficient (alpha), controls the LoRA's influence on the original model
+
+### Command Line Method
+
+Specify LoRA path directly in command line (supports loading single LoRA only):
+
+```bash
+python -m lightx2v.infer \
+ --model_cls wan2.1 \
+ --task t2v \
+ --model_path /path/to/model \
+ --config_json /path/to/config.json \
+ --lora_path /path/to/your/lora.safetensors \
+ --lora_strength 0.8 \
+ --prompt "Your prompt here"
+```
+
+### Multiple LoRAs Configuration
+
+To use multiple LoRAs with different strengths, specify them in the config JSON file:
+
+```json
+{
+ "lora_configs": [
+ {
+ "path": "/path/to/first_lora.safetensors",
+ "strength": 0.8
+ },
+ {
+ "path": "/path/to/second_lora.safetensors",
+ "strength": 0.5
+ }
+ ]
+}
+```
+
+### Supported LoRA Formats
+
+LightX2V supports multiple LoRA weight naming conventions:
+
+| Format Type | Weight Naming | Description |
+|-------------|---------------|-------------|
+| **Standard LoRA** | `lora_A.weight`, `lora_B.weight` | Standard LoRA matrix decomposition format |
+| **Down/Up Format** | `lora_down.weight`, `lora_up.weight` | Another common naming convention |
+| **Diff Format** | `diff` | `weight` difference values |
+| **Bias Diff** | `diff_b` | `bias` weight difference values |
+| **Modulation Diff** | `diff_m` | `modulation` weight difference values |
+
+### Inference Script Examples
+
+**Step Distillation LoRA Inference:**
+
+```bash
+# T2V LoRA Inference
+bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh
+
+# I2V LoRA Inference
+bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh
+```
+
+**Audio-Driven LoRA Inference:**
+
+```bash
+bash scripts/wan/run_wan_i2v_audio.sh
+```
+
+### Using LoRA in API Service
+
+Specify the LoRA through the [config file](wan_t2v_distill_4step_cfg_lora.json) and modify the startup command in [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh):
+
+```bash
+python -m lightx2v.api_server \
+ --model_cls wan2.1_distill \
+ --task t2v \
+ --model_path $model_path \
+ --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \
+ --port 8000 \
+ --nproc_per_node 1
+```
+
+## 🔧 LoRA Extraction Tool
+
+Use `tools/extract/lora_extractor.py` to extract LoRA weights from the difference between two models.
+
+### Basic Usage
+
+```bash
+python tools/extract/lora_extractor.py \
+ --source-model /path/to/base/model \
+ --target-model /path/to/finetuned/model \
+ --output /path/to/extracted/lora.safetensors \
+ --rank 32
+```
+
+### Parameter Description
+
+| Parameter | Type | Required | Default | Description |
+|-----------|------|----------|---------|-------------|
+| `--source-model` | str | ✅ | - | Base model path |
+| `--target-model` | str | ✅ | - | Fine-tuned model path |
+| `--output` | str | ✅ | - | Output LoRA file path |
+| `--source-type` | str | ❌ | `safetensors` | Base model format (`safetensors`/`pytorch`) |
+| `--target-type` | str | ❌ | `safetensors` | Fine-tuned model format (`safetensors`/`pytorch`) |
+| `--output-format` | str | ❌ | `safetensors` | Output format (`safetensors`/`pytorch`) |
+| `--rank` | int | ❌ | `32` | LoRA rank value |
+| `--output-dtype` | str | ❌ | `bf16` | Output data type |
+| `--diff-only` | bool | ❌ | `False` | Save weight differences only, without LoRA decomposition |
+
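+Conceptually, what the extractor does for each weight matrix is factor the difference between the fine-tuned and base weights into two low-rank matrices. The sketch below shows this per-tensor step with a truncated SVD; it is only an illustration of the idea, not the script's actual internals:
+
+```python
+import torch
+
+def extract_lora_pair(w_base: torch.Tensor, w_tuned: torch.Tensor, rank: int = 32):
+    """Factor (w_tuned - w_base) into LoRA up/down matrices via truncated SVD."""
+    diff = (w_tuned - w_base).float()                  # (out_features, in_features)
+    u, s, vh = torch.linalg.svd(diff, full_matrices=False)
+    u, s, vh = u[:, :rank], s[:rank], vh[:rank, :]
+    lora_up = u * s.sqrt()                             # B: (out_features, rank)
+    lora_down = s.sqrt().unsqueeze(1) * vh             # A: (rank, in_features)
+    return lora_up.to(torch.bfloat16), lora_down.to(torch.bfloat16)
+
+# Sanity check: the low-rank product should approximate the weight difference
+w_base, w_tuned = torch.randn(256, 128), torch.randn(256, 128)
+up, down = extract_lora_pair(w_base, w_tuned, rank=32)
+approx = w_base + up.float() @ down.float()
+print(torch.norm(approx - w_tuned) / torch.norm(w_tuned - w_base))  # relative residual
+```
+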
+### Advanced Usage Examples
+
+**Extract High-Rank LoRA:**
+
+```bash
+python tools/extract/lora_extractor.py \
+ --source-model /path/to/base/model \
+ --target-model /path/to/finetuned/model \
+ --output /path/to/high_rank_lora.safetensors \
+ --rank 64 \
+ --output-dtype fp16
+```
+
+**Save Weight Differences Only:**
+
+```bash
+python tools/extract/lora_extractor.py \
+ --source-model /path/to/base/model \
+ --target-model /path/to/finetuned/model \
+ --output /path/to/weight_diff.safetensors \
+ --diff-only
+```
+
+## 🔀 LoRA Merging Tool
+
+Use `tools/extract/lora_merger.py` to merge LoRA weights into the base model for subsequent quantization and other operations.
+
+### Basic Usage
+
+```bash
+python tools/extract/lora_merger.py \
+ --source-model /path/to/base/model \
+ --lora-model /path/to/lora.safetensors \
+ --output /path/to/merged/model.safetensors \
+ --alpha 1.0
+```
+
+### Parameter Description
+
+| Parameter | Type | Required | Default | Description |
+|-----------|------|----------|---------|-------------|
+| `--source-model` | str | ✅ | - | Base model path |
+| `--lora-model` | str | ✅ | - | LoRA weights path |
+| `--output` | str | ✅ | - | Output merged model path |
+| `--source-type` | str | ❌ | `safetensors` | Base model format |
+| `--lora-type` | str | ❌ | `safetensors` | LoRA weights format |
+| `--output-format` | str | ❌ | `safetensors` | Output format |
+| `--alpha` | float | ❌ | `1.0` | LoRA merge strength |
+| `--output-dtype` | str | ❌ | `bf16` | Output data type |
+
+### Advanced Usage Examples
+
+**Partial Strength Merging:**
+
+```bash
+python tools/extract/lora_merger.py \
+ --source-model /path/to/base/model \
+ --lora-model /path/to/lora.safetensors \
+ --output /path/to/merged_model.safetensors \
+ --alpha 0.7 \
+ --output-dtype fp32
+```
+
+**Multi-Format Support:**
+
+```bash
+python tools/extract/lora_merger.py \
+ --source-model /path/to/base/model.pt \
+ --source-type pytorch \
+ --lora-model /path/to/lora.safetensors \
+ --lora-type safetensors \
+ --output /path/to/merged_model.safetensors \
+ --output-format safetensors \
+ --alpha 1.0
+```
diff --git a/docs/EN/source/getting_started/benchmark.md b/docs/EN/source/getting_started/benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..ecc8dc95e801f4e6a6f77e5888e014c0e0e4b846
--- /dev/null
+++ b/docs/EN/source/getting_started/benchmark.md
@@ -0,0 +1,3 @@
+# Benchmark
+
+For video playback examples and detailed performance comparisons, see this [🔗 page](https://github.com/ModelTC/LightX2V/blob/main/docs/EN/source/getting_started/benchmark_source.md), which provides a better presentation along with the corresponding documentation content.
diff --git a/docs/EN/source/getting_started/benchmark_source.md b/docs/EN/source/getting_started/benchmark_source.md
new file mode 100644
index 0000000000000000000000000000000000000000..35bafdfef46fc56020302b701724199ea5c16117
--- /dev/null
+++ b/docs/EN/source/getting_started/benchmark_source.md
@@ -0,0 +1,149 @@
+# 🚀 Benchmark
+
+> This document showcases the performance test results of LightX2V across different hardware environments, including detailed comparison data for H200 and RTX 4090 platforms.
+
+---
+
+## 🖥️ H200 Environment (~140GB VRAM)
+
+### 📋 Software Environment Configuration
+
+| Component | Version |
+|:----------|:--------|
+| **Python** | 3.11 |
+| **PyTorch** | 2.7.1+cu128 |
+| **SageAttention** | 2.2.0 |
+| **vLLM** | 0.9.2 |
+| **sgl-kernel** | 0.1.8 |
+
+---
+
+### 🎬 480P 5s Video Test
+
+**Test Configuration:**
+- **Model**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
+- **Parameters**: `infer_steps=40`, `seed=42`, `enable_cfg=True`
+
+#### 📊 Performance Comparison Table
+
+| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
+|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
+| **Wan2.1 Official** | 366 | 71 | 1.0x | |
+| **FastVideo** | 292 | 26 | **1.25x** | |
+| **LightX2V_1** | 250 | 53 | **1.46x** | |
+| **LightX2V_2** | 216 | 50 | **1.70x** | |
+| **LightX2V_3** | 191 | 35 | **1.92x** | |
+| **LightX2V_3-Distill** | 14 | 35 | **🏆 20.85x** | |
+| **LightX2V_4** | 107 | 35 | **3.41x** | |
+
+---
+
+### 🎬 720P 5s Video Test
+
+**Test Configuration:**
+- **Model**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
+- **Parameters**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`
+
+#### 📊 Performance Comparison Table
+
+| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
+|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
+| **Wan2.1 Official** | 974 | 81 | 1.0x | |
+| **FastVideo** | 914 | 40 | **1.07x** | |
+| **LightX2V_1** | 807 | 65 | **1.21x** | |
+| **LightX2V_2** | 751 | 57 | **1.30x** | |
+| **LightX2V_3** | 671 | 43 | **1.45x** | |
+| **LightX2V_3-Distill** | 44 | 43 | **🏆 22.14x** | |
+| **LightX2V_4** | 344 | 46 | **2.83x** | |
+
+---
+
+## 🖥️ RTX 4090 Environment (~24GB VRAM)
+
+### 📋 Software Environment Configuration
+
+| Component | Version |
+|:----------|:--------|
+| **Python** | 3.9.16 |
+| **PyTorch** | 2.5.1+cu124 |
+| **SageAttention** | 2.1.0 |
+| **vLLM** | 0.6.6 |
+| **sgl-kernel** | 0.0.5 |
+| **q8-kernels** | 0.0.0 |
+
+---
+
+### 🎬 480P 5s Video Test
+
+**Test Configuration:**
+- **Model**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
+- **Parameters**: `infer_steps=40`, `seed=42`, `enable_cfg=True`
+
+#### 📊 Performance Comparison Table
+
+| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
+|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
+| **Wan2GP(profile=3)** | 779 | 20 | **1.0x** | |
+| **LightX2V_5** | 738 | 16 | **1.05x** | |
+| **LightX2V_5-Distill** | 68 | 16 | **11.45x** | |
+| **LightX2V_6** | 630 | 12 | **1.24x** | |
+| **LightX2V_6-Distill** | 63 | 12 | **🏆 12.36x** | |
+
+---
+
+### 🎬 720P 5s Video Test
+
+**Test Configuration:**
+- **Model**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
+- **Parameters**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`
+
+#### 📊 Performance Comparison Table
+
+| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
+|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
+| **Wan2GP(profile=3)** | -- | OOM | -- | |
+| **LightX2V_5** | 2473 | 23 | -- | |
+| **LightX2V_5-Distill** | 183 | 23 | -- | |
+| **LightX2V_6** | 2169 | 18 | -- | |
+| **LightX2V_6-Distill** | 171 | 18 | -- | |
+
+---
+
+## 📖 Configuration Descriptions
+
+### 🖥️ H200 Environment Configuration Descriptions
+
+| Configuration | Technical Features |
+|:--------------|:------------------|
+| **Wan2.1 Official** | Based on [Wan2.1 official repository](https://github.com/Wan-Video/Wan2.1) original implementation |
+| **FastVideo** | Based on [FastVideo official repository](https://github.com/hao-ai-lab/FastVideo), using SageAttention2 backend optimization |
+| **LightX2V_1** | Uses SageAttention2 to replace native attention mechanism, adopts DIT BF16+FP32 (partial sensitive layers) mixed precision computation, improving computational efficiency while maintaining precision |
+| **LightX2V_2** | Unified BF16 precision computation, further reducing memory usage and computational overhead while maintaining generation quality |
+| **LightX2V_3** | Introduces FP8 quantization technology to significantly reduce computational precision requirements, combined with Tiling VAE technology to optimize memory usage |
+| **LightX2V_3-Distill** | Based on LightX2V_3 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality |
+| **LightX2V_4** | Based on LightX2V_3 with TeaCache(teacache_thresh=0.2) caching reuse technology, achieving acceleration through intelligent redundant computation skipping |
+
+### 🖥️ RTX 4090 Environment Configuration Descriptions
+
+| Configuration | Technical Features |
+|:--------------|:------------------|
+| **Wan2GP(profile=3)** | Implementation based on [Wan2GP repository](https://github.com/deepbeepmeep/Wan2GP), using MMGP optimization technology. Profile=3 configuration is suitable for RTX 3090/4090 environments with at least 32GB RAM and 24GB VRAM, adapting to limited memory resources by sacrificing VRAM. Uses quantized models: [480P model](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors) and [720P model](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors) |
+| **LightX2V_5** | Uses SageAttention2 to replace native attention mechanism, adopts DIT FP8+FP32 (partial sensitive layers) mixed precision computation, enables CPU offload technology, executes partial sensitive layers with FP32 precision, asynchronously offloads DIT inference process data to CPU, saves VRAM, with block-level offload granularity |
+| **LightX2V_5-Distill** | Based on LightX2V_5 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality |
+| **LightX2V_6** | Based on LightX2V_3 with CPU offload technology enabled, executes partial sensitive layers with FP32 precision, asynchronously offloads DIT inference process data to CPU, saves VRAM, with block-level offload granularity |
+| **LightX2V_6-Distill** | Based on LightX2V_6 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality |
+
+---
+
+## 📁 Configuration Files Reference
+
+Benchmark-related configuration files and execution scripts are available at:
+
+| Type | Link | Description |
+|:-----|:-----|:------------|
+| **Configuration Files** | [configs/bench](https://github.com/ModelTC/LightX2V/tree/main/configs/bench) | Contains JSON files with various optimization configurations |
+| **Execution Scripts** | [scripts/bench](https://github.com/ModelTC/LightX2V/tree/main/scripts/bench) | Contains benchmark execution scripts |
+
+---
+
+> 💡 **Tip**: It is recommended to choose the appropriate optimization solution based on your hardware configuration to achieve the best performance.
diff --git a/docs/EN/source/getting_started/model_structure.md b/docs/EN/source/getting_started/model_structure.md
new file mode 100644
index 0000000000000000000000000000000000000000..6edfc288bfc916392d91ee2205d3b4a3a71736b9
--- /dev/null
+++ b/docs/EN/source/getting_started/model_structure.md
@@ -0,0 +1,573 @@
+# Model Format and Loading Guide
+
+## 📖 Overview
+
+LightX2V is a flexible video generation inference framework that supports multiple model sources and formats, providing users with rich options:
+
+- ✅ **Wan Official Models**: Directly compatible with officially released complete models from Wan2.1 and Wan2.2
+- ✅ **Single-File Models**: Supports single-file format models released by LightX2V (including quantized versions)
+- ✅ **LoRA Models**: Supports loading distilled LoRAs released by LightX2V
+
+This document provides detailed instructions on how to use various model formats, configuration parameters, and best practices.
+
+---
+
+## 🗂️ Format 1: Wan Official Models
+
+### Model Repositories
+- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b)
+- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8)
+
+### Model Features
+- **Official Guarantee**: Complete models officially released by Wan-AI with highest quality
+- **Complete Components**: Includes all necessary components (DIT, T5, CLIP, VAE)
+- **Original Precision**: Uses BF16/FP32 precision with no quantization loss
+- **Strong Compatibility**: Fully compatible with Wan official toolchain
+
+### Wan2.1 Official Models
+
+#### Directory Structure
+
+Using [Wan2.1-I2V-14B-720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) as an example:
+
+```
+Wan2.1-I2V-14B-720P/
+├── diffusion_pytorch_model-00001-of-00007.safetensors # DIT model shard 1
+├── diffusion_pytorch_model-00002-of-00007.safetensors # DIT model shard 2
+├── diffusion_pytorch_model-00003-of-00007.safetensors # DIT model shard 3
+├── diffusion_pytorch_model-00004-of-00007.safetensors # DIT model shard 4
+├── diffusion_pytorch_model-00005-of-00007.safetensors # DIT model shard 5
+├── diffusion_pytorch_model-00006-of-00007.safetensors # DIT model shard 6
+├── diffusion_pytorch_model-00007-of-00007.safetensors # DIT model shard 7
+├── diffusion_pytorch_model.safetensors.index.json # Shard index file
+├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder
+├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP encoder
+├── Wan2.1_VAE.pth # VAE encoder/decoder
+├── config.json # Model configuration
+├── xlm-roberta-large/ # CLIP tokenizer
+├── google/ # T5 tokenizer
+├── assets/
+└── examples/
+```
+
+#### Usage
+
+```bash
+# Download model
+huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \
+ --local-dir ./models/Wan2.1-I2V-14B-720P
+
+# Configure launch script
+model_path=./models/Wan2.1-I2V-14B-720P
+lightx2v_path=/path/to/LightX2V
+
+# Run inference
+cd LightX2V/scripts
+bash wan/run_wan_i2v.sh
+```
+
+### Wan2.2 Official Models
+
+#### Directory Structure
+
+Using [Wan2.2-I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) as an example:
+
+```
+Wan2.2-I2V-A14B/
+├── high_noise_model/ # High-noise model directory
+│ ├── diffusion_pytorch_model-00001-of-00009.safetensors
+│ ├── diffusion_pytorch_model-00002-of-00009.safetensors
+│ ├── ...
+│ ├── diffusion_pytorch_model-00009-of-00009.safetensors
+│ └── diffusion_pytorch_model.safetensors.index.json
+├── low_noise_model/ # Low-noise model directory
+│ ├── diffusion_pytorch_model-00001-of-00009.safetensors
+│ ├── diffusion_pytorch_model-00002-of-00009.safetensors
+│ ├── ...
+│ ├── diffusion_pytorch_model-00009-of-00009.safetensors
+│ └── diffusion_pytorch_model.safetensors.index.json
+├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder
+├── Wan2.1_VAE.pth # VAE encoder/decoder
+├── configuration.json # Model configuration
+├── google/ # T5 tokenizer
+├── assets/ # Example assets (optional)
+└── examples/ # Example files (optional)
+```
+
+#### Usage
+
+```bash
+# Download model
+huggingface-cli download Wan-AI/Wan2.2-I2V-A14B \
+ --local-dir ./models/Wan2.2-I2V-A14B
+
+# Configure launch script
+model_path=./models/Wan2.2-I2V-A14B
+lightx2v_path=/path/to/LightX2V
+
+# Run inference
+cd LightX2V/scripts
+bash wan22/run_wan22_moe_i2v.sh
+```
+
+### Available Model List
+
+#### Wan2.1 Official Model List
+
+| Model Name | Download Link |
+|---------|----------|
+| Wan2.1-I2V-14B-720P | [Link](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) |
+| Wan2.1-I2V-14B-480P | [Link](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) |
+| Wan2.1-T2V-14B | [Link](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) |
+| Wan2.1-T2V-1.3B | [Link](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) |
+| Wan2.1-FLF2V-14B-720P | [Link](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) |
+| Wan2.1-VACE-14B | [Link](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) |
+| Wan2.1-VACE-1.3B | [Link](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) |
+
+#### Wan2.2 Official Model List
+
+| Model Name | Download Link |
+|---------|----------|
+| Wan2.2-I2V-A14B | [Link](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) |
+| Wan2.2-T2V-A14B | [Link](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) |
+| Wan2.2-TI2V-5B | [Link](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) |
+| Wan2.2-Animate-14B | [Link](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) |
+
+### Usage Tips
+
+> 💡 **Quantized Model Usage**: To use quantized models, refer to the [Model Conversion Script](https://github.com/ModelTC/LightX2V/blob/main/tools/convert/readme_zh.md) for conversion, or directly use pre-converted quantized models in Format 2 below
+>
+> 💡 **Memory Optimization**: For devices with RTX 4090 24GB or smaller memory, it's recommended to combine quantization techniques with CPU offload features:
+> - Quantization Configuration: Refer to [Quantization Documentation](../method_tutorials/quantization.md)
+> - CPU Offload: Refer to [Parameter Offload Documentation](../method_tutorials/offload.md)
+> - Wan2.1 Configuration: Refer to [offload config files](https://github.com/ModelTC/LightX2V/tree/main/configs/offload)
+> - Wan2.2 Configuration: Refer to [wan22 config files](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) with `4090` suffix
+
+---
+
+## 🗂️ Format 2: LightX2V Single-File Models (Recommended)
+
+### Model Repositories
+- [Wan2.1-LightX2V](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan2.2-LightX2V](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+
+### Model Features
+- **Single-File Management**: Single safetensors file, easy to manage and deploy
+- **Multi-Precision Support**: Provides original precision, FP8, INT8, and other precision versions
+- **Distillation Acceleration**: Supports 4-step fast inference
+- **Tool Compatibility**: Compatible with ComfyUI and other tools
+
+**Examples**:
+- `wan2.1_i2v_720p_lightx2v_4step.safetensors` - 720P I2V original precision
+- `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` - 720P I2V FP8 quantization
+- `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` - 480P I2V INT8 quantization
+- ...
+
+### Wan2.1 Single-File Models
+
+#### Scenario A: Download Single Model File
+
+**Step 1: Select and Download Model**
+
+```bash
+# Create model directory
+mkdir -p ./models/wan2.1_i2v_720p
+
+# Download 720P I2V FP8 quantized model
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p \
+ --include "wan2.1_i2v_720p_lightx2v_4step.safetensors"
+```
+
+**Step 2: Manually Organize Other Components**
+
+Directory structure as follows:
+```
+wan2.1_i2v_720p/
+├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision
+└── t5/clip/vae/config.json/xlm-roberta-large/google and other components # Need manual organization
+```
+
+**Step 3: Configure Launch Script**
+
+```bash
+# Set in launch script (point to directory containing model file)
+model_path=./models/wan2.1_i2v_720p
+lightx2v_path=/path/to/LightX2V
+
+# Run script
+cd LightX2V/scripts
+bash wan/run_wan_i2v_distill_4step_cfg.sh
+```
+
+> 💡 **Tip**: When there's only one model file in the directory, LightX2V will automatically load it.
+
+#### Scenario B: Download Multiple Model Files
+
+When you download multiple models with different precisions to the same directory, you need to explicitly specify which model to use in the configuration file.
+
+**Step 1: Download Multiple Models**
+
+```bash
+# Create model directory
+mkdir -p ./models/wan2.1_i2v_720p_multi
+
+# Download original precision model
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p_multi \
+ --include "wan2.1_i2v_720p_lightx2v_4step.safetensors"
+
+# Download FP8 quantized model
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p_multi \
+ --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+
+# Download INT8 quantized model
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p_multi \
+ --include "wan2.1_i2v_720p_int8_lightx2v_4step.safetensors"
+```
+
+**Step 2: Manually Organize Other Components**
+
+Directory structure as follows:
+
+```
+wan2.1_i2v_720p_multi/
+├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision
+├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization
+├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 quantization
+└── t5/clip/vae/config.json/xlm-roberta-large/google and other components # Need manual organization
+```
+
+**Step 3: Specify Model in Configuration File**
+
+Edit configuration file (e.g., `configs/distill/wan_i2v_distill_4step_cfg.json`):
+
+```json
+{
+ // Use original precision model
+ "dit_original_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_lightx2v_4step.safetensors",
+
+ // Or use FP8 quantized model
+ // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "fp8-vllm",
+
+ // Or use INT8 quantized model
+ // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_int8_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "int8-vllm",
+
+ // Other configurations...
+}
+```
+### Usage Tips
+
+> 💡 **Configuration Parameter Description**:
+> - **dit_original_ckpt**: Used to specify the path to original precision models (BF16/FP32/FP16)
+> - **dit_quantized_ckpt**: Used to specify the path to quantized models (FP8/INT8), must be used with `dit_quantized` and `dit_quant_scheme` parameters
+
+**Step 4: Start Inference**
+
+```bash
+cd LightX2V/scripts
+bash wan/run_wan_i2v_distill_4step_cfg.sh
+```
+
+> 💡 **Tip**: Other components (T5, CLIP, VAE, tokenizer, etc.) need to be manually organized into the model directory
+
+### Wan2.2 Single-File Models
+
+#### Directory Structure Requirements
+
+When using Wan2.2 single-file models, you need to manually create a specific directory structure:
+
+```
+wan2.2_models/
+├── high_noise_model/ # High-noise model directory (required)
+│ └── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors
+├── low_noise_model/ # Low-noise model directory (required)
+│ └── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors
+└── t5/clip/vae/config.json/... # Other components (manually organized)
+```
+
+#### Scenario A: Only One Model File Per Directory
+
+```bash
+# Create required subdirectories
+mkdir -p ./models/wan2.2_models/high_noise_model
+mkdir -p ./models/wan2.2_models/low_noise_model
+
+# Download high-noise model to corresponding directory
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models/high_noise_model \
+ --include "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+
+# Download low-noise model to corresponding directory
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models/low_noise_model \
+ --include "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+
+# Configure launch script (point to parent directory)
+model_path=./models/wan2.2_models
+lightx2v_path=/path/to/LightX2V
+
+# Run script
+cd LightX2V/scripts
+bash wan22/run_wan22_moe_i2v_distill.sh
+```
+
+> 💡 **Tip**: When there's only one model file in each subdirectory, LightX2V will automatically load it.
+
+#### Scenario B: Multiple Model Files Per Directory
+
+When you place multiple models with different precisions in both `high_noise_model/` and `low_noise_model/` directories, you need to explicitly specify them in the configuration file.
+
+```bash
+# Create directories
+mkdir -p ./models/wan2.2_models_multi/high_noise_model
+mkdir -p ./models/wan2.2_models_multi/low_noise_model
+
+# Download multiple versions of high-noise model
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models_multi/high_noise_model \
+ --include "wan2.2_i2v_A14b_high_noise_*.safetensors"
+
+# Download multiple versions of low-noise model
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models_multi/low_noise_model \
+ --include "wan2.2_i2v_A14b_low_noise_*.safetensors"
+```
+
+**Directory Structure**:
+
+```
+wan2.2_models_multi/
+├── high_noise_model/
+│ ├── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors # Original precision
+│ ├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization
+│ └── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors # INT8 quantization
+├── low_noise_model/
+│ ├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # Original precision
+│ ├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization
+│ └── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # INT8 quantization
+└── t5/vae/config.json/xlm-roberta-large/google and other components # Need manual organization
+```
+
+**Configuration File Settings**:
+
+```json
+{
+ // Use original precision model
+ "high_noise_original_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors",
+ "low_noise_original_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
+
+ // Or use FP8 quantized model
+ // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors",
+ // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "fp8-vllm"
+
+ // Or use INT8 quantized model
+ // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors",
+ // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "int8-vllm"
+}
+```
+
+### Usage Tips
+
+> 💡 **Configuration Parameter Description**:
+> - **high_noise_original_ckpt** / **low_noise_original_ckpt**: Used to specify the path to original precision models (BF16/FP32/FP16)
+> - **high_noise_quantized_ckpt** / **low_noise_quantized_ckpt**: Used to specify the path to quantized models (FP8/INT8), must be used with `dit_quantized` and `dit_quant_scheme` parameters
+
+
+### Available Model List
+
+#### Wan2.1 Single-File Model List
+
+**Image-to-Video Models (I2V)**
+
+| Filename | Precision | Description |
+|--------|------|------|
+| `wan2.1_i2v_480p_lightx2v_4step.safetensors` | BF16 | 4-step model original precision |
+| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4-step model FP8 quantization |
+| `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` | INT8 | 4-step model INT8 quantization |
+| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4-step model ComfyUI format |
+| `wan2.1_i2v_720p_lightx2v_4step.safetensors` | BF16 | 4-step model original precision |
+| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4-step model FP8 quantization |
+| `wan2.1_i2v_720p_int8_lightx2v_4step.safetensors` | INT8 | 4-step model INT8 quantization |
+| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4-step model ComfyUI format |
+
+**Text-to-Video Models (T2V)**
+
+| Filename | Precision | Description |
+|--------|------|------|
+| `wan2.1_t2v_14b_lightx2v_4step.safetensors` | BF16 | 4-step model original precision |
+| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4-step model FP8 quantization |
+| `wan2.1_t2v_14b_int8_lightx2v_4step.safetensors` | INT8 | 4-step model INT8 quantization |
+| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4-step model ComfyUI format |
+
+#### Wan2.2 Single-File Model List
+
+**Image-to-Video Models (I2V) - A14B Series**
+
+| Filename | Precision | Description |
+|--------|------|------|
+| `wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors` | BF16 | High-noise model - 4-step original precision |
+| `wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | High-noise model - 4-step FP8 quantization |
+| `wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors` | INT8 | High-noise model - 4-step INT8 quantization |
+| `wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors` | BF16 | Low-noise model - 4-step original precision |
+| `wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | Low-noise model - 4-step FP8 quantization |
+| `wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors` | INT8 | Low-noise model - 4-step INT8 quantization |
+
+> 💡 **Usage Tips**:
+> - Wan2.2 models use a dual-noise architecture, requiring both high-noise and low-noise models to be downloaded
+> - Refer to the "Wan2.2 Single-File Models" section above for detailed directory organization
+
+---
+
+## 🗂️ Format 3: LightX2V LoRA Models
+
+LoRA (Low-Rank Adaptation) models provide a lightweight model fine-tuning solution that enables customization for specific effects without modifying the base model.
+
+### Model Repositories
+
+- **Wan2.1 LoRA Models**: [lightx2v/Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras)
+- **Wan2.2 LoRA Models**: [lightx2v/Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras)
+
+### Usage Methods
+
+#### Method 1: Offline Merging
+
+Merge LoRA weights offline into the base model to generate a new complete model file.
+
+**Steps**:
+
+Refer to the [Model Conversion Documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) for offline merging.
+
+**Advantages**:
+- ✅ No need to load LoRA during inference
+- ✅ Better performance
+
+**Disadvantages**:
+- ❌ Requires additional storage space
+- ❌ Switching different LoRAs requires re-merging
+
+#### Method 2: Online Loading
+
+Dynamically load LoRA weights during inference without modifying the base model.
+
+**LoRA Application Principle**:
+
+```python
+# LoRA weight application formula
+# lora_scale = (alpha / rank)
+# W' = W + lora_scale * B @ A
+# Where: B = up_proj (out_features, rank)
+# A = down_proj (rank, in_features)
+
+if weights_dict["alpha"] is not None:
+ lora_scale = weights_dict["alpha"] / lora_down.shape[0]
+elif alpha is not None:
+ lora_scale = alpha / lora_down.shape[0]
+else:
+ lora_scale = 1.0
+```
+
+**Configuration Method**:
+
+**Wan2.1 LoRA Configuration**:
+
+```json
+{
+ "lora_configs": [
+ {
+ "path": "wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0,
+ "alpha": null
+ }
+ ]
+}
+```
+
+**Wan2.2 LoRA Configuration**:
+
+Since Wan2.2 uses a dual-model architecture (high-noise/low-noise), LoRA needs to be configured separately for both models:
+
+```json
+{
+ "lora_configs": [
+ {
+ "name": "low_noise_model",
+ "path": "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0,
+ "alpha": null
+ },
+ {
+ "name": "high_noise_model",
+ "path": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0,
+ "alpha": null
+ }
+ ]
+}
+```
+
+**Parameter Description**:
+
+| Parameter | Description | Default |
+|------|------|--------|
+| `path` | LoRA model file path | Required |
+| `strength` | LoRA strength coefficient, range [0.0, 1.0] | 1.0 |
+| `alpha` | LoRA scaling factor, uses model's built-in value when `null` | null |
+| `name` | (Wan2.2 only) Specifies which model to apply to | Required |
+
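+Putting these knobs together, a reasonable mental model (an illustrative sketch, not the exact LightX2V code) is that `alpha` sets the base scale via `alpha / rank`, and `strength` then scales the whole LoRA delta:
+
+```python
+import torch
+
+def apply_lora(weight, lora_up, lora_down, alpha=None, strength=1.0):
+    """Illustrative LoRA application: W' = W + strength * (alpha / rank) * B @ A."""
+    rank = lora_down.shape[0]
+    lora_scale = (alpha / rank) if alpha is not None else 1.0  # alpha=None keeps the built-in scale
+    return weight + strength * lora_scale * (lora_up.float() @ lora_down.float())
+```
+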
+**Advantages**:
+- ✅ Flexible switching between different LoRAs
+- ✅ Saves storage space
+- ✅ Can dynamically adjust LoRA strength
+
+**Disadvantages**:
+- ❌ Additional loading time during inference
+- ❌ Slightly increases memory usage
+
+---
+
+## 📚 Related Resources
+
+### Official Repositories
+- [LightX2V GitHub](https://github.com/ModelTC/LightX2V)
+- [LightX2V Single-File Model Repository](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan-AI Official Model Repository](https://huggingface.co/Wan-AI)
+
+### Model Download Links
+
+**Wan2.1 Series**
+- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b)
+
+**Wan2.2 Series**
+- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8)
+
+**LightX2V Single-File Models**
+- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+
+### Documentation Links
+- [Quantization Documentation](../method_tutorials/quantization.md)
+- [Parameter Offload Documentation](../method_tutorials/offload.md)
+- [Configuration File Examples](https://github.com/ModelTC/LightX2V/tree/main/configs)
+
+---
+
+Through this document, you should be able to:
+
+✅ Understand all model formats supported by LightX2V
+✅ Select appropriate models and precisions based on your needs
+✅ Correctly download and organize model files
+✅ Configure launch parameters and successfully run inference
+✅ Resolve common model loading issues
+
+If you have other questions, feel free to ask in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues).
diff --git a/docs/EN/source/getting_started/quickstart.md b/docs/EN/source/getting_started/quickstart.md
new file mode 100644
index 0000000000000000000000000000000000000000..799959e4ae2cd6dcbf0f8ad262c276a87370806b
--- /dev/null
+++ b/docs/EN/source/getting_started/quickstart.md
@@ -0,0 +1,349 @@
+# LightX2V Quick Start Guide
+
+Welcome to LightX2V! This guide will help you quickly set up the environment and start using LightX2V for video generation.
+
+## 📋 Table of Contents
+
+- [System Requirements](#system-requirements)
+- [Linux Environment Setup](#linux-environment-setup)
+ - [Docker Environment (Recommended)](#docker-environment-recommended)
+ - [Conda Environment Setup](#conda-environment-setup)
+- [Windows Environment Setup](#windows-environment-setup)
+- [Inference Usage](#inference-usage)
+
+## 🚀 System Requirements
+
+- **Operating System**: Linux (Ubuntu 18.04+) or Windows 10/11
+- **Python**: 3.10 or higher
+- **GPU**: NVIDIA GPU with CUDA support, at least 8GB VRAM
+- **Memory**: 16GB or more recommended
+- **Storage**: At least 50GB available space
+
+## 🐧 Linux Environment Setup
+
+### 🐳 Docker Environment (Recommended)
+
+We strongly recommend using the Docker environment, which is the simplest and fastest installation method.
+
+#### 1. Pull Image
+
+Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags), select a tag with the latest date, such as `25111101-cu128`:
+
+```bash
+docker pull lightx2v/lightx2v:25111101-cu128
+```
+
+We recommend using the `cuda128` environment for faster inference speed. If you need to use the `cuda124` environment, you can use image versions with the `-cu124` suffix:
+
+```bash
+docker pull lightx2v/lightx2v:25101501-cu124
+```
+
+#### 2. Run Container
+
+```bash
+docker run --gpus all -itd --ipc=host --name [container_name] -v [mount_settings] --entrypoint /bin/bash [image_id]
+```
+
+#### 3. China Mirror Source (Optional)
+
+For users in mainland China, if the network is unstable when pulling images, you can pull from the Alibaba Cloud mirror instead:
+
+```bash
+# cuda128
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25111101-cu128
+
+# cuda124
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25101501-cu124
+```
+
+### 🐍 Conda Environment Setup
+
+If you prefer to set up the environment yourself using Conda, please follow these steps:
+
+#### Step 1: Clone Repository
+
+```bash
+# Download project code
+git clone https://github.com/ModelTC/LightX2V.git
+cd LightX2V
+```
+
+#### Step 2: Create Conda Virtual Environment
+
+```bash
+# Create and activate conda environment
+conda create -n lightx2v python=3.11 -y
+conda activate lightx2v
+```
+
+#### Step 3: Install Dependencies
+
+```bash
+pip install -v -e .
+```
+
+#### Step 4: Install Attention Operators
+
+**Option A: Flash Attention 2**
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git --recursive
+cd flash-attention && python setup.py install
+```
+
+**Option B: Flash Attention 3 (for Hopper architecture GPUs)**
+```bash
+cd flash-attention/hopper && python setup.py install
+```
+
+**Option C: SageAttention 2 (Recommended)**
+```bash
+git clone https://github.com/thu-ml/SageAttention.git
+cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install -v -e .
+```
+
+#### Step 5: Install Quantization Operators (Optional)
+
+Quantization operators are used to support model quantization, which can significantly reduce memory usage and accelerate inference. Choose the appropriate quantization operator based on your needs:
+
+**Option A: VLLM Kernels (Recommended)**
+Suitable for various quantization schemes, supports FP8 and other quantization formats.
+
+```bash
+pip install vllm
+```
+
+Or install from source for the latest features:
+
+```bash
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+uv pip install -e .
+```
+
+**Option B: SGL Kernels**
+Suitable for SGL quantization scheme, requires torch == 2.8.0.
+
+```bash
+pip install sgl-kernel --upgrade
+```
+
+**Option C: Q8 Kernels**
+Suitable for Ada architecture GPUs (such as RTX 4090, L40S, etc.).
+
+```bash
+git clone https://github.com/KONAKONA666/q8_kernels.git
+cd q8_kernels && git submodule init && git submodule update
+python setup.py install
+```
+
+> 💡 **Note**:
+> - You can skip this step if you don't need quantization functionality
+> - Quantized models can be downloaded from [LightX2V HuggingFace](https://huggingface.co/lightx2v)
+> - For more quantization information, please refer to the [Quantization Documentation](method_tutorials/quantization.html)
+
+#### Step 6: Verify Installation
+
+```python
+import lightx2v
+print(f"LightX2V Version: {lightx2v.__version__}")
+```
+
+## 🪟 Windows Environment Setup
+
+Windows systems only support Conda environment setup. Please follow these steps:
+
+### 🐍 Conda Environment Setup
+
+#### Step 1: Check CUDA Version
+
+First, confirm your GPU driver and CUDA version:
+
+```cmd
+nvidia-smi
+```
+
+Note the **CUDA Version** shown in the output; the packages you install in subsequent steps must match it.
+
+#### Step 2: Create Python Environment
+
+```cmd
+# Create new environment (Python 3.12 recommended)
+conda create -n lightx2v python=3.12 -y
+
+# Activate environment
+conda activate lightx2v
+```
+
+> 💡 **Note**: Python 3.10 or higher is recommended for best compatibility.
+
+#### Step 3: Install PyTorch Framework
+
+**Method 1: Download Official Wheel Package (Recommended)**
+
+1. Visit the [PyTorch Official Download Page](https://download.pytorch.org/whl/torch/)
+2. Select the corresponding version wheel package, paying attention to matching:
+ - **Python Version**: Consistent with your environment
+ - **CUDA Version**: Matches your GPU driver
+ - **Platform**: Select Windows version
+
+**Example (Python 3.12 + PyTorch 2.6 + CUDA 12.4):**
+
+```cmd
+# Download and install PyTorch
+pip install torch-2.6.0+cu124-cp312-cp312-win_amd64.whl
+
+# Install supporting packages
+pip install torchvision==0.21.0 torchaudio==2.6.0
+```
+
+**Method 2: Direct Installation via pip**
+
+```cmd
+# CUDA 12.4 version example
+pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
+```
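+
+After installation, a short Python check (a minimal sketch; it only uses standard PyTorch attributes) can confirm that the installed build matches your driver:
+
+```python
+import torch
+
+# Verify that the installed PyTorch build and its CUDA toolkit match what you selected above.
+print("PyTorch:", torch.__version__)       # e.g. 2.6.0+cu124
+print("CUDA build:", torch.version.cuda)   # should correspond to the CUDA version from nvidia-smi
+print("GPU available:", torch.cuda.is_available())
+if torch.cuda.is_available():
+    print("Device:", torch.cuda.get_device_name(0))
+```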
+
+#### Step 4: Install Windows Version vLLM
+
+Download the corresponding wheel package from [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases).
+
+**Version Matching Requirements:**
+- Python version matching
+- PyTorch version matching
+- CUDA version matching
+
+```cmd
+# Install vLLM (please adjust according to actual filename)
+pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl
+```
+
+#### Step 5: Install Attention Mechanism Operators
+
+**Option A: Flash Attention 2**
+
+```cmd
+pip install flash-attn==2.7.2.post1
+```
+
+**Option B: SageAttention 2 (Strongly Recommended)**
+
+**Download Sources:**
+- [Windows Special Version 1](https://github.com/woct0rdho/SageAttention/releases)
+- [Windows Special Version 2](https://github.com/sdbds/SageAttention-for-windows/releases)
+
+```cmd
+# Install SageAttention (please adjust according to actual filename)
+pip install sageattention-2.1.1+cu126torch2.6.0-cp312-cp312-win_amd64.whl
+```
+
+> ⚠️ **Note**: SageAttention's CUDA version doesn't need to be strictly aligned, but Python and PyTorch versions must match.
+
+#### Step 6: Clone Repository
+
+```cmd
+# Clone project code
+git clone https://github.com/ModelTC/LightX2V.git
+cd LightX2V
+
+# Install Windows-specific dependencies
+pip install -r requirements_win.txt
+pip install -v -e .
+```
+
+#### Step 7: Install Quantization Operators (Optional)
+
+Quantization operators are used to support model quantization, which can significantly reduce memory usage and accelerate inference.
+
+**Install VLLM (Recommended):**
+
+Download the corresponding wheel package from [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) and install it.
+
+```cmd
+# Install vLLM (please adjust according to actual filename)
+pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl
+```
+
+> 💡 **Note**:
+> - You can skip this step if you don't need quantization functionality
+> - Quantized models can be downloaded from [LightX2V HuggingFace](https://huggingface.co/lightx2v)
+> - For more quantization information, please refer to the [Quantization Documentation](method_tutorials/quantization.html)
+
+## 🎯 Inference Usage
+
+### 📥 Model Preparation
+
+Before starting inference, you need to download the model files in advance. We recommend:
+
+- **Download Source**: Download models from [LightX2V Official Hugging Face](https://huggingface.co/lightx2v/) or other open-source model repositories
+- **Storage Location**: It's recommended to store models on SSD disks for better read performance
+- **Available Models**: Including Wan2.1-I2V, Wan2.1-T2V, and other models supporting different resolutions and functionalities
+
+### 📁 Configuration Files and Scripts
+
+The configuration files used for inference are available [here](https://github.com/ModelTC/LightX2V/tree/main/configs), and scripts are available [here](https://github.com/ModelTC/LightX2V/tree/main/scripts).
+
+You need to configure the downloaded model path in the run script. In addition to the input arguments in the script, there are also some necessary parameters in the configuration file specified by `--config_json`. You can modify them as needed.
+
+### 🚀 Start Inference
+
+#### Linux Environment
+
+```bash
+# Run after modifying the path in the script
+bash scripts/wan/run_wan_t2v.sh
+```
+
+#### Windows Environment
+
+```cmd
+# Use Windows batch script
+scripts\win\run_wan_t2v.bat
+```
+
+#### Python Script Launch
+
+```python
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-T2V-14B",
+ model_cls="wan2.1",
+ task="t2v",
+)
+
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ height=480, # 720
+ width=832, # 1280
+ num_frames=81,
+ guidance_scale=5.0,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+save_result_path="/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
+```
+
+> 💡 **More Examples**: For more usage examples including quantization, offloading, caching, and other advanced configurations, please refer to the [examples directory](https://github.com/ModelTC/LightX2V/tree/main/examples).
+
+## 📞 Get Help
+
+If you encounter problems during installation or usage, please:
+
+1. Search for related issues in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues)
+2. Submit a new Issue describing your problem
+
+---
+
+🎉 **Congratulations!** You have successfully set up the LightX2V environment and can now start enjoying video generation!
diff --git a/docs/EN/source/index.rst b/docs/EN/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..3e244e62ab26025fc1d5ae328eaf29205a5cf254
--- /dev/null
+++ b/docs/EN/source/index.rst
@@ -0,0 +1,67 @@
+Welcome to Lightx2v!
+====================
+
+.. figure:: ../../../assets/img_lightx2v.png
+ :width: 80%
+ :align: center
+ :alt: Lightx2v
+ :class: no-scaled-link
+
+.. raw:: html
+
+   <p align="center"><strong>LightX2V: Light Video Generation Inference Framework</strong></p>
+
+LightX2V is a lightweight video generation inference framework designed to provide an inference tool that leverages multiple advanced video generation inference techniques. As a unified inference platform, this framework supports various generation tasks such as text-to-video (T2V) and image-to-video (I2V) across different models. X2V means transforming different input modalities (such as text or images) to video output.
+
+GitHub: https://github.com/ModelTC/lightx2v
+
+HuggingFace: https://huggingface.co/lightx2v
+
+Documentation
+-------------
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Quick Start
+
+ Quick Start
+ Model Structure
+ Benchmark
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Method Tutorials
+
+ Model Quantization
+ Feature Caching
+ Attention Module
+ Offload
+ Parallel Inference
+ Changing Resolution Inference
+ Step Distill
+ Autoregressive Distill
+ Video Frame Interpolation
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Deployment Guides
+
+ Low Latency Deployment
+ Low Resource Deployment
+ Lora Deployment
+ Service Deployment
+ Gradio Deployment
+ ComfyUI Deployment
+ Local Windows Deployment
diff --git a/docs/EN/source/method_tutorials/attention.md b/docs/EN/source/method_tutorials/attention.md
new file mode 100644
index 0000000000000000000000000000000000000000..1396140c74bd7af94b7dc6ccb6f5bb867ed05794
--- /dev/null
+++ b/docs/EN/source/method_tutorials/attention.md
@@ -0,0 +1,35 @@
+# Attention Mechanisms
+
+## Attention Mechanisms Supported by LightX2V
+
+| Name | Type Name | GitHub Link |
+|--------------------|------------------|-------------|
+| Flash Attention 2 | `flash_attn2` | [flash-attention v2](https://github.com/Dao-AILab/flash-attention) |
+| Flash Attention 3 | `flash_attn3` | [flash-attention v3](https://github.com/Dao-AILab/flash-attention) |
+| Sage Attention 2 | `sage_attn2` | [SageAttention](https://github.com/thu-ml/SageAttention) |
+| Radial Attention | `radial_attn` | [Radial Attention](https://github.com/mit-han-lab/radial-attention) |
+| Sparge Attention | `sparge_ckpt` | [Sparge Attention](https://github.com/thu-ml/SpargeAttn) |
+
+---
+
+## Configuration Examples
+
+The configuration files for attention mechanisms are located [here](https://github.com/ModelTC/lightx2v/tree/main/configs/attentions)
+
+By specifying --config_json to a specific config file, you can test different attention mechanisms.
+
+For example, for radial_attn, the configuration is as follows:
+
+```json
+{
+ "self_attn_1_type": "radial_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3"
+}
+```
+
+To switch to other types, simply replace the corresponding values with the type names from the table above.
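+
+As an illustration of comparing the operators above, the following sketch writes one config variant per attention type (the helper and output paths are hypothetical, not part of LightX2V):
+
+```python
+import json
+from pathlib import Path
+
+# Hypothetical helper: emit one config file per attention operator so each can be
+# passed to --config_json for a quick side-by-side comparison.
+attn_types = ["flash_attn2", "flash_attn3", "sage_attn2"]
+keys = ["self_attn_1_type", "cross_attn_1_type", "cross_attn_2_type"]
+
+out_dir = Path("configs/attentions/generated")
+out_dir.mkdir(parents=True, exist_ok=True)
+
+for attn in attn_types:
+    cfg = {key: attn for key in keys}  # use the same operator for every attention slot
+    path = out_dir / f"wan_{attn}.json"
+    path.write_text(json.dumps(cfg, indent=4))
+    print("wrote", path)
+```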
+
+Tips: radial_attn can only be used in self attention due to the limitations of its sparse algorithm principle.
+
+For further customization of attention mechanism behavior, please refer to the official documentation or implementation code of each attention library.
diff --git a/docs/EN/source/method_tutorials/autoregressive_distill.md b/docs/EN/source/method_tutorials/autoregressive_distill.md
new file mode 100644
index 0000000000000000000000000000000000000000..32b9ed50ad551f69e283ff062a825906e6b8a8fe
--- /dev/null
+++ b/docs/EN/source/method_tutorials/autoregressive_distill.md
@@ -0,0 +1,53 @@
+# Autoregressive Distillation
+
+Autoregressive distillation is a technical exploration in LightX2V. By training distilled models, it reduces inference steps from the original 40-50 steps to **8 steps**, achieving inference acceleration while enabling infinite-length video generation through KV Cache technology.
+
+> ⚠️ Warning: Currently, autoregressive distillation has mediocre effects and the acceleration improvement has not met expectations, but it can serve as a long-term research project. Currently, LightX2V only supports autoregressive models for T2V.
+
+## 🔍 Technical Principle
+
+Autoregressive distillation is implemented through [CausVid](https://github.com/tianweiy/CausVid) technology. CausVid performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements:
+
+1. **Larger Models**: Supports autoregressive distillation training for 14B models;
+2. **More Complete Data Processing Pipeline**: Generates a training dataset of 50,000 prompt-video pairs;
+
+For detailed implementation, refer to [CausVid-Plus](https://github.com/GoatWu/CausVid-Plus).
+
+## 🛠️ Configuration Files
+
+### Configuration File
+
+Configuration options are provided in the [configs/causvid/](https://github.com/ModelTC/lightx2v/tree/main/configs/causvid) directory:
+
+| Configuration File | Model Address |
+|-------------------|---------------|
+| [wan_t2v_causvid.json](https://github.com/ModelTC/lightx2v/blob/main/configs/causvid/wan_t2v_causvid.json) | https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid |
+
+### Key Configuration Parameters
+
+```json
+{
+ "enable_cfg": false, // Disable CFG for speed improvement
+ "num_fragments": 3, // Number of video segments generated at once, 5s each
+ "num_frames": 21, // Frames per video segment, modify with caution!
+ "num_frame_per_block": 3, // Frames per autoregressive block, modify with caution!
+ "num_blocks": 7, // Autoregressive blocks per video segment, modify with caution!
+ "frame_seq_length": 1560, // Encoding length per frame, modify with caution!
+ "denoising_step_list": [ // Denoising timestep list
+ 999, 934, 862, 756, 603, 410, 250, 140, 74
+ ]
+}
+```
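+
+These values are tightly coupled. A quick illustrative check (assuming, based on the comments above, that the frames per segment factor into the autoregressive blocks) can catch inconsistent edits:
+
+```python
+cfg = {
+    "num_fragments": 3,
+    "num_frames": 21,
+    "num_frame_per_block": 3,
+    "num_blocks": 7,
+}
+
+# Frames per segment should equal frames-per-block times the number of blocks.
+assert cfg["num_frames"] == cfg["num_frame_per_block"] * cfg["num_blocks"]
+print("frames generated in one run:", cfg["num_fragments"] * cfg["num_frames"])
+```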
+
+## 📜 Usage
+
+### Model Preparation
+
+Place the downloaded model (`causal_model.pt` or `causal_model.safetensors`) in the `causvid_models/` folder under the Wan model root directory:
+- For T2V: `Wan2.1-T2V-14B/causvid_models/`
+
+### Inference Script
+
+```bash
+bash scripts/wan/run_wan_t2v_causvid.sh
+```
diff --git a/docs/EN/source/method_tutorials/cache.md b/docs/EN/source/method_tutorials/cache.md
new file mode 100644
index 0000000000000000000000000000000000000000..953c36af8a7ec604ecf98f12fb401ecb1803307e
--- /dev/null
+++ b/docs/EN/source/method_tutorials/cache.md
@@ -0,0 +1,3 @@
+# Feature Cache
+
+Because this page includes video playback comparisons, please refer to this [🔗 page](https://github.com/ModelTC/LightX2V/blob/main/docs/EN/source/method_tutorials/cache_source.md) for better display and the corresponding documentation content.
diff --git a/docs/EN/source/method_tutorials/cache_source.md b/docs/EN/source/method_tutorials/cache_source.md
new file mode 100644
index 0000000000000000000000000000000000000000..3fd14c010867058f64644e43a07a625a438840b2
--- /dev/null
+++ b/docs/EN/source/method_tutorials/cache_source.md
@@ -0,0 +1,139 @@
+# Feature Caching
+
+## Cache Acceleration Algorithm
+- In the inference process of diffusion models, cache reuse is an important acceleration algorithm.
+- The core idea is to skip redundant computations at certain time steps by reusing historical cache results to improve inference efficiency.
+- The key to the algorithm is how to decide which time steps to perform cache reuse, usually based on dynamic judgment of model state changes or error thresholds.
+- During inference, key content such as intermediate features, residuals, and attention outputs need to be cached. When entering reusable time steps, the cached content is directly utilized, and the current output is reconstructed through approximation methods like Taylor expansion, thereby reducing repeated calculations and achieving efficient inference.
+
+### TeaCache
+The core idea of `TeaCache` is to accumulate the **relative L1** distance between adjacent time step inputs. When the accumulated distance reaches a set threshold, it determines that the current time step should not use cache reuse; conversely, when the accumulated distance does not reach the set threshold, cache reuse is used to accelerate the inference process.
+- Specifically, the algorithm calculates the relative L1 distance between the current input and the previous step input at each inference step and accumulates it.
+- When the accumulated distance does not exceed the threshold, it indicates that the model state change is not obvious, so the most recently cached content is directly reused, skipping some redundant calculations. This can significantly reduce the number of forward computations of the model and improve inference speed.
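+
+A minimal sketch of this decision rule is shown below (illustrative only; the threshold, rescaling, and reset behavior of the actual TeaCache implementation may differ):
+
+```python
+import torch
+
+class TeaCacheLikeDecider:
+    """Toy skip rule: accumulate the relative L1 distance between adjacent-step
+    inputs and only run the full forward pass once the accumulated change is large."""
+
+    def __init__(self, threshold: float = 0.08):
+        self.threshold = threshold
+        self.prev_input = None
+        self.accumulated = 0.0
+
+    def should_recompute(self, x: torch.Tensor) -> bool:
+        if self.prev_input is None:
+            self.prev_input = x.detach()
+            return True  # first step: nothing cached yet
+        rel_l1 = (x - self.prev_input).abs().mean() / (self.prev_input.abs().mean() + 1e-8)
+        self.accumulated += rel_l1.item()
+        self.prev_input = x.detach()
+        if self.accumulated >= self.threshold:
+            self.accumulated = 0.0
+            return True   # state changed enough: recompute and refresh the cache
+        return False      # small change: reuse the cached output
+```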
+
+In practical effects, TeaCache achieves significant acceleration while ensuring generation quality. On a single H200 card, the time consumption and video comparison before and after acceleration are as follows:
+
+- Before acceleration: 58s
+- After acceleration: 17.9s
+- Acceleration ratio: **3.24**
+- Config: [wan_t2v_1_3b_tea_480p.json](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json)
+- Reference paper: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)
+
+### TaylorSeer Cache
+The core of `TaylorSeer Cache` lies in using Taylor's formula to recalculate cached content as residual compensation for cache reuse time steps.
+- The specific approach is to not only simply reuse historical cache at cache reuse time steps, but also approximately reconstruct the current output through Taylor expansion. This can further improve output accuracy while reducing computational load.
+- Taylor expansion can effectively capture minor changes in model state, allowing errors caused by cache reuse to be compensated, thereby ensuring generation quality while accelerating.
+
+`TaylorSeer Cache` is suitable for scenarios with high output accuracy requirements and can further improve model inference performance based on cache reuse.
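+
+The idea can be sketched as a simple first-order extrapolation of cached features (illustrative only; the actual implementation keeps higher-order terms and per-module caches):
+
+```python
+import torch
+
+def taylor_extrapolate(feat_prev: torch.Tensor, feat_prev2: torch.Tensor, dt: float = 1.0) -> torch.Tensor:
+    """Approximate the current feature from the two most recent cached features by
+    treating their finite difference as a derivative (first-order Taylor term)."""
+    derivative = (feat_prev - feat_prev2) / dt
+    return feat_prev + derivative * dt
+```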
+
+- Before acceleration: 57.7s
+- After acceleration: 41.3s
+- Acceleration ratio: **1.39**
+- Config: [wan_t2v_taylorseer](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/taylorseer/wan_t2v_taylorseer.json)
+- Reference paper: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)
+
+### AdaCache
+The core idea of `AdaCache` is to dynamically adjust the step size of cache reuse based on partial cached content in specified block chunks.
+- The algorithm analyzes feature differences between two adjacent time steps within specific blocks and adaptively determines the next cache reuse time step interval based on the difference magnitude.
+- When model state changes are small, the step size automatically increases, reducing cache update frequency; when state changes are large, the step size decreases to ensure output quality.
+
+This allows flexible adjustment of caching strategies based on dynamic changes in the actual inference process, achieving more efficient acceleration and better generation results. AdaCache is suitable for application scenarios that have high requirements for both inference speed and generation quality.
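+
+A toy version of such an adaptive schedule is sketched below (the thresholds and intervals are made-up example values, not the ones used by AdaCache):
+
+```python
+def next_cache_interval(feature_diff: float,
+                        thresholds=(0.05, 0.10, 0.20),
+                        intervals=(6, 4, 2)) -> int:
+    """Map the feature difference between adjacent steps to a reuse interval:
+    small change -> reuse the cache for longer, large change -> recompute soon."""
+    for threshold, interval in zip(thresholds, intervals):
+        if feature_diff < threshold:
+            return interval
+    return 1  # state is changing quickly: recompute at every step
+```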
+
+- Before acceleration: 227s
+- After acceleration: 83s
+- Acceleration ratio: **2.73**
+- Config: [wan_i2v_ada](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/adacache/wan_i2v_ada.json)
+- Reference paper: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)
+
+### CustomCache
+`CustomCache` combines the advantages of `TeaCache` and `TaylorSeer Cache`.
+- It combines the real-time and reasonable cache decision-making of `TeaCache`, determining when to perform cache reuse through dynamic thresholds.
+- At the same time, it utilizes `TaylorSeer`'s Taylor expansion method to make use of cached content.
+
+This not only efficiently determines the timing of cache reuse but also maximizes the utilization of cached content, improving output accuracy and generation quality. Actual testing shows that `CustomCache` produces video quality superior to using `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone across multiple content generation tasks, making it one of the currently optimal comprehensive cache acceleration algorithms.
+
+- Before acceleration: 57.9s
+- After acceleration: 16.6s
+- Acceleration ratio: **3.49**
+- Config: [wan_t2v_custom_1_3b](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/custom/wan_t2v_custom_1_3b.json)
+
+
+## Usage
+
+The config files for feature caching are located [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching)
+
+By specifying --config_json to the specific config file, you can test different cache algorithms.
+
+[Here](https://github.com/ModelTC/lightx2v/tree/main/scripts/cache) are some running scripts for use.
diff --git a/docs/EN/source/method_tutorials/changing_resolution.md b/docs/EN/source/method_tutorials/changing_resolution.md
new file mode 100644
index 0000000000000000000000000000000000000000..9c936949efc29464dfa313454844618b34d80ef3
--- /dev/null
+++ b/docs/EN/source/method_tutorials/changing_resolution.md
@@ -0,0 +1,66 @@
+# Variable Resolution Inference
+
+## Overview
+
+Variable resolution inference is a technical strategy for optimizing the denoising process. It improves computational efficiency while maintaining generation quality by using different resolutions at different stages of the denoising process. The core idea of this method is to use lower resolution for coarse denoising in the early stages and switch to normal resolution for fine processing in the later stages.
+
+## Technical Principles
+
+### Multi-stage Denoising Strategy
+
+Variable resolution inference is based on the following observations:
+
+- **Early-stage denoising**: Mainly handles coarse noise and overall structure, requiring less detailed information
+- **Late-stage denoising**: Focuses on detail optimization and high-frequency information recovery, requiring complete resolution information
+
+### Resolution Switching Mechanism
+
+1. **Low-resolution stage** (early stage)
+ - Downsample the input to a lower resolution (e.g., 0.75x of original size)
+ - Execute initial denoising steps
+ - Quickly remove most noise and establish basic structure
+
+2. **Normal resolution stage** (late stage)
+ - Upsample the denoising result from the first step back to original resolution
+ - Continue executing remaining denoising steps
+ - Restore detailed information and complete fine processing
+
+### U-shaped Resolution Strategy
+
+If resolution is reduced at the very beginning of the denoising steps, it may cause significant differences between the final generated video and the video generated through normal inference. Therefore, a U-shaped resolution strategy can be adopted, where the original resolution is maintained for the first few steps, then resolution is reduced for inference.
+
+## Usage
+
+The config files for variable resolution inference are located [here](https://github.com/ModelTC/LightX2V/tree/main/configs/changing_resolution)
+
+You can test variable resolution inference by specifying --config_json to the specific config file.
+
+You can refer to the scripts [here](https://github.com/ModelTC/LightX2V/blob/main/scripts/changing_resolution) to run.
+
+### Example 1:
+```
+{
+ "infer_steps": 50,
+ "changing_resolution": true,
+ "resolution_rate": [0.75],
+ "changing_resolution_steps": [25]
+}
+```
+
+This means a total of 50 steps, with resolution at 0.75x original resolution from step 1 to 25, and original resolution from step 26 to the final step.
+
+### Example 2:
+```
+{
+ "infer_steps": 50,
+ "changing_resolution": true,
+ "resolution_rate": [1.0, 0.75],
+ "changing_resolution_steps": [10, 35]
+}
+```
+
+This means a total of 50 steps, with original resolution from step 1 to 10, 0.75x original resolution from step 11 to 35, and original resolution from step 36 to the final step.
+
+Generally, if `changing_resolution_steps` is [A, B, C], denoising starts at step 1, and the total number of steps is X, then the inference process is divided into four segments.
+
+Specifically, these segments are (0, A], (A, B], (B, C], and (C, X], where each segment is a left-open, right-closed interval.
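+
+A small helper (illustrative, not the framework API) that maps a 1-indexed step to the resolution scale implied by such a config:
+
+```python
+def resolution_scale_for_step(step, changing_steps, rates):
+    """Segment i covers (changing_steps[i-1], changing_steps[i]] and uses rates[i];
+    steps after the last boundary run at the original resolution (scale 1.0)."""
+    for boundary, rate in zip(changing_steps, rates):
+        if step <= boundary:
+            return rate
+    return 1.0
+
+# Example 2 above: rates [1.0, 0.75] with boundaries [10, 35] over 50 steps.
+assert resolution_scale_for_step(5, [10, 35], [1.0, 0.75]) == 1.0
+assert resolution_scale_for_step(20, [10, 35], [1.0, 0.75]) == 0.75
+assert resolution_scale_for_step(40, [10, 35], [1.0, 0.75]) == 1.0
+```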
diff --git a/docs/EN/source/method_tutorials/offload.md b/docs/EN/source/method_tutorials/offload.md
new file mode 100644
index 0000000000000000000000000000000000000000..0fff2dc569d02bd6c669ab52112d4c70d5f9c03d
--- /dev/null
+++ b/docs/EN/source/method_tutorials/offload.md
@@ -0,0 +1,177 @@
+# Parameter Offload
+
+## 📖 Overview
+
+LightX2V implements an advanced parameter offload mechanism specifically designed for large model inference under limited hardware resources. The system provides an excellent speed-memory balance by intelligently managing model weights across different memory hierarchies.
+
+**Core Features:**
+- **Block/Phase-level Offload**: Efficiently manages model weights in block/phase units for optimal memory usage
+ - **Block**: The basic computational unit of Transformer models, containing complete Transformer layers (self-attention, cross-attention, feedforward networks, etc.), serving as a larger memory management unit
+ - **Phase**: Finer-grained computational stages within blocks, containing individual computational components (such as self-attention, cross-attention, feedforward networks, etc.), providing more precise memory control
+- **Multi-tier Storage Support**: GPU → CPU → Disk hierarchy with intelligent caching
+- **Asynchronous Operations**: Overlaps computation and data transfer using CUDA streams
+- **Disk/NVMe Serialization**: Supports secondary storage when memory is insufficient
+
+## 🎯 Offload Strategies
+
+### Strategy 1: GPU-CPU Block/Phase Offload
+
+**Use Case**: Insufficient GPU memory but sufficient system memory
+
+**How It Works**: Manages model weights in block or phase units between GPU and CPU memory, utilizing CUDA streams to overlap computation and data transfer. Blocks contain complete Transformer layers, while Phases are individual computational components within blocks.
+
+
+
+
+
+
+
+
+**Block vs Phase Explanation**:
+- **Block Granularity**: Larger memory management unit containing complete Transformer layers (self-attention, cross-attention, feedforward networks, etc.), suitable for sufficient memory scenarios with reduced management overhead
+- **Phase Granularity**: Finer-grained memory management containing individual computational components (such as self-attention, cross-attention, feedforward networks, etc.), suitable for memory-constrained scenarios with more flexible memory control
+
+**Key Features:**
+- **Asynchronous Transfer**: Uses three CUDA streams with different priorities for parallel computation and transfer
+ - Compute stream (priority=-1): High priority, handles current computation
+ - GPU load stream (priority=0): Medium priority, handles CPU to GPU prefetching
+ - CPU load stream (priority=0): Medium priority, handles GPU to CPU offloading
+- **Prefetch Mechanism**: Preloads the next block/phase to GPU in advance
+- **Intelligent Caching**: Maintains weight cache in CPU memory
+- **Stream Synchronization**: Ensures correctness of data transfer and computation
+- **Swap Operation**: Rotates block/phase positions after computation for continuous execution
+
+
+
+
+### Strategy 2: Disk-CPU-GPU Block/Phase Offload (Lazy Loading)
+
+**Use Case**: Both GPU memory and system memory are insufficient
+
+**How It Works**: Builds upon Strategy 1 by introducing disk storage, implementing a three-tier storage hierarchy (Disk → CPU → GPU). CPU continues to serve as a cache pool with configurable size, suitable for devices with limited CPU memory.
+
+
+
+
+
+
+**Key Features:**
+- **Lazy Loading**: Model weights are loaded from disk on-demand, avoiding loading the entire model at once
+- **Intelligent Caching**: CPU memory buffer uses FIFO strategy with configurable size
+- **Multi-threaded Prefetch**: Uses multiple disk worker threads for parallel loading
+- **Asynchronous Transfer**: Uses CUDA streams to overlap computation and data transfer
+- **Swap Rotation**: Achieves continuous computation through position rotation, avoiding repeated loading/offloading
+
+**Working Steps**:
+- **Disk Storage**: Model weights are stored on SSD/NVMe by block, one .safetensors file per block
+- **Task Scheduling**: When a block/phase is needed, priority task queue assigns disk worker threads
+- **Asynchronous Loading**: Multiple disk threads load weight files from disk to CPU memory buffer in parallel
+- **Intelligent Caching**: CPU memory buffer manages cache using FIFO strategy with configurable size
+- **Cache Hit**: If weights are already in cache, transfer directly to GPU without disk read
+- **Prefetch Transfer**: Weights in cache are asynchronously transferred to GPU memory (using GPU load stream)
+- **Compute Execution**: Weights on GPU perform computation (using compute stream) while background continues prefetching next block/phase
+- **Swap Rotation**: After computation completes, rotate block/phase positions for continuous computation
+- **Memory Management**: When the CPU cache is full, the oldest cached weight block/phase is automatically evicted (FIFO)
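+
+The caching behaviour described above can be sketched as a small FIFO buffer (a toy illustration, not the actual LightX2V buffer implementation):
+
+```python
+from collections import OrderedDict
+
+class FIFOWeightCache:
+    """Toy CPU-side cache for weight blocks: cache hits skip the disk read,
+    and when the cache is full the oldest block is evicted (FIFO)."""
+
+    def __init__(self, max_blocks: int):
+        self.max_blocks = max_blocks
+        self._cache = OrderedDict()  # block name -> weights loaded from disk
+
+    def get(self, name):
+        return self._cache.get(name)  # hit: transfer straight to GPU
+
+    def put(self, name, weights):
+        if name in self._cache:
+            return
+        if len(self._cache) >= self.max_blocks:
+            evicted, _ = self._cache.popitem(last=False)  # drop the oldest entry
+            print(f"evicted {evicted} from the CPU cache")
+        self._cache[name] = weights
+```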
+
+
+
+## ⚙️ Configuration Parameters
+
+### GPU-CPU Offload Configuration
+
+```python
+config = {
+ "cpu_offload": True,
+ "offload_ratio": 1.0, # Offload ratio (0.0-1.0)
+ "offload_granularity": "block", # Offload granularity: "block" or "phase"
+ "lazy_load": False, # Disable lazy loading
+}
+```
+
+### Disk-CPU-GPU Offload Configuration
+
+```python
+config = {
+ "cpu_offload": True,
+ "lazy_load": True, # Enable lazy loading
+ "offload_ratio": 1.0, # Offload ratio
+ "offload_granularity": "phase", # Recommended to use phase granularity
+ "num_disk_workers": 2, # Number of disk worker threads
+ "offload_to_disk": True, # Enable disk offload
+}
+```
+
+**Intelligent Cache Key Parameters:**
+- `max_memory`: Controls CPU cache size, affects cache hit rate and memory usage
+- `num_disk_workers`: Controls number of disk loading threads, affects prefetch speed
+- `offload_granularity`: Controls cache granularity (block or phase), affects cache efficiency
+ - `"block"`: Cache management in complete Transformer layer units
+ - `"phase"`: Cache management in individual computational component units
+
+**Offload Configuration for Non-DIT Model Components (T5, CLIP, VAE):**
+
+The offload behavior of these components follows these rules:
+- **Default Behavior**: If not specified separately, T5, CLIP, VAE will follow the `cpu_offload` setting
+- **Independent Configuration**: Can set offload strategy separately for each component for fine-grained control
+
+**Configuration Example**:
+```json
+{
+ "cpu_offload": true, // DIT model offload switch
+ "t5_cpu_offload": false, // T5 encoder independent setting
+ "clip_cpu_offload": false, // CLIP encoder independent setting
+ "vae_cpu_offload": false // VAE encoder independent setting
+}
+```
+
+For memory-constrained devices, a progressive offload strategy is recommended:
+
+1. **Step 1**: Only enable `cpu_offload`, disable `t5_cpu_offload`, `clip_cpu_offload`, `vae_cpu_offload`
+2. **Step 2**: If memory is still insufficient, gradually enable CPU offload for T5, CLIP, VAE
+3. **Step 3**: If memory is still not enough, consider using quantization + CPU offload or enable `lazy_load`
+
+**Practical Experience**:
+- **RTX 4090 24GB + 14B Model**: Usually only need to enable `cpu_offload`, manually set other component offload to `false`, and use FP8 quantized version
+- **Smaller Memory GPUs**: Need to combine quantization, CPU offload, and lazy loading
+- **Quantization Schemes**: Refer to [Quantization Documentation](../method_tutorials/quantization.md) to select appropriate quantization strategy
+
+
+**Configuration File Reference**:
+- **Wan2.1 Series Models**: Refer to [offload config files](https://github.com/ModelTC/lightx2v/tree/main/configs/offload)
+- **Wan2.2 Series Models**: Refer to [wan22 config files](https://github.com/ModelTC/lightx2v/tree/main/configs/wan22) with `4090` suffix
+
+## 🎯 Usage Recommendations
+- 🔄 GPU-CPU Block/Phase Offload: Suitable for insufficient GPU memory (RTX 3090/4090 24G) but sufficient system memory (>64/128G)
+
+- 💾 Disk-CPU-GPU Block/Phase Offload: Suitable for both insufficient GPU memory (RTX 3060/4090 8G) and system memory (16/32G)
+
+- 🚫 No Offload: Suitable for high-end hardware configurations pursuing best performance
+
+
+## 🔍 Troubleshooting
+
+### Common Issues and Solutions
+
+1. **Disk I/O Bottleneck**
+ - Solution: Use NVMe SSD, increase num_disk_workers
+
+
+2. **Memory Buffer Overflow**
+ - Solution: Increase max_memory or reduce num_disk_workers
+
+3. **Loading Timeout**
+ - Solution: Check disk performance, optimize file system
+
+
+**Note**: This offload mechanism is specifically designed for LightX2V, fully utilizing the asynchronous computing capabilities of modern hardware, significantly lowering the hardware threshold for large model inference.
diff --git a/docs/EN/source/method_tutorials/parallel.md b/docs/EN/source/method_tutorials/parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..324ea7b74c28117cfebc5d6fa66baefad31bc32f
--- /dev/null
+++ b/docs/EN/source/method_tutorials/parallel.md
@@ -0,0 +1,53 @@
+# Parallel Inference
+
+LightX2V supports distributed parallel inference, enabling the utilization of multiple GPUs for inference. The DiT component supports two parallel attention mechanisms: **Ulysses** and **Ring**, while also supporting **Cfg parallel inference**. Parallel inference significantly reduces inference time and alleviates memory overhead on each GPU.
+
+## DiT Parallel Configuration
+
+### 1. Ulysses Parallel
+
+**Configuration method:**
+```json
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+```
+
+### 2. Ring Parallel
+
+**Configuration method:**
+```json
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ring"
+ }
+```
+
+## Cfg Parallel Configuration
+
+**Configuration method:**
+```json
+ "parallel": {
+ "cfg_p_size": 2
+ }
+```
+
+## Hybrid Parallel Configuration
+
+**Configuration method:**
+```json
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+```
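+
+As a rule of thumb, the GPUs are divided into `cfg_p_size` groups of `seq_p_size` ranks, so their product should equal the number of GPUs launched. A small illustrative check (an assumed convention, not framework code):
+
+```python
+def validate_parallel_config(world_size: int, seq_p_size: int = 1, cfg_p_size: int = 1) -> None:
+    """Sanity-check that the parallel sizes exactly partition the available GPUs."""
+    if seq_p_size * cfg_p_size != world_size:
+        raise ValueError(
+            f"seq_p_size ({seq_p_size}) * cfg_p_size ({cfg_p_size}) = "
+            f"{seq_p_size * cfg_p_size}, but {world_size} GPUs were launched"
+        )
+
+validate_parallel_config(world_size=8, seq_p_size=4, cfg_p_size=2)  # matches the hybrid example above
+```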
+
+## Usage
+
+Parallel inference configuration files are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/dist_infer)
+
+By specifying --config_json to a specific config file, you can test parallel inference.
+
+[Here](https://github.com/ModelTC/lightx2v/tree/main/scripts/dist_infer) are some run scripts for your use.
diff --git a/docs/EN/source/method_tutorials/quantization.md b/docs/EN/source/method_tutorials/quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..da355a92a08c9280cfb3fd0447d851b8a658c0f7
--- /dev/null
+++ b/docs/EN/source/method_tutorials/quantization.md
@@ -0,0 +1,158 @@
+# Model Quantization Techniques
+
+## 📖 Overview
+
+LightX2V supports quantized inference for DIT, T5, and CLIP models, reducing memory usage and improving inference speed by lowering model precision.
+
+---
+
+## 🔧 Quantization Modes
+
+| Quantization Mode | Weight Quantization | Activation Quantization | Compute Kernel | Supported Hardware |
+|--------------|----------|----------|----------|----------|
+| `fp8-vllm` | FP8 channel symmetric | FP8 channel dynamic symmetric | [VLLM](https://github.com/vllm-project/vllm) | H100/H200/H800, RTX 40 series, etc. |
+| `int8-vllm` | INT8 channel symmetric | INT8 channel dynamic symmetric | [VLLM](https://github.com/vllm-project/vllm) | A100/A800, RTX 30/40 series, etc. |
+| `fp8-sgl` | FP8 channel symmetric | FP8 channel dynamic symmetric | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | H100/H200/H800, RTX 40 series, etc. |
+| `int8-sgl` | INT8 channel symmetric | INT8 channel dynamic symmetric | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | A100/A800, RTX 30/40 series, etc. |
+| `fp8-q8f` | FP8 channel symmetric | FP8 channel dynamic symmetric | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40 series, L40S, etc. |
+| `int8-q8f` | INT8 channel symmetric | INT8 channel dynamic symmetric | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40 series, L40S, etc. |
+| `int8-torchao` | INT8 channel symmetric | INT8 channel dynamic symmetric | [TorchAO](https://github.com/pytorch/ao) | A100/A800, RTX 30/40 series, etc. |
+| `int4-g128-marlin` | INT4 group symmetric | FP16 | [Marlin](https://github.com/IST-DASLab/marlin) | H200/H800/A100/A800, RTX 30/40 series, etc. |
+| `fp8-b128-deepgemm` | FP8 block symmetric | FP8 group symmetric | [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) | H100/H200/H800, RTX 40 series, etc.|
+
+---
+
+## 🔧 Obtaining Quantized Models
+
+### Method 1: Download Pre-Quantized Models
+
+Download pre-quantized models from LightX2V model repositories:
+
+**DIT Models**
+
+Download pre-quantized DIT models from [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models):
+
+```bash
+# Download DIT FP8 quantized model
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models \
+ --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+```
+
+**Encoder Models**
+
+Download pre-quantized T5 and CLIP models from [Encoders-LightX2V](https://huggingface.co/lightx2v/Encoders-Lightx2v):
+
+```bash
+# Download T5 FP8 quantized model
+huggingface-cli download lightx2v/Encoders-Lightx2v \
+ --local-dir ./models \
+ --include "models_t5_umt5-xxl-enc-fp8.pth"
+
+# Download CLIP FP8 quantized model
+huggingface-cli download lightx2v/Encoders-Lightx2v \
+ --local-dir ./models \
+ --include "models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth"
+```
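+
+If you prefer Python over the CLI, the same files can be fetched with `huggingface_hub` (equivalent to the commands above):
+
+```python
+from huggingface_hub import hf_hub_download
+
+# Download the FP8 T5 and CLIP encoder checkpoints into ./models.
+t5_path = hf_hub_download(
+    repo_id="lightx2v/Encoders-Lightx2v",
+    filename="models_t5_umt5-xxl-enc-fp8.pth",
+    local_dir="./models",
+)
+clip_path = hf_hub_download(
+    repo_id="lightx2v/Encoders-Lightx2v",
+    filename="models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth",
+    local_dir="./models",
+)
+print(t5_path, clip_path)
+```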
+
+### Method 2: Self-Quantize Models
+
+For detailed quantization tool usage, refer to: [Model Conversion Documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md)
+
+---
+
+## 🚀 Using Quantized Models
+
+### DIT Model Quantization
+
+#### Supported Quantization Modes
+
+DIT quantization modes (`dit_quant_scheme`) support: `fp8-vllm`, `int8-vllm`, `fp8-sgl`, `int8-sgl`, `fp8-q8f`, `int8-q8f`, `int8-torchao`, `int4-g128-marlin`, `fp8-b128-deepgemm`
+
+#### Configuration Example
+
+```json
+{
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "dit_quantized_ckpt": "/path/to/dit_quantized_model" // Optional
+}
+```
+
+> 💡 **Tip**: When there's only one DIT model in the script's `model_path`, `dit_quantized_ckpt` doesn't need to be specified separately.
+
+### T5 Model Quantization
+
+#### Supported Quantization Modes
+
+T5 quantization modes (`t5_quant_scheme`) support: `int8-vllm`, `fp8-sgl`, `int8-q8f`, `fp8-q8f`, `int8-torchao`
+
+#### Configuration Example
+
+```json
+{
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl",
+ "t5_quantized_ckpt": "/path/to/t5_quantized_model" // Optional
+}
+```
+
+> 💡 **Tip**: When a T5 quantized model exists in the script's specified `model_path` (such as `models_t5_umt5-xxl-enc-fp8.pth` or `models_t5_umt5-xxl-enc-int8.pth`), `t5_quantized_ckpt` doesn't need to be specified separately.
+
+### CLIP Model Quantization
+
+#### Supported Quantization Modes
+
+CLIP quantization modes (`clip_quant_scheme`) support: `int8-vllm`, `fp8-sgl`, `int8-q8f`, `fp8-q8f`, `int8-torchao`
+
+#### Configuration Example
+
+```json
+{
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-sgl",
+ "clip_quantized_ckpt": "/path/to/clip_quantized_model" // Optional
+}
+```
+
+> 💡 **Tip**: When a CLIP quantized model exists in the script's specified `model_path` (such as `models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth` or `models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8.pth`), `clip_quantized_ckpt` doesn't need to be specified separately.
+
+### Performance Optimization Strategy
+
+If memory is insufficient, you can combine quantization with parameter offloading to further reduce memory usage. Refer to [Parameter Offload Documentation](../method_tutorials/offload.md):
+
+> - **Wan2.1 Configuration**: Refer to [offload config files](https://github.com/ModelTC/LightX2V/tree/main/configs/offload)
+> - **Wan2.2 Configuration**: Refer to [wan22 config files](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) with `4090` suffix
+
+---
+
+## 📚 Related Resources
+
+### Configuration File Examples
+- [INT8 Quantization Config](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v.json)
+- [Q8F Quantization Config](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_q8f.json)
+- [TorchAO Quantization Config](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_torchao.json)
+
+### Run Scripts
+- [Quantization Inference Scripts](https://github.com/ModelTC/LightX2V/tree/main/scripts/quantization)
+
+### Tool Documentation
+- [Quantization Tool Documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md)
+- [LightCompress Quantization Documentation](https://github.com/ModelTC/llmc/blob/main/docs/zh_cn/source/backend/lightx2v.md)
+
+### Model Repositories
+- [Wan2.1-LightX2V Quantized Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan2.2-LightX2V Quantized Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+- [Encoders Quantized Models](https://huggingface.co/lightx2v/Encoders-Lightx2v)
+
+---
+
+Through this document, you should be able to:
+
+✅ Understand quantization schemes supported by LightX2V
+✅ Select appropriate quantization strategies based on hardware
+✅ Correctly configure quantization parameters
+✅ Obtain and use quantized models
+✅ Optimize inference performance and memory usage
+
+If you have other questions, feel free to ask in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues).
diff --git a/docs/EN/source/method_tutorials/step_distill.md b/docs/EN/source/method_tutorials/step_distill.md
new file mode 100644
index 0000000000000000000000000000000000000000..023a30b0d21a518e2c0f3ef8a4cbc91c7249b8d5
--- /dev/null
+++ b/docs/EN/source/method_tutorials/step_distill.md
@@ -0,0 +1,183 @@
+# Step Distillation
+
+Step distillation is an important optimization technique in LightX2V. By training distilled models, it significantly reduces inference steps from the original 40-50 steps to **4 steps**, dramatically improving inference speed while maintaining video quality. LightX2V implements step distillation along with CFG distillation to further enhance inference speed.
+
+## 🔍 Technical Principle
+
+### DMD Distillation
+
+The core technology of step distillation is [DMD Distillation](https://arxiv.org/abs/2311.18828). The DMD distillation framework is shown in the following diagram:
+
+
+

+
+
+The core idea of DMD distillation is to minimize the KL divergence between the output distributions of the distilled model and the original model:
+
+$$
+\begin{aligned}
+D_{KL}\left(p_{\text{fake}} \; \| \; p_{\text{real}} \right) &= \mathbb{E}_{x\sim p_\text{fake}}\left(\log\left(\frac{p_\text{fake}(x)}{p_\text{real}(x)}\right)\right)\\
+&= \mathbb{E}_{\substack{
+z \sim \mathcal{N}(0; \mathbf{I}) \\
+x = G_\theta(z)
+}}-\big(\log~p_\text{real}(x) - \log~p_\text{fake}(x)\big).
+\end{aligned}
+$$
+
+Since directly computing the probability density is nearly impossible, DMD distillation instead computes the gradient of this KL divergence:
+
+$$
+\begin{aligned}
+\nabla_\theta D_{KL}
+&= \mathbb{E}_{\substack{
+z \sim \mathcal{N}(0; \mathbf{I}) \\
+x = G_\theta(z)
+} } \Big[-
+\big(
+s_\text{real}(x) - s_\text{fake}(x)\big)
+\hspace{.5mm} \frac{dG}{d\theta}
+\Big],
+\end{aligned}
+$$
+
+where $s_\text{real}(x) =\nabla_{x} \text{log}~p_\text{real}(x)$ and $s_\text{fake}(x) =\nabla_{x} \text{log}~p_\text{fake}(x)$ are score functions. Score functions can be computed by the model. Therefore, DMD distillation maintains three models in total:
+
+- `real_score`, computes the score of the real distribution; since the real distribution is fixed, DMD distillation uses the original model with fixed weights as its score function;
+- `fake_score`, computes the score of the fake distribution; since the fake distribution is constantly updated, DMD distillation initializes it with the original model and fine-tunes it to learn the output distribution of the generator;
+- `generator`, the student model, guided by computing the gradient of the KL divergence between `real_score` and `fake_score`.
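+
+A heavily simplified PyTorch-style sketch of how these three models interact in one generator update (schematic only; real training adds loss weighting, timestep sampling, and a separate `fake_score` update):
+
+```python
+import torch
+
+def dmd_generator_step(generator, real_score, fake_score, optimizer, z, t):
+    """Push the generator along the score difference, matching the gradient above."""
+    x = generator(z, t)
+    with torch.no_grad():
+        # Gradient of the KL divergence with respect to x: -(s_real(x) - s_fake(x)).
+        grad = -(real_score(x, t) - fake_score(x, t))
+    optimizer.zero_grad()
+    # Surrogate loss whose gradient w.r.t. generator parameters is grad * dG/dtheta.
+    loss = (grad * x).sum()
+    loss.backward()
+    optimizer.step()
+    return loss.item()
+```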
+
+> References:
+> 1. [DMD (One-step Diffusion with Distribution Matching Distillation)](https://arxiv.org/abs/2311.18828)
+> 2. [DMD2 (Improved Distribution Matching Distillation for Fast Image Synthesis)](https://arxiv.org/abs/2405.14867)
+
+### Self-Forcing
+
+DMD distillation technology is designed for image generation. The step distillation in LightX2V is implemented based on [Self-Forcing](https://github.com/guandeh17/Self-Forcing) technology. The overall implementation of Self-Forcing is similar to DMD, but following DMD2, it removes the regression loss and uses ODE initialization instead. Additionally, Self-Forcing adds an important optimization for video generation tasks:
+
+Current DMD distillation-based methods struggle to generate videos in one step. Self-Forcing selects one timestep for optimization each time, with the generator computing gradients only at this step. This approach significantly improves Self-Forcing's training speed and enhances the denoising quality at intermediate timesteps, also improving its effectiveness.
+
+> References:
+> 1. [Self-Forcing (Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion)](https://arxiv.org/abs/2506.08009)
+
+### LightX2V
+
+Self-Forcing performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements:
+
+1. **Larger Models**: Supports step distillation training for 14B models;
+2. **More Model Types**: Supports standard bidirectional models and I2V model step distillation training;
+3. **Better Results**: LightX2V uses high-quality prompts from approximately 50,000 data entries for training;
+
+For detailed implementation, refer to [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus).
+
+## 🎯 Technical Features
+
+- **Inference Acceleration**: Reduces inference steps from 40-50 to 4 steps without CFG, achieving approximately **20-24x** speedup
+- **Quality Preservation**: Maintains original video generation quality through distillation techniques
+- **Strong Compatibility**: Supports both T2V and I2V tasks
+- **Flexible Usage**: Supports loading complete step distillation models or loading step distillation LoRA on top of native models; compatible with int8/fp8 model quantization
+
+## 🛠️ Configuration Files
+
+### Basic Configuration Files
+
+Multiple configuration options are provided in the [configs/distill/](https://github.com/ModelTC/lightx2v/tree/main/configs/distill) directory:
+
+| Configuration File | Purpose | Model Address |
+|-------------------|---------|---------------|
+| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | Load T2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
+| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | Load I2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
+| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | Load Wan-T2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
+| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | Load Wan-I2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
+
+### Key Configuration Parameters
+
+- Since DMD distillation only trains a few fixed timesteps, we recommend using `LCM Scheduler` for inference. In [WanStepDistillScheduler](https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/schedulers/wan/step_distill/scheduler.py), `LCM Scheduler` is already fixed in use, requiring no user configuration.
+- `infer_steps`, `denoising_step_list`, and `sample_shift` are set to match the values used during training; modifying them is generally not recommended.
+- `enable_cfg` must be set to `false` (equivalent to setting `sample_guide_scale = 1`), otherwise the video may become completely blurred.
+- `lora_configs` supports merging multiple LoRAs with different strengths. When `lora_configs` is not empty, the original `Wan2.1` model is loaded by default. Therefore, if you use `lora_configs` and still want step distillation, include the path and strength of the step distillation LoRA in the list.
+
+```json
+{
+ "infer_steps": 4, // Inference steps
+ "denoising_step_list": [1000, 750, 500, 250], // Denoising timestep list
+ "sample_shift": 5, // Scheduler timestep shift
+ "enable_cfg": false, // Disable CFG for speed improvement
+ "lora_configs": [ // LoRA weights path (optional)
+ {
+ "path": "path/to/distill_lora.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
+```
+
+## 📜 Usage
+
+### Model Preparation
+
+**Complete Model:**
+Place the downloaded model (`distill_model.pt` or `distill_model.safetensors`) in the `distill_models/` folder under the Wan model root directory:
+
+- For T2V: `Wan2.1-T2V-14B/distill_models/`
+- For I2V-480P: `Wan2.1-I2V-14B-480P/distill_models/`
+
+**LoRA:**
+
+1. Place the downloaded LoRA in any location
+2. Modify the `lora_path` parameter in the configuration file to the LoRA storage path
+
+### Inference Scripts
+
+**T2V Complete Model:**
+
+```bash
+bash scripts/wan/run_wan_t2v_distill_4step_cfg.sh
+```
+
+**I2V Complete Model:**
+
+```bash
+bash scripts/wan/run_wan_i2v_distill_4step_cfg.sh
+```
+
+### Step Distillation LoRA Inference Scripts
+
+**T2V LoRA:**
+
+```bash
+bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh
+```
+
+**I2V LoRA:**
+
+```bash
+bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh
+```
+
+## 🔧 Service Deployment
+
+### Start Distillation Model Service
+
+Modify the startup command in [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh):
+
+```bash
+python -m lightx2v.api_server \
+ --model_cls wan2.1_distill \
+ --task t2v \
+ --model_path $model_path \
+ --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \
+ --port 8000 \
+ --nproc_per_node 1
+```
+
+Run the service startup script:
+
+```bash
+scripts/server/start_server.sh
+```
+
+For more details, see [Service Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_service.html).
+
+### Usage in Gradio Interface
+
+See [Gradio Documentation](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html)
diff --git a/docs/EN/source/method_tutorials/video_frame_interpolation.md b/docs/EN/source/method_tutorials/video_frame_interpolation.md
new file mode 100644
index 0000000000000000000000000000000000000000..f262f52a1c8762a4c349507fad0a379a1dcc64cb
--- /dev/null
+++ b/docs/EN/source/method_tutorials/video_frame_interpolation.md
@@ -0,0 +1,246 @@
+# Video Frame Interpolation (VFI)
+
+> **Important Note**: Video frame interpolation is enabled through configuration files, not command-line parameters. Please add a `video_frame_interpolation` configuration block to your JSON config file to enable this feature.
+
+## Overview
+
+Video Frame Interpolation (VFI) is a technique that generates intermediate frames between existing frames to increase the frame rate and create smoother video playback. LightX2V integrates the RIFE (Real-Time Intermediate Flow Estimation) model to provide high-quality frame interpolation capabilities.
+
+## What is RIFE?
+
+RIFE is a state-of-the-art video frame interpolation method that uses optical flow estimation to generate intermediate frames. It can effectively:
+
+- Increase video frame rate (e.g., from 16 FPS to 32 FPS)
+- Create smooth motion transitions
+- Maintain high visual quality with minimal artifacts
+- Process videos in real-time
+
+## Installation and Setup
+
+### Download RIFE Model
+
+First, download the RIFE model weights using the provided script:
+
+```bash
+python tools/download_rife.py
+```
+
+For example, to download to a specific location:
+```bash
+python tools/download_rife.py /path/to/rife/train_log
+```
+
+This script will:
+- Download RIFEv4.26 model from HuggingFace
+- Extract and place the model files in the correct directory
+- Clean up temporary files
+
+## Usage
+
+### Configuration File Setup
+
+Video frame interpolation is enabled through configuration files. Add a `video_frame_interpolation` configuration block to your JSON config file:
+
+```json
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "fps": 16,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 32,
+ "model_path": "/path/to/rife/train_log"
+ }
+}
+```
+
+### Command Line Interface
+
+Run inference using a configuration file that includes VFI settings:
+
+```bash
+python lightx2v/infer.py \
+ --model_cls wan2.1 \
+ --task t2v \
+ --model_path /path/to/model \
+ --config_json ./configs/video_frame_interpolation/wan_t2v.json \
+ --prompt "A beautiful sunset over the ocean" \
+ --save_result_path ./output.mp4
+```
+
+### Configuration Parameters
+
+In the `video_frame_interpolation` configuration block:
+
+- `algo`: Frame interpolation algorithm, currently supports "rife"
+- `target_fps`: Target frame rate for the output video
+- `model_path`: RIFE model path, typically "/path/to/rife/train_log"
+
+Other related configurations:
+- `fps`: Source video frame rate (default 16)
+
+### Configuration Priority
+
+The system automatically handles video frame rate configuration with the following priority:
+1. `video_frame_interpolation.target_fps` - If video frame interpolation is enabled, this frame rate is used as the output frame rate
+2. `fps` (default 16) - If video frame interpolation is not enabled, this frame rate is used; it's always used as the source frame rate
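+
+This priority rule can be expressed as a tiny helper (illustrative only, not the LightX2V implementation):
+
+```python
+def resolve_output_fps(config: dict) -> int:
+    """Prefer video_frame_interpolation.target_fps when VFI is configured,
+    otherwise fall back to the source fps (default 16)."""
+    vfi = config.get("video_frame_interpolation") or {}
+    return vfi.get("target_fps", config.get("fps", 16))
+
+assert resolve_output_fps({"fps": 16, "video_frame_interpolation": {"target_fps": 32}}) == 32
+assert resolve_output_fps({"fps": 16}) == 16
+```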
+
+
+## How It Works
+
+### Frame Interpolation Process
+
+1. **Source Video Generation**: The base model generates video frames at the source FPS
+2. **Frame Analysis**: RIFE analyzes adjacent frames to estimate optical flow
+3. **Intermediate Frame Generation**: New frames are generated between existing frames
+4. **Temporal Smoothing**: The interpolated frames create smooth motion transitions
+
+### Technical Details
+
+- **Input Format**: ComfyUI Image tensors [N, H, W, C] in range [0, 1]
+- **Output Format**: Interpolated ComfyUI Image tensors [M, H, W, C] in range [0, 1]
+- **Processing**: Automatic padding and resolution handling
+- **Memory Optimization**: Efficient GPU memory management
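+
+For a rough sense of the output size, the number of frames after uniform interpolation can be estimated as follows (an approximation; the exact count depends on the RIFE wrapper):
+
+```python
+def interpolated_frame_count(num_frames: int, src_fps: int, target_fps: int) -> int:
+    """Estimate the frame count after uniformly interpolating from src_fps to target_fps."""
+    ratio = target_fps / src_fps
+    return int(round((num_frames - 1) * ratio)) + 1
+
+print(interpolated_frame_count(81, 16, 32))  # ~161 frames for a 2x frame-rate increase
+```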
+
+## Example Configurations
+
+### Basic Frame Rate Doubling
+
+Create configuration file `wan_t2v_vfi_32fps.json`:
+
+```json
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "seed": 42,
+ "sample_guide_scale": 6,
+ "enable_cfg": true,
+ "fps": 16,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 32,
+ "model_path": "/path/to/rife/train_log"
+ }
+}
+```
+
+Run command:
+```bash
+python lightx2v/infer.py \
+ --model_cls wan2.1 \
+ --task t2v \
+ --model_path ./models/wan2.1 \
+ --config_json ./wan_t2v_vfi_32fps.json \
+ --prompt "A cat playing in the garden" \
+ --save_result_path ./output_32fps.mp4
+```
+
+### Higher Frame Rate Enhancement
+
+Create configuration file `wan_i2v_vfi_60fps.json`:
+
+```json
+{
+ "infer_steps": 30,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "seed": 42,
+ "sample_guide_scale": 6,
+ "enable_cfg": true,
+ "fps": 16,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 60,
+ "model_path": "/path/to/rife/train_log"
+ }
+}
+```
+
+Run command:
+```bash
+python lightx2v/infer.py \
+ --model_cls wan2.1 \
+ --task i2v \
+ --model_path ./models/wan2.1 \
+ --config_json ./wan_i2v_vfi_60fps.json \
+ --image_path ./input.jpg \
+ --prompt "Smooth camera movement" \
+ --save_result_path ./output_60fps.mp4
+```
+
+## Performance Considerations
+
+### Memory Usage
+
+- RIFE processing requires additional GPU memory
+- Memory usage scales with video resolution and length
+- Consider using lower resolutions for longer videos
+
+### Processing Time
+
+- Frame interpolation adds processing overhead
+- Higher target frame rates require more computation
+- Processing time is roughly proportional to the number of interpolated frames
+
+### Quality vs Speed Trade-offs
+
+- Higher interpolation ratios may introduce artifacts
+- Optimal range: 2x to 4x frame rate increase
+- For extreme interpolation (>4x), consider multiple passes
+
+## Best Practices
+
+### Optimal Use Cases
+
+- **Motion-heavy videos**: Benefit most from frame interpolation
+- **Camera movements**: Smoother panning and zooming
+- **Action sequences**: Reduced motion blur perception
+- **Slow-motion effects**: Create fluid slow-motion videos
+
+### Recommended Settings
+
+- **Source FPS**: 16-24 FPS (generated by base model)
+- **Target FPS**: 32-60 FPS (2x to 4x increase)
+- **Resolution**: Up to 720p for best performance
+
+### Troubleshooting
+
+#### Common Issues
+
+1. **Out of Memory**: Reduce video resolution or target FPS
+2. **Artifacts in output**: Lower the interpolation ratio
+3. **Slow processing**: Check GPU memory and consider using CPU offloading
+
+#### Solutions
+
+Solve issues by modifying the configuration file:
+
+```json
+{
+ // For memory issues, use lower resolution
+ "target_height": 480,
+ "target_width": 832,
+
+ // For quality issues, use moderate interpolation
+ "video_frame_interpolation": {
+ "target_fps": 24 // instead of 60
+ },
+
+ // For performance issues, enable offloading
+ "cpu_offload": true
+}
+```
+
+## Technical Implementation
+
+The RIFE integration in LightX2V includes:
+
+- **RIFEWrapper**: ComfyUI-compatible wrapper for RIFE model
+- **Automatic Model Loading**: Seamless integration with the inference pipeline
+- **Memory Optimization**: Efficient tensor management and GPU memory usage
+- **Quality Preservation**: Maintains original video quality while adding frames
diff --git a/docs/PAPERS_ZH_CN/.readthedocs.yaml b/docs/PAPERS_ZH_CN/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..784bfbe697f4f39630a2beb948b20f8dc7c22a66
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/.readthedocs.yaml
@@ -0,0 +1,17 @@
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+ os: ubuntu-20.04
+ tools:
+ python: "3.10"
+
+formats:
+ - epub
+
+sphinx:
+ configuration: docs/PAPERS_ZH_CN/source/conf.py
+
+python:
+ install:
+ - requirements: requirements-docs.txt
diff --git a/docs/PAPERS_ZH_CN/Makefile b/docs/PAPERS_ZH_CN/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/PAPERS_ZH_CN/make.bat b/docs/PAPERS_ZH_CN/make.bat
new file mode 100644
index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/PAPERS_ZH_CN/source/conf.py b/docs/PAPERS_ZH_CN/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9d499df2da21e3609d39882eda567372fdd544f
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/conf.py
@@ -0,0 +1,122 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import logging
+import os
+import sys
+from typing import List
+
+import sphinxcontrib.redoc
+from sphinx.ext import autodoc
+
+logger = logging.getLogger(__name__)
+sys.path.append(os.path.abspath("../.."))
+
+# -- Project information -----------------------------------------------------
+
+project = "Lightx2v"
+copyright = "2025, Lightx2v Team"
+author = "the Lightx2v Team"
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.napoleon",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.intersphinx",
+ "sphinx_copybutton",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "myst_parser",
+ "sphinxarg.ext",
+ "sphinxcontrib.redoc",
+ "sphinxcontrib.openapi",
+]
+
+html_static_path = ["_static"]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns: List[str] = ["**/*.template.rst"]
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r"\$ "
+copybutton_prompt_is_regexp = True
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_title = project
+html_theme = "sphinx_book_theme"
+# html_theme = 'sphinx_rtd_theme'
+html_logo = "../../../assets/img_lightx2v.png"
+html_theme_options = {
+ "path_to_docs": "docs/ZH_CN/source",
+ "repository_url": "https://github.com/ModelTC/lightx2v",
+ "use_repository_button": True,
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+
+# Generate additional rst documentation here.
+def setup(app):
+ # from docs.source.generate_examples import generate_examples
+ # generate_examples()
+ pass
+
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+ "cpuinfo",
+ "torch",
+ "transformers",
+ "psutil",
+ "prometheus_client",
+ "sentencepiece",
+ "lightllmnumpy",
+ "tqdm",
+ "tensorizer",
+]
+
+for mock_target in autodoc_mock_imports:
+ if mock_target in sys.modules:
+ logger.info(
+ "Potentially problematic mock target (%s) found; autodoc_mock_imports cannot mock modules that have already been loaded into sys.modules when the sphinx build starts.",
+ mock_target,
+ )
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+ """Remove note about base class when a class is derived from object."""
+
+ def add_line(self, line: str, source: str, *lineno: int) -> None:
+ if line == " Bases: :py:class:`object`":
+ return
+ super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False
diff --git a/docs/PAPERS_ZH_CN/source/index.rst b/docs/PAPERS_ZH_CN/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..cb4414a9a3c80e31572fe99b052aa9e065122274
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/index.rst
@@ -0,0 +1,53 @@
+欢迎了解 Lightx2v 论文收藏集!
+==============================
+
+.. figure:: ../../../assets/img_lightx2v.png
+ :width: 80%
+ :align: center
+ :alt: Lightx2v
+ :class: no-scaled-link
+
+.. raw:: html
+
+ <p style="text-align:center">
+ <strong>LightX2V: 一个轻量级的视频生成推理框架</strong>
+ </p>
+
+LightX2V 是一个轻量级的视频生成推理框架。这里是我们维护的一个视频生成推理加速相关的论文收藏集,帮助你快速了解视频生成推理加速相关的经典方法和最新进展。
+
+GitHub: https://github.com/ModelTC/lightx2v
+
+HuggingFace: https://huggingface.co/lightx2v
+
+论文收藏集
+-------------
+
+.. toctree::
+ :maxdepth: 1
+ :caption: 论文分类
+
+ 图像视频生成基础 <papers/generation_basics>
+ 开源模型 <papers/models>
+ 模型量化 <papers/quantization>
+ 特征缓存 <papers/cache>
+ 注意力机制 <papers/attention>
+ 参数卸载 <papers/offload>
+ 并行推理 <papers/parallel>
+ 变分辨率推理 <papers/changing_resolution>
+ 步数蒸馏 <papers/step_distill>
+ 自回归模型 <papers/autoregressive>
+ vae加速 <papers/vae>
+ prompt增强 <papers/prompt_enhance>
+ 强化学习 <papers/RL>
diff --git a/docs/PAPERS_ZH_CN/source/papers/RL.md b/docs/PAPERS_ZH_CN/source/papers/RL.md
new file mode 100644
index 0000000000000000000000000000000000000000..0081f92328a9fe548654098b76adc5815aba6171
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/RL.md
@@ -0,0 +1,3 @@
+# 强化学习
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/attention.md b/docs/PAPERS_ZH_CN/source/papers/attention.md
new file mode 100644
index 0000000000000000000000000000000000000000..a17d295b51fb08e97f2a41c7c79bd5a1831cddfc
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/attention.md
@@ -0,0 +1,113 @@
+# 注意力机制
+
+### Sparse VideoGen: Accelerating Video Diffusion Transformers with Spatial-Temporal Sparsity
+
+[paper](https://arxiv.org/abs/2502.01776) | [code](https://github.com/svg-project/Sparse-VideoGen)
+
+### Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation
+
+[paper](https://arxiv.org/abs/2505.18875)
+
+### Training-free and Adaptive Sparse Attention for Efficient Long Video Generation
+
+[paper](https://arxiv.org/abs/2502.21079)
+
+### DSV: Exploiting Dynamic Sparsity to Accelerate Large-Scale Video DiT Training
+
+[paper](https://arxiv.org/abs/2502.07590)
+
+### MMInference: Accelerating Pre-filling for Long-Context VLMs via Modality-Aware Permutation Sparse Attention
+
+[code](https://github.com/microsoft/MInference)
+
+### FPSAttention: Training-Aware FP8 and Sparsity Co-Design for Fast Video Diffusion
+
+[paper](https://arxiv.org/abs/2506.04648)
+
+### VORTA: Efficient Video Diffusion via Routing Sparse Attention
+
+[paper](https://arxiv.org/abs/2505.18809)
+
+### Training-Free Efficient Video Generation via Dynamic Token Carving
+
+[paper](https://arxiv.org/abs/2505.16864)
+
+### RainFusion: Adaptive Video Generation Acceleration via Multi-Dimensional Visual Redundancy
+
+[paper](https://arxiv.org/abs/2505.21036)
+
+### Radial Attention: O(nlogn) Sparse Attention with Energy Decay for Long Video Generation
+
+[paper](https://arxiv.org/abs/2506.19852)
+
+### VMoBA: Mixture-of-Block Attention for Video Diffusion Models
+
+[paper](https://arxiv.org/abs/2506.23858)
+
+### SpargeAttention: Accurate and Training-free Sparse Attention Accelerating Any Model Inference
+
+[paper](https://arxiv.org/abs/2502.18137) | [code](https://github.com/thu-ml/SpargeAttn)
+
+### Fast Video Generation with Sliding Tile Attention
+
+[paper](https://arxiv.org/abs/2502.04507) | [code](https://github.com/hao-ai-lab/FastVideo)
+
+### PAROAttention: Pattern-Aware ReOrdering for Efficient Sparse and Quantized Attention in Visual Generation Models
+
+[paper](https://arxiv.org/abs/2506.16054)
+
+### Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light
+
+[paper](https://arxiv.org/abs/2504.16922)
+
+### Astraea: A GPU-Oriented Token-wise Acceleration Framework for Video Diffusion Transformers
+
+[paper](https://arxiv.org/abs/2506.05096)
+
+### ∇NABLA: Neighborhood Adaptive Block-Level Attention
+
+[paper](https://arxiv.org/abs/2507.13546v1) [code](https://github.com/gen-ai-team/Wan2.1-NABLA)
+
+### Compact Attention: Exploiting Structured Spatio-Temporal Sparsity for Fast Video Generation
+
+[paper](https://arxiv.org/abs/2508.12969)
+
+### A Survey of Efficient Attention Methods: Hardware-efficient, Sparse, Compact, and Linear Attention
+
+[paper](https://attention-survey.github.io/files/Attention_Survey.pdf)
+
+### Bidirectional Sparse Attention for Faster Video Diffusion Training
+
+[paper](https://arxiv.org/abs/2509.01085)
+
+### Mixture of Contexts for Long Video Generation
+
+[paper](https://arxiv.org/abs/2508.21058)
+
+### LoViC: Efficient Long Video Generation with Context Compression
+
+[paper](https://arxiv.org/abs/2507.12952)
+
+### MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training
+
+[paper](https://sandai-org.github.io/MagiAttention/blog/) [code](https://github.com/SandAI-org/MagiAttention)
+
+### DraftAttention: Fast Video Diffusion via Low-Resolution Attention Guidance
+
+[paper](https://arxiv.org/abs/2505.14708) [code](https://github.com/shawnricecake/draft-attention)
+
+### XAttention: Block Sparse Attention with Antidiagonal Scoring
+
+[paper](https://arxiv.org/abs/2503.16428) [code](https://github.com/mit-han-lab/x-attention)
+
+### VSA: Faster Video Diffusion with Trainable Sparse Attention
+
+[paper](https://arxiv.org/abs/2505.13389) [code](https://github.com/hao-ai-lab/FastVideo)
+
+### QuantSparse: Comprehensively Compressing Video Diffusion Transformer with Model Quantization and Attention Sparsification
+
+[paper](https://arxiv.org/abs/2509.23681)
+
+### SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention
+
+[paper](https://arxiv.org/abs/2509.24006)
diff --git a/docs/PAPERS_ZH_CN/source/papers/autoregressive.md b/docs/PAPERS_ZH_CN/source/papers/autoregressive.md
new file mode 100644
index 0000000000000000000000000000000000000000..27c2d1dad269960325ee64c86df155622f2410c8
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/autoregressive.md
@@ -0,0 +1,3 @@
+# 自回归模型
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/cache.md b/docs/PAPERS_ZH_CN/source/papers/cache.md
new file mode 100644
index 0000000000000000000000000000000000000000..661ef10213cbf5af62d0f3443a8ab774285f4be4
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/cache.md
@@ -0,0 +1,3 @@
+# 特征缓存
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/changing_resolution.md b/docs/PAPERS_ZH_CN/source/papers/changing_resolution.md
new file mode 100644
index 0000000000000000000000000000000000000000..175eea9e79dc2936404ad276ffe96f86ae8b7a35
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/changing_resolution.md
@@ -0,0 +1,3 @@
+# 变分辨率推理
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/generation_basics.md b/docs/PAPERS_ZH_CN/source/papers/generation_basics.md
new file mode 100644
index 0000000000000000000000000000000000000000..6d5b97f72edf792e6f670998849da9a2a3e7d96d
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/generation_basics.md
@@ -0,0 +1,3 @@
+# 图像视频生成基础
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/models.md b/docs/PAPERS_ZH_CN/source/papers/models.md
new file mode 100644
index 0000000000000000000000000000000000000000..de9e00c69d9b1d02e35766addcc9df106604b842
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/models.md
@@ -0,0 +1,276 @@
+
+
+
+# Open-Source Models
+
+📢: Collections of Awesome Open-Source Model Resources.
+
+
+
+
+## 📚 *Contents*
+
+- Open-Source Models
+ - [Foundation Models](#foundation-models)
+ - [World Models](#world-models)
+
+
+### Foundation Models:
+
+- **Stable Video Diffusion: Scaling Latent Video Diffusion Models to Large Datasets**, Technical Report 2023.
+
+ *Andreas Blattmann, Tim Dockhorn, Sumith Kulal, Daniel Mendelevitch, Maciej Kilian, et al.*
+
+ [[Paper](https://arxiv.org/abs/2311.15127)] [[Code](https://github.com/Stability-AI/generative-models)]   
+
+ BibTex
+
+ ```text
+ @article{blattmann2023stable,
+ title={Stable video diffusion: Scaling latent video diffusion models to large datasets},
+ author={Blattmann, Andreas and Dockhorn, Tim and Kulal, Sumith and Mendelevitch, Daniel and Kilian, Maciej and Lorenz, Dominik and Levi, Yam and English, Zion and Voleti, Vikram and Letts, Adam and others},
+ journal={arXiv preprint arXiv:2311.15127},
+ year={2023}
+ }
+ ```
+
+
+- **Wan: Open and Advanced Large-Scale Video Generative Models**, Technical Report 2025.
+
+ *Team Wan, Ang Wang, Baole Ai, Bin Wen, Chaojie Mao, Chen-Wei Xie, et al.*
+
+ [[Paper](https://arxiv.org/abs/2503.20314)] [[Code](https://github.com/Wan-Video/Wan2.1)]   
+
+ BibTex
+
+ ```text
+ @article{wan2025wan,
+ title={Wan: Open and advanced large-scale video generative models},
+ author={Wan, Team and Wang, Ang and Ai, Baole and Wen, Bin and Mao, Chaojie and Xie, Chen-Wei and Chen, Di and Yu, Feiwu and Zhao, Haiming and Yang, Jianxiao and others},
+ journal={arXiv preprint arXiv:2503.20314},
+ year={2025}
+ }
+ ```
+
+
+- **HunyuanVideo: A Systematic Framework For Large Video Generation Model**, Technical Report 2024.
+
+ *Weijie Kong, Qi Tian, Zijian Zhang, Rox Min, Zuozhuo Dai, Jin Zhou, et al.*
+
+ [[Paper](https://arxiv.org/abs/2412.03603)] [[Code](https://github.com/Tencent-Hunyuan/HunyuanVideo)]   
+
+ BibTex
+
+ ```text
+ @article{kong2024hunyuanvideo,
+ title={Hunyuanvideo: A systematic framework for large video generative models},
+ author={Kong, Weijie and Tian, Qi and Zhang, Zijian and Min, Rox and Dai, Zuozhuo and Zhou, Jin and Xiong, Jiangfeng and Li, Xin and Wu, Bo and Zhang, Jianwei and others},
+ journal={arXiv preprint arXiv:2412.03603},
+ year={2024}
+ }
+ ```
+
+
+- **CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer**, ICLR 2025.
+
+ *Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, et al.*
+
+ [[Paper](https://arxiv.org/abs/2408.06072)] [[Code](https://github.com/zai-org/CogVideo)]   
+
+ BibTex
+
+ ```text
+ @article{yang2024cogvideox,
+ title={Cogvideox: Text-to-video diffusion models with an expert transformer},
+ author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
+ journal={arXiv preprint arXiv:2408.06072},
+ year={2024}
+ }
+ ```
+
+
+
+- **SkyReels V2: Infinite-Length Film Generative Model**, Technical Report 2025.
+
+ *Guibin Chen, Dixuan Lin, Jiangping Yang, Chunze Lin, Junchen Zhu, Mingyuan Fan, et al.*
+
+ [[Paper](https://arxiv.org/abs/2504.13074)] [[Code](https://github.com/SkyworkAI/SkyReels-V2)]   
+
+ BibTex
+
+ ```text
+ @misc{chen2025skyreelsv2infinitelengthfilmgenerative,
+ title={SkyReels-V2: Infinite-length Film Generative Model},
+ author={Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
+ year={2025},
+ eprint={2504.13074},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2504.13074},
+ }
+ ```
+
+
+
+- **Open-Sora: Democratizing Efficient Video Production for All**, Technical Report 2025.
+
+ *Xiangyu Peng, Zangwei Zheng, Chenhui Shen, Tom Young, Xinying Guo, et al.*
+
+ [[Paper](https://arxiv.org/abs/2503.09642v2)] [[Code](https://github.com/hpcaitech/Open-Sora)]    
+
+ BibTex
+
+ ```text
+ @article{peng2025open,
+ title={Open-sora 2.0: Training a commercial-level video generation model in \$200k},
+ author={Peng, Xiangyu and Zheng, Zangwei and Shen, Chenhui and Young, Tom and Guo, Xinying and Wang, Binluo and Xu, Hang and Liu, Hongxin and Jiang, Mingyan and Li, Wenjun and others},
+ journal={arXiv preprint arXiv:2503.09642},
+ year={2025}
+ }
+ ```
+
+
+- **Pyramidal Flow Matching for Efficient Video Generative Modeling**, Technical Report 2024.
+
+ *Yang Jin, Zhicheng Sun, Ningyuan Li, Kun Xu, Kun Xu, et al.*
+
+ [[Paper](https://arxiv.org/abs/2410.05954)] [[Code](https://github.com/jy0205/Pyramid-Flow)]   
+ BibTex
+
+ ```text
+ @article{jin2024pyramidal,
+ title={Pyramidal flow matching for efficient video generative modeling},
+ author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
+ journal={arXiv preprint arXiv:2410.05954},
+ year={2024}
+ }
+ ```
+
+
+- **MAGI-1: Autoregressive Video Generation at Scale**, Technical Report 2025.
+
+ *Sand.ai, Hansi Teng, Hongyu Jia, Lei Sun, Lingzhi Li, Maolin Li, Mingqiu Tang, et al.*
+
+ [[Paper](https://arxiv.org/pdf/2505.13211)] [[Code](https://github.com/SandAI-org/Magi-1)]    
+ BibTex
+
+ ```text
+ @article{teng2025magi,
+ title={MAGI-1: Autoregressive Video Generation at Scale},
+ author={Teng, Hansi and Jia, Hongyu and Sun, Lei and Li, Lingzhi and Li, Maolin and Tang, Mingqiu and Han, Shuai and Zhang, Tianning and Zhang, WQ and Luo, Weifeng and others},
+ journal={arXiv preprint arXiv:2505.13211},
+ year={2025}
+ }
+ ```
+
+
+- **From Slow Bidirectional to Fast Autoregressive Video Diffusion Models**, CVPR 2025.
+
+ *Tianwei Yin, Qiang Zhang, Richard Zhang, William T. Freeman, Fredo Durand, et al.*
+
+ [[Paper](http://arxiv.org/abs/2412.07772)] [[Code](https://github.com/tianweiy/CausVid)]   
+ BibTex
+
+ ```text
+ @inproceedings{yin2025slow,
+ title={From slow bidirectional to fast autoregressive video diffusion models},
+ author={Yin, Tianwei and Zhang, Qiang and Zhang, Richard and Freeman, William T and Durand, Fredo and Shechtman, Eli and Huang, Xun},
+ booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
+ pages={22963--22974},
+ year={2025}
+ }
+ ```
+
+
+- **Packing Input Frame Context in Next-Frame Prediction Models for Video Generation**, arxiv 2025.
+
+ *Lvmin Zhang, Maneesh Agrawala.*
+
+ [[Paper](https://arxiv.org/abs/2504.12626)] [[Code](https://github.com/lllyasviel/FramePack)]   
+ BibTex
+
+ ```text
+ @article{zhang2025packing,
+ title={Packing input frame context in next-frame prediction models for video generation},
+ author={Zhang, Lvmin and Agrawala, Maneesh},
+ journal={arXiv preprint arXiv:2504.12626},
+ year={2025}
+ }
+ ```
+
+
+### World Models:
+
+- **Matrix-Game 2.0: An Open-Source, Real-Time, and Streaming Interactive World Model**, Technical Report 2025.
+
+ *Xianglong He, Chunli Peng, Zexiang Liu, Boyang Wang, Yifan Zhang, et al.*
+
+ [[Paper](https://arxiv.org/abs/2508.13009)] [[Code](https://matrix-game-v2.github.io/)]   
+ BibTex
+
+ ```text
+ @article{he2025matrix,
+ title={Matrix-Game 2.0: An Open-Source, Real-Time, and Streaming Interactive World Model},
+ author={He, Xianglong and Peng, Chunli and Liu, Zexiang and Wang, Boyang and Zhang, Yifan and Cui, Qi and Kang, Fei and Jiang, Biao and An, Mengyin and Ren, Yangyang and others},
+ journal={arXiv preprint arXiv:2508.13009},
+ year={2025}
+ }
+ ```
+
+
+- **HunyuanWorld 1.0: Generating Immersive, Explorable, and Interactive 3D Worlds from Words or Pixels**, Technical Report 2025.
+
+ *HunyuanWorld Team, Zhenwei Wang, Yuhao Liu, Junta Wu, Zixiao Gu, Haoyuan Wang, et al.*
+
+ [[Paper](https://arxiv.org/abs/2507.21809)] [[Code](https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0)]   
+ BibTex
+
+ ```text
+ @article{team2025hunyuanworld,
+ title={HunyuanWorld 1.0: Generating Immersive, Explorable, and Interactive 3D Worlds from Words or Pixels},
+ author={Team, HunyuanWorld and Wang, Zhenwei and Liu, Yuhao and Wu, Junta and Gu, Zixiao and Wang, Haoyuan and Zuo, Xuhui and Huang, Tianyu and Li, Wenhuan and Zhang, Sheng and others},
+ journal={arXiv preprint arXiv:2507.21809},
+ year={2025}
+ }
+ ```
+
+
+- **Cosmos-Drive-Dreams: Scalable Synthetic Driving Data Generation with World Foundation Models**, Technical Report 2025.
+
+ *Xuanchi Ren, Yifan Lu, Tianshi Cao, Ruiyuan Gao, Shengyu Huang, Amirmojtaba Sabour, et al.*
+
+ [[Paper](https://arxiv.org/abs/2506.09042)] [[Code](https://research.nvidia.com/labs/toronto-ai/cosmos_drive_dreams)]  
+ BibTex
+
+ ```text
+ @article{ren2025cosmos,
+ title={Cosmos-Drive-Dreams: Scalable Synthetic Driving Data Generation with World Foundation Models},
+ author={Ren, Xuanchi and Lu, Yifan and Cao, Tianshi and Gao, Ruiyuan and Huang, Shengyu and Sabour, Amirmojtaba and Shen, Tianchang and Pfaff, Tobias and Wu, Jay Zhangjie and Chen, Runjian and others},
+ journal={arXiv preprint arXiv:2506.09042},
+ year={2025}
+ }
+ ```
+
+
+- **Genie 3: A new frontier for world models**, Blog 2025.
+
+ *Google DeepMind*
+
+ [[Blog](https://deepmind.google/discover/blog/genie-3-a-new-frontier-for-world-models/)]  
+
+- **GAIA-2: A Controllable Multi-View Generative World Model for Autonomous Driving**, Technical Report 2025.
+
+ *Lloyd Russell, Anthony Hu, Lorenzo Bertoni, George Fedoseev, Jamie Shotton, et al.*
+
+ [[Paper](https://arxiv.org/abs/2503.20523)]  
+ BibTex
+
+ ```text
+ @article{russell2025gaia,
+ title={Gaia-2: A controllable multi-view generative world model for autonomous driving},
+ author={Russell, Lloyd and Hu, Anthony and Bertoni, Lorenzo and Fedoseev, George and Shotton, Jamie and Arani, Elahe and Corrado, Gianluca},
+ journal={arXiv preprint arXiv:2503.20523},
+ year={2025}
+ }
+ ```
+
diff --git a/docs/PAPERS_ZH_CN/source/papers/offload.md b/docs/PAPERS_ZH_CN/source/papers/offload.md
new file mode 100644
index 0000000000000000000000000000000000000000..302f0b6869df1d6263f2748a7dfd11bd7dae4f1e
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/offload.md
@@ -0,0 +1,3 @@
+# 参数卸载
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/parallel.md b/docs/PAPERS_ZH_CN/source/papers/parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..b68de1c0ebdc89c76e3155a1c22daed2ee3a7a8d
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/parallel.md
@@ -0,0 +1,3 @@
+# 并行推理
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/prompt_enhance.md b/docs/PAPERS_ZH_CN/source/papers/prompt_enhance.md
new file mode 100644
index 0000000000000000000000000000000000000000..d51d77bf8a8338c6eb04de88a09eb7cf65d01681
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/prompt_enhance.md
@@ -0,0 +1,3 @@
+# prompt增强
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/quantization.md b/docs/PAPERS_ZH_CN/source/papers/quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..d29d8f6af2e8e44f9a179d7a33150286ec500504
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/quantization.md
@@ -0,0 +1,3 @@
+# 模型量化
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/step_distill.md b/docs/PAPERS_ZH_CN/source/papers/step_distill.md
new file mode 100644
index 0000000000000000000000000000000000000000..e5e772f9a881bdce0e42f0a2862741da9c630830
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/step_distill.md
@@ -0,0 +1,3 @@
+# 步数蒸馏
+
+xxx
diff --git a/docs/PAPERS_ZH_CN/source/papers/vae.md b/docs/PAPERS_ZH_CN/source/papers/vae.md
new file mode 100644
index 0000000000000000000000000000000000000000..914d54db1b4b307acceba09a66d1eb5bcd9bbee3
--- /dev/null
+++ b/docs/PAPERS_ZH_CN/source/papers/vae.md
@@ -0,0 +1,3 @@
+# vae加速
+
+xxx
diff --git a/docs/ZH_CN/.readthedocs.yaml b/docs/ZH_CN/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3677c6092fd6559cf83d4663cfcf32df98d891a
--- /dev/null
+++ b/docs/ZH_CN/.readthedocs.yaml
@@ -0,0 +1,17 @@
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+ os: ubuntu-20.04
+ tools:
+ python: "3.10"
+
+formats:
+ - epub
+
+sphinx:
+ configuration: docs/ZH_CN/source/conf.py
+
+python:
+ install:
+ - requirements: requirements-docs.txt
diff --git a/docs/ZH_CN/Makefile b/docs/ZH_CN/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293
--- /dev/null
+++ b/docs/ZH_CN/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/ZH_CN/make.bat b/docs/ZH_CN/make.bat
new file mode 100644
index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a
--- /dev/null
+++ b/docs/ZH_CN/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/ZH_CN/source/conf.py b/docs/ZH_CN/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..659c3130dff364844704bfed3350f0ff797fd9f0
--- /dev/null
+++ b/docs/ZH_CN/source/conf.py
@@ -0,0 +1,128 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import logging
+import os
+import sys
+from typing import List
+
+import sphinxcontrib.redoc
+from sphinx.ext import autodoc
+
+logger = logging.getLogger(__name__)
+sys.path.append(os.path.abspath("../.."))
+
+# -- Project information -----------------------------------------------------
+
+project = "Lightx2v"
+copyright = "2025, Lightx2v Team"
+author = "the Lightx2v Team"
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.napoleon",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.intersphinx",
+ "sphinx_copybutton",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.mathjax",
+ "myst_parser",
+ "sphinxarg.ext",
+ "sphinxcontrib.redoc",
+ "sphinxcontrib.openapi",
+]
+
+myst_enable_extensions = [
+ "dollarmath",
+ "amsmath",
+]
+
+html_static_path = ["_static"]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns: List[str] = ["**/*.template.rst"]
+
+# Exclude the prompt "$" when copying code
+copybutton_prompt_text = r"\$ "
+copybutton_prompt_is_regexp = True
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_title = project
+html_theme = "sphinx_book_theme"
+# html_theme = 'sphinx_rtd_theme'
+html_logo = "../../../assets/img_lightx2v.png"
+html_theme_options = {
+ "path_to_docs": "docs/ZH_CN/source",
+ "repository_url": "https://github.com/ModelTC/lightx2v",
+ "use_repository_button": True,
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ['_static']
+
+
+# Generate additional rst documentation here.
+def setup(app):
+ # from docs.source.generate_examples import generate_examples
+ # generate_examples()
+ pass
+
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+ "cpuinfo",
+ "torch",
+ "transformers",
+ "psutil",
+ "prometheus_client",
+ "sentencepiece",
+ "lightllmnumpy",
+ "tqdm",
+ "tensorizer",
+]
+
+for mock_target in autodoc_mock_imports:
+ if mock_target in sys.modules:
+ logger.info(
+ "Potentially problematic mock target (%s) found; autodoc_mock_imports cannot mock modules that have already been loaded into sys.modules when the sphinx build starts.",
+ mock_target,
+ )
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+ """Remove note about base class when a class is derived from object."""
+
+ def add_line(self, line: str, source: str, *lineno: int) -> None:
+ if line == " Bases: :py:class:`object`":
+ return
+ super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
+
+navigation_with_keys = False
diff --git a/docs/ZH_CN/source/deploy_guides/deploy_comfyui.md b/docs/ZH_CN/source/deploy_guides/deploy_comfyui.md
new file mode 100644
index 0000000000000000000000000000000000000000..9355731c754462092d18b36cc46dca1c0761f8c4
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/deploy_comfyui.md
@@ -0,0 +1,25 @@
+# ComfyUI 部署
+
+## ComfyUI-Lightx2vWrapper
+
+LightX2V 的官方 ComfyUI 集成节点已经发布在独立仓库中,提供了完整的模块化配置系统和优化功能。
+
+### 项目地址
+
+- GitHub: [https://github.com/ModelTC/ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper)
+
+### 主要特性
+
+- 模块化配置系统:为视频生成的各个方面提供独立节点
+- 支持文生视频(T2V)和图生视频(I2V)两种生成模式
+- 高级优化功能:
+ - TeaCache 加速(最高 3 倍加速)
+ - 量化支持(int8、fp8)
+ - CPU 卸载内存优化
+ - 轻量级 VAE 选项
+- LoRA 支持:可链式组合多个 LoRA 模型
+- 多模型支持:wan2.1、hunyuan 等架构
+
+### 安装和使用
+
+请访问上述 GitHub 仓库查看详细的安装说明、使用教程和示例工作流。
diff --git a/docs/ZH_CN/source/deploy_guides/deploy_gradio.md b/docs/ZH_CN/source/deploy_guides/deploy_gradio.md
new file mode 100644
index 0000000000000000000000000000000000000000..a2f16b57c8144ace18a01bd7001dddc2a1dc69a5
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/deploy_gradio.md
@@ -0,0 +1,241 @@
+# Gradio 部署指南
+
+## 📖 概述
+
+Lightx2v 是一个轻量级的视频推理和生成引擎,提供基于 Gradio 的 Web 界面,支持图像到视频(Image-to-Video)和文本到视频(Text-to-Video)两种生成模式。
+
+对于 Windows 系统,我们提供了便捷的一键部署方式,支持自动环境配置和智能参数优化。详细操作请参考[一键启动Gradio](./deploy_local_windows.md#一键启动gradio推荐)章节。
+
+
+
+## 📁 文件结构
+
+```
+LightX2V/app/
+├── gradio_demo.py # 英文界面演示
+├── gradio_demo_zh.py # 中文界面演示
+├── run_gradio.sh # 启动脚本
+├── README.md # 说明文档
+├── outputs/ # 生成视频保存目录
+└── inference_logs.log # 推理日志
+```
+
+本项目包含两个主要演示文件:
+- `gradio_demo.py` - 英文界面版本
+- `gradio_demo_zh.py` - 中文界面版本
+
+## 🚀 快速开始
+
+### 环境要求
+
+按照[快速开始文档](../getting_started/quickstart.md)安装环境
+
+#### 推荐优化库配置
+
+- ✅ [Flash attention](https://github.com/Dao-AILab/flash-attention)
+- ✅ [Sage attention](https://github.com/thu-ml/SageAttention)
+- ✅ [vllm-kernel](https://github.com/vllm-project/vllm)
+- ✅ [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
+- ✅ [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (仅支持ADA架构的GPU)
+
+可根据需要,按照各算子的项目主页教程进行安装。
+
+### 📥 模型下载
+
+可参考[模型结构文档](../getting_started/model_structure.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。
+
+#### wan2.1 模型目录结构
+
+```
+models/
+├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度
+├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化
+├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 量化
+├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 量化分block存储目录
+├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 量化分block存储目录
+├── 其他权重(例如t2v)
+├── t5/clip/xlm-roberta-large/google # text和image encoder
+├── vae/lightvae/lighttae # vae
+└── config.json # 模型配置文件
+```
+
+#### wan2.2 模型目录结构
+
+```
+models/
+├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise 原始精度
+├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 量化
+├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 量化
+├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 量化分block存储目录
+├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise 原始精度
+├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 量化
+├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 量化
+├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 量化分block存储目录
+├── t5/clip/xlm-roberta-large/google # text和image encoder
+├── vae/lightvae/lighttae # vae
+└── config.json # 模型配置文件
+```
+
+**📝 下载说明**:
+
+- 模型权重可从 HuggingFace 下载:
+ - [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+ - [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+- Text 和 Image Encoder 可从 [Encoders](https://huggingface.co/lightx2v/Encoders) 下载
+- VAE 可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载
+- 对于 `xxx_split` 目录(例如 `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`),即按 block 拆分存储的多个 safetensors,适用于内存不足(例如 16GB 以内)的设备,请根据自身情况选择下载
+
+
+### 启动方式
+
+#### 方式一:使用启动脚本(推荐)
+
+**Linux 环境:**
+```bash
+# 1. 编辑启动脚本,配置相关路径
+cd app/
+vim run_gradio.sh
+
+# 需要修改的配置项:
+# - lightx2v_path: Lightx2v项目根目录路径
+# - model_path: 模型根目录路径(包含所有模型文件)
+
+# 💾 重要提示:建议将模型路径指向SSD存储位置
+# 例如:/mnt/ssd/models/ 或 /data/ssd/models/
+
+# 2. 运行启动脚本
+bash run_gradio.sh
+
+# 3. 或使用参数启动
+bash run_gradio.sh --lang zh --port 8032
+bash run_gradio.sh --lang en --port 7862
+```
+
+**Windows 环境:**
+```cmd
+# 1. 编辑启动脚本,配置相关路径
+cd app\
+notepad run_gradio_win.bat
+
+# 需要修改的配置项:
+# - lightx2v_path: Lightx2v项目根目录路径
+# - model_path: 模型根目录路径(包含所有模型文件)
+
+# 💾 重要提示:建议将模型路径指向SSD存储位置
+# 例如:D:\models\ 或 E:\models\
+
+# 2. 运行启动脚本
+run_gradio_win.bat
+
+# 3. 或使用参数启动
+run_gradio_win.bat --lang zh --port 8032
+run_gradio_win.bat --lang en --port 7862
+```
+
+#### 方式二:直接命令行启动
+
+```bash
+pip install -v git+https://github.com/ModelTC/LightX2V.git
+```
+
+**Linux 环境:**
+
+**中文界面版本:**
+```bash
+python gradio_demo_zh.py \
+ --model_path /path/to/models \
+ --server_name 0.0.0.0 \
+ --server_port 7862
+```
+
+**英文界面版本:**
+```bash
+python gradio_demo.py \
+ --model_path /path/to/models \
+ --server_name 0.0.0.0 \
+ --server_port 7862
+```
+
+**Windows 环境:**
+
+**中文界面版本:**
+```cmd
+python gradio_demo_zh.py ^
+ --model_path D:\models ^
+ --server_name 127.0.0.1 ^
+ --server_port 7862
+```
+
+**英文界面版本:**
+```cmd
+python gradio_demo.py ^
+ --model_path D:\models ^
+ --server_name 127.0.0.1 ^
+ --server_port 7862
+```
+
+**💡 提示**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。
+
+## 📋 命令行参数
+
+| 参数 | 类型 | 必需 | 默认值 | 说明 |
+|------|------|------|--------|------|
+| `--model_path` | str | ✅ | - | 模型根目录路径(包含所有模型文件的目录) |
+| `--server_port` | int | ❌ | 7862 | 服务器端口 |
+| `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 |
+| `--output_dir` | str | ❌ | ./outputs | 输出视频保存目录 |
+
+**💡 说明**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。
+
+## 🎯 功能特性
+
+### 模型配置
+
+- **模型类型**: 支持 wan2.1 和 wan2.2 两种模型架构
+- **任务类型**: 支持图像到视频(i2v)和文本到视频(t2v)两种生成模式
+- **模型选择**: 前端自动识别并筛选可用的模型文件,支持自动检测量化精度
+- **编码器配置**: 支持选择 T5 文本编码器、CLIP 图像编码器和 VAE 解码器
+- **算子选择**: 支持多种注意力算子和量化矩阵乘法算子,系统会根据安装状态自动排序
+
+### 输入参数
+
+- **提示词 (Prompt)**: 描述期望的视频内容
+- **负向提示词 (Negative Prompt)**: 指定不希望出现的元素
+- **输入图像**: i2v 模式下需要上传输入图像
+- **分辨率**: 支持多种预设分辨率(480p/540p/720p)
+- **随机种子**: 控制生成结果的随机性
+- **推理步数**: 影响生成质量和速度的平衡(蒸馏模型默认为 4 步)
+
+### 视频参数
+
+- **FPS**: 每秒帧数
+- **总帧数**: 视频长度
+- **CFG缩放因子**: 控制提示词影响强度(1-10,蒸馏模型默认为 1)
+- **分布偏移**: 控制生成风格偏离程度(0-10)
+
+## 🔧 自动配置功能
+
+系统会根据您的硬件配置(GPU 显存和 CPU 内存)自动配置最优推理选项,无需手动调整。启动时会自动应用最佳配置,包括:
+
+- **GPU 内存优化**: 根据显存大小自动启用 CPU 卸载、VAE 分块推理等
+- **CPU 内存优化**: 根据系统内存自动启用延迟加载、模块卸载等
+- **算子选择**: 自动选择已安装的最优算子(按优先级排序)
+- **量化配置**: 根据模型文件名自动检测并应用量化精度
+
+
+### 日志查看
+
+```bash
+# 查看推理日志
+tail -f inference_logs.log
+
+# 查看GPU使用情况
+nvidia-smi
+
+# 查看系统资源
+htop
+```
+
+欢迎提交Issue和Pull Request来改进这个项目!
+
+**注意**: 使用本工具生成的视频内容请遵守相关法律法规,不得用于非法用途。
diff --git a/docs/ZH_CN/source/deploy_guides/deploy_local_windows.md b/docs/ZH_CN/source/deploy_guides/deploy_local_windows.md
new file mode 100644
index 0000000000000000000000000000000000000000..e2baabbe703a7df6fbe68e7f51c140248b7be527
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/deploy_local_windows.md
@@ -0,0 +1,129 @@
+# Windows 本地部署指南
+
+## 📖 概述
+
+本文档将详细指导您在Windows环境下完成LightX2V的本地部署配置,包括批处理文件推理、Gradio Web界面推理等多种使用方式。
+
+## 🚀 快速开始
+
+### 环境要求
+
+#### 硬件要求
+- **GPU**: NVIDIA GPU,建议 8GB+ VRAM
+- **内存**: 建议 16GB+ RAM
+- **存储**: 强烈建议使用 SSD 固态硬盘,机械硬盘会导致模型加载缓慢
+
+## 🎯 使用方式
+
+### 方式一:使用批处理文件推理
+
+参考[快速开始文档](../getting_started/quickstart.md)安装环境,并使用[批处理文件](https://github.com/ModelTC/LightX2V/tree/main/scripts/win)运行。
+
+### 方式二:使用Gradio Web界面推理
+
+#### 手动配置Gradio
+
+参考[快速开始文档](../getting_started/quickstart.md)安装环境,参考[Gradio部署指南](./deploy_gradio.md)
+
+#### 一键启动Gradio(推荐)
+
+**📦 下载软件包**
+- [夸克网盘](https://pan.quark.cn/s/8af1162d7a15)
+
+**📁 目录结构**
+解压后,确保目录结构如下:
+
+```
+├── env/ # LightX2V 环境目录
+├── LightX2V/ # LightX2V 项目目录
+├── start_lightx2v.bat # 一键启动脚本
+├── lightx2v_config.txt # 配置文件
+├── LightX2V使用说明.txt # LightX2V使用说明
+├── outputs/ # 生成的视频保存目录
+└── models/ # 模型存放目录
+```
+
+**📥 下载模型**:
+
+可参考[模型结构文档](../getting_started/model_structure.md)或者[gradio部署文档](./deploy_gradio.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。
+
+
+**📋 配置参数**
+
+编辑 `lightx2v_config.txt` 文件,根据需要修改以下参数:
+
+```ini
+
+# 界面语言 (zh: 中文, en: 英文)
+lang=zh
+
+# 服务器端口
+port=8032
+
+# GPU设备ID (0, 1, 2...)
+gpu=0
+
+# 模型路径
+model_path=models/
+```
+
+**🚀 启动服务**
+
+双击运行 `start_lightx2v.bat` 文件,脚本将:
+1. 自动读取配置文件
+2. 验证模型路径和文件完整性
+3. 启动 Gradio Web 界面
+4. 自动打开浏览器访问服务
+
+
+
+
+**⚠️ 重要提示**:
+- **页面显示问题**: 如果网页打开空白或显示异常,请运行 `pip install --upgrade gradio` 升级Gradio版本。
+
+
+### 方式三:使用ComfyUI推理
+
+此说明将指导您如何下载与使用便携版的Lightx2v-ComfyUI环境,如此可以免去手动配置环境的步骤,适用于想要在Windows系统下快速开始体验使用Lightx2v加速视频生成的用户。
+
+#### 下载Windows便携环境:
+
+- [百度网盘下载](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid),提取码:wfid
+
+便携环境中已经打包了所有Python运行相关的依赖,也包括ComfyUI和LightX2V的代码及其相关依赖,下载后解压即可使用。
+
+解压后对应的文件目录说明如下:
+
+```shell
+lightx2v_env
+├──📂 ComfyUI # ComfyUI代码
+├──📂 portable_python312_embed # 独立的Python环境
+└── run_nvidia_gpu.bat # Windows启动脚本(双击启动)
+```
+
+#### 启动ComfyUI
+
+直接双击run_nvidia_gpu.bat文件,系统会打开一个Command Prompt窗口并运行程序,一般第一次启动时间会比较久,请耐心等待,启动完成后会自动打开浏览器并出现ComfyUI的前端界面。
+
+
+
+LightX2V-ComfyUI 的插件使用的是 [ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper),示例工作流可以从该项目中获取。
+
+#### 已测试显卡(offload模式)
+
+- 测试模型`Wan2.1-I2V-14B-480P`
+
+| 显卡型号 | 任务类型 | 显存容量 | 实际最大显存占用 | 实际最大内存占用 |
+|:----------|:-----------|:-----------|:-------- |:---------- |
+| 3090Ti | I2V | 24G | 6.1G | 7.1G |
+| 3080Ti | I2V | 12G | 6.1G | 7.1G |
+| 3060Ti | I2V | 8G | 6.1G | 7.1G |
+| 4070Ti Super | I2V | 16G | 6.1G | 7.1G |
+| 4070 | I2V | 12G | 6.1G | 7.1G |
+| 4060 | I2V | 8G | 6.1G | 7.1G |
+
+
+
+#### 环境打包和使用参考
+- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
+- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2)
diff --git a/docs/ZH_CN/source/deploy_guides/deploy_service.md b/docs/ZH_CN/source/deploy_guides/deploy_service.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0111b32e8186661f862ecd339f0a2e59649b058
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/deploy_service.md
@@ -0,0 +1,88 @@
+# 服务化部署
+
+lightx2v 提供异步服务功能。代码入口点在 [这里](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
+
+### 启动服务
+
+```shell
+# 修改脚本中的路径
+bash scripts/start_server.sh
+```
+
+`--port 8000` 选项表示服务将绑定到本地机器的 `8000` 端口。您可以根据需要更改此端口。
+
+### 客户端发送请求
+
+```shell
+python scripts/post.py
+```
+
+服务端点:`/v1/tasks/`
+
+`scripts/post.py` 中的 `message` 参数如下:
+
+```python
+message = {
+ "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+ "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ "image_path": ""
+}
+```
+
+1. `prompt`、`negative_prompt` 和 `image_path` 是视频生成的基本输入。`image_path` 可以是空字符串,表示不需要图像输入。
+
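+下面给出一个用 `requests` 直接调用上述端点提交任务并轮询状态的极简示意(假设服务监听本机 8000 端口;返回 JSON 的字段名仅为示意,实际字段请以 `scripts/post.py` 和服务端实现为准):
+
+```python
+import time
+
+import requests
+
+BASE_URL = "http://localhost:8000"
+
+message = {
+    "prompt": "Two anthropomorphic cats in comfy boxing gear fight intensely on a spotlighted stage.",
+    "negative_prompt": "",
+    "image_path": "",  # 空字符串表示不需要图像输入
+}
+
+# 提交任务(此处假设返回体中包含 task_id 字段)
+resp = requests.post(f"{BASE_URL}/v1/tasks/", json=message)
+resp.raise_for_status()
+task_id = resp.json()["task_id"]
+
+# 轮询任务状态,直到 completed(字段名同样为示意)
+while True:
+    status = requests.get(f"{BASE_URL}/v1/tasks/{task_id}/status").json()
+    print(status)
+    if status.get("status") == "completed":
+        break
+    time.sleep(5)
+```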
+
+### 客户端检查服务器状态
+
+```shell
+python scripts/check_status.py
+```
+
+服务端点包括:
+
+1. `/v1/service/status` 用于检查服务状态。返回服务是 `busy` 还是 `idle`。服务只有在 `idle` 时才接受新请求。
+
+2. `/v1/tasks/` 用于获取服务器接收和完成的所有任务。
+
+3. `/v1/tasks/{task_id}/status` 用于获取指定 `task_id` 的任务状态。返回任务是 `processing` 还是 `completed`。
+
+### 客户端随时停止服务器上的当前任务
+
+```shell
+python scripts/stop_running_task.py
+```
+
+服务端点:`/v1/tasks/running`
+
+终止任务后,服务器不会退出,而是返回等待新请求的状态。
+
+### 在单个节点上启动多个服务
+
+在单个节点上,您可以使用 `scripts/start_server.sh` 启动多个服务(注意同一 IP 下的端口号必须不同),或者可以使用 `scripts/start_multi_servers.sh` 同时启动多个服务:
+
+```shell
+num_gpus=8 bash scripts/start_multi_servers.sh
+```
+
+其中 `num_gpus` 表示要启动的服务数量;服务将从 `--start_port` 开始在连续端口上运行。
+
+### 多个服务之间的调度
+
+```shell
+python scripts/post_multi_servers.py
+```
+
+`post_multi_servers.py` 将根据服务的空闲状态调度多个客户端请求。
+
+### API 端点总结
+
+| 端点 | 方法 | 描述 |
+|------|------|------|
+| `/v1/tasks/` | POST | 创建视频生成任务 |
+| `/v1/tasks/form` | POST | 通过表单创建视频生成任务 |
+| `/v1/tasks/` | GET | 获取所有任务列表 |
+| `/v1/tasks/{task_id}/status` | GET | 获取指定任务状态 |
+| `/v1/tasks/{task_id}/result` | GET | 获取指定任务的结果视频文件 |
+| `/v1/tasks/running` | DELETE | 停止当前运行的任务 |
+| `/v1/files/download/{file_path}` | GET | 下载文件 |
+| `/v1/service/status` | GET | 获取服务状态 |
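+
+下面是一个查询服务状态的极简 Python 示意(假设服务监听本机 8000 端口):
+
+```python
+import requests
+
+# 返回服务是 busy 还是 idle
+print(requests.get("http://localhost:8000/v1/service/status").json())
+```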
diff --git a/docs/ZH_CN/source/deploy_guides/for_low_latency.md b/docs/ZH_CN/source/deploy_guides/for_low_latency.md
new file mode 100644
index 0000000000000000000000000000000000000000..b101f6494a579773cf80ba0e886e42f802357646
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/for_low_latency.md
@@ -0,0 +1,42 @@
+# 低延迟场景部署
+
+在低延迟的场景,我们会追求更快的速度,忽略显存和内存开销等问题。我们提供两套方案:
+
+## 💡 方案一:步数蒸馏模型的推理
+
+该方案可以参考[步数蒸馏文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html)
+
+🧠 **步数蒸馏**是非常直接的视频生成模型的加速推理方案。从50步蒸馏到4步,耗时将缩短到原来的4/50。同时,该方案下,仍然可以和以下方案结合使用:
+1. [高效注意力机制方案](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html)
+2. [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html)
+
+## 💡 方案二:非步数蒸馏模型的推理
+
+步数蒸馏需要比较大的训练资源,以及步数蒸馏后的模型,可能会出现视频动态范围变差的问题。
+
+对于非步数蒸馏的原始模型,我们可以使用以下方案或者多种方案结合的方式进行加速:
+
+1. [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) 进行多卡并行加速。
+2. [特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) 降低实际推理步数。
+3. [高效注意力机制方案](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) 加速 Attention 的推理。
+4. [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) 加速 Linear 层的推理。
+5. [变分辨率推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) 降低中间推理步的分辨率。
+
+## 💡 使用Tiny VAE
+
+在某些情况下,VAE部分耗时会比较大,可以使用轻量级VAE进行加速,同时也可以降低一部分显存。
+
+```json
+{
+  "use_tae": true,
+  "tae_path": "/path/to/taew2_1.pth"
+}
+```
+
+taew2_1.pth 权重可以从[这里](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)下载。
+
+
+## ⚠️ 注意
+
+有一部分的加速方案之间目前无法结合使用,我们目前正在致力于解决这一问题。
+
+如有问题,欢迎在 [🐛 GitHub Issues](https://github.com/ModelTC/lightx2v/issues) 中进行错误报告或者功能请求
diff --git a/docs/ZH_CN/source/deploy_guides/for_low_resource.md b/docs/ZH_CN/source/deploy_guides/for_low_resource.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5e50dea8edddd5242bfc0c08bfbc1109e867438
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/for_low_resource.md
@@ -0,0 +1,223 @@
+# Lightx2v 低资源部署指南
+
+## 📋 概述
+
+本指南专门针对硬件资源受限的环境,特别是**8GB显存 + 16/32GB内存**的配置,详细说明如何成功运行Lightx2v 14B模型进行480p和720p视频生成。
+
+Lightx2v是一个强大的视频生成模型,但在资源受限的环境下需要精心优化才能流畅运行。本指南将为您提供从硬件选择到软件配置的完整解决方案,确保您能够在有限的硬件条件下获得最佳的视频生成体验。
+
+## 🎯 目标硬件配置详解
+
+### 推荐硬件规格
+
+**GPU要求**:
+- **显存**: 8GB (RTX 3060/3070/4060/4060Ti 等)
+- **架构**: 支持CUDA的NVIDIA显卡
+
+**系统内存**:
+- **最低要求**: 16GB DDR4
+- **推荐配置**: 32GB DDR4/DDR5
+- **内存速度**: 建议3200MHz及以上
+
+**存储要求**:
+- **类型**: 强烈推荐NVMe SSD
+- **容量**: 至少50GB可用空间
+- **速度**: 读取速度建议3000MB/s以上
+
+**CPU要求**:
+- **核心数**: 建议8核心及以上
+- **频率**: 建议3.0GHz及以上
+- **架构**: 支持AVX2指令集
+
+## ⚙️ 核心优化策略详解
+
+### 1. 环境优化
+
+在运行Lightx2v之前,建议设置以下环境变量以优化性能:
+
+```bash
+# CUDA内存分配优化
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+# 启用CUDA Graph模式,提升推理性能
+export ENABLE_GRAPH_MODE=true
+
+# 使用BF16精度推理,减少显存占用(默认FP32精度)
+export DTYPE=BF16
+```
+
+**优化说明**:
+- `expandable_segments:True`: 允许CUDA内存段动态扩展,减少内存碎片
+- `ENABLE_GRAPH_MODE=true`: 启用CUDA Graph,减少内核启动开销
+- `DTYPE=BF16`: 使用BF16精度,在保持质量的同时减少显存占用
+
+### 2. 量化策略
+
+量化是低资源环境下的关键优化技术,通过降低模型精度来减少内存占用。
+
+#### 量化方案对比
+
+**FP8量化** (推荐用于RTX 40系列):
+```python
+# 适用于支持FP8的GPU,提供更好的精度
+dit_quant_scheme = "fp8" # DIT模型量化
+t5_quant_scheme = "fp8" # T5文本编码器量化
+clip_quant_scheme = "fp8" # CLIP视觉编码器量化
+```
+
+**INT8量化** (通用方案):
+```python
+# 适用于所有GPU,内存占用最小
+dit_quant_scheme = "int8" # 8位整数量化
+t5_quant_scheme = "int8" # 文本编码器量化
+clip_quant_scheme = "int8" # 视觉编码器量化
+```
+### 3. 高效算子选择指南
+
+选择合适的算子可以显著提升推理速度和减少内存占用。
+
+#### 注意力算子选择
+
+**推荐优先级**:
+1. **[Sage Attention](https://github.com/thu-ml/SageAttention)** (最高优先级)
+
+2. **[Flash Attention](https://github.com/Dao-AILab/flash-attention)** (通用方案)
+
+
+#### 矩阵乘算子选择
+
+**ADA架构显卡** (RTX 40系列):
+
+推荐优先级:
+1. **[q8-kernel](https://github.com/KONAKONA666/q8_kernels)** (最高性能,仅支持ADA架构)
+2. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (平衡方案)
+3. **[vllm-kernel](https://github.com/vllm-project/vllm)** (通用方案)
+
+**其他架构显卡**:
+1. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (推荐)
+2. **[vllm-kernel](https://github.com/vllm-project/vllm)** (备选)
+
+### 4. 参数卸载策略详解
+
+参数卸载技术允许模型在CPU和磁盘之间动态调度参数,突破显存限制。
+
+#### 三级卸载架构
+
+```python
+# 磁盘-CPU-GPU三级卸载配置
+cpu_offload=True # 启用CPU卸载
+t5_cpu_offload=True # 启用T5编码器CPU卸载
+offload_granularity=phase # DIT模型细粒度卸载
+t5_offload_granularity=block # T5编码器细粒度卸载
+lazy_load = True # 启用延迟加载机制
+num_disk_workers = 2 # 磁盘I/O工作线程数
+```
+
+#### 卸载策略详解
+
+**延迟加载机制**:
+- 模型参数按需从磁盘加载到CPU
+- 减少运行时内存占用
+- 支持大模型在有限内存下运行
+
+**磁盘存储优化**:
+- 使用高速SSD存储模型参数
+- 按照block分组存储模型文件
+- 参考转换脚本[文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md),转换时指定`--save_by_block`参数
+
+### 5. 显存优化技术详解
+
+针对720p视频生成的显存优化策略。
+
+#### CUDA内存管理
+
+```python
+# CUDA内存清理配置
+clean_cuda_cache = True # 及时清理GPU缓存
+rotary_chunk = True # 旋转位置编码分块计算
+rotary_chunk_size = 100 # 分块大小,可根据显存调整
+```
+
+#### 分块计算策略
+
+**旋转位置编码分块**:
+- 将长序列分成小块处理
+- 减少峰值显存占用
+- 保持计算精度
+
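+下面是一个"按块应用旋转位置编码"的概念性示意(仅为演示分块思路,并非 LightX2V 的实际实现;假设 `rope_fn` 接受 `offset` 参数来指定块的起始位置):
+
+```python
+import torch
+
+
+def apply_rotary_chunked(q, k, rope_fn, chunk_size=100):
+    """沿序列维分块调用 rope_fn,降低峰值显存;chunk_size 对应配置中的 rotary_chunk_size。"""
+    out_q, out_k = [], []
+    for start in range(0, q.shape[1], chunk_size):
+        end = min(start + chunk_size, q.shape[1])
+        q_chunk, k_chunk = rope_fn(q[:, start:end], k[:, start:end], offset=start)
+        out_q.append(q_chunk)
+        out_k.append(k_chunk)
+    return torch.cat(out_q, dim=1), torch.cat(out_k, dim=1)
+```
+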
+### 6. VAE优化详解
+
+VAE (变分自编码器) 是视频生成的关键组件,优化VAE可以显著提升性能。
+
+#### VAE分块推理
+
+```python
+# VAE优化配置
+use_tiling_vae = True # 启用VAE分块推理
+```
+
+#### 轻量级VAE
+
+```python
+# VAE优化配置
+use_tae = True
+tae_path = "/path to taew2_1.pth"
+```
+taew2_1.pth 权重可以从[这里](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)下载
+
+**VAE优化效果**:
+- 标准VAE: 基准性能,100%质量保持
+- 标准VAE分块: 降低显存,增加推理时间,100%质量保持
+- 轻量VAE: 极低显存,视频质量有损
+
+
+### 7. 模型选择策略
+
+选择合适的模型版本对低资源环境至关重要。
+
+#### 推荐模型对比
+
+**蒸馏模型** (强烈推荐):
+- ✅ **[Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v)**
+
+- ✅ **[Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v)**
+
+
+#### 性能优化建议
+
+使用上述蒸馏模型时,可以进一步优化性能:
+- 关闭CFG: `"enable_cfg": false`
+- 减少推理步数: `infer_step: 4`
+- 参考配置文件: [config](https://github.com/ModelTC/LightX2V/tree/main/configs/distill)
+
+## 🚀 完整配置示例
+
+### 预配置模板
+
+- **[14B模型480p视频生成配置](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json)**
+
+- **[14B模型720p视频生成配置](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json)**
+
+- **[1.3B模型720p视频生成配置](https://github.com/ModelTC/LightX2V/tree/main/configs/offload/block/wan_t2v_1_3b.json)**
+ - 1.3B模型推理瓶颈是T5 encoder,配置文件专门针对T5进行优化
+
+**[启动脚本](https://github.com/ModelTC/LightX2V/tree/main/scripts/wan/run_wan_i2v_lazy_load.sh)**
+
+
+## 📚 参考资源
+
+- [参数卸载机制文档](../method_tutorials/offload.md) - 深入了解卸载技术原理
+- [量化技术指南](../method_tutorials/quantization.md) - 量化技术详细说明
+- [Gradio部署指南](deploy_gradio.md) - Gradio部署详细说明
+
+## ⚠️ 重要注意事项
+
+1. **硬件要求**: 确保您的硬件满足最低配置要求
+2. **驱动版本**: 建议使用最新的NVIDIA驱动 (535+)
+3. **CUDA版本**: 确保CUDA版本与PyTorch兼容 (建议CUDA 11.8+)
+4. **存储空间**: 预留足够的磁盘空间用于模型缓存 (至少50GB)
+5. **网络环境**: 首次下载模型需要稳定的网络连接
+6. **环境变量**: 务必设置推荐的环境变量以优化性能
+
+
+**技术支持**: 如遇到问题,请提交Issue到项目仓库。
diff --git a/docs/ZH_CN/source/deploy_guides/lora_deploy.md b/docs/ZH_CN/source/deploy_guides/lora_deploy.md
new file mode 100644
index 0000000000000000000000000000000000000000..c75b71dc88c0d2a7f042ad2d06d77de10295f455
--- /dev/null
+++ b/docs/ZH_CN/source/deploy_guides/lora_deploy.md
@@ -0,0 +1,213 @@
+# LoRA 模型部署与相关工具
+
+LoRA (Low-Rank Adaptation) 是一种高效的模型微调技术,通过低秩矩阵分解显著减少可训练参数数量。LightX2V 全面支持 LoRA 技术,包括 LoRA 推理、LoRA 提取和 LoRA 合并等功能。
+
+## 🎯 LoRA 技术特性
+
+- **灵活部署**:支持动态加载和移除 LoRA 权重
+- **多种格式**:支持多种 LoRA 权重格式和命名约定
+- **工具完善**:提供完整的 LoRA 提取、合并工具链
+
+## 📜 LoRA 推理部署
+
+### 配置文件方式
+
+在配置文件中指定 LoRA 路径:
+
+```json
+{
+ "lora_configs": [
+ {
+ "path": "/path/to/your/lora.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
+```
+
+**配置参数说明:**
+
+- `lora_configs`: LoRA 配置列表,支持同时加载多个 LoRA
+- `path`: LoRA 权重文件路径
+- `strength`: LoRA 强度系数 (alpha),控制 LoRA 对原模型的影响程度
+
+### 命令行方式
+
+直接在命令行中指定 LoRA 路径(仅支持加载单个 LoRA):
+
+```bash
+python -m lightx2v.infer \
+ --model_cls wan2.1 \
+ --task t2v \
+ --model_path /path/to/model \
+ --config_json /path/to/config.json \
+ --lora_path /path/to/your/lora.safetensors \
+ --lora_strength 0.8 \
+ --prompt "Your prompt here"
+```
+
+### 多LoRA配置
+
+要使用多个具有不同强度的LoRA,请在配置JSON文件中指定:
+
+```json
+{
+ "lora_configs": [
+ {
+ "path": "/path/to/first_lora.safetensors",
+ "strength": 0.8
+ },
+ {
+ "path": "/path/to/second_lora.safetensors",
+ "strength": 0.5
+ }
+ ]
+}
+```
+
+### 支持的 LoRA 格式
+
+LightX2V 支持多种 LoRA 权重命名约定:
+
+| 格式类型 | 权重命名 | 说明 |
+|----------|----------|------|
+| **标准 LoRA** | `lora_A.weight`, `lora_B.weight` | 标准的 LoRA 矩阵分解格式 |
+| **Down/Up 格式** | `lora_down.weight`, `lora_up.weight` | 另一种常见的命名约定 |
+| **差值格式** | `diff` | `weight` 权重差值 |
+| **偏置差值** | `diff_b` | `bias` 权重差值 |
+| **调制差值** | `diff_m` | `modulation` 权重差值 |
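+
+下面用几行 PyTorch 伪代码示意这些格式在推理时是如何合并进原始权重的(仅为概念演示,并非 LightX2V 内部实现):
+
+```python
+import torch
+
+
+def apply_lora(weight, lora_A, lora_B, strength=1.0):
+    """标准 LoRA / Down-Up 格式:W' = W + strength * (B @ A)。
+    lora_A: [rank, in_features](lora_A.weight / lora_down.weight)
+    lora_B: [out_features, rank](lora_B.weight / lora_up.weight)
+    """
+    return weight + strength * (lora_B @ lora_A)
+
+
+def apply_diff(weight, diff, strength=1.0):
+    """差值格式(diff / diff_b / diff_m):直接加上对应的权重差值。"""
+    return weight + strength * diff
+```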
+
+### 推理脚本示例
+
+**步数蒸馏 LoRA 推理:**
+
+```bash
+# T2V LoRA 推理
+bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh
+
+# I2V LoRA 推理
+bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh
+```
+
+**音频驱动 LoRA 推理:**
+
+```bash
+bash scripts/wan/run_wan_i2v_audio.sh
+```
+
+### API 服务中使用 LoRA
+
+在 API 服务中通过 [config 文件](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) 指定,并对 [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh) 中的启动命令进行修改:
+
+```bash
+python -m lightx2v.api_server \
+ --model_cls wan2.1_distill \
+ --task t2v \
+ --model_path $model_path \
+ --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \
+ --port 8000 \
+ --nproc_per_node 1
+```
+
+## 🔧 LoRA 提取工具
+
+使用 `tools/extract/lora_extractor.py` 从两个模型的差异中提取 LoRA 权重。
+
+### 基本用法
+
+```bash
+python tools/extract/lora_extractor.py \
+ --source-model /path/to/base/model \
+ --target-model /path/to/finetuned/model \
+ --output /path/to/extracted/lora.safetensors \
+ --rank 32
+```
+
+### 参数说明
+
+| 参数 | 类型 | 必需 | 默认值 | 说明 |
+|------|------|------|--------|------|
+| `--source-model` | str | ✅ | - | 基础模型路径 |
+| `--target-model` | str | ✅ | - | 微调后模型路径 |
+| `--output` | str | ✅ | - | 输出 LoRA 文件路径 |
+| `--source-type` | str | ❌ | `safetensors` | 基础模型格式 (`safetensors`/`pytorch`) |
+| `--target-type` | str | ❌ | `safetensors` | 微调模型格式 (`safetensors`/`pytorch`) |
+| `--output-format` | str | ❌ | `safetensors` | 输出格式 (`safetensors`/`pytorch`) |
+| `--rank` | int | ❌ | `32` | LoRA 秩值 |
+| `--output-dtype` | str | ❌ | `bf16` | 输出数据类型 |
+| `--diff-only` | bool | ❌ | `False` | 仅保存权重差值,不进行 LoRA 分解 |
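+
+提取的基本思路可以用下面的 SVD 伪代码来示意(仅为概念演示,具体细节请以 `tools/extract/lora_extractor.py` 为准):
+
+```python
+import torch
+
+
+def extract_lora(w_base: torch.Tensor, w_tuned: torch.Tensor, rank: int = 32):
+    """对权重差值做截断 SVD,得到 diff ≈ lora_B @ lora_A。"""
+    diff = (w_tuned - w_base).float()
+    U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
+    lora_B = U[:, :rank] * S[:rank]   # [out_features, rank]
+    lora_A = Vh[:rank, :]             # [rank, in_features]
+    return lora_A, lora_B
+```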
+
+### 高级用法示例
+
+**提取高秩 LoRA:**
+
+```bash
+python tools/extract/lora_extractor.py \
+ --source-model /path/to/base/model \
+ --target-model /path/to/finetuned/model \
+ --output /path/to/high_rank_lora.safetensors \
+ --rank 64 \
+ --output-dtype fp16
+```
+
+**仅保存权重差值:**
+
+```bash
+python tools/extract/lora_extractor.py \
+ --source-model /path/to/base/model \
+ --target-model /path/to/finetuned/model \
+ --output /path/to/weight_diff.safetensors \
+ --diff-only
+```
+
+## 🔀 LoRA 合并工具
+
+使用 `tools/extract/lora_merger.py` 将 LoRA 权重合并到基础模型中,以进行后续量化等操作。
+
+### 基本用法
+
+```bash
+python tools/extract/lora_merger.py \
+ --source-model /path/to/base/model \
+ --lora-model /path/to/lora.safetensors \
+ --output /path/to/merged/model.safetensors \
+ --alpha 1.0
+```
+
+### 参数说明
+
+| 参数 | 类型 | 必需 | 默认值 | 说明 |
+|------|------|------|--------|------|
+| `--source-model` | str | ✅ | 无 | 基础模型路径 |
+| `--lora-model` | str | ✅ | 无 | LoRA 权重路径 |
+| `--output` | str | ✅ | 无 | 输出合并模型路径 |
+| `--source-type` | str | ❌ | `safetensors` | 基础模型格式 |
+| `--lora-type` | str | ❌ | `safetensors` | LoRA 权重格式 |
+| `--output-format` | str | ❌ | `safetensors` | 输出格式 |
+| `--alpha` | float | ❌ | `1.0` | LoRA 合并强度 |
+| `--output-dtype` | str | ❌ | `bf16` | 输出数据类型 |
+
+### 高级用法示例
+
+**部分强度合并:**
+
+```bash
+python tools/extract/lora_merger.py \
+ --source-model /path/to/base/model \
+ --lora-model /path/to/lora.safetensors \
+ --output /path/to/merged_model.safetensors \
+ --alpha 0.7 \
+ --output-dtype fp32
+```
+
+**多格式支持:**
+
+```bash
+python tools/extract/lora_merger.py \
+ --source-model /path/to/base/model.pt \
+ --source-type pytorch \
+ --lora-model /path/to/lora.safetensors \
+ --lora-type safetensors \
+ --output /path/to/merged_model.safetensors \
+ --output-format safetensors \
+ --alpha 1.0
+```
diff --git a/docs/ZH_CN/source/getting_started/benchmark.md b/docs/ZH_CN/source/getting_started/benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..6adc116c38878349b9175f2e23c39cc67606c5a8
--- /dev/null
+++ b/docs/ZH_CN/source/getting_started/benchmark.md
@@ -0,0 +1,3 @@
+# 基准测试
+
+由于本页需要展示视频的播放效果和详细的性能对比,建议访问这个[🔗 页面](https://github.com/ModelTC/LightX2V/blob/main/docs/ZH_CN/source/getting_started/benchmark_source.md),以获得更好的展示效果和对应的文档内容。
diff --git a/docs/ZH_CN/source/getting_started/benchmark_source.md b/docs/ZH_CN/source/getting_started/benchmark_source.md
new file mode 100644
index 0000000000000000000000000000000000000000..c399d1837cee6a435b042f30b2af5b870cb53503
--- /dev/null
+++ b/docs/ZH_CN/source/getting_started/benchmark_source.md
@@ -0,0 +1,149 @@
+# 🚀 基准测试
+
+> 本文档展示了LightX2V在不同硬件环境下的性能测试结果,包括H200和RTX 4090平台的详细对比数据。
+
+---
+
+## 🖥️ H200 环境 (~140GB显存)
+
+### 📋 软件环境配置
+
+| 组件 | 版本 |
+|:-----|:-----|
+| **Python** | 3.11 |
+| **PyTorch** | 2.7.1+cu128 |
+| **SageAttention** | 2.2.0 |
+| **vLLM** | 0.9.2 |
+| **sgl-kernel** | 0.1.8 |
+
+---
+
+### 🎬 480P 5s视频测试
+
+**测试配置:**
+- **模型**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
+- **参数**: `infer_steps=40`, `seed=42`, `enable_cfg=True`
+
+#### 📊 性能对比表
+
+| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
+|:-----|:----------:|:---------------:|:------:|:--------:|
+| **Wan2.1 Official** | 366 | 71 | 1.0x | |
+| **FastVideo** | 292 | 26 | **1.25x** | |
+| **LightX2V_1** | 250 | 53 | **1.46x** | |
+| **LightX2V_2** | 216 | 50 | **1.70x** | |
+| **LightX2V_3** | 191 | 35 | **1.92x** | |
+| **LightX2V_3-Distill** | 14 | 35 | **🏆 20.85x** | |
+| **LightX2V_4** | 107 | 35 | **3.41x** | |
+
+---
+
+### 🎬 720P 5s视频测试
+
+**测试配置:**
+- **模型**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
+- **参数**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`
+
+#### 📊 性能对比表
+
+| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
+|:-----|:----------:|:---------------:|:------:|:--------:|
+| **Wan2.1 Official** | 974 | 81 | 1.0x | |
+| **FastVideo** | 914 | 40 | **1.07x** | |
+| **LightX2V_1** | 807 | 65 | **1.21x** | |
+| **LightX2V_2** | 751 | 57 | **1.30x** | |
+| **LightX2V_3** | 671 | 43 | **1.45x** | |
+| **LightX2V_3-Distill** | 44 | 43 | **🏆 22.14x** | |
+| **LightX2V_4** | 344 | 46 | **2.83x** | |
+
+---
+
+## 🖥️ RTX 4090 环境 (~24GB显存)
+
+### 📋 软件环境配置
+
+| 组件 | 版本 |
+|:-----|:-----|
+| **Python** | 3.9.16 |
+| **PyTorch** | 2.5.1+cu124 |
+| **SageAttention** | 2.1.0 |
+| **vLLM** | 0.6.6 |
+| **sgl-kernel** | 0.0.5 |
+| **q8-kernels** | 0.0.0 |
+
+---
+
+### 🎬 480P 5s视频测试
+
+**测试配置:**
+- **模型**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
+- **参数**: `infer_steps=40`, `seed=42`, `enable_cfg=True`
+
+#### 📊 性能对比表
+
+| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
+|:-----|:----------:|:---------------:|:------:|:--------:|
+| **Wan2GP(profile=3)** | 779 | 20 | **1.0x** | |
+| **LightX2V_5** | 738 | 16 | **1.05x** | |
+| **LightX2V_5-Distill** | 68 | 16 | **11.45x** | |
+| **LightX2V_6** | 630 | 12 | **1.24x** | |
+| **LightX2V_6-Distill** | 63 | 12 | **🏆 12.36x** | |
+
+---
+
+### 🎬 720P 5s视频测试
+
+**测试配置:**
+- **模型**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
+- **参数**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`
+
+#### 📊 性能对比表
+
+| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
+|:-----|:----------:|:---------------:|:------:|:--------:|
+| **Wan2GP(profile=3)** | -- | OOM | -- | |
+| **LightX2V_5** | 2473 | 23 | -- | |
+| **LightX2V_5-Distill** | 183 | 23 | -- | |
+| **LightX2V_6** | 2169 | 18 | -- | |
+| **LightX2V_6-Distill** | 171 | 18 | -- | |
+
+---
+
+## 📖 配置说明
+
+### 🖥️ H200 环境配置说明
+
+| 配置 | 技术特点 |
+|:-----|:---------|
+| **Wan2.1 Official** | 基于[Wan2.1官方仓库](https://github.com/Wan-Video/Wan2.1)的原始实现 |
+| **FastVideo** | 基于[FastVideo官方仓库](https://github.com/hao-ai-lab/FastVideo),使用SageAttention2后端优化 |
+| **LightX2V_1** | 使用SageAttention2替换原生注意力机制,采用DIT BF16+FP32(部分敏感层)混合精度计算,在保持精度的同时提升计算效率 |
+| **LightX2V_2** | 统一使用BF16精度计算,进一步减少显存占用和计算开销,同时保持生成质量 |
+| **LightX2V_3** | 引入FP8量化技术显著减少计算精度要求,结合Tiling VAE技术优化显存使用 |
+| **LightX2V_3-Distill** | 在LightX2V_3基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 |
+| **LightX2V_4** | 在LightX2V_3基础上加入TeaCache(teacache_thresh=0.2)缓存复用技术,通过智能跳过冗余计算实现加速 |
+
+### 🖥️ RTX 4090 环境配置说明
+
+| 配置 | 技术特点 |
+|:-----|:---------|
+| **Wan2GP(profile=3)** | 基于[Wan2GP仓库](https://github.com/deepbeepmeep/Wan2GP)实现,使用MMGP优化技术。profile=3配置适用于至少32GB内存和24GB显存的RTX 3090/4090环境,通过牺牲显存来适应有限的内存资源。使用量化模型:[480P模型](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors)和[720P模型](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors) |
+| **LightX2V_5** | 使用SageAttention2替换原生注意力机制,采用DIT FP8+FP32(部分敏感层)混合精度计算;启用CPU offload技术,将DIT推理过程中的数据异步卸载到CPU上以节省显存,offload粒度为block级别 |
+| **LightX2V_5-Distill** | 在LightX2V_5基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 |
+| **LightX2V_6** | 在LightX2V_3基础上启用CPU offload技术,部分敏感层以FP32精度计算,将DIT推理过程中的数据异步卸载到CPU上以节省显存,offload粒度为block级别 |
+| **LightX2V_6-Distill** | 在LightX2V_6基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 |
+
+---
+
+## 📁 配置文件参考
+
+基准测试相关的配置文件和运行脚本可在以下位置获取:
+
+| 类型 | 链接 | 说明 |
+|:-----|:-----|:-----|
+| **配置文件** | [configs/bench](https://github.com/ModelTC/LightX2V/tree/main/configs/bench) | 包含各种优化配置的JSON文件 |
+| **运行脚本** | [scripts/bench](https://github.com/ModelTC/LightX2V/tree/main/scripts/bench) | 包含基准测试的执行脚本 |
+
+---
+
+> 💡 **提示**: 建议根据您的硬件配置选择合适的优化方案,以获得最佳的性能表现。
diff --git a/docs/ZH_CN/source/getting_started/model_structure.md b/docs/ZH_CN/source/getting_started/model_structure.md
new file mode 100644
index 0000000000000000000000000000000000000000..8ceb48a107ef321c12829b487b2c5b1e9dcd1154
--- /dev/null
+++ b/docs/ZH_CN/source/getting_started/model_structure.md
@@ -0,0 +1,571 @@
+# 模型格式与加载指南
+
+## 📖 概述
+
+LightX2V 是一个灵活的视频生成推理框架,支持多种模型来源和格式,为用户提供丰富的选择:
+
+- ✅ **Wan 官方模型**:直接兼容 Wan2.1 和 Wan2.2 官方发布的完整模型
+- ✅ **单文件模型**:支持 LightX2V 发布的单文件格式模型(包含量化版本)
+- ✅ **LoRA 模型**:支持加载 LightX2V 发布的蒸馏 LoRA
+
+本文档将详细介绍各种模型格式的使用方法、配置参数和最佳实践。
+
+---
+
+## 🗂️ 格式一:Wan 官方模型
+
+### 模型仓库
+- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b)
+- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8)
+
+### 模型特点
+- **官方保证**:Wan-AI 官方发布的完整模型,质量最高
+- **完整组件**:包含所有必需的组件(DIT、T5、CLIP、VAE)
+- **原始精度**:使用 BF16/FP32 精度,无量化损失
+- **兼容性强**:与 Wan 官方工具链完全兼容
+
+### Wan2.1 官方模型
+
+#### 目录结构
+
+以 [Wan2.1-I2V-14B-720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 为例:
+
+```
+Wan2.1-I2V-14B-720P/
+├── diffusion_pytorch_model-00001-of-00007.safetensors # DIT 模型分片 1
+├── diffusion_pytorch_model-00002-of-00007.safetensors # DIT 模型分片 2
+├── diffusion_pytorch_model-00003-of-00007.safetensors # DIT 模型分片 3
+├── diffusion_pytorch_model-00004-of-00007.safetensors # DIT 模型分片 4
+├── diffusion_pytorch_model-00005-of-00007.safetensors # DIT 模型分片 5
+├── diffusion_pytorch_model-00006-of-00007.safetensors # DIT 模型分片 6
+├── diffusion_pytorch_model-00007-of-00007.safetensors # DIT 模型分片 7
+├── diffusion_pytorch_model.safetensors.index.json # 分片索引文件
+├── models_t5_umt5-xxl-enc-bf16.pth # T5 文本编码器
+├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP 编码器
+├── Wan2.1_VAE.pth # VAE 编解码器
+├── config.json # 模型配置
+├── xlm-roberta-large/ # CLIP tokenizer
+├── google/ # T5 tokenizer
+├── assets/
+└── examples/
+```
+
+#### 使用方法
+
+```bash
+# 下载模型
+huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \
+ --local-dir ./models/Wan2.1-I2V-14B-720P
+
+# 配置启动脚本
+model_path=./models/Wan2.1-I2V-14B-720P
+lightx2v_path=/path/to/LightX2V
+
+# 运行推理
+cd LightX2V/scripts
+bash wan/run_wan_i2v.sh
+```
+
+### Wan2.2 官方模型
+
+#### 目录结构
+
+以 [Wan2.2-I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) 为例:
+
+```
+Wan2.2-I2V-A14B/
+├── high_noise_model/ # 高噪声模型目录
+│ ├── diffusion_pytorch_model-00001-of-00009.safetensors
+│ ├── diffusion_pytorch_model-00002-of-00009.safetensors
+│ ├── ...
+│ ├── diffusion_pytorch_model-00009-of-00009.safetensors
+│ └── diffusion_pytorch_model.safetensors.index.json
+├── low_noise_model/ # 低噪声模型目录
+│ ├── diffusion_pytorch_model-00001-of-00009.safetensors
+│ ├── diffusion_pytorch_model-00002-of-00009.safetensors
+│ ├── ...
+│ ├── diffusion_pytorch_model-00009-of-00009.safetensors
+│ └── diffusion_pytorch_model.safetensors.index.json
+├── models_t5_umt5-xxl-enc-bf16.pth # T5 文本编码器
+├── Wan2.1_VAE.pth # VAE 编解码器
+├── configuration.json # 模型配置
+├── google/ # T5 tokenizer
+├── assets/ # 示例资源(可选)
+└── examples/ # 示例文件(可选)
+```
+
+#### 使用方法
+
+```bash
+# 下载模型
+huggingface-cli download Wan-AI/Wan2.2-I2V-A14B \
+ --local-dir ./models/Wan2.2-I2V-A14B
+
+# 配置启动脚本
+model_path=./models/Wan2.2-I2V-A14B
+lightx2v_path=/path/to/LightX2V
+
+# 运行推理
+cd LightX2V/scripts
+bash wan22/run_wan22_moe_i2v.sh
+```
+
+### 可用模型列表
+
+#### Wan2.1 官方模型列表
+
+| 模型名称 | 下载链接 |
+|---------|----------|
+| Wan2.1-I2V-14B-720P | [链接](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) |
+| Wan2.1-I2V-14B-480P | [链接](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) |
+| Wan2.1-T2V-14B | [链接](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) |
+| Wan2.1-T2V-1.3B | [链接](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) |
+| Wan2.1-FLF2V-14B-720P | [链接](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) |
+| Wan2.1-VACE-14B | [链接](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) |
+| Wan2.1-VACE-1.3B | [链接](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) |
+
+#### Wan2.2 官方模型列表
+
+| 模型名称 | 下载链接 |
+|---------|----------|
+| Wan2.2-I2V-A14B | [链接](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) |
+| Wan2.2-T2V-A14B | [链接](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) |
+| Wan2.2-TI2V-5B | [链接](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) |
+| Wan2.2-Animate-14B | [链接](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) |
+
+### 使用提示
+
+> 💡 **量化模型使用**:如需使用量化模型,可参考[模型转换脚本](https://github.com/ModelTC/LightX2V/blob/main/tools/convert/readme_zh.md)进行转换,或直接使用下方格式二中的预转换量化模型
+>
+> 💡 **显存优化**:对于 RTX 4090 24GB 或更小显存的设备,建议结合量化技术和 CPU 卸载功能:
+> - 量化配置:参考[量化技术文档](../method_tutorials/quantization.md)
+> - CPU 卸载:参考[参数卸载文档](../method_tutorials/offload.md)
+> - Wan2.1 配置:参考 [offload 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/offload)
+> - Wan2.2 配置:参考 [wan22 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) 中以 `4090` 结尾的配置
+
+---
+
+## 🗂️ 格式二:LightX2V 单文件模型(推荐)
+
+### 模型仓库
+- [Wan2.1-LightX2V](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan2.2-LightX2V](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+
+### 模型特点
+- **单文件管理**:单个 safetensors 文件,易于管理和部署
+- **多精度支持**:提供原始精度、FP8、INT8 等多种精度版本
+- **蒸馏加速**:支持 4-step 快速推理
+- **工具兼容**:兼容 ComfyUI 等其他工具
+
+**示例**:
+- `wan2.1_i2v_720p_lightx2v_4step.safetensors` - 720P 图生视频原始精度
+- `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` - 720P 图生视频 FP8 量化
+- `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` - 480P 图生视频 INT8 量化
+- ...
+
+### Wan2.1 单文件模型
+
+#### 场景 A:下载单个模型文件
+
+**步骤 1:选择并下载模型**
+
+```bash
+# 创建模型目录
+mkdir -p ./models/wan2.1_i2v_720p
+
+# 下载 720P 图生视频原始精度模型
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p \
+ --include "wan2.1_i2v_720p_lightx2v_4step.safetensors"
+```
+
+**步骤 2:手动组织其他模块**
+
+目录结构如下:
+```
+wan2.1_i2v_720p/
+├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度
+└── t5/clip/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织
+```
+
+**步骤 3:配置启动脚本**
+
+```bash
+# 在启动脚本中设置(指向包含模型文件的目录)
+model_path=./models/wan2.1_i2v_720p
+lightx2v_path=/path/to/LightX2V
+
+# 运行脚本
+cd LightX2V/scripts
+bash wan/run_wan_i2v_distill_4step_cfg.sh
+```
+
+> 💡 **提示**:当目录下只有一个模型文件时,LightX2V 会自动加载该文件。
+
+#### 场景 B:下载多个模型文件
+
+当您下载了多个不同精度的模型到同一目录时,需要在配置文件中明确指定使用哪个模型。
+
+**步骤 1:下载多个模型**
+
+```bash
+# 创建模型目录
+mkdir -p ./models/wan2.1_i2v_720p_multi
+
+# 下载原始精度模型
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p_multi \
+ --include "wan2.1_i2v_720p_lightx2v_4step.safetensors"
+
+# 下载 FP8 量化模型
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p_multi \
+ --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+
+# 下载 INT8 量化模型
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models/wan2.1_i2v_720p_multi \
+ --include "wan2.1_i2v_720p_int8_lightx2v_4step.safetensors"
+```
+
+**步骤 2:手动组织其他模块**
+
+目录结构如下:
+
+```
+wan2.1_i2v_720p_multi/
+├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度
+├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化
+├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 量化
+└── t5/clip/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织
+```
+
+**步骤 3:在配置文件中指定模型**
+
+编辑配置文件(如 `configs/distill/wan_i2v_distill_4step_cfg.json`):
+
+```json
+{
+ // 使用原始精度模型
+ "dit_original_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_lightx2v_4step.safetensors",
+
+ // 或使用 FP8 量化模型
+ // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "fp8-vllm",
+
+ // 或使用 INT8 量化模型
+ // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_int8_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "int8-vllm",
+
+ // 其他配置...
+}
+```
+
+### 使用提示
+
+> 💡 **配置参数说明**:
+> - **dit_original_ckpt**:用于指定原始精度模型(BF16/FP32/FP16)的路径
+> - **dit_quantized_ckpt**:用于指定量化模型(FP8/INT8)的路径,需配合 `dit_quantized` 和 `dit_quant_scheme` 参数使用
+
+**步骤 4:启动推理**
+
+```bash
+cd LightX2V/scripts
+bash wan/run_wan_i2v_distill_4step_cfg.sh
+```
+
+### Wan2.2 单文件模型
+
+#### 目录结构要求
+
+使用 Wan2.2 单文件模型时,需要手动创建特定的目录结构:
+
+```
+wan2.2_models/
+├── high_noise_model/ # 高噪声模型目录(必须)
+│ └── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors # 高噪声模型文件
+├── low_noise_model/ # 低噪声模型目录(必须)
+│ └── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # 低噪声模型文件
+└── t5/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织
+```
+
+#### 场景 A:每个目录下只有一个模型文件
+
+```bash
+# 创建必需的子目录
+mkdir -p ./models/wan2.2_models/high_noise_model
+mkdir -p ./models/wan2.2_models/low_noise_model
+
+# 下载高噪声模型到对应目录
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models/high_noise_model \
+ --include "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+
+# 下载低噪声模型到对应目录
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models/low_noise_model \
+ --include "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+
+# 配置启动脚本(指向父目录)
+model_path=./models/wan2.2_models
+lightx2v_path=/path/to/LightX2V
+
+# 运行脚本
+cd LightX2V/scripts
+bash wan22/run_wan22_moe_i2v_distill.sh
+```
+
+> 💡 **提示**:当每个子目录下只有一个模型文件时,LightX2V 会自动加载。
+
+#### 场景 B:每个目录下有多个模型文件
+
+当您在 `high_noise_model/` 和 `low_noise_model/` 目录下分别放置了多个不同精度的模型时,需要在配置文件中明确指定。
+
+```bash
+# 创建目录
+mkdir -p ./models/wan2.2_models_multi/high_noise_model
+mkdir -p ./models/wan2.2_models_multi/low_noise_model
+
+# 下载高噪声模型的多个版本
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models_multi/high_noise_model \
+ --include "wan2.2_i2v_A14b_high_noise_*.safetensors"
+
+# 下载低噪声模型的多个版本
+huggingface-cli download lightx2v/Wan2.2-Distill-Models \
+ --local-dir ./models/wan2.2_models_multi/low_noise_model \
+ --include "wan2.2_i2v_A14b_low_noise_*.safetensors"
+```
+
+**目录结构**:
+
+```
+wan2.2_models_multi/
+├── high_noise_model/
+│ ├── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors # 原始精度
+│ ├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化
+│ └── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors # INT8 量化
+├── low_noise_model/
+│ ├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # 原始精度
+│ ├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化
+│ └── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # INT8 量化
+└── t5/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织
+```
+
+**配置文件设置**:
+
+```json
+{
+ // 使用原始精度模型
+ "high_noise_original_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors",
+ "low_noise_original_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
+
+ // 或使用 FP8 量化模型
+ // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors",
+ // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "fp8-vllm"
+
+ // 或使用 INT8 量化模型
+ // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors",
+ // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors",
+ // "dit_quantized": true,
+ // "dit_quant_scheme": "int8-vllm"
+}
+```
+
+### 使用提示
+
+> 💡 **配置参数说明**:
+> - **high_noise_original_ckpt** / **low_noise_original_ckpt**:用于指定原始精度模型(BF16/FP32/FP16)的路径
+> - **high_noise_quantized_ckpt** / **low_noise_quantized_ckpt**:用于指定量化模型(FP8/INT8)的路径,需配合 `dit_quantized` 和 `dit_quant_scheme` 参数使用
+
+
+### 可用模型列表
+
+#### Wan2.1 单文件模型列表
+
+**图生视频模型(I2V)**
+
+| 文件名 | 精度 | 说明 |
+|--------|------|------|
+| `wan2.1_i2v_480p_lightx2v_4step.safetensors` | BF16 | 4步模型原始精度 |
+| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4步模型FP8 量化 |
+| `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` | INT8 | 4步模型INT8 量化 |
+| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4步模型ComfyUI 格式 |
+| `wan2.1_i2v_720p_lightx2v_4step.safetensors` | BF16 | 4步模型原始精度 |
+| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4步模型FP8 量化 |
+| `wan2.1_i2v_720p_int8_lightx2v_4step.safetensors` | INT8 | 4步模型INT8 量化 |
+| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4步模型ComfyUI 格式 |
+
+**文生视频模型(T2V)**
+
+| 文件名 | 精度 | 说明 |
+|--------|------|------|
+| `wan2.1_t2v_14b_lightx2v_4step.safetensors` | BF16 | 4步模型原始精度 |
+| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4步模型FP8 量化 |
+| `wan2.1_t2v_14b_int8_lightx2v_4step.safetensors` | INT8 | 4步模型INT8 量化 |
+| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4步模型ComfyUI 格式 |
+
+#### Wan2.2 单文件模型列表
+
+**图生视频模型(I2V)- A14B 系列**
+
+| 文件名 | 精度 | 说明 |
+|--------|------|------|
+| `wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors` | BF16 | 高噪声模型-4步原始精度 |
+| `wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 高噪声模型-4步FP8量化 |
+| `wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors` | INT8 | 高噪声模型-4步INT8量化 |
+| `wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors` | BF16 | 低噪声模型-4步原始精度 |
+| `wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 低噪声模型-4步FP8量化 |
+| `wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors` | INT8 | 低噪声模型-4步INT8量化 |
+
+> 💡 **使用提示**:
+> - Wan2.2 模型采用双噪声架构,需要同时下载高噪声(high_noise)和低噪声(low_noise)模型
+> - 详细的目录组织方式请参考上方"Wan2.2 单文件模型"部分
+
+---
+
+## 🗂️ 格式三:LightX2V LoRA 模型
+
+LoRA(Low-Rank Adaptation)模型提供了一种轻量级的模型微调方案,可以在不修改基础模型的情况下实现特定效果的定制化。
+
+### 模型仓库
+
+- **Wan2.1 LoRA 模型**:[lightx2v/Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras)
+- **Wan2.2 LoRA 模型**:[lightx2v/Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras)
+
+### 使用方式
+
+#### 方式一:离线合并
+
+将 LoRA 权重离线合并到基础模型中,生成新的完整模型文件。
+
+**操作步骤**:
+
+参考 [模型转换文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) 进行离线合并。
+
+**优点**:
+- ✅ 推理时无需额外加载 LoRA
+- ✅ 性能更优
+
+**缺点**:
+- ❌ 需要额外存储空间
+- ❌ 切换不同 LoRA 需要重新合并
+
+#### 方式二:在线加载
+
+在推理时动态加载 LoRA 权重,无需修改基础模型。
+
+**LoRA 应用原理**:
+
+```python
+# LoRA 权重应用公式
+# lora_scale = (alpha / rank)
+# W' = W + lora_scale * B @ A
+# 其中:B = up_proj (out_features, rank)
+# A = down_proj (rank, in_features)
+
+if weights_dict["alpha"] is not None:
+ lora_scale = weights_dict["alpha"] / lora_down.shape[0]
+elif alpha is not None:
+ lora_scale = alpha / lora_down.shape[0]
+else:
+ lora_scale = 1.0
+```
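+
+结合上面的公式,在线加载时每个 LoRA 条目大致按 `W' = W + strength * lora_scale * (B @ A)` 依次叠加到基础权重上。下面给出一个最小示意(并非框架源码,strength 与 alpha 的具体组合方式以实际实现为准):
+
+```python
+# 最小示意(假设):按 lora_configs 依次把多个 LoRA 叠加到基础权重
+import torch
+
+def apply_lora_configs(weight, lora_entries):
+    """lora_entries: [{"down": A[rank,in], "up": B[out,rank], "strength": float, "alpha": float|None}, ...]"""
+    merged = weight.float()
+    for e in lora_entries:
+        rank = e["down"].shape[0]
+        lora_scale = (e["alpha"] / rank) if e.get("alpha") is not None else 1.0
+        merged = merged + e["strength"] * lora_scale * (e["up"].float() @ e["down"].float())
+    return merged.to(weight.dtype)
+```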
+
+**配置方法**:
+
+**Wan2.1 LoRA 配置**:
+
+```json
+{
+ "lora_configs": [
+ {
+ "path": "wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0,
+ "alpha": null
+ }
+ ]
+}
+```
+
+**Wan2.2 LoRA 配置**:
+
+由于 Wan2.2 采用双模型架构(高噪声/低噪声),需要分别为两个模型配置 LoRA:
+
+```json
+{
+ "lora_configs": [
+ {
+ "name": "low_noise_model",
+ "path": "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0,
+ "alpha": null
+ },
+ {
+ "name": "high_noise_model",
+ "path": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors",
+ "strength": 1.0,
+ "alpha": null
+ }
+ ]
+}
+```
+
+**参数说明**:
+
+| 参数 | 说明 | 默认值 |
+|------|------|--------|
+| `path` | LoRA 模型文件路径 | 必填 |
+| `strength` | LoRA 强度系数,范围 [0.0, 1.0] | 1.0 |
+| `alpha` | LoRA 缩放因子,`null` 时使用模型内置值 | null |
+| `name` | (仅 Wan2.2)指定应用到哪个模型 | 必填 |
+
+**优点**:
+- ✅ 灵活切换不同 LoRA
+- ✅ 节省存储空间
+- ✅ 可动态调整 LoRA 强度
+
+**缺点**:
+- ❌ 推理时需额外加载时间
+- ❌ 略微增加显存占用
+
+---
+
+## 📚 相关资源
+
+### 官方仓库
+- [LightX2V GitHub](https://github.com/ModelTC/LightX2V)
+- [LightX2V 单文件模型仓库](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan-AI 官方模型仓库](https://huggingface.co/Wan-AI)
+
+### 模型下载链接
+
+**Wan2.1 系列**
+- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b)
+
+**Wan2.2 系列**
+- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8)
+
+**LightX2V 单文件模型**
+- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+
+### 文档链接
+- [量化技术文档](../method_tutorials/quantization.md)
+- [参数卸载文档](../method_tutorials/offload.md)
+- [配置文件示例](https://github.com/ModelTC/LightX2V/tree/main/configs)
+
+---
+
+通过本文档,您应该能够:
+
+✅ 理解 LightX2V 支持的所有模型格式
+✅ 根据需求选择合适的模型和精度
+✅ 正确下载和组织模型文件
+✅ 配置启动参数并成功运行推理
+✅ 解决常见的模型加载问题
+
+如有其他问题,欢迎在 [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中提问。
diff --git a/docs/ZH_CN/source/getting_started/quickstart.md b/docs/ZH_CN/source/getting_started/quickstart.md
new file mode 100644
index 0000000000000000000000000000000000000000..a62a1939cbc4a1f65169d4a2e11b49b2f70295d0
--- /dev/null
+++ b/docs/ZH_CN/source/getting_started/quickstart.md
@@ -0,0 +1,352 @@
+# LightX2V 快速入门指南
+
+欢迎使用 LightX2V!本指南将帮助您快速搭建环境并开始使用 LightX2V 进行视频生成。
+
+## 📋 目录
+
+- [系统要求](#系统要求)
+- [Linux 系统环境搭建](#linux-系统环境搭建)
+ - [Docker 环境(推荐)](#docker-环境推荐)
+ - [Conda 环境搭建](#conda-环境搭建)
+- [Windows 系统环境搭建](#windows-系统环境搭建)
+- [推理使用](#推理使用)
+
+## 🚀 系统要求
+
+- **操作系统**: Linux (Ubuntu 18.04+) 或 Windows 10/11
+- **Python**: 3.10 或更高版本
+- **GPU**: NVIDIA GPU,支持 CUDA,至少 8GB 显存
+- **内存**: 建议 16GB 或更多
+- **存储**: 至少 50GB 可用空间
+
+## 🐧 Linux 系统环境搭建
+
+### 🐳 Docker 环境(推荐)
+
+我们强烈推荐使用 Docker 环境,这是最简单快捷的安装方式。
+
+#### 1. 拉取镜像
+
+访问 LightX2V 的 [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags),选择一个最新日期的 tag,比如 `25111101-cu128`:
+
+```bash
+docker pull lightx2v/lightx2v:25111101-cu128
+```
+
+我们推荐使用 `cuda128` 环境以获得更快的推理速度;若需要使用 `cuda124` 环境,可以使用带 `-cu124` 后缀的镜像版本:
+
+```bash
+docker pull lightx2v/lightx2v:25101501-cu124
+```
+
+#### 2. 运行容器
+
+```bash
+docker run --gpus all -itd --ipc=host --name [容器名] -v [挂载设置] --entrypoint /bin/bash [镜像id]
+```
+
+#### 3. 中国镜像源(可选)
+
+对于中国大陆地区,如果拉取镜像时网络不稳定,可以从阿里云上拉取:
+
+```bash
+# cuda128
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25111101-cu128
+
+# cuda124
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25101501-cu124
+```
+
+### 🐍 Conda 环境搭建
+
+如果您希望使用 Conda 自行搭建环境,请按照以下步骤操作:
+
+#### 步骤 1: 克隆项目
+
+```bash
+# 下载项目代码
+git clone https://github.com/ModelTC/LightX2V.git
+cd LightX2V
+```
+
+#### 步骤 2: 创建 conda 虚拟环境
+
+```bash
+# 创建并激活 conda 环境
+conda create -n lightx2v python=3.11 -y
+conda activate lightx2v
+```
+
+#### 步骤 3: 安装依赖及代码
+
+```bash
+pip install -v -e .
+```
+
+#### 步骤 4: 安装注意力机制算子
+
+**选项 A: Flash Attention 2**
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git --recursive
+cd flash-attention && python setup.py install
+```
+
+**选项 B: Flash Attention 3(用于 Hopper 架构显卡)**
+```bash
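+# 需先按选项 A 克隆 flash-attention 仓库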
+cd flash-attention/hopper && python setup.py install
+```
+
+**选项 C: SageAttention 2(推荐)**
+```bash
+git clone https://github.com/thu-ml/SageAttention.git
+cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install -v -e .
+```
+
+#### 步骤 5: 安装量化算子(可选)
+
+量化算子用于支持模型量化功能,可以显著降低显存占用并加速推理。根据您的需求选择合适的量化算子:
+
+**选项 A: VLLM Kernels(推荐)**
+适用于多种量化方案,支持 FP8 等量化格式。
+
+```bash
+pip install vllm
+```
+
+或者从源码安装以获得最新功能:
+
+```bash
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+uv pip install -e .
+```
+
+**选项 B: SGL Kernels**
+适用于 SGL 量化方案,需要 torch == 2.8.0。
+
+```bash
+pip install sgl-kernel --upgrade
+```
+
+**选项 C: Q8 Kernels**
+适用于 Ada 架构显卡(如 RTX 4090、L40S 等)。
+
+```bash
+git clone https://github.com/KONAKONA666/q8_kernels.git
+cd q8_kernels && git submodule init && git submodule update
+python setup.py install
+```
+
+> 💡 **提示**:
+> - 如果不需要使用量化功能,可以跳过此步骤
+> - 量化模型可以从 [LightX2V HuggingFace](https://huggingface.co/lightx2v) 下载
+> - 更多量化相关信息请参考 [量化文档](method_tutorials/quantization.html)
+
+#### 步骤 6: 验证安装
+```python
+import lightx2v
+print(f"LightX2V 版本: {lightx2v.__version__}")
+```
+
+## 🪟 Windows 系统环境搭建
+
+Windows 系统仅支持 Conda 环境搭建方式,请按照以下步骤操作:
+
+### 🐍 Conda 环境搭建
+
+#### 步骤 1: 检查 CUDA 版本
+
+首先确认您的 GPU 驱动和 CUDA 版本:
+
+```cmd
+nvidia-smi
+```
+
+记录输出中的 **CUDA Version** 信息,后续安装时需要保持版本一致。
+
+#### 步骤 2: 创建 Python 环境
+
+```cmd
+# 创建新环境(推荐 Python 3.12)
+conda create -n lightx2v python=3.12 -y
+
+# 激活环境
+conda activate lightx2v
+```
+
+> 💡 **提示**: 建议使用 Python 3.10 或更高版本以获得最佳兼容性。
+
+#### 步骤 3: 安装 PyTorch 框架
+
+**方法一:下载官方 wheel 包(推荐)**
+
+1. 访问 [PyTorch 官方下载页面](https://download.pytorch.org/whl/torch/)
+2. 选择对应版本的 wheel 包,注意匹配:
+ - **Python 版本**: 与您的环境一致
+ - **CUDA 版本**: 与您的 GPU 驱动匹配
+ - **平台**: 选择 Windows 版本
+
+**示例(Python 3.12 + PyTorch 2.6 + CUDA 12.4):**
+
+```cmd
+# 下载并安装 PyTorch
+pip install torch-2.6.0+cu124-cp312-cp312-win_amd64.whl
+
+# 安装配套包
+pip install torchvision==0.21.0 torchaudio==2.6.0
+```
+
+**方法二:使用 pip 直接安装**
+
+```cmd
+# CUDA 12.4 版本示例
+pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
+```
+
+#### 步骤 4: 安装 Windows 版 vLLM
+
+从 [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) 下载对应的 wheel 包。
+
+**版本匹配要求:**
+- Python 版本匹配
+- PyTorch 版本匹配
+- CUDA 版本匹配
+
+```cmd
+# 安装 vLLM(请根据实际文件名调整)
+pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl
+```
+
+#### 步骤 5: 安装注意力机制算子
+
+**选项 A: Flash Attention 2**
+
+```cmd
+pip install flash-attn==2.7.2.post1
+```
+
+**选项 B: SageAttention 2(强烈推荐)**
+
+**下载源:**
+- [Windows 专用版本 1](https://github.com/woct0rdho/SageAttention/releases)
+- [Windows 专用版本 2](https://github.com/sdbds/SageAttention-for-windows/releases)
+
+```cmd
+# 安装 SageAttention(请根据实际文件名调整)
+pip install sageattention-2.1.1+cu126torch2.6.0-cp312-cp312-win_amd64.whl
+```
+
+> ⚠️ **注意**: SageAttention 的 CUDA 版本可以不严格对齐,但 Python 和 PyTorch 版本必须匹配。
+
+#### 步骤 6: 克隆项目
+
+```cmd
+# 克隆项目代码
+git clone https://github.com/ModelTC/LightX2V.git
+cd LightX2V
+
+# 安装 Windows 专用依赖
+pip install -r requirements_win.txt
+pip install -v -e .
+```
+
+#### 步骤 7: 安装量化算子(可选)
+
+量化算子用于支持模型量化功能,可以显著降低显存占用并加速推理。
+
+**安装 VLLM(推荐):**
+
+从 [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) 下载对应的 wheel 包并安装。
+
+```cmd
+# 安装 vLLM(请根据实际文件名调整)
+pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl
+```
+
+> 💡 **提示**:
+> - 如果不需要使用量化功能,可以跳过此步骤
+> - 量化模型可以从 [LightX2V HuggingFace](https://huggingface.co/lightx2v) 下载
+> - 更多量化相关信息请参考 [量化文档](method_tutorials/quantization.html)
+
+#### 步骤 8: 验证安装
+```python
+import lightx2v
+print(f"LightX2V 版本: {lightx2v.__version__}")
+```
+
+## 🎯 推理使用
+
+### 📥 模型准备
+
+在开始推理之前,您需要提前下载好模型文件。我们推荐:
+
+- **下载源**: 从 [LightX2V 官方 Hugging Face](https://huggingface.co/lightx2v/)或者其他开源模型库下载模型
+- **存储位置**: 建议将模型存储在 SSD 磁盘上以获得更好的读取性能
+- **可用模型**: 包括 Wan2.1-I2V、Wan2.1-T2V 等多种模型,支持不同分辨率和功能
+
+### 📁 配置文件与脚本
+
+推理会用到的配置文件都在[这里](https://github.com/ModelTC/LightX2V/tree/main/configs),脚本都在[这里](https://github.com/ModelTC/LightX2V/tree/main/scripts)。
+
+需要将下载的模型路径配置到运行脚本中。除了脚本中的输入参数,`--config_json` 指向的配置文件中也会包含一些必要参数,您可以根据需要自行修改。
+
+### 🚀 开始推理
+
+#### Linux 环境
+
+```bash
+# 修改脚本中的路径后运行
+bash scripts/wan/run_wan_t2v.sh
+```
+
+#### Windows 环境
+
+```cmd
+# 使用 Windows 批处理脚本
+scripts\win\run_wan_t2v.bat
+```
+
+#### Python 脚本启动
+
+```python
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-T2V-14B",
+ model_cls="wan2.1",
+ task="t2v",
+)
+
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ height=480, # 720
+ width=832, # 1280
+ num_frames=81,
+ guidance_scale=5.0,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+save_result_path="/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
+```
+
+
+## 📞 获取帮助
+
+如果您在安装或使用过程中遇到问题,请:
+
+1. 在 [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中搜索相关问题
+2. 提交新的 Issue 描述您的问题
+
+---
+
+🎉 **恭喜!** 现在您已经成功搭建了 LightX2V 环境,可以开始享受视频生成的乐趣了!
diff --git a/docs/ZH_CN/source/index.rst b/docs/ZH_CN/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..3ea4ebd747337c67518da537f3833879d0e4187a
--- /dev/null
+++ b/docs/ZH_CN/source/index.rst
@@ -0,0 +1,68 @@
+欢迎了解 Lightx2v!
+==================
+
+.. figure:: ../../../assets/img_lightx2v.png
+ :width: 80%
+ :align: center
+ :alt: Lightx2v
+ :class: no-scaled-link
+
+.. raw:: html
+
+   <div align="center">LightX2V: 一个轻量级的视频生成推理框架</div>
+
+LightX2V 是一个轻量级的视频生成推理框架,集成多种先进的视频生成推理技术,统一支持 文本生成视频 (T2V)、图像生成视频 (I2V) 等多种生成任务及模型。X2V 表示将不同的输入模态(X,如文本或图像)转换(to)为视频输出(V)。
+
+GitHub: https://github.com/ModelTC/lightx2v
+
+HuggingFace: https://huggingface.co/lightx2v
+
+文档列表
+-------------
+
+.. toctree::
+ :maxdepth: 1
+ :caption: 快速入门
+
+ 快速入门
+ 模型结构
+ 基准测试
+
+.. toctree::
+ :maxdepth: 1
+ :caption: 方法教程
+
+ 模型量化
+ 特征缓存
+ 注意力机制
+ 参数卸载
+ 并行推理
+ 变分辨率推理
+ 步数蒸馏
+ 自回归蒸馏
+ 视频帧插值
+
+.. toctree::
+ :maxdepth: 1
+ :caption: 部署指南
+
+ 低延迟场景部署
+ 低资源场景部署
+ Lora模型部署
+ 服务化部署
+ Gradio部署
+ ComfyUI部署
+ 本地windows电脑部署
diff --git a/docs/ZH_CN/source/method_tutorials/attention.md b/docs/ZH_CN/source/method_tutorials/attention.md
new file mode 100644
index 0000000000000000000000000000000000000000..06f9570579cdfb4042e3cfdcf6f63329bf4a9554
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/attention.md
@@ -0,0 +1,35 @@
+# 注意力机制
+
+## LightX2V支持的注意力机制
+
+| 名称 | 类型名称 | GitHub 链接 |
+|--------------------|------------------|-------------|
+| Flash Attention 2 | `flash_attn2` | [flash-attention v2](https://github.com/Dao-AILab/flash-attention) |
+| Flash Attention 3 | `flash_attn3` | [flash-attention v3](https://github.com/Dao-AILab/flash-attention) |
+| Sage Attention 2 | `sage_attn2` | [SageAttention](https://github.com/thu-ml/SageAttention) |
+| Radial Attention | `radial_attn` | [Radial Attention](https://github.com/mit-han-lab/radial-attention) |
+| Sparge Attention | `sparge_ckpt` | [Sparge Attention](https://github.com/thu-ml/SpargeAttn) |
+
+---
+
+## 配置示例
+
+注意力机制的config文件在[这里](https://github.com/ModelTC/lightx2v/tree/main/configs/attentions)
+
+通过将 `--config_json` 指定为具体的 config 文件,即可测试不同的注意力机制。
+
+比如对于radial_attn,配置如下:
+
+```json
+{
+ "self_attn_1_type": "radial_attn",
+ "cross_attn_1_type": "flash_attn3",
+ "cross_attn_2_type": "flash_attn3"
+}
+```
+
+如需更换为其他类型,只需将对应值替换为上述表格中的类型名称即可。
+
+提示:受稀疏算法原理的限制,`radial_attn` 仅能用于 self attention。
+
+如需进一步定制注意力机制的行为,请参考各注意力库的官方文档或实现代码。
diff --git a/docs/ZH_CN/source/method_tutorials/autoregressive_distill.md b/docs/ZH_CN/source/method_tutorials/autoregressive_distill.md
new file mode 100644
index 0000000000000000000000000000000000000000..19845ce07f44e7b6ea3d15e7255637cc4d13bea4
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/autoregressive_distill.md
@@ -0,0 +1,53 @@
+# 自回归蒸馏
+
+自回归蒸馏是 LightX2V 中的一个技术探索,通过训练蒸馏模型将推理步数从原始的 40-50 步减少到 **8 步**,在实现推理加速的同时能够通过 KV Cache 技术生成无限长视频。
+
+> ⚠️ 警告:目前自回归蒸馏的效果一般,加速效果也没有达到预期,但是可以作为一个长期的研究项目。目前 LightX2V 仅支持 T2V 的自回归模型。
+
+## 🔍 技术原理
+
+自回归蒸馏通过 [CausVid](https://github.com/tianweiy/CausVid) 技术实现。CausVid 针对 1.3B 的自回归模型进行步数蒸馏、CFG蒸馏。LightX2V 在其基础上,进行了一系列扩展:
+
+1. **更大的模型**:支持 14B 模型的自回归蒸馏训练;
+2. **更完整的数据处理流程**:生成 50,000 个提示词-视频对的训练数据集;
+
+具体实现可参考 [CausVid-Plus](https://github.com/GoatWu/CausVid-Plus)。
+
+## 🛠️ 配置文件说明
+
+### 配置文件
+
+在 [configs/causvid/](https://github.com/ModelTC/lightx2v/tree/main/configs/causvid) 目录下提供了配置选项:
+
+| 配置文件 | 模型地址 |
+|----------|------------|
+| [wan_t2v_causvid.json](https://github.com/ModelTC/lightx2v/blob/main/configs/causvid/wan_t2v_causvid.json) | https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid |
+
+### 关键配置参数
+
+```json
+{
+ "enable_cfg": false, // 关闭CFG以提升速度
+ "num_fragments": 3, // 一次生成视频的段数,每段5s
+ "num_frames": 21, // 每段视频的帧数,谨慎修改!
+ "num_frame_per_block": 3, // 每个自回归块的帧数,谨慎修改!
+ "num_blocks": 7, // 每段视频的自回归块数,谨慎修改!
+ "frame_seq_length": 1560, // 每帧的编码长度,谨慎修改!
+ "denoising_step_list": [ // 去噪时间步列表
+ 999, 934, 862, 756, 603, 410, 250, 140, 74
+ ]
+}
+```
+
+## 📜 使用方法
+
+### 模型准备
+
+将下载好的模型(`causal_model.pt` 或者 `causal_model.safetensors`)放到 Wan 模型根目录的 `causvid_models/` 文件夹下即可
+- 对于 T2V:`Wan2.1-T2V-14B/causvid_models/`
+
+### 推理脚本
+
+```bash
+bash scripts/wan/run_wan_t2v_causvid.sh
+```
diff --git a/docs/ZH_CN/source/method_tutorials/cache.md b/docs/ZH_CN/source/method_tutorials/cache.md
new file mode 100644
index 0000000000000000000000000000000000000000..a867119c81d3f779a66105feed3be7df695a2ea5
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/cache.md
@@ -0,0 +1,3 @@
+# 特征缓存
+
+由于本页需要展示视频的播放效果,建议访问这个[🔗 页面](https://github.com/ModelTC/LightX2V/blob/main/docs/ZH_CN/source/method_tutorials/cache_source.md),以获得更好的展示效果和对应的文档内容。
diff --git a/docs/ZH_CN/source/method_tutorials/cache_source.md b/docs/ZH_CN/source/method_tutorials/cache_source.md
new file mode 100644
index 0000000000000000000000000000000000000000..37624af05f9d2ad63a65d071891a62a9911b28d8
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/cache_source.md
@@ -0,0 +1,139 @@
+# 特征缓存
+
+## 缓存加速算法
+- 在扩散模型的推理过程中,缓存复用是一种重要的加速算法。
+- 其核心思想是在部分时间步跳过冗余计算,通过复用历史缓存结果提升推理效率。
+- 算法的关键在于如何决策在哪些时间步进行缓存复用,通常基于模型状态变化或误差阈值动态判断。
+- 在推理过程中,需要缓存如中间特征、残差、注意力输出等关键内容。当进入可复用时间步时,直接利用已缓存的内容,通过泰勒展开等近似方法重构当前输出,从而减少重复计算,实现高效推理。
+
+### TeaCache
+`TeaCache`的核心思想是通过对相邻时间步输入的**相对L1**距离进行累加,当累计距离达到设定阈值时,判定当前时间步不使用缓存复用;相反,当累计距离未达到设定阈值时则使用缓存复用加速推理过程。
+- 具体来说,算法在每一步推理时计算当前输入与上一步输入的相对L1距离,并将其累加。
+- 当累计距离未超过阈值,说明模型状态变化不明显,则直接复用最近一次缓存的内容,跳过部分冗余计算。这样可以显著减少模型的前向计算次数,提高推理速度(简化的判定逻辑示意见下方代码)。
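+
+下面是该判定逻辑的一个简化示意(非框架源码,阈值与变量名均为示例假设,输入为相邻两步的模型输入张量):
+
+```python
+# 简化示意(假设):TeaCache 的缓存复用判定
+class TeaCacheGate:
+    """累计相邻时间步输入的相对 L1 距离;未达到阈值则复用缓存,达到阈值则正常计算并清零。"""
+    def __init__(self, threshold: float = 0.2):
+        self.threshold = threshold
+        self.accumulated = 0.0
+
+    def should_reuse(self, x_cur, x_prev) -> bool:
+        rel_l1 = (x_cur - x_prev).abs().mean() / (x_prev.abs().mean() + 1e-8)
+        self.accumulated += float(rel_l1)
+        if self.accumulated < self.threshold:
+            return True    # 复用最近一次缓存的输出/残差
+        self.accumulated = 0.0
+        return False       # 执行完整前向计算,并更新缓存
+```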
+
+实际效果上,TeaCache 在保证生成质量的前提下,实现了明显的加速。在单卡H200上,加速前后的用时与视频对比如下:
+
+- 加速前用时:58s,加速后用时:17.9s(视频对比请参见原仓库文档页面)
+- 加速比为:**3.24**
+- config:[wan_t2v_1_3b_tea_480p.json](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json)
+- 参考论文:[https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)
+
+### TaylorSeer Cache
+`TaylorSeer Cache`的核心在于利用泰勒公式对缓存内容进行再次计算,作为缓存复用时间步的残差补偿。
+- 具体做法是在缓存复用的时间步,不仅简单地复用历史缓存,还通过泰勒展开对当前输出进行近似重构。这样可以在减少计算量的同时,进一步提升输出的准确性。
+- 泰勒展开能够有效捕捉模型状态的微小变化,使得缓存复用带来的误差得到补偿,从而在加速的同时保证生成质量。
+
+`TaylorSeer Cache`适用于对输出精度要求较高的场景,能够在缓存复用的基础上进一步提升模型推理的表现。
+
+- 加速前用时:57.7s,加速后用时:41.3s(视频对比请参见原仓库文档页面)
+- 加速比为:**1.39**
+- config:[wan_t2v_taylorseer](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/taylorseer/wan_t2v_taylorseer.json)
+- 参考论文:[https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)
+
+### AdaCache
+`AdaCache`的核心思想是根据指定block块中的部分缓存内容,动态调整缓存复用的步长。
+- 算法会分析相邻两个时间步在特定 block 内的特征差异,根据差异大小自适应地决定下一个缓存复用的时间步间隔。
+- 当模型状态变化较小时,步长自动加大,减少缓存更新频率;当状态变化较大时,步长缩小,保证输出质量。
+
+这样可以根据实际推理过程中的动态变化,灵活调整缓存策略,实现更高效的加速和更优的生成效果。AdaCache 适合对推理速度和生成质量都有较高要求的应用场景。
+
+- 加速前用时:227s,加速后用时:83s(视频对比请参见原仓库文档页面)
+- 加速比为:**2.73**
+- config:[wan_i2v_ada](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/adacache/wan_i2v_ada.json)
+- 参考论文:[https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)
+
+### CustomCache
+`CustomCache`综合了`TeaCache`和`TaylorSeer Cache`的优势。
+- 它结合了`TeaCache`在缓存决策上的实时性和合理性,通过动态阈值判断何时进行缓存复用;
+- 同时利用`TaylorSeer`的泰勒展开方法对已缓存内容进行利用。
+
+这样不仅能够高效地决定缓存复用的时机,还能最大程度地利用缓存内容,提升输出的准确性和生成质量。实际测试表明,`CustomCache`在多个内容生成任务上,生成的视频质量优于单独使用`TeaCache、TaylorSeer Cache`或`AdaCache`的方案,是目前综合性能最优的缓存加速算法之一。
+
+- 加速前用时:57.9s,加速后用时:16.6s(视频对比请参见原仓库文档页面)
+- 加速比为:**3.49**
+- config:[wan_t2v_custom_1_3b](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/custom/wan_t2v_custom_1_3b.json)
+
+
+## 使用方式
+
+特征缓存的config文件在[这里](https://github.com/ModelTC/lightx2v/tree/main/configs/caching)
+
+通过将 `--config_json` 指定为具体的 config 文件,即可测试不同的 cache 算法。
+
+[这里](https://github.com/ModelTC/lightx2v/tree/main/scripts/cache)有一些运行脚本供使用。
diff --git a/docs/ZH_CN/source/method_tutorials/changing_resolution.md b/docs/ZH_CN/source/method_tutorials/changing_resolution.md
new file mode 100644
index 0000000000000000000000000000000000000000..e4a51490d032bcf556d3ca1cae9f06b7648bbc9d
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/changing_resolution.md
@@ -0,0 +1,68 @@
+# 变分辨率推理
+
+## 概述
+
+变分辨率推理是一种优化去噪过程的技术策略,通过在不同的去噪阶段采用不同的分辨率来提升计算效率并保持生成质量。该方法的核心思想是:在去噪过程的前期使用较低分辨率进行粗略去噪,在后期切换到正常分辨率进行精细化处理。
+
+## 技术原理
+
+### 分阶段去噪策略
+
+变分辨率推理基于以下观察:
+
+- **前期去噪**:主要处理粗糙的噪声和整体结构,不需要过多的细节信息
+- **后期去噪**:专注于细节优化和高频信息恢复,需要完整的分辨率信息
+
+### 分辨率切换机制
+
+1. **低分辨率阶段**(前期)
+ - 将输入降采样到较低分辨率(如原尺寸的0.75)
+ - 执行初始的去噪步骤
+ - 快速移除大部分噪声,建立基本结构
+
+2. **正常分辨率阶段**(后期)
+ - 将第一步的去噪结果上采样回原始分辨率
+ - 继续执行剩余的去噪步骤
+ - 恢复细节信息,完成精细化处理
+
+
+### U型分辨率策略
+
+如果在最开始的去噪步就降低分辨率,可能导致最终生成的视频与正常推理生成的视频整体差异偏大。因此可以采用 U 型分辨率策略:最开始的几步保持原始分辨率,随后降低分辨率推理,后期再恢复原始分辨率完成精细化。
+
+## 使用方式
+
+变分辨率推理的config文件在[这里](https://github.com/ModelTC/LightX2V/tree/main/configs/changing_resolution)
+
+通过将 `--config_json` 指定为具体的 config 文件,即可测试变分辨率推理。
+
+可以参考[这里](https://github.com/ModelTC/LightX2V/blob/main/scripts/changing_resolution)的脚本运行。
+
+
+### 举例1:
+```json
+{
+ "infer_steps": 50,
+ "changing_resolution": true,
+ "resolution_rate": [0.75],
+ "changing_resolution_steps": [25]
+}
+```
+
+表示总步数为50,1到25步的分辨率为0.75倍原始分辨率,26到最后一步的分辨率为原始分辨率。
+
+### 举例2:
+```json
+{
+ "infer_steps": 50,
+ "changing_resolution": true,
+ "resolution_rate": [1.0, 0.75],
+ "changing_resolution_steps": [10, 35]
+}
+```
+
+表示总步数为50,1到10步的分辨率为原始分辨率,11到35步的分辨率为0.75倍原始分辨率,36到最后一步的分辨率为原始分辨率。
+
+通常来说,假设`changing_resolution_steps`为[A, B, C],去噪的起始步为1,总步数为X,则推理会被分成4个部分。
+
+分别是 (0, A]、(A, B]、(B, C]、(C, X],每个部分都是左开右闭区间;前三个部分依次使用 `resolution_rate` 中对应的分辨率比例,最后一个部分使用原始分辨率。
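+
+上述配置语义可以用一个最小示意表达(非框架源码,函数名为示例假设):
+
+```python
+# 最小示意(假设):根据配置确定第 step 步使用的分辨率比例
+def resolution_rate_at(step, resolution_rate, changing_resolution_steps):
+    """step 从 1 开始计数;落在第 i 个边界之内用 resolution_rate[i],超过最后一个边界用原始分辨率 1.0。"""
+    for boundary, rate in zip(changing_resolution_steps, resolution_rate):
+        if step <= boundary:
+            return rate
+    return 1.0
+
+# 举例2 的配置:1-10 步 -> 1.0,11-35 步 -> 0.75,36-50 步 -> 1.0
+rates = [resolution_rate_at(s, [1.0, 0.75], [10, 35]) for s in range(1, 51)]
+```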
diff --git a/docs/ZH_CN/source/method_tutorials/offload.md b/docs/ZH_CN/source/method_tutorials/offload.md
new file mode 100644
index 0000000000000000000000000000000000000000..2f30f47d8496efb08581bc1bf5fc80dab8c3657b
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/offload.md
@@ -0,0 +1,177 @@
+# 参数卸载
+
+## 📖 概述
+
+Lightx2v 实现了先进的参数卸载机制,专为在有限硬件资源下处理大型模型推理而设计。该系统通过智能管理不同内存层次中的模型权重,提供了优秀的速度-内存平衡。
+
+**核心特性:**
+- **分block/phase卸载**:高效地以block/phase为单位管理模型权重,实现最优内存使用
+ - **Block**:Transformer模型的基本计算单元,包含完整的Transformer层(自注意力、交叉注意力、前馈网络等),是较大的内存管理单位
+ - **Phase**:Block内部的更细粒度计算阶段,包含单个计算组件(如自注意力、交叉注意力、前馈网络等),提供更精细的内存控制
+- **多级存储支持**:GPU → CPU → 磁盘层次结构,配合智能缓存
+- **异步操作**:使用 CUDA 流实现计算和数据传输的重叠
+- **磁盘/NVMe 序列化**:当内存不足时支持二级存储
+
+## 🎯 卸载策略
+
+### 策略一:GPU-CPU 分block/phase卸载
+
+**适用场景**:GPU 显存不足但系统内存充足
+
+**工作原理**:在 GPU 和 CPU 内存之间以block或phase为单位管理模型权重,利用 CUDA 流实现计算和数据传输的重叠。Block包含完整的Transformer层,而Phase则是Block内部的单个计算组件。
+
+**Block vs Phase 说明**:
+- **Block粒度**:较大的内存管理单位,包含完整的Transformer层(自注意力、交叉注意力、前馈网络等),适合内存充足的情况,减少管理开销
+- **Phase粒度**:更细粒度的内存管理,包含单个计算组件(如自注意力、交叉注意力、前馈网络等),适合内存受限的情况,提供更灵活的内存控制
+
+**关键特性:**
+- **异步传输**:使用三个不同优先级的CUDA流实现计算和传输的并行
+ - 计算流(priority=-1):高优先级,负责当前计算
+ - GPU加载流(priority=0):中优先级,负责从CPU到GPU的预取
+ - CPU加载流(priority=0):中优先级,负责从GPU到CPU的卸载
+- **预取机制**:提前将下一个block/phase加载到 GPU
+- **智能缓存**:在 CPU 内存中维护权重缓存
+- **流同步**:确保数据传输和计算的正确性
+- **Swap操作**:计算完成后轮换block/phase位置,实现连续计算(计算与预取重叠的简化示意见下方代码)
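+
+下面是用两个 CUDA 流重叠"预取下一个 block 权重"与"当前 block 计算"的最小示意(非框架源码,`forward_block` 等函数名为示例假设):
+
+```python
+# 简化示意(假设):双流重叠预取与计算
+import torch
+
+compute_stream = torch.cuda.Stream(priority=-1)   # 高优先级:负责当前计算
+load_stream = torch.cuda.Stream(priority=0)       # 中优先级:负责 CPU -> GPU 预取
+
+def run_blocks(blocks_cpu, x, forward_block):
+    """blocks_cpu: 每个 block 的权重字典列表(常驻 CPU,建议 pin_memory);forward_block 为假设的前向函数。"""
+    gpu_block = {k: v.cuda(non_blocking=True) for k, v in blocks_cpu[0].items()}
+    torch.cuda.synchronize()                       # 确保第一个 block 加载完成
+    for i in range(len(blocks_cpu)):
+        next_block = None
+        if i + 1 < len(blocks_cpu):
+            with torch.cuda.stream(load_stream):   # 预取下一个 block
+                next_block = {k: v.cuda(non_blocking=True) for k, v in blocks_cpu[i + 1].items()}
+        with torch.cuda.stream(compute_stream):    # 计算当前 block
+            x = forward_block(gpu_block, x)
+        torch.cuda.synchronize()                   # 简化起见:等预取与计算都完成后再轮换
+        if next_block is not None:
+            gpu_block = next_block
+    return x
+```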
+
+
+
+
+### 策略二:磁盘-CPU-GPU 分block/phase卸载(延迟加载)
+
+**适用场景**:GPU 显存和系统内存都不足
+
+**工作原理**:在策略一的基础上引入磁盘存储,实现三级存储层次(磁盘 → CPU → GPU)。CPU继续作为缓存池,但大小可配置,适用于CPU内存受限的设备。
+
+**关键特性:**
+- **延迟加载**:模型权重按需从磁盘加载,避免一次性加载全部模型
+- **智能缓存**:CPU内存缓冲区使用FIFO策略管理,可配置大小
+- **多线程预取**:使用多个磁盘工作线程并行加载
+- **异步传输**:使用CUDA流实现计算和数据传输的重叠
+- **Swap轮换**:通过位置轮换实现连续计算,避免重复加载/卸载
+
+**工作步骤**:
+- **磁盘存储**:模型权重按block存储在SSD/NVMe上,每个block一个.safetensors文件
+- **任务调度**:当需要某个block/phase时,优先级任务队列分配磁盘工作线程
+- **异步加载**:多个磁盘线程并行从磁盘读取权重文件到CPU内存缓冲区
+- **智能缓存**:CPU内存缓冲区使用FIFO策略管理缓存,可配置大小
+- **缓存命中**:如果权重已在缓存中,直接传输到GPU,无需磁盘读取
+- **预取传输**:缓存中的权重异步传输到GPU内存(使用GPU加载流)
+- **计算执行**:GPU上的权重进行计算(使用计算流),同时后台继续预取下一个block/phase
+- **Swap轮换**:计算完成后轮换block/phase位置,实现连续计算
+- **内存管理**:当CPU缓存满时,自动淘汰最早使用的权重block/phase
+
+
+
+## ⚙️ 配置参数
+
+### GPU-CPU 卸载配置
+
+```python
+config = {
+ "cpu_offload": True,
+ "offload_ratio": 1.0, # 卸载比例(0.0-1.0)
+ "offload_granularity": "block", # 卸载粒度:"block"或"phase"
+ "lazy_load": False, # 禁用延迟加载
+}
+```
+
+### 磁盘-CPU-GPU 卸载配置
+
+```python
+config = {
+ "cpu_offload": True,
+ "lazy_load": True, # 启用延迟加载
+ "offload_ratio": 1.0, # 卸载比例
+ "offload_granularity": "phase", # 推荐使用phase粒度
+ "num_disk_workers": 2, # 磁盘工作线程数
+ "offload_to_disk": True, # 启用磁盘卸载
+}
+```
+
+**智能缓存关键参数:**
+- `max_memory`:控制CPU缓存大小,影响缓存命中率和内存使用
+- `num_disk_workers`:控制磁盘加载线程数,影响预取速度
+- `offload_granularity`:控制缓存粒度(block或phase),影响缓存效率
+ - `"block"`:以完整的Transformer层为单位进行缓存管理
+ - `"phase"`:以单个计算组件为单位进行缓存管理
+
+**非 DIT 模型组件(T5、CLIP、VAE)的卸载配置:**
+
+这些组件的卸载行为遵循以下规则:
+- **默认行为**:如果没有单独指定,T5、CLIP、VAE 会跟随 `cpu_offload` 的设置
+- **独立配置**:可以为每个组件单独设置卸载策略,实现精细控制
+
+**配置示例**:
+```json
+{
+ "cpu_offload": true, // DIT 模型卸载开关
+ "t5_cpu_offload": false, // T5 编码器独立设置
+ "clip_cpu_offload": false, // CLIP 编码器独立设置
+ "vae_cpu_offload": false // VAE 编码器独立设置
+}
+```
+
+在显存受限的设备上,建议采用渐进式卸载策略:
+
+1. **第一步**:仅开启 `cpu_offload`,关闭 `t5_cpu_offload`、`clip_cpu_offload`、`vae_cpu_offload`
+2. **第二步**:如果显存仍不足,逐步开启 T5、CLIP、VAE 的 CPU 卸载
+3. **第三步**:如果显存仍然不够,考虑使用量化 + CPU 卸载或启用 `lazy_load`
+
+**实践经验**:
+- **RTX 4090 24GB + 14B 模型**:通常只需开启 `cpu_offload`,其他组件卸载需要手动设为 `false`,同时使用 FP8 量化版本
+- **更小显存的 GPU**:需要组合使用量化、CPU 卸载和延迟加载
+- **量化方案**:建议参考[量化技术文档](../method_tutorials/quantization.md)选择合适的量化策略
+
+
+**配置文件参考**:
+- **Wan2.1 系列模型**:参考 [offload 配置文件](https://github.com/ModelTC/lightx2v/tree/main/configs/offload)
+- **Wan2.2 系列模型**:参考 [wan22 配置文件](https://github.com/ModelTC/lightx2v/tree/main/configs/wan22) 中以 `4090` 结尾的配置文件
+
+## 🎯 使用建议
+- 🔄 GPU-CPU分block/phase卸载:适合GPU显存不足(RTX 3090/4090 24G)但系统内存(>64/128G)充足
+
+- 💾 磁盘-CPU-GPU分block/phase卸载:适合GPU显存(RTX 3060/4090 8G)和系统内存(16/32G)都不足
+
+- 🚫 无Offload:适合高端硬件配置,追求最佳性能
+
+
+## 🔍 故障排除
+
+### 常见问题及解决方案
+
+1. **磁盘I/O瓶颈**
+ - 解决方案:使用NVMe SSD,增加num_disk_workers
+
+
+2. **内存缓冲区溢出**
+ - 解决方案:增加max_memory或减少num_disk_workers
+
+3. **加载超时**
+ - 解决方案:检查磁盘性能,优化文件系统
+
+
+**注意**:本卸载机制专为Lightx2v设计,充分利用了现代硬件的异步计算能力,能够显著降低大模型推理的硬件门槛。
diff --git a/docs/ZH_CN/source/method_tutorials/parallel.md b/docs/ZH_CN/source/method_tutorials/parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..a6ff0dfe2c2a65322a2434c5365c4143d4ff6764
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/parallel.md
@@ -0,0 +1,55 @@
+# 并行推理
+
+LightX2V 支持分布式并行推理,能够利用多个 GPU 进行推理。DiT 部分支持两种并行注意力机制:**Ulysses** 和 **Ring**,同时还支持 **CFG 并行推理**。并行推理可以显著降低推理耗时,并减轻每个 GPU 的显存开销。
+
+## DiT 并行配置
+
+### 1. Ulysses 并行
+
+**配置方式:**
+```json
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses"
+ }
+```
+
+### 2. Ring 并行
+
+
+**配置方式:**
+```json
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ring"
+ }
+```
+
+## Cfg 并行配置
+
+**配置方式:**
+```json
+ "parallel": {
+ "cfg_p_size": 2
+ }
+```
+
+## 混合并行配置
+
+**配置方式:**
+```json
+ "parallel": {
+ "seq_p_size": 4,
+ "seq_p_attn_type": "ulysses",
+ "cfg_p_size": 2
+ }
+```
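+
+一个需要注意的点(按常见实现假设,具体以框架实现为准):混合并行所需的 GPU 总数一般为两种并行度的乘积。
+
+```python
+# 简化示意(假设):估算混合并行所需的 GPU 数
+parallel = {"seq_p_size": 4, "seq_p_attn_type": "ulysses", "cfg_p_size": 2}
+world_size = parallel["seq_p_size"] * parallel["cfg_p_size"]   # 4 * 2 = 8 张 GPU
+print(f"需要的 GPU 数(假设):{world_size}")
+```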
+
+
+## 使用方式
+
+并行推理的config文件在[这里](https://github.com/ModelTC/lightx2v/tree/main/configs/dist_infer)
+
+通过将 `--config_json` 指定为具体的 config 文件,即可测试并行推理。
+
+[这里](https://github.com/ModelTC/lightx2v/tree/main/scripts/dist_infer)有一些运行脚本供使用。
diff --git a/docs/ZH_CN/source/method_tutorials/quantization.md b/docs/ZH_CN/source/method_tutorials/quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..311367cc612f29744974dba4345d9d7c10ac10b9
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/quantization.md
@@ -0,0 +1,158 @@
+# 模型量化技术
+
+## 📖 概述
+
+LightX2V 支持对 DIT、T5 和 CLIP 模型进行量化推理,通过降低模型精度来减少显存占用并提升推理速度。
+
+---
+
+## 🔧 量化模式
+
+| 量化模式 | 权重量化 | 激活量化 | 计算内核 | 适用硬件 |
+|--------------|----------|----------|----------|----------|
+| `fp8-vllm` | FP8 通道对称 | FP8 通道动态对称 | [VLLM](https://github.com/vllm-project/vllm) | H100/H200/H800, RTX 40系等 |
+| `int8-vllm` | INT8 通道对称 | INT8 通道动态对称 | [VLLM](https://github.com/vllm-project/vllm) | A100/A800, RTX 30/40系等 |
+| `fp8-sgl` | FP8 通道对称 | FP8 通道动态对称 | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | H100/H200/H800, RTX 40系等 |
+| `int8-sgl` | INT8 通道对称 | INT8 通道动态对称 | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | A100/A800, RTX 30/40系等 |
+| `fp8-q8f` | FP8 通道对称 | FP8 通道动态对称 | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40系, L40S等 |
+| `int8-q8f` | INT8 通道对称 | INT8 通道动态对称 | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40系, L40S等 |
+| `int8-torchao` | INT8 通道对称 | INT8 通道动态对称 | [TorchAO](https://github.com/pytorch/ao) | A100/A800, RTX 30/40系等 |
+| `int4-g128-marlin` | INT4 分组对称 | FP16 | [Marlin](https://github.com/IST-DASLab/marlin) | H200/H800/A100/A800, RTX 30/40系等 |
+| `fp8-b128-deepgemm` | FP8 分块对称 | FP8 分组对称 | [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) | H100/H200/H800, RTX 40系等|
+
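+上表中的"通道对称"权重量化可以用一个最小示意来理解(并非各计算内核的实际实现,仅说明 per-channel 对称量化的基本原理):
+
+```python
+# 最小示意(假设):per-channel 对称 INT8 权重量化
+import torch
+
+def quantize_weight_int8_per_channel(w: torch.Tensor):
+    """w: [out_features, in_features];按输出通道计算对称 scale。"""
+    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
+    w_q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
+    return w_q, scale.squeeze(1)   # 反量化:w ≈ w_q.float() * scale[:, None]
+```
+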
+---
+
+## 🔧 量化模型获取
+
+### 方式一:下载预量化模型
+
+从 LightX2V 模型仓库下载预量化的模型:
+
+**DIT 模型**
+
+从 [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) 下载预量化的 DIT 模型:
+
+```bash
+# 下载 DIT FP8 量化模型
+huggingface-cli download lightx2v/Wan2.1-Distill-Models \
+ --local-dir ./models \
+ --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors"
+```
+
+**Encoder 模型**
+
+从 [Encoders-LightX2V](https://huggingface.co/lightx2v/Encoders-Lightx2v) 下载预量化的 T5 和 CLIP 模型:
+
+```bash
+# 下载 T5 FP8 量化模型
+huggingface-cli download lightx2v/Encoders-Lightx2v \
+ --local-dir ./models \
+ --include "models_t5_umt5-xxl-enc-fp8.pth"
+
+# 下载 CLIP FP8 量化模型
+huggingface-cli download lightx2v/Encoders-Lightx2v \
+ --local-dir ./models \
+ --include "models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth"
+```
+
+### 方式二:自行量化模型
+
+详细量化工具使用方法请参考:[模型转换文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md)
+
+---
+
+## 🚀 量化模型使用
+
+### DIT 模型量化
+
+#### 支持的量化模式
+
+DIT 量化模式(`dit_quant_scheme`)支持:`fp8-vllm`、`int8-vllm`、`fp8-sgl`、`int8-sgl`、`fp8-q8f`、`int8-q8f`、`int8-torchao`、`int4-g128-marlin`、`fp8-b128-deepgemm`
+
+#### 配置示例
+
+```json
+{
+ "dit_quantized": true,
+ "dit_quant_scheme": "fp8-sgl",
+ "dit_quantized_ckpt": "/path/to/dit_quantized_model" // 可选
+}
+```
+
+> 💡 **提示**:当运行脚本的 `model_path` 中只有一个 DIT 模型时,`dit_quantized_ckpt` 可以不用单独指定。
+
+### T5 模型量化
+
+#### 支持的量化模式
+
+T5 量化模式(`t5_quant_scheme`)支持:`int8-vllm`、`fp8-sgl`、`int8-q8f`、`fp8-q8f`、`int8-torchao`
+
+#### 配置示例
+
+```json
+{
+ "t5_quantized": true,
+ "t5_quant_scheme": "fp8-sgl",
+ "t5_quantized_ckpt": "/path/to/t5_quantized_model" // 可选
+}
+```
+
+> 💡 **提示**:当运行脚本指定的 `model_path` 中存在 T5 量化模型(如 `models_t5_umt5-xxl-enc-fp8.pth` 或 `models_t5_umt5-xxl-enc-int8.pth`)时,`t5_quantized_ckpt` 可以不用单独指定。
+
+### CLIP 模型量化
+
+#### 支持的量化模式
+
+CLIP 量化模式(`clip_quant_scheme`)支持:`int8-vllm`、`fp8-sgl`、`int8-q8f`、`fp8-q8f`、`int8-torchao`
+
+#### 配置示例
+
+```json
+{
+ "clip_quantized": true,
+ "clip_quant_scheme": "fp8-sgl",
+ "clip_quantized_ckpt": "/path/to/clip_quantized_model" // 可选
+}
+```
+
+> 💡 **提示**:当运行脚本指定的 `model_path` 中存在 CLIP 量化模型(如 `models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth` 或 `models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8.pth`)时,`clip_quantized_ckpt` 可以不用单独指定。
+
+### 性能优化策略
+
+如果显存不够,可以结合参数卸载来进一步减少显存占用,参考[参数卸载文档](../method_tutorials/offload.md):
+
+> - **Wan2.1 配置**:参考 [offload 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/offload)
+> - **Wan2.2 配置**:参考 [wan22 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) 中以 `4090` 结尾的配置
+
+---
+
+## 📚 相关资源
+
+### 配置文件示例
+- [INT8 量化配置](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v.json)
+- [Q8F 量化配置](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_q8f.json)
+- [TorchAO 量化配置](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_torchao.json)
+
+### 运行脚本
+- [量化推理脚本](https://github.com/ModelTC/LightX2V/tree/main/scripts/quantization)
+
+### 工具文档
+- [量化工具文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md)
+- [LightCompress 量化文档](https://github.com/ModelTC/llmc/blob/main/docs/zh_cn/source/backend/lightx2v.md)
+
+### 模型仓库
+- [Wan2.1-LightX2V 量化模型](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- [Wan2.2-LightX2V 量化模型](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+- [Encoders 量化模型](https://huggingface.co/lightx2v/Encoders-Lightx2v)
+
+---
+
+通过本文档,您应该能够:
+
+✅ 理解 LightX2V 支持的量化方案
+✅ 根据硬件选择合适的量化策略
+✅ 正确配置量化参数
+✅ 获取和使用量化模型
+✅ 优化推理性能和显存使用
+
+如有其他问题,欢迎在 [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中提问。
diff --git a/docs/ZH_CN/source/method_tutorials/step_distill.md b/docs/ZH_CN/source/method_tutorials/step_distill.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ed48c810dedee47a5b7a657e155ce0f6b4e6a34
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/step_distill.md
@@ -0,0 +1,183 @@
+# 步数蒸馏
+
+步数蒸馏是 LightX2V 中的一项重要优化技术,通过训练蒸馏模型将推理步数从原始的 40-50 步大幅减少到 **4 步**,在保持视频质量的同时显著提升推理速度。LightX2V 在实现步数蒸馏的同时也加入了 CFG 蒸馏,进一步提升推理速度。
+
+## 🔍 技术原理
+
+### DMD 蒸馏
+
+步数蒸馏的核心技术是 [DMD 蒸馏](https://arxiv.org/abs/2311.18828)。DMD 蒸馏的框架如下图所示:
+
+*(图:DMD 蒸馏框架示意,见原仓库文档)*
+
+DMD蒸馏的核心思想是最小化蒸馏模型与原始模型输出分布的 KL 散度:
+
+$$
+\begin{aligned}
+D_{KL}\left(p_{\text{fake}} \;\|\; p_{\text{real}} \right) &= \mathbb{E}_{x\sim p_{\text{fake}}}\left(\log\left(\frac{p_{\text{fake}}(x)}{p_{\text{real}}(x)}\right)\right)\\
+&= \mathbb{E}_{\substack{
+z \sim \mathcal{N}(0; \mathbf{I}) \\
+x = G_\theta(z)
+}}-\big(\log~p_{\text{real}}(x) - \log~p_{\text{fake}}(x)\big).
+\end{aligned}
+$$
+
+由于直接计算概率密度几乎是不可能的,因此 DMD 蒸馏改为计算这个 KL 散度的梯度:
+
+$$
+\begin{aligned}
+\nabla_\theta D_{KL}
+&= \mathbb{E}_{\substack{
+z \sim \mathcal{N}(0; \mathbf{I}) \\
+x = G_\theta(z)
+}} \Big[-
+\big(
+s_{\text{real}}(x) - s_{\text{fake}}(x)\big)
+\hspace{.5mm} \frac{dG}{d\theta}
+\Big],
+\end{aligned}
+$$
+
+其中 $s_{\text{real}}(x) =\nabla_{x} \log p_{\text{real}}(x)$ 和 $s_{\text{fake}}(x) =\nabla_{x} \log p_{\text{fake}}(x)$ 为得分函数。得分函数可以由模型进行计算。因此,DMD 蒸馏一共维护三个模型:
+
+- `real_score`,计算真实分布的得分;由于真实分布是固定的,因此 DMD 蒸馏使用固定权重的原始模型作为其得分函数;
+- `fake_score`,计算伪分布的得分;由于伪分布是不断更新的,因此 DMD 蒸馏使用原始模型对其初始化,并对其进行微调以学习生成器的输出分布;
+- `generator`,学生模型,通过计算 `real_score` 与 `fake_score` KL 散度的梯度指导其优化方向。
+
+> 参考文献:
+> 1. [DMD (One-step Diffusion with Distribution Matching Distillation)](https://arxiv.org/abs/2311.18828)
+> 2. [DMD2 (Improved Distribution Matching Distillation for Fast Image Synthesis)](https://arxiv.org/abs/2405.14867)
+
+### Self-Forcing
+
+DMD 蒸馏技术是针对图像生成的。Lightx2v 中的步数蒸馏基于 [Self-Forcing](https://github.com/guandeh17/Self-Forcing) 技术实现。Self-Forcing 的整体实现与 DMD 类似,但是仿照 DMD2,去掉了它的回归损失,而是使用了 ODE 初始化。此外,Self-Forcing 针对视频生成任务加入了一个重要优化:
+
+目前基于 DMD 蒸馏的方法难以一步生成视频。Self-Forcing 每次选择一个时间步进行优化,generator 仅在这一步计算梯度。这种方法显著提升了 Self-Forcing 的训练速度,也改善了中间时间步的去噪质量,整体生成效果因此有所提升。
+
+> 参考文献:
+> 1. [Self-Forcing (Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion)](https://arxiv.org/abs/2506.08009)
+
+### Lightx2v
+
+Self-Forcing 针对 1.3B 的自回归模型进行步数蒸馏、CFG蒸馏。LightX2V 在其基础上,进行了一系列扩展:
+
+1. **更大的模型**:支持 14B 模型的步数蒸馏训练;
+2. **更多的模型**:支持标准的双向模型,以及 I2V 模型的步数蒸馏训练;
+3. **更好的效果**:Lightx2v 使用了约 50,000 条数据的高质量 prompt 进行训练;
+
+具体实现可参考 [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus)。
+
+## 🎯 技术特性
+
+- **推理加速**:推理步数从 40-50 步减少到 4 步且无需 CFG,速度提升约 **20-24x**
+- **质量保持**:通过蒸馏技术保持原有的视频生成质量
+- **兼容性强**:支持 T2V 和 I2V 任务
+- **使用灵活**:支持加载完整步数蒸馏模型,或者在原生模型的基础上加载步数蒸馏LoRA;支持与 int8/fp8 模型量化相兼容
+
+## 🛠️ 配置文件说明
+
+### 基础配置文件
+
+在 [configs/distill/](https://github.com/ModelTC/lightx2v/tree/main/configs/distill) 目录下提供了多种配置选项:
+
+| 配置文件 | 用途 | 模型地址 |
+|----------|------|------------|
+| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | 加载 T2V 4步蒸馏完整模型 | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
+| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | 加载 I2V 4步蒸馏完整模型 | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
+| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | 加载 Wan-T2V 模型和步数蒸馏 LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
+| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | 加载 Wan-I2V 模型和步数蒸馏 LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
+
+### 关键配置参数
+
+- 由于 DMD 蒸馏仅训练几个固定的时间步,因此我们推荐使用 `LCM Scheduler` 进行推理。[WanStepDistillScheduler](https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/schedulers/wan/step_distill/scheduler.py) 中,已经固定使用 `LCM Scheduler`,无需用户进行配置。
+- `infer_steps`, `denoising_step_list` 和 `sample_shift` 设置为与训练时相匹配的参数,一般不建议用户修改。
+- `enable_cfg` 一定设置为 `false`(等价于设置 `sample_guide_scale = 1`),否则可能出现视频完全模糊的现象。
+- `lora_configs` 支持融合不同强度的多个 LoRA。当 `lora_configs` 不为空时,默认加载原始的 `Wan2.1` 模型。因此在使用 `lora_configs` 且希望使用步数蒸馏时,请在其中设置步数蒸馏 LoRA 的路径与强度。
+
+```json
+{
+ "infer_steps": 4, // 推理步数
+ "denoising_step_list": [1000, 750, 500, 250], // 去噪时间步列表
+ "sample_shift": 5, // 调度器 timestep shift
+ "enable_cfg": false, // 关闭CFG以提升速度
+ "lora_configs": [ // LoRA权重路径(可选)
+ {
+ "path": "path/to/distill_lora.safetensors",
+ "strength": 1.0
+ }
+ ]
+}
+```
+
+## 📜 使用方法
+
+### 模型准备
+
+**完整模型:**
+将下载好的模型(`distill_model.pt` 或者 `distill_model.safetensors`)放到 Wan 模型根目录的 `distill_models/` 文件夹下即可
+
+- 对于 T2V:`Wan2.1-T2V-14B/distill_models/`
+- 对于 I2V-480P:`Wan2.1-I2V-14B-480P/distill_models/`
+
+**LoRA:**
+
+1. 将下载好的 LoRA 放到任意位置
+2. 将配置文件中 `lora_configs` 的 `path` 参数修改为 LoRA 存放路径即可
+
+### 推理脚本
+
+**T2V 完整模型:**
+
+```bash
+bash scripts/wan/run_wan_t2v_distill_4step_cfg.sh
+```
+
+**I2V 完整模型:**
+
+```bash
+bash scripts/wan/run_wan_i2v_distill_4step_cfg.sh
+```
+
+### 步数蒸馏 LoRA 推理脚本
+
+**T2V LoRA:**
+
+```bash
+bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh
+```
+
+**I2V LoRA:**
+
+```bash
+bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh
+```
+
+## 🔧 服务化部署
+
+### 启动蒸馏模型服务
+
+对 [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh) 中的启动命令进行修改:
+
+```bash
+python -m lightx2v.api_server \
+ --model_cls wan2.1_distill \
+ --task t2v \
+ --model_path $model_path \
+ --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \
+ --port 8000 \
+ --nproc_per_node 1
+```
+
+运行服务启动脚本:
+
+```bash
+scripts/server/start_server.sh
+```
+
+更多详细信息见[服务化部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html)。
+
+### 在 Gradio 界面中使用
+
+见 [Gradio 文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
diff --git a/docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md b/docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md
new file mode 100644
index 0000000000000000000000000000000000000000..7df2b6b4edc0cd16d97a57a3d222639c3ca3b2ae
--- /dev/null
+++ b/docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md
@@ -0,0 +1,246 @@
+# 视频帧插值 (VFI)
+
+> **重要说明**: 视频帧插值功能通过配置文件启用,而不是通过命令行参数。请在配置 JSON 文件中添加 `video_frame_interpolation` 配置块来启用此功能。
+
+## 概述
+
+视频帧插值(VFI)是一种在现有帧之间生成中间帧的技术,用于提高帧率并创建更流畅的视频播放效果。LightX2V 集成了 RIFE(Real-Time Intermediate Flow Estimation)模型,提供高质量的帧插值能力。
+
+## 什么是 RIFE?
+
+RIFE 是一种最先进的视频帧插值方法,使用光流估计来生成中间帧。它能够有效地:
+
+- 提高视频帧率(例如,从 16 FPS 提升到 32 FPS)
+- 创建平滑的运动过渡
+- 在尽量减少伪影的同时保持较高的视觉质量
+- 实时处理视频
+
+## 安装和设置
+
+### 下载 RIFE 模型
+
+首先,使用提供的脚本下载 RIFE 模型权重:
+
+```bash
+python tools/download_rife.py <目标目录>
+```
+
+例如,下载到指定位置:
+```bash
+python tools/download_rife.py /path/to/rife/train_log
+```
+
+此脚本将:
+- 从 HuggingFace 下载 RIFEv4.26 模型
+- 提取并将模型文件放置在正确的目录中
+- 清理临时文件
+
+## 使用方法
+
+### 配置文件设置
+
+视频帧插值功能通过配置文件启用。在你的配置 JSON 文件中添加 `video_frame_interpolation` 配置块:
+
+```json
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "fps": 16,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 32,
+ "model_path": "/path/to/rife/train_log"
+ }
+}
+```
+
+### 命令行使用
+
+使用包含 VFI 配置的配置文件运行推理:
+
+```bash
+python lightx2v/infer.py \
+ --model_cls wan2.1 \
+ --task t2v \
+ --model_path /path/to/model \
+ --config_json ./configs/video_frame_interpolation/wan_t2v.json \
+ --prompt "美丽的海上日落" \
+ --save_result_path ./output.mp4
+```
+
+### 配置参数说明
+
+在 `video_frame_interpolation` 配置块中:
+
+- `algo`: 帧插值算法,目前支持 "rife"
+- `target_fps`: 输出视频的目标帧率
+- `model_path`: RIFE 模型路径,通常为 "train_log"
+
+其他相关配置:
+- `fps`: 源视频帧率(默认 16)
+
+### 配置优先级
+
+系统会自动处理视频帧率配置,优先级如下:
+1. `video_frame_interpolation.target_fps` - 如果启用视频帧插值,使用此帧率作为输出帧率
+2. `fps`(默认 16)- 如果未启用视频帧插值,则以此帧率输出;同时它总是被用作源帧率(帧率解析逻辑的示意代码见下)
+
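+下面用一段最小的 Python 示意代码演示上述优先级逻辑。其中 `resolve_output_fps` 是本文为说明而假设的辅助函数,并非 LightX2V 的实际实现:
+
+```python
+def resolve_output_fps(config: dict) -> tuple[int, int]:
+    """按上述优先级返回 (源帧率, 输出帧率) 的示意实现。"""
+    source_fps = config.get("fps", 16)  # fps 总是用作源帧率,默认 16
+    vfi = config.get("video_frame_interpolation")
+    if vfi and vfi.get("target_fps"):
+        return source_fps, vfi["target_fps"]  # 启用 VFI 时,以 target_fps 作为输出帧率
+    return source_fps, source_fps  # 未启用 VFI 时,输出帧率即源帧率
+
+
+config = {
+    "fps": 16,
+    "video_frame_interpolation": {"algo": "rife", "target_fps": 32, "model_path": "/path/to/rife/train_log"},
+}
+print(resolve_output_fps(config))  # (16, 32)
+```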
+
+## 工作原理
+
+### 帧插值过程
+
+1. **源视频生成**: 基础模型以源 FPS 生成视频帧
+2. **帧分析**: RIFE 分析相邻帧以估计光流
+3. **中间帧生成**: 在现有帧之间生成新帧
+4. **时序平滑**: 插值帧创建平滑的运动过渡
+
+### 技术细节
+
+- **输入格式**: ComfyUI 图像张量 [N, H, W, C],范围 [0, 1]
+- **输出格式**: 插值后的 ComfyUI 图像张量 [M, H, W, C],范围 [0, 1](M 的估算示例见下方代码)
+- **处理**: 自动填充和分辨率处理
+- **内存优化**: 高效的 GPU 内存管理
+
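+结合上述张量约定,可以用下面的小段代码粗略估算插值后输出张量 [M, H, W, C] 中的帧数 M。这里假设 RIFE 在时间轴上均匀插帧、视频总时长保持不变,仅作示意,并非 LightX2V 的实际计算逻辑:
+
+```python
+def estimated_output_frames(num_frames: int, source_fps: int, target_fps: int) -> int:
+    """假设总时长不变、均匀插帧时,估算输出帧数 M。"""
+    duration = (num_frames - 1) / source_fps  # 源视频时长(秒)
+    return round(duration * target_fps) + 1
+
+
+# 81 帧、16 FPS 的源视频插值到 32 FPS,约得到 161 帧
+print(estimated_output_frames(81, 16, 32))  # 161
+```
+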
+## 示例配置
+
+### 基础帧率翻倍
+
+创建配置文件 `wan_t2v_vfi_32fps.json`:
+
+```json
+{
+ "infer_steps": 50,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "seed": 42,
+ "sample_guide_scale": 6,
+ "enable_cfg": true,
+ "fps": 16,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 32,
+ "model_path": "/path/to/rife/train_log"
+ }
+}
+```
+
+运行命令:
+```bash
+python lightx2v/infer.py \
+ --model_cls wan2.1 \
+ --task t2v \
+ --model_path ./models/wan2.1 \
+ --config_json ./wan_t2v_vfi_32fps.json \
+ --prompt "一只小猫在花园里玩耍" \
+ --save_result_path ./output_32fps.mp4
+```
+
+### 更高帧率增强
+
+创建配置文件 `wan_i2v_vfi_60fps.json`:
+
+```json
+{
+ "infer_steps": 30,
+ "target_video_length": 81,
+ "target_height": 480,
+ "target_width": 832,
+ "seed": 42,
+ "sample_guide_scale": 6,
+ "enable_cfg": true,
+ "fps": 16,
+ "video_frame_interpolation": {
+ "algo": "rife",
+ "target_fps": 60,
+ "model_path": "/path/to/rife/train_log"
+ }
+}
+```
+
+运行命令:
+```bash
+python lightx2v/infer.py \
+ --model_cls wan2.1 \
+ --task i2v \
+ --model_path ./models/wan2.1 \
+ --config_json ./wan_i2v_vfi_60fps.json \
+ --image_path ./input.jpg \
+ --prompt "平滑的相机运动" \
+ --save_result_path ./output_60fps.mp4
+```
+
+## 性能考虑
+
+### 内存使用
+
+- RIFE 处理需要额外的 GPU 内存
+- 内存使用量与视频分辨率和长度成正比
+- 对于较长的视频,考虑使用较低的分辨率
+
+### 处理时间
+
+- 帧插值会增加处理开销
+- 更高的目标帧率需要更多计算
+- 处理时间大致与插值帧数成正比
+
+### 质量与速度权衡
+
+- 更高的插值比率可能引入伪影
+- 最佳范围:2x 到 4x 帧率增加
+- 对于极端插值(>4x),考虑多次处理
+
+## 最佳实践
+
+### 最佳使用场景
+
+- **运动密集视频**: 从帧插值中受益最多
+- **相机运动**: 更平滑的平移和缩放
+- **动作序列**: 减少运动模糊感知
+- **慢动作效果**: 创建流畅的慢动作视频
+
+### 推荐设置
+
+- **源 FPS**: 16-24 FPS(基础模型生成)
+- **目标 FPS**: 32-60 FPS(2x 到 4x 增加)
+- **分辨率**: 最高 720p 以获得最佳性能
+
+### 故障排除
+
+#### 常见问题
+
+1. **内存不足**: 减少视频分辨率或目标 FPS
+2. **输出中有伪影**: 降低插值比率
+3. **处理缓慢**: 检查 GPU 内存并考虑使用 CPU 卸载
+
+#### 解决方案
+
+通过修改配置文件来解决问题:
+
+```json
+{
+ // 内存问题解决:使用较低分辨率
+ "target_height": 480,
+ "target_width": 832,
+
+ // 质量问题解决:使用适中的插值
+ "video_frame_interpolation": {
+ "target_fps": 24 // 而不是 60
+ },
+
+ // 性能问题解决:启用卸载
+ "cpu_offload": true
+}
+```
+
+## 技术实现
+
+LightX2V 中的 RIFE 集成包括:
+
+- **RIFEWrapper**: 与 ComfyUI 兼容的 RIFE 模型包装器
+- **自动模型加载**: 与推理管道的无缝集成
+- **内存优化**: 高效的张量管理和 GPU 内存使用
+- **质量保持**: 在添加帧的同时保持原始视频质量
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8bee25d4f8f8c09f4192273cb99938badadd429e
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,284 @@
+# LightX2V Usage Examples
+
+This document introduces how to use LightX2V for video generation, including basic usage and advanced configurations.
+
+## 📋 Table of Contents
+
+- [Environment Setup](#environment-setup)
+- [Basic Usage Examples](#basic-usage-examples)
+- [Model Path Configuration](#model-path-configuration)
+- [Creating Generator](#creating-generator)
+- [Advanced Configurations](#advanced-configurations)
+ - [Parameter Offloading](#parameter-offloading)
+ - [Model Quantization](#model-quantization)
+ - [Parallel Inference](#parallel-inference)
+ - [Feature Caching](#feature-caching)
+ - [Lightweight VAE](#lightweight-vae)
+
+## 🔧 Environment Setup
+
+Please refer to the main project's [Quick Start Guide](../docs/EN/source/getting_started/quickstart.md) for environment setup.
+
+## 🚀 Basic Usage Examples
+
+A minimal code example can be found in `examples/wan_t2v.py`:
+
+```python
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-T2V-14B",
+ model_cls="wan2.1",
+ task="t2v",
+)
+
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ height=480,
+ width=832,
+ num_frames=81,
+ guidance_scale=5.0,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Your prompt here"
+negative_prompt = ""
+save_result_path = "/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
+```
+
+## 📁 Model Path Configuration
+
+### Basic Configuration
+
+Pass the model path to `LightX2VPipeline`:
+
+```python
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe", # For wan2.1, use "wan2.1"
+ task="i2v",
+)
+```
+
+### Specifying Multiple Model Weight Versions
+
+When the `model_path` directory contains multiple bf16-precision DIT safetensors files, use the following parameters to specify which weights to load:
+
+- **`dit_original_ckpt`**: Used to specify the original DIT weight path for models like wan2.1 and hunyuan15
+- **`low_noise_original_ckpt`**: Used to specify the low noise branch weight path for wan2.2 models
+- **`high_noise_original_ckpt`**: Used to specify the high noise branch weight path for wan2.2 models
+
+**Usage Example:**
+
+```python
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe",
+ task="i2v",
+ low_noise_original_ckpt="/path/to/low_noise_model.safetensors",
+ high_noise_original_ckpt="/path/to/high_noise_model.safetensors",
+)
+```
+
+## 🎛️ Creating Generator
+
+### Loading from Configuration File
+
+The generator can be loaded directly from a JSON configuration file. Configuration files are located in the `configs` directory:
+
+```python
+pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
+```
+
+### Creating Generator Manually
+
+You can also create the generator manually and configure multiple parameters:
+
+```python
+pipe.create_generator(
+ attn_mode="flash_attn2", # Options: flash_attn2, flash_attn3, sage_attn2, sage_attn3 (B-architecture GPUs)
+ infer_steps=50, # Number of inference steps
+ num_frames=81, # Number of video frames
+ height=480, # Video height
+ width=832, # Video width
+ guidance_scale=5.0, # CFG guidance strength (CFG disabled when =1)
+ sample_shift=5.0, # Sample shift
+ fps=16, # Frame rate
+ aspect_ratio="16:9", # Aspect ratio
+ boundary=0.900, # Boundary value
+ boundary_step_index=2, # Boundary step index
+ denoising_step_list=[1000, 750, 500, 250], # Denoising step list
+)
+```
+
+**Parameter Description:**
+- **Resolution**: Specified via `height` and `width`
+- **CFG**: Specified via `guidance_scale` (set to 1 to disable CFG)
+- **FPS**: Specified via `fps`
+- **Video Length**: Specified via `num_frames`
+- **Inference Steps**: Specified via `infer_steps`
+- **Sample Shift**: Specified via `sample_shift`
+- **Attention Mode**: Specified via `attn_mode`, options include `flash_attn2`, `flash_attn3`, `sage_attn2`, `sage_attn3` (for Blackwell-architecture GPUs)
+
+## ⚙️ Advanced Configurations
+
+**⚠️ Important: The advanced options below apply when creating the generator manually. Every advanced configuration must be set *before* calling `create_generator()`, otherwise it will not take effect!**
+
+### Parameter Offloading
+
+Significantly reduces memory usage with almost no impact on inference speed. Suitable for RTX 30/40/50 series GPUs.
+
+```python
+pipe.enable_offload(
+ cpu_offload=True, # Enable CPU offloading
+ offload_granularity="block", # Offload granularity: "block" or "phase"
+ text_encoder_offload=False, # Whether to offload text encoder
+ image_encoder_offload=False, # Whether to offload image encoder
+ vae_offload=False, # Whether to offload VAE
+)
+```
+
+**Notes:**
+- For Wan models, `offload_granularity` supports both `"block"` and `"phase"`
+- For HunyuanVideo-1.5, only `"block"` is currently supported
+
+### Model Quantization
+
+Quantization can significantly reduce memory usage and accelerate inference.
+
+```python
+pipe.enable_quantize(
+ dit_quantized=False, # Whether to use quantized DIT model
+ text_encoder_quantized=False, # Whether to use quantized text encoder
+ image_encoder_quantized=False, # Whether to use quantized image encoder
+ dit_quantized_ckpt=None, # DIT quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files)
+ low_noise_quantized_ckpt=None, # Wan2.2 low noise branch quantized weight path
+ high_noise_quantized_ckpt=None, # Wan2.2 high noise branch quantized weight path
+ text_encoder_quantized_ckpt=None, # Text encoder quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files)
+ image_encoder_quantized_ckpt=None, # Image encoder quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files)
+ quant_scheme="fp8-sgl", # Quantization scheme
+)
+```
+
+**Parameter Description:**
+- **`dit_quantized_ckpt`**: When the `model_path` directory doesn't contain quantized weights, or has multiple weight files, you need to specify the specific DIT quantized weight path
+- **`text_encoder_quantized_ckpt`** and **`image_encoder_quantized_ckpt`**: Similarly, used to specify encoder quantized weight paths
+- **`low_noise_quantized_ckpt`** and **`high_noise_quantized_ckpt`**: Used to specify dual-branch quantized weights for Wan2.2 models
+
+**Quantized Model Downloads:**
+
+- **Wan-2.1 Quantized Models**: Download from [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
+- **Wan-2.2 Quantized Models**: Download from [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
+- **HunyuanVideo-1.5 Quantized Models**: Download from [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models)
+ - `hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors` is the quantized weight for the text encoder
+
+**Usage Examples:**
+
+```python
+# HunyuanVideo-1.5 Quantization Example
+pipe.enable_quantize(
+ quant_scheme='fp8-sgl',
+ dit_quantized=True,
+ dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors",
+ text_encoder_quantized=True,
+ image_encoder_quantized=False,
+ text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
+)
+
+# Wan2.1 Quantization Example
+pipe.enable_quantize(
+ dit_quantized=True,
+ dit_quantized_ckpt="/path/to/wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors",
+)
+
+# Wan2.2 Quantization Example
+pipe.enable_quantize(
+ dit_quantized=True,
+ low_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
+ high_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step_1030.safetensors",
+)
+```
+
+**Quantization Scheme Reference:** For detailed information, please refer to the [Quantization Documentation](../docs/EN/source/method_tutorials/quantization.md)
+
+### Parallel Inference
+
+Supports multi-GPU parallel inference. Requires running with `torchrun`:
+
+```python
+pipe.enable_parallel(
+ seq_p_size=4, # Sequence parallel size
+ seq_p_attn_type="ulysses", # Sequence parallel attention type
+)
+```
+
+**Running Method:**
+```bash
+torchrun --nproc_per_node=4 your_script.py
+```
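+
+For reference, a minimal `your_script.py` could look like the sketch below. It simply combines the `LightX2VPipeline` calls already shown in this README; paths and the prompt are placeholders, and we assume `seq_p_size` matches `--nproc_per_node`:
+
+```python
+# your_script.py -- minimal sketch; launch with: torchrun --nproc_per_node=4 your_script.py
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+    model_path="/path/to/Wan2.1-T2V-14B",
+    model_cls="wan2.1",
+    task="t2v",
+)
+
+# Parallel settings (like all advanced options) must come before create_generator()
+pipe.enable_parallel(
+    seq_p_size=4,  # assumed to match --nproc_per_node
+    seq_p_attn_type="ulysses",
+)
+
+pipe.create_generator(
+    attn_mode="sage_attn2",
+    infer_steps=50,
+    height=480,
+    width=832,
+    num_frames=81,
+    guidance_scale=5.0,
+    sample_shift=5.0,
+)
+
+pipe.generate(
+    seed=42,
+    prompt="Your prompt here",
+    negative_prompt="",
+    save_result_path="/path/to/save_results/output.mp4",
+)
+```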
+
+### Feature Caching
+
+Set `cache_method` to `'Tea'` or `'Mag'` to use the TeaCache or MagCache method, respectively:
+
+```python
+pipe.enable_cache(
+ cache_method='Tea', # Cache method: 'Tea' or 'Mag'
+ coefficients=[-3.08907507e+04, 1.67786188e+04, -3.19178643e+03,
+ 2.60740519e+02, -8.19205881e+00, 1.07913775e-01], # Coefficients
+ teacache_thresh=0.15, # TeaCache threshold
+)
+```
+
+**Coefficient Reference:** Refer to configuration files in `configs/caching` or `configs/hunyuan_video_15/cache` directories
+
+### Lightweight VAE
+
+Using lightweight VAE can accelerate decoding and reduce memory usage.
+
+```python
+pipe.enable_lightvae(
+ use_lightvae=False, # Whether to use LightVAE
+ use_tae=False, # Whether to use LightTAE
+ vae_path=None, # Path to LightVAE
+ tae_path=None, # Path to LightTAE
+)
+```
+
+**Support Status:**
+- **LightVAE**: Currently only supports wan2.1, wan2.2 moe
+- **LightTAE**: Currently only supports wan2.1, wan2.2-ti2v, wan2.2 moe, HunyuanVideo-1.5
+
+**Model Downloads:** Lightweight VAE models can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
+
+- LightVAE for Wan-2.1: [lightvaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lightvaew2_1.safetensors)
+- LightTAE for Wan-2.1: [lighttaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_1.safetensors)
+- LightTAE for Wan-2.2-ti2v: [lighttaew2_2.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_2.safetensors)
+- LightTAE for HunyuanVideo-1.5: [lighttaehy1_5.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)
+
+**Usage Example:**
+
+```python
+# Using LightTAE for HunyuanVideo-1.5
+pipe.enable_lightvae(
+ use_tae=True,
+ tae_path="/path/to/lighttaehy1_5.safetensors",
+ use_lightvae=False,
+ vae_path=None
+)
+```
+
+## 📚 More Resources
+
+- [Full Documentation](https://lightx2v-en.readthedocs.io/en/latest/)
+- [GitHub Repository](https://github.com/ModelTC/LightX2V)
+- [HuggingFace Model Hub](https://huggingface.co/lightx2v)
diff --git a/examples/README_zh.md b/examples/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..124b527f98f59bf8af4d897c354cab6e460d3fda
--- /dev/null
+++ b/examples/README_zh.md
@@ -0,0 +1,284 @@
+# LightX2V 使用示例
+
+本文档介绍如何使用 LightX2V 进行视频生成,包括基础使用和进阶配置。
+
+## 📋 目录
+
+- [环境安装](#环境安装)
+- [基础运行示例](#基础运行示例)
+- [模型路径配置](#模型路径配置)
+- [创建生成器](#创建生成器)
+- [进阶配置](#进阶配置)
+ - [参数卸载 (Offload)](#参数卸载-offload)
+ - [模型量化 (Quantization)](#模型量化-quantization)
+ - [并行推理 (Parallel Inference)](#并行推理-parallel-inference)
+ - [特征缓存 (Cache)](#特征缓存-cache)
+ - [轻量 VAE (Light VAE)](#轻量-vae-light-vae)
+
+## 🔧 环境安装
+
+请参考主项目的[快速入门文档](../docs/ZH_CN/source/getting_started/quickstart.md)进行环境安装。
+
+## 🚀 基础运行示例
+
+最小化代码示例可参考 `examples/wan_t2v.py`:
+
+```python
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-T2V-14B",
+ model_cls="wan2.1",
+ task="t2v",
+)
+
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ height=480,
+ width=832,
+ num_frames=81,
+ guidance_scale=5.0,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Your prompt here"
+negative_prompt = ""
+save_result_path = "/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
+```
+
+## 📁 模型路径配置
+
+### 基础配置
+
+将模型路径传入 `LightX2VPipeline`:
+
+```python
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe", # 对于 wan2.1,使用 "wan2.1"
+ task="i2v",
+)
+```
+
+### 多版本模型权重指定
+
+当 `model_path` 目录下存在多个不同版本的 bf16 精度 DIT 模型 safetensors 文件时,需要使用以下参数指定具体使用哪个权重:
+
+- **`dit_original_ckpt`**: 用于指定 wan2.1 和 hunyuan15 等模型的原始 DIT 权重路径
+- **`low_noise_original_ckpt`**: 用于指定 wan2.2 模型的低噪声分支权重路径
+- **`high_noise_original_ckpt`**: 用于指定 wan2.2 模型的高噪声分支权重路径
+
+**使用示例:**
+
+```python
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe",
+ task="i2v",
+ low_noise_original_ckpt="/path/to/low_noise_model.safetensors",
+ high_noise_original_ckpt="/path/to/high_noise_model.safetensors",
+)
+```
+
+## 🎛️ 创建生成器
+
+### 从配置文件加载
+
+生成器可以从 JSON 配置文件直接加载,配置文件位于 `configs` 目录:
+
+```python
+pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
+```
+
+### 手动创建生成器
+
+也可以手动创建生成器,并配置多个参数:
+
+```python
+pipe.create_generator(
+ attn_mode="flash_attn2", # 可选: flash_attn2, flash_attn3, sage_attn2, sage_attn3 (B架构显卡适用)
+ infer_steps=50, # 推理步数
+ num_frames=81, # 视频帧数
+ height=480, # 视频高度
+ width=832, # 视频宽度
+    guidance_scale=5.0,  # CFG引导强度 (=1时禁用CFG)
+ sample_shift=5.0, # 采样偏移
+ fps=16, # 帧率
+ aspect_ratio="16:9", # 宽高比
+ boundary=0.900, # 边界值
+ boundary_step_index=2, # 边界步索引
+ denoising_step_list=[1000, 750, 500, 250], # 去噪步列表
+)
+```
+
+**参数说明:**
+- **分辨率**: 通过 `height` 和 `width` 指定
+- **CFG**: 通过 `guidance_scale` 指定(设置为 1 时禁用 CFG)
+- **FPS**: 通过 `fps` 指定帧率
+- **视频长度**: 通过 `num_frames` 指定帧数
+- **推理步数**: 通过 `infer_steps` 指定
+- **采样偏移**: 通过 `sample_shift` 指定
+- **注意力模式**: 通过 `attn_mode` 指定,可选 `flash_attn2`, `flash_attn3`, `sage_attn2`, `sage_attn3`(Blackwell 架构显卡适用)
+
+## ⚙️ 进阶配置
+
+**⚠️ 重要提示:以下进阶选项适用于手动创建生成器的场景,且所有进阶配置都必须在调用 `create_generator()` 之前完成,否则不会生效!**
+
+### 参数卸载 (Offload)
+
+显著降低显存占用,几乎不影响推理速度,适用于 RTX 30/40/50 系列显卡。
+
+```python
+pipe.enable_offload(
+ cpu_offload=True, # 启用 CPU 卸载
+ offload_granularity="block", # 卸载粒度: "block" 或 "phase"
+ text_encoder_offload=False, # 文本编码器是否卸载
+ image_encoder_offload=False, # 图像编码器是否卸载
+ vae_offload=False, # VAE 是否卸载
+)
+```
+
+**说明:**
+- 对于 Wan 模型,`offload_granularity` 支持 `"block"` 和 `"phase"`
+- 对于 HunyuanVideo-1.5,目前只支持 `"block"`
+
+### 模型量化 (Quantization)
+
+量化可以显著降低显存占用并加速推理。
+
+```python
+pipe.enable_quantize(
+ dit_quantized=False, # 是否使用量化的 DIT 模型
+ text_encoder_quantized=False, # 是否使用量化的文本编码器
+ image_encoder_quantized=False, # 是否使用量化的图像编码器
+ dit_quantized_ckpt=None, # DIT 量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定)
+ low_noise_quantized_ckpt=None, # Wan2.2 低噪声分支量化权重路径
+ high_noise_quantized_ckpt=None, # Wan2.2 高噪声分支量化权重路径
+ text_encoder_quantized_ckpt=None, # 文本编码器量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定)
+ image_encoder_quantized_ckpt=None, # 图像编码器量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定)
+ quant_scheme="fp8-sgl", # 量化方案
+)
+```
+
+**参数说明:**
+- **`dit_quantized_ckpt`**: 当 `model_path` 目录下没有量化权重,或存在多个权重文件时,需要指定具体的 DIT 量化权重路径
+- **`text_encoder_quantized_ckpt`** 和 **`image_encoder_quantized_ckpt`**: 类似地,用于指定编码器的量化权重路径
+- **`low_noise_quantized_ckpt`** 和 **`high_noise_quantized_ckpt`**: 用于指定 Wan2.2 模型的双分支量化权重
+
+**量化模型下载:**
+
+- **Wan-2.1 量化模型**: 从 [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) 下载
+- **Wan-2.2 量化模型**: 从 [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) 下载
+- **HunyuanVideo-1.5 量化模型**: 从 [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models) 下载
+ - `hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors` 是文本编码器的量化权重
+
+**使用示例:**
+
+```python
+# HunyuanVideo-1.5 量化示例
+pipe.enable_quantize(
+ quant_scheme='fp8-sgl',
+ dit_quantized=True,
+ dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors",
+ text_encoder_quantized=True,
+ image_encoder_quantized=False,
+ text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
+)
+
+# Wan2.1 量化示例
+pipe.enable_quantize(
+ dit_quantized=True,
+ dit_quantized_ckpt="/path/to/wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors",
+)
+
+# Wan2.2 量化示例
+pipe.enable_quantize(
+ dit_quantized=True,
+ low_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
+ high_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step_1030.safetensors",
+)
+```
+
+**量化方案参考:** 详细说明请参考 [量化文档](../docs/ZH_CN/source/method_tutorials/quantization.md)
+
+### 并行推理 (Parallel Inference)
+
+支持多 GPU 并行推理,需要使用 `torchrun` 运行:
+
+```python
+pipe.enable_parallel(
+ seq_p_size=4, # 序列并行大小
+ seq_p_attn_type="ulysses", # 序列并行注意力类型
+)
+```
+
+**运行方式:**
+```bash
+torchrun --nproc_per_node=4 your_script.py
+```
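+
+作为参考,一个最小的 `your_script.py` 可以写成下面的示意形式。它只是组合了本文档中已出现的 `LightX2VPipeline` 调用;路径与提示词均为占位符,并且这里假设 `seq_p_size` 与 `--nproc_per_node` 保持一致:
+
+```python
+# your_script.py 最小示意,启动方式: torchrun --nproc_per_node=4 your_script.py
+from lightx2v import LightX2VPipeline
+
+pipe = LightX2VPipeline(
+    model_path="/path/to/Wan2.1-T2V-14B",
+    model_cls="wan2.1",
+    task="t2v",
+)
+
+# 并行设置(与其他进阶配置一样)必须在 create_generator() 之前完成
+pipe.enable_parallel(
+    seq_p_size=4,  # 假设与 --nproc_per_node 保持一致
+    seq_p_attn_type="ulysses",
+)
+
+pipe.create_generator(
+    attn_mode="sage_attn2",
+    infer_steps=50,
+    height=480,
+    width=832,
+    num_frames=81,
+    guidance_scale=5.0,
+    sample_shift=5.0,
+)
+
+pipe.generate(
+    seed=42,
+    prompt="Your prompt here",
+    negative_prompt="",
+    save_result_path="/path/to/save_results/output.mp4",
+)
+```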
+
+### 特征缓存 (Cache)
+
+将 `cache_method` 指定为 `'Tea'` 或 `'Mag'`,即可分别使用 TeaCache 或 MagCache 方法:
+
+```python
+pipe.enable_cache(
+ cache_method='Tea', # 缓存方法: 'Tea' 或 'Mag'
+ coefficients=[-3.08907507e+04, 1.67786188e+04, -3.19178643e+03,
+ 2.60740519e+02, -8.19205881e+00, 1.07913775e-01], # 系数
+ teacache_thresh=0.15, # TeaCache 阈值
+)
+```
+
+**系数参考:** 可参考 `configs/caching` 或 `configs/hunyuan_video_15/cache` 目录下的配置文件
+
+### 轻量 VAE (Light VAE)
+
+使用轻量 VAE 可以加速解码并降低显存占用。
+
+```python
+pipe.enable_lightvae(
+ use_lightvae=False, # 是否使用 LightVAE
+ use_tae=False, # 是否使用 LightTAE
+ vae_path=None, # LightVAE 的路径
+ tae_path=None, # LightTAE 的路径
+)
+```
+
+**支持情况:**
+- **LightVAE**: 目前只支持 wan2.1、wan2.2 moe
+- **LightTAE**: 目前只支持 wan2.1、wan2.2-ti2v、wan2.2 moe、HunyuanVideo-1.5
+
+**模型下载:** 轻量 VAE 模型可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载
+
+- Wan-2.1 的 LightVAE: [lightvaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lightvaew2_1.safetensors)
+- Wan-2.1 的 LightTAE: [lighttaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_1.safetensors)
+- Wan-2.2-ti2v 的 LightTAE: [lighttaew2_2.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_2.safetensors)
+- HunyuanVideo-1.5 的 LightTAE: [lighttaehy1_5.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)
+
+**使用示例:**
+
+```python
+# 使用 HunyuanVideo-1.5 的 LightTAE
+pipe.enable_lightvae(
+ use_tae=True,
+ tae_path="/path/to/lighttaehy1_5.safetensors",
+ use_lightvae=False,
+ vae_path=None
+)
+```
+
+## 📚 更多资源
+
+- [完整文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)
+- [GitHub 仓库](https://github.com/ModelTC/LightX2V)
+- [HuggingFace 模型库](https://huggingface.co/lightx2v)
diff --git a/examples/hunyuan_video/hunyuan_i2v.py b/examples/hunyuan_video/hunyuan_i2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..654b1f75333dda87cae3e135a7a5d22ae0301ce2
--- /dev/null
+++ b/examples/hunyuan_video/hunyuan_i2v.py
@@ -0,0 +1,63 @@
+"""
+HunyuanVideo-1.5 image-to-video generation example with quantization.
+This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for I2V generation,
+including quantized model usage for reduced memory consumption.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for HunyuanVideo-1.5 I2V task
+pipe = LightX2VPipeline(
+ model_path="/path/to/ckpts/hunyuanvideo-1.5/",
+ model_cls="hunyuan_video_1.5",
+ transformer_model_name="720p_i2v",
+ task="i2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_i2v_720p.json")
+
+# Enable offloading to significantly reduce VRAM usage with minimal speed impact
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Enable quantization for reduced memory usage
+# Quantized models can be downloaded from: https://huggingface.co/lightx2v/Hy1.5-Quantized-Models
+pipe.enable_quantize(
+ quant_scheme="fp8-sgl",
+ dit_quantized=True,
+ dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors",
+ text_encoder_quantized=True,
+ image_encoder_quantized=False,
+ text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
+)
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ num_frames=121,
+ guidance_scale=6.0,
+ sample_shift=7.0,
+ fps=24,
+)
+
+# Generation parameters
+seed = 42
+prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+negative_prompt = ""
+save_result_path = "/path/to/save_results/output2.mp4"
+
+# Generate video
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/hunyuan_video/hunyuan_t2v.py b/examples/hunyuan_video/hunyuan_t2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..7283bdede10cf9a7597b7c7551aa63935c283795
--- /dev/null
+++ b/examples/hunyuan_video/hunyuan_t2v.py
@@ -0,0 +1,60 @@
+"""
+HunyuanVideo-1.5 text-to-video generation example.
+This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for T2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for HunyuanVideo-1.5
+pipe = LightX2VPipeline(
+ model_path="/path/to/ckpts/hunyuanvideo-1.5/",
+ model_cls="hunyuan_video_1.5",
+ transformer_model_name="720p_t2v",
+ task="t2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_t2v_720p.json")
+
+# Enable offloading to significantly reduce VRAM usage with minimal speed impact
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Use lighttae
+pipe.enable_lightvae(
+ use_tae=True,
+ tae_path="/path/to/lighttaehy1_5.safetensors",
+ use_lightvae=False,
+ vae_path=None,
+)
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ num_frames=121,
+ guidance_scale=6.0,
+ sample_shift=9.0,
+ aspect_ratio="16:9",
+ fps=24,
+)
+
+# Generation parameters
+seed = 123
+prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
+negative_prompt = ""
+save_result_path = "/path/to/save_results/output.mp4"
+
+# Generate video
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/hunyuan_video/hunyuan_t2v_distill.py b/examples/hunyuan_video/hunyuan_t2v_distill.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c7cdd53d73d9e0c20e6eb94d396c4db00b1078
--- /dev/null
+++ b/examples/hunyuan_video/hunyuan_t2v_distill.py
@@ -0,0 +1,55 @@
+"""
+HunyuanVideo-1.5 text-to-video generation example.
+This example demonstrates how to use LightX2V with HunyuanVideo-1.5 4-step distilled model for T2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for HunyuanVideo-1.5
+pipe = LightX2VPipeline(
+ model_path="/path/to/ckpts/hunyuanvideo-1.5/",
+ model_cls="hunyuan_video_1.5",
+ transformer_model_name="480p_t2v",
+ task="t2v",
+ # 4-step distilled model ckpt
+ dit_original_ckpt="/path/to/hy1.5_t2v_480p_lightx2v_4step.safetensors",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_t2v_720p.json")
+
+# Enable offloading to significantly reduce VRAM usage with minimal speed impact
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Use lighttae
+pipe.enable_lightvae(
+ use_tae=True,
+ tae_path="/path/to/lighttaehy1_5.safetensors",
+ use_lightvae=False,
+ vae_path=None,
+)
+
+# Create generator with specified parameters
+pipe.create_generator(
+    attn_mode="sage_attn2",
+    infer_steps=4,
+    num_frames=81,
+    guidance_scale=1,
+    sample_shift=9.0,
+    aspect_ratio="16:9",
+    fps=16,
+    denoising_step_list=[1000, 750, 500, 250],
+)
+
+
+# Generation parameters
+seed = 123
+prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
+negative_prompt = ""
+save_result_path = "/path/to/save_results/output.mp4"
+
+# Generate video
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_animate.py b/examples/wan/wan_animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..317f34d71d2c10c6490c42279fd3ef60a58a98aa
--- /dev/null
+++ b/examples/wan/wan_animate.py
@@ -0,0 +1,72 @@
+"""
+Wan2.2 animate video generation example.
+This example demonstrates how to use LightX2V with Wan2.2 model for animate video generation.
+
+First, run preprocessing:
+1. Set up environment: pip install -r ../requirements_animate.txt
+2. For animate mode:
+ python ../tools/preprocess/preprocess_data.py \
+ --ckpt_path /path/to/Wan2.1-FLF2V-14B-720P/process_checkpoint \
+ --video_path /path/to/video \
+ --refer_path /path/to/ref_img \
+ --save_path ../save_results/animate/process_results \
+ --resolution_area 1280 720 \
+ --retarget_flag
+3. For replace mode:
+ python ../tools/preprocess/preprocess_data.py \
+ --ckpt_path /path/to/Wan2.1-FLF2V-14B-720P/process_checkpoint \
+ --video_path /path/to/video \
+ --refer_path /path/to/ref_img \
+ --save_path ../save_results/replace/process_results \
+ --resolution_area 1280 720 \
+ --iterations 3 \
+ --k 7 \
+ --w_len 1 \
+ --h_len 1 \
+ --replace_flag
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for animate task
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-FLF2V-14B-720P",
+ model_cls="wan2.2_animate",
+ task="animate",
+)
+pipe.replace_flag = True # Set to True for replace mode, False for animate mode
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="../configs/wan/wan_animate_replace.json"
+# )
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=20,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=77,
+ guidance_scale=1,
+ sample_shift=5.0,
+ fps=30,
+)
+
+seed = 42
+prompt = "视频中的人在做动作"
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+src_pose_path = "../save_results/animate/process_results/src_pose.mp4"
+src_face_path = "../save_results/animate/process_results/src_face.mp4"
+src_ref_images = "../save_results/animate/process_results/src_ref.png"
+save_result_path = "/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ src_pose_path=src_pose_path,
+ src_face_path=src_face_path,
+ src_ref_images=src_ref_images,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_flf2v.py b/examples/wan/wan_flf2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f18ccf4a0b440a07f594b1f4345f2ca470d8841
--- /dev/null
+++ b/examples/wan/wan_flf2v.py
@@ -0,0 +1,55 @@
+"""
+Wan2.1 first-last-frame-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.1 model for FLF2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for FLF2V task
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-FLF2V-14B-720P",
+ model_cls="wan2.1",
+ task="flf2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="../configs/wan/wan_flf2v.json"
+# )
+
+# Optional: enable offloading to significantly reduce VRAM usage
+# Suitable for RTX 30/40/50 consumer GPUs
+# pipe.enable_offload(
+# cpu_offload=True,
+# offload_granularity="block",
+# text_encoder_offload=True,
+# image_encoder_offload=False,
+# vae_offload=False,
+# )
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=40,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=5,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+image_path = "../assets/inputs/imgs/flf2v_input_first_frame-fs8.png"
+last_frame_path = "../assets/inputs/imgs/flf2v_input_last_frame-fs8.png"
+save_result_path = "/path/to/save_results/output.mp4"
+
+pipe.generate(
+ image_path=image_path,
+ last_frame_path=last_frame_path,
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_i2v.py b/examples/wan/wan_i2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..632566b4be0447ee8f442882b67e40f2a848ec66
--- /dev/null
+++ b/examples/wan/wan_i2v.py
@@ -0,0 +1,56 @@
+"""
+Wan2.2 image-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for Wan2.2 I2V task
+# For wan2.1, use model_cls="wan2.1"
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe",
+ task="i2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="../configs/wan22/wan_moe_i2v.json"
+# )
+
+# Enable offloading to significantly reduce VRAM usage with minimal speed impact
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For Wan models, supports both "block" and "phase"
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Create generator manually with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=40,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
+ sample_shift=5.0,
+)
+
+# Generation parameters
+seed = 42
+prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+image_path = "/path/to/img_0.jpg"
+save_result_path = "/path/to/save_results/output.mp4"
+
+# Generate video
+pipe.generate(
+ seed=seed,
+ image_path=image_path,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_i2v_distilled.py b/examples/wan/wan_i2v_distilled.py
new file mode 100644
index 0000000000000000000000000000000000000000..be89488b03f3f3507bf734be6c326d801c7e5637
--- /dev/null
+++ b/examples/wan/wan_i2v_distilled.py
@@ -0,0 +1,57 @@
+"""
+Wan2.2 distilled model image-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.2 distilled model for I2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for Wan2.2 distilled I2V task
+# For wan2.1, use model_cls="wan2.1_distill"
+pipe = LightX2VPipeline(
+ model_path="/path/to/wan2.2/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe_distill",
+ task="i2v",
+ # Distilled weights: For wan2.1, only need to specify dit_original_ckpt="/path/to/wan2.1_i2v_720p_lightx2v_4step.safetensors"
+ low_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
+ high_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="../configs/wan22/wan_moe_i2v_distill.json"
+# )
+
+# Enable offloading to significantly reduce VRAM usage
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block",
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Create generator manually with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=4,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=1,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+save_result_path = "/path/to/save_results/output.mp4"
+image_path = "/path/to/img_0.jpg"
+
+pipe.generate(
+ seed=seed,
+ image_path=image_path,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_i2v_with_distill_loras.py b/examples/wan/wan_i2v_with_distill_loras.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d629d52b94cbfb6226a7db36e734b2658c27132
--- /dev/null
+++ b/examples/wan/wan_i2v_with_distill_loras.py
@@ -0,0 +1,62 @@
+"""
+Wan2.2 distilled model with LoRA image-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.2 distilled model and LoRA for I2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for Wan2.2 distilled I2V task with LoRA
+# For wan2.1, use model_cls="wan2.1_distill"
+pipe = LightX2VPipeline(
+ model_path="/path/to/wan2.2/Wan2.2-I2V-A14B",
+ model_cls="wan2.2_moe_distill",
+ task="i2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="../configs/wan22/wan_moe_i2v_distill_with_lora.json"
+# )
+
+# Enable offloading to significantly reduce VRAM usage
+# Suitable for RTX 30/40/50 consumer GPUs
+pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block",
+ text_encoder_offload=True,
+ image_encoder_offload=False,
+ vae_offload=False,
+)
+
+# Load distilled LoRA weights
+pipe.enable_lora(
+ [
+ {"name": "high_noise_model", "path": "/path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0},
+ {"name": "low_noise_model", "path": "/path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0},
+ ]
+)
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=4,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=1,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+save_result_path = "/path/to/save_results/output.mp4"
+image_path = "/path/to/img_0.jpg"
+
+pipe.generate(
+ seed=seed,
+ image_path=image_path,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_t2v.py b/examples/wan/wan_t2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2e71b27e3658f86dae33eb8d3907a1e831a8db
--- /dev/null
+++ b/examples/wan/wan_t2v.py
@@ -0,0 +1,39 @@
+"""
+Wan2.1 text-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.1 model for T2V generation.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for Wan2.1 T2V task
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-T2V-14B",
+ model_cls="wan2.1",
+ task="t2v",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=50,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=5.0,
+ sample_shift=5.0,
+)
+
+seed = 42
+prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+save_result_path = "/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/examples/wan/wan_vace.py b/examples/wan/wan_vace.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bbf88c2db6848fc9fab6aaa70487226c3ec4a87
--- /dev/null
+++ b/examples/wan/wan_vace.py
@@ -0,0 +1,52 @@
+"""
+Wan2.1 VACE (Video Animate Character Exchange) generation example.
+This example demonstrates how to use LightX2V with Wan2.1 VACE model for character exchange in videos.
+"""
+
+from lightx2v import LightX2VPipeline
+
+# Initialize pipeline for VACE task
+pipe = LightX2VPipeline(
+ model_path="/path/to/Wan2.1-VACE-1.3B",
+ src_ref_images="../assets/inputs/imgs/girl.png,../assets/inputs/imgs/snake.png",
+ model_cls="wan2.1_vace",
+ task="vace",
+)
+
+# Alternative: create generator from config JSON file
+# pipe.create_generator(
+# config_json="../configs/wan/wan_vace.json"
+# )
+
+# Optional: enable offloading to significantly reduce VRAM usage
+# Suitable for RTX 30/40/50 consumer GPUs
+# pipe.enable_offload(
+# cpu_offload=True,
+# offload_granularity="block",
+# text_encoder_offload=True,
+# image_encoder_offload=False,
+# vae_offload=False,
+# )
+
+# Create generator with specified parameters
+pipe.create_generator(
+ attn_mode="sage_attn2",
+ infer_steps=40,
+ height=480, # Can be set to 720 for higher resolution
+ width=832, # Can be set to 1280 for higher resolution
+ num_frames=81,
+ guidance_scale=5,
+ sample_shift=16,
+)
+
+seed = 42
+prompt = "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+save_result_path = "/path/to/save_results/output.mp4"
+
+pipe.generate(
+ seed=seed,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/lightx2v/__init__.py b/lightx2v/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7298b043cb7a1a768579e0639eb0a2c17c82451b
--- /dev/null
+++ b/lightx2v/__init__.py
@@ -0,0 +1,18 @@
+__version__ = "0.1.0"
+__author__ = "LightX2V Contributors"
+__license__ = "Apache 2.0"
+
+# Side-effect import: lightx2v_platform.set_ai_device configures the AI device backend before the submodules below are imported.
+import lightx2v_platform.set_ai_device  # noqa: F401
+from lightx2v import common, deploy, models, utils
+from lightx2v.pipeline import LightX2VPipeline
+
+__all__ = [
+ "__version__",
+ "__author__",
+ "__license__",
+ "models",
+ "common",
+ "deploy",
+ "utils",
+ "LightX2VPipeline",
+]
diff --git a/lightx2v/common/__init__.py b/lightx2v/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/common/modules/__init__.py b/lightx2v/common/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/common/modules/weight_module.py b/lightx2v/common/modules/weight_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..59d646cc57df3ccbdb72b687f357c2ced68bceb5
--- /dev/null
+++ b/lightx2v/common/modules/weight_module.py
@@ -0,0 +1,175 @@
+class WeightModule:
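+    """Minimal ``nn.Module``-like container for inference weights.
+
+    Child modules and parameters are tracked in plain dicts, with helpers to
+    load them from a weight dict, gather/apply per-block state dicts, and move
+    them between CPU and the accelerator, either synchronously or
+    asynchronously (the ``*_async`` variants request non-blocking transfers).
+    """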
+ def __init__(self):
+ self._modules = {}
+ self._parameters = {}
+
+ def is_empty(self):
+ return len(self._modules) == 0 and len(self._parameters) == 0
+
+ def add_module(self, name, module):
+ self._modules[name] = module
+ setattr(self, name, module)
+
+ def register_parameter(self, name, param):
+ self._parameters[name] = param
+ setattr(self, name, param)
+
+ def load(self, weight_dict):
+ for _, module in self._modules.items():
+ if hasattr(module, "load"):
+ module.load(weight_dict)
+
+ for _, parameter in self._parameters.items():
+ if hasattr(parameter, "load"):
+ parameter.load(weight_dict)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ for _, param in self._parameters.items():
+ if param is not None:
+ param.state_dict(destination)
+ for _, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination)
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if destination is None:
+ destination = {}
+ for _, param in self._parameters.items():
+ if param is not None:
+ param.load_state_dict(destination, block_index, adapter_block_index)
+ for _, module in self._modules.items():
+ if module is not None:
+ module.load_state_dict(destination, block_index, adapter_block_index)
+ return destination
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ for _, param in self._parameters.items():
+ if param is not None:
+ param.load_state_dict_from_disk(block_index, adapter_block_index)
+ for _, module in self._modules.items():
+ if module is not None:
+ module.load_state_dict_from_disk(block_index, adapter_block_index)
+
+ def named_parameters(self, prefix=""):
+ for name, param in self._parameters.items():
+ if param is not None:
+ yield prefix + name, param
+ for name, module in self._modules.items():
+ if module is not None:
+ yield from module.named_parameters(prefix + name + ".")
+
+ def to_cpu(self):
+ for name, param in self._parameters.items():
+ if param is not None:
+ if hasattr(param, "cpu"):
+ self._parameters[name] = param.cpu()
+ setattr(self, name, self._parameters[name])
+ elif hasattr(param, "to_cpu"):
+ self._parameters[name].to_cpu()
+ setattr(self, name, self._parameters[name])
+ for module in self._modules.values():
+ if isinstance(module, WeightModuleList):
+ for i in range(len(module)):
+ for m in module[i]._modules.values():
+ if m is not None and hasattr(m, "to_cpu"):
+ m.to_cpu()
+ for m in module[i]._parameters.values():
+ if m is not None and hasattr(m, "to_cpu"):
+ m.to_cpu()
+ else:
+ if module is not None and hasattr(module, "to_cpu"):
+ module.to_cpu()
+
+ def to_cuda(self):
+ for name, param in self._parameters.items():
+ if param is not None:
+ if hasattr(param, "cuda"):
+ self._parameters[name] = param.cuda()
+ elif hasattr(param, "to_cuda"):
+ self._parameters[name].to_cuda()
+ setattr(self, name, self._parameters[name])
+ for module in self._modules.values():
+ if isinstance(module, WeightModuleList):
+ for i in range(len(module)):
+ for m in module[i]._modules.values():
+ if m is not None and hasattr(m, "to_cuda"):
+ m.to_cuda()
+ for m in module[i]._parameters.values():
+ if m is not None and hasattr(m, "to_cuda"):
+ m.to_cuda()
+ else:
+ if module is not None and hasattr(module, "to_cuda"):
+ module.to_cuda()
+
+ def to_cpu_async(self):
+ for name, param in self._parameters.items():
+ if param is not None:
+ if hasattr(param, "cpu"):
+ self._parameters[name] = param.cpu(non_blocking=True)
+ setattr(self, name, self._parameters[name])
+ elif hasattr(param, "to_cpu"):
+ self._parameters[name].to_cpu(non_blocking=True)
+ setattr(self, name, self._parameters[name])
+ for module in self._modules.values():
+ if isinstance(module, WeightModuleList):
+ for i in range(len(module)):
+ for m in module[i]._modules.values():
+ if m is not None and hasattr(m, "to_cpu"):
+ m.to_cpu(non_blocking=True)
+ for m in module[i]._parameters.values():
+ if m is not None and hasattr(m, "to_cpu"):
+ m.to_cpu(non_blocking=True)
+ else:
+ if module is not None and hasattr(module, "to_cpu"):
+ module.to_cpu(non_blocking=True)
+
+ def to_cuda_async(self):
+ for name, param in self._parameters.items():
+ if param is not None:
+ if hasattr(param, "cuda"):
+ self._parameters[name] = param.cuda(non_blocking=True)
+ elif hasattr(param, "to_cuda"):
+ self._parameters[name].to_cuda(non_blocking=True)
+ setattr(self, name, self._parameters[name])
+ for module in self._modules.values():
+ if isinstance(module, WeightModuleList):
+ for i in range(len(module)):
+ for m in module[i]._modules.values():
+ if m is not None and hasattr(m, "to_cuda"):
+ m.to_cuda(non_blocking=True)
+ for m in module[i]._parameters.values():
+ if m is not None and hasattr(m, "to_cuda"):
+ m.to_cuda(non_blocking=True)
+ else:
+ if module is not None and hasattr(module, "to_cuda"):
+ module.to_cuda(non_blocking=True)
+
+
+class WeightModuleList(WeightModule):
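+    """List-style container of ``WeightModule`` children.
+
+    Each appended module is also registered under its string index, so the
+    load/offload helpers inherited from ``WeightModule`` recurse into every
+    entry.
+    """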
+ def __init__(self, modules=None):
+ super().__init__()
+ self._list = []
+ if modules is not None:
+ for idx, module in enumerate(modules):
+ self.append(module)
+
+ def append(self, module):
+ idx = len(self._list)
+ self._list.append(module)
+ self.add_module(str(idx), module)
+
+ def __getitem__(self, idx):
+ return self._list[idx]
+
+ def __setitem__(self, idx, module):
+ self._list[idx] = module
+ self.add_module(str(idx), module)
+
+ def __len__(self):
+ return len(self._list)
+
+ def __iter__(self):
+ return iter(self._list)
diff --git a/lightx2v/common/offload/manager.py b/lightx2v/common/offload/manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1671ebec66cf8a17f4c53c1b4191e836e0c3936
--- /dev/null
+++ b/lightx2v/common/offload/manager.py
@@ -0,0 +1,133 @@
+from concurrent.futures import ThreadPoolExecutor
+
+import torch
+from loguru import logger
+from packaging.version import parse
+from tqdm import tqdm
+
+from lightx2v.utils.profiler import ExcludedProfilingContext
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
+
+class WeightAsyncStreamManager(object):
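+    """Asynchronous weight-streaming manager for block/phase offloading.
+
+    While the compute stream works on the current weights, the next block's
+    (or phase's) weights are prefetched into device buffers on a separate
+    load stream; optional CPU buffers, filled from disk (via a thread pool
+    when lazy loading is enabled), act as a staging area. ``swap_blocks`` /
+    ``swap_phases`` synchronize both streams before the buffers are reused.
+    """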
+ def __init__(self, offload_granularity):
+ self.offload_granularity = offload_granularity
+ self.init_stream = torch_device_module.Stream(priority=0)
+ self.need_init_first_buffer = True
+ self.lazy_load = False
+ torch_version = parse(torch.__version__.split("+")[0])
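+        # With torch >= 2.7 on CUDA, run the load and compute streams at equal priority;
+        # otherwise give the compute stream higher priority than the load stream.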
+ if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
+ self.cuda_load_stream = torch_device_module.Stream(priority=1)
+ self.compute_stream = torch_device_module.Stream(priority=1)
+ else:
+ self.cuda_load_stream = torch_device_module.Stream(priority=0)
+ self.compute_stream = torch_device_module.Stream(priority=-1)
+
+ def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
+ self.need_init_first_buffer = True
+ if self.offload_granularity == "block":
+ assert blocks_cpu_buffer is not None
+ self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
+ elif self.offload_granularity == "phase":
+ assert phases_cpu_buffer is not None
+ self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
+ else:
+ raise NotImplementedError
+
+ def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
+ self.need_init_first_buffer = True
+ if self.offload_granularity == "block":
+ assert blocks_cuda_buffer is not None
+ self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
+ elif self.offload_granularity == "phase":
+ assert phases_cuda_buffer is not None
+ self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
+ else:
+ raise NotImplementedError
+
+ def init_first_buffer(self, blocks, adapter_block_idx=None):
+ with torch_device_module.stream(self.init_stream):
+ if hasattr(self, "cpu_buffers"):
+ self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
+ else:
+ if self.offload_granularity == "block":
+ self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
+ else:
+ self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
+ self.init_stream.synchronize()
+ self.need_init_first_buffer = False
+
+ def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
+ with torch_device_module.stream(self.cuda_load_stream):
+ if hasattr(self, "cpu_buffers"):
+ self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
+ self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
+ else:
+ self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
+
+ def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
+ with torch_device_module.stream(self.cuda_load_stream):
+ if hasattr(self, "cpu_buffers"):
+ self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
+ else:
+ self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
+
+ def swap_blocks(self):
+ self.cuda_load_stream.synchronize()
+ self.compute_stream.synchronize()
+ self.cuda_buffers[0], self.cuda_buffers[1] = (
+ self.cuda_buffers[1],
+ self.cuda_buffers[0],
+ )
+
+ def swap_phases(self):
+ self.cuda_load_stream.synchronize()
+ self.compute_stream.synchronize()
+
+ @ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
+ def warm_up_cpu_buffers(self, blocks_num):
+ logger.info("🔥 Warming up cpu buffers...")
+ for i in tqdm(range(blocks_num)):
+ for phase in self.cpu_buffers[0]:
+ phase.load_state_dict_from_disk(i, None)
+ for phase in self.cpu_buffers[1]:
+ phase.load_state_dict_from_disk(i, None)
+
+ for phase in self.cpu_buffers[0]:
+ phase.load_state_dict_from_disk(0, None)
+ for phase in self.cpu_buffers[1]:
+ phase.load_state_dict_from_disk(1, None)
+ logger.info("✅ CPU buffers warm-up completed.")
+
+ def init_lazy_load(self, num_workers=6):
+ self.lazy_load = True
+ self.executor = ThreadPoolExecutor(max_workers=num_workers)
+ self.prefetch_futures = []
+ self.prefetch_block_idx = -1
+
+ def start_prefetch_block(self, block_idx, adapter_block_idx=None):
+ self.prefetch_block_idx = block_idx
+ self.prefetch_futures = []
+ for phase in self.cpu_buffers[1]:
+ future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
+ self.prefetch_futures.append(future)
+
+ def swap_cpu_buffers(self):
+ # wait_start = time.time()
+ # already_done = all(f.done() for f in self.prefetch_futures)
+ for f in self.prefetch_futures:
+ f.result()
+ # wait_time = time.time() - wait_start
+ # logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
+ self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
+
+ def __del__(self):
+ if hasattr(self, "executor") and self.executor is not None:
+ for f in self.prefetch_futures:
+ if not f.done():
+ f.result()
+ self.executor.shutdown(wait=False)
+ self.executor = None
+ logger.debug("ThreadPoolExecutor shut down successfully.")
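+
+
+# Rough usage sketch for block-level offloading (comments only, nothing here executes).
+# Buffer objects are assumed to expose `state_dict()` / `load_state_dict()` (and
+# `load_state_dict_from_disk()` when lazy loading is enabled); `run_block` below is a
+# placeholder for the model's per-block forward.
+#
+#   manager = WeightAsyncStreamManager(offload_granularity="block")
+#   manager.init_cuda_buffer(blocks_cuda_buffer=[buf_a, buf_b])
+#   manager.init_first_buffer(blocks)                      # stage block 0 into cuda_buffers[0]
+#   for i in range(len(blocks)):
+#       if i + 1 < len(blocks):
+#           manager.prefetch_weights(i + 1, blocks)        # async H2D copy on cuda_load_stream
+#       with torch_device_module.stream(manager.compute_stream):
+#           hidden = run_block(manager.cuda_buffers[0], hidden)
+#       manager.swap_blocks()                              # sync both streams, swap buffers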
diff --git a/lightx2v/common/ops/__init__.py b/lightx2v/common/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aab2983ef2b2649da7a004b59ef96993f4fd0d54
--- /dev/null
+++ b/lightx2v/common/ops/__init__.py
@@ -0,0 +1,6 @@
+from .attn import *
+from .conv import *
+from .embedding import *
+from .mm import *
+from .norm import *
+from .tensor import *
diff --git a/lightx2v/common/ops/attn/__init__.py b/lightx2v/common/ops/attn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65e23ddef56a45665ed795b3e56ef6516e7d09f4
--- /dev/null
+++ b/lightx2v/common/ops/attn/__init__.py
@@ -0,0 +1,10 @@
+from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
+from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
+from .radial_attn import RadialAttnWeight
+from .ring_attn import RingAttnWeight
+from .sage_attn import SageAttn2Weight, SageAttn3Weight
+from .spassage_attn import SageAttnWeight
+from .svg2_attn import Svg2AttnWeight
+from .svg_attn import SvgAttnWeight
+from .torch_sdpa import TorchSDPAWeight
+from .ulysses_attn import Ulysses4090AttnWeight, UlyssesAttnWeight
diff --git a/lightx2v/common/ops/attn/flash_attn.py b/lightx2v/common/ops/attn/flash_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc8150bb034ec42ebb57d43fa31a1b045eee468e
--- /dev/null
+++ b/lightx2v/common/ops/attn/flash_attn.py
@@ -0,0 +1,89 @@
+from loguru import logger
+
+try:
+ import flash_attn # noqa: F401
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+ logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
+ flash_attn_varlen_func = None
+
+try:
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
+except ImportError:
+ logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
+ flash_attn_varlen_func_v3 = None
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+
+@ATTN_WEIGHT_REGISTER("flash_attn2")
+class FlashAttn2Weight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ if len(q.shape) == 3:
+ bs = 1
+ elif len(q.shape) == 4:
+ bs = q.shape[0]
+ q = q.reshape(-1, q.shape[-2], q.shape[-1])
+ k = k.reshape(-1, k.shape[-2], k.shape[-1])
+ v = v.reshape(-1, v.shape[-2], v.shape[-1])
+ x = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ ).reshape(bs * max_seqlen_q, -1)
+ return x
+
+
+@ATTN_WEIGHT_REGISTER("flash_attn3")
+class FlashAttn3Weight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ if len(q.shape) == 3:
+ bs = 1
+ elif len(q.shape) == 4:
+ bs = q.shape[0]
+ if model_cls is not None and model_cls in ["hunyuan_video_1.5"]:
+ q = q.reshape(-1, q.shape[-2], q.shape[-1])
+ k = k.reshape(-1, k.shape[-2], k.shape[-1])
+ v = v.reshape(-1, v.shape[-2], v.shape[-1])
+ x = flash_attn_varlen_func_v3(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ ).reshape(bs * max_seqlen_q, -1)
+ return x
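+
+
+# Minimal smoke test for the flash-attn 2 wrapper above (a sketch, not part of the library
+# API). It assumes flash-attn 2 is installed and a CUDA device is available; q/k/v follow
+# the varlen layout used by `apply`: [total_tokens, heads, head_dim] with int32 cu_seqlens.
+if __name__ == "__main__":
+    import torch
+
+    seqlen, heads, head_dim = 1024, 8, 64
+    q = torch.randn(seqlen, heads, head_dim, dtype=torch.float16, device="cuda")
+    k = torch.randn(seqlen, heads, head_dim, dtype=torch.float16, device="cuda")
+    v = torch.randn(seqlen, heads, head_dim, dtype=torch.float16, device="cuda")
+    cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")
+
+    attn = FlashAttn2Weight()
+    out = attn.apply(q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen)
+    print(f"out: {out.shape}, {out.dtype}")  # expected: [1024, 512]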
diff --git a/lightx2v/common/ops/attn/nbhd_attn.py b/lightx2v/common/ops/attn/nbhd_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f32b88aac2dfc5d2a36554721a297c8767d3bc
--- /dev/null
+++ b/lightx2v/common/ops/attn/nbhd_attn.py
@@ -0,0 +1,196 @@
+import torch
+from loguru import logger
+
+try:
+ from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
+except ImportError:
+ magi_ffa_func = None
+
+try:
+ import flashinfer
+except ImportError:
+ flashinfer = None
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+
+def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5, 0.056], min_width=1.0, device="cpu"):
+ """
+    a : number of blocks per frame (may be fractional)
+    block_num : number of blocks per row/column of the block-level attention map
+    attnmap_frame_num : total number of frames
+ """
+ i_indices = torch.arange(block_num, device=device).unsqueeze(1) # [block_num, 1]
+ j_indices = torch.arange(block_num, device=device).unsqueeze(0) # [1, block_num]
+
+ assert len(coefficient) <= attnmap_frame_num, f"coefficient length {len(coefficient)} should <= attnmap_frame_num {attnmap_frame_num}"
+ width_list = [max(min_width, coefficient[i] * a) for i in range(len(coefficient))] + [min_width] * (attnmap_frame_num - len(coefficient))
+ logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}")
+
+ # attention sink frame: j <= a
+ mask_sink = j_indices <= a
+
+ mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
+ for interval in range(0, attnmap_frame_num):
+ n = i_indices // a
+ mask_sparse_base_1 = (j_indices >= (n + interval) * a) & (j_indices <= (n + interval + 1) * a)
+ n = j_indices // a
+ mask_sparse_base_2 = (i_indices >= (n + interval) * a) & (i_indices <= (n + interval + 1) * a)
+
+ width = width_list[interval]
+
+ mask_1 = mask_sparse_base_1 & (i_indices - j_indices + (interval * a + width) >= 0) & (i_indices - j_indices + (interval * a - width) <= 0)
+ mask_2 = mask_sparse_base_2 & (i_indices - j_indices - (interval * a - width) >= 0) & (i_indices - j_indices - (interval * a + width) <= 0)
+
+ mask_sparse = mask_sparse | mask_1 | mask_2
+
+ mask = mask_sink | mask_sparse
+ return mask
+
+
+def generate_qk_ranges(mask, block_size, seqlen):
+ indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2]
+
+ i_indices = indices[:, 0] # [N]
+ j_indices = indices[:, 1] # [N]
+
+ q_start = i_indices * block_size # [N]
+ q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N]
+
+ k_start = j_indices * block_size # [N]
+ k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N]
+
+ q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2]
+ k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2]
+
+ return q_ranges, k_ranges
+
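+# Worked example of the mask-to-ranges mapping above (comments only): a True entry (i, j)
+# in the block mask becomes a query token range [i * block_size, (i + 1) * block_size) paired
+# with a key token range [j * block_size, (j + 1) * block_size), both clipped to seqlen.
+# E.g. with block_size=128 and seqlen=300, a True entry at (2, 0) yields
+# q_range = [256, 300) and k_range = [0, 128).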
+
+@ATTN_WEIGHT_REGISTER("nbhd_attn")
+class NbhdAttnWeight(AttnWeightTemplate):
+ block_size = 128
+ seqlen = None
+ attnmap_frame_num = None
+ q_ranges = None
+ k_ranges = None
+ attn_type_map = None
+ coefficient = [1.0, 0.5, 0.056]
+ min_width = 1.0
+
+ def __init__(self):
+ self.config = {}
+
+ @classmethod
+ @torch.compiler.disable
+ def prepare_mask(cls, seqlen):
+ if seqlen == cls.seqlen:
+ return
+ block_num = (seqlen + cls.block_size - 1) // cls.block_size
+ block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
+ mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
+ q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
+ attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
+ q_ranges = q_ranges.to(torch.int32).to("cuda")
+ k_ranges = k_ranges.to(torch.int32).to("cuda")
+ cls.seqlen = seqlen
+ cls.q_ranges = q_ranges
+ cls.k_ranges = k_ranges
+ cls.attn_type_map = attn_type_map
+ logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
+ sparsity = 1 - mask.sum().item() / mask.numel()
+ logger.info(f"Attention sparsity: {sparsity}")
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ """
+ q: [seqlen, head_num, head_dim]
+ k: [seqlen, head_num, head_dim]
+ v: [seqlen, head_num, head_dim]
+ """
+ self.prepare_mask(seqlen=q.shape[0])
+ out = magi_ffa_func(
+ q,
+ k,
+ v,
+ q_ranges=self.q_ranges,
+ k_ranges=self.k_ranges,
+ attn_type_map=self.attn_type_map,
+ auto_range_merge=True,
+ )[0]
+ return out.reshape(out.shape[0], -1)
+
+
+@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer")
+class NbhdAttnWeightFlashInfer(AttnWeightTemplate):
+ block_size = 128
+ seqlen = None
+ attnmap_frame_num = None
+ coefficient = [1.0, 0.5, 0.056]
+ min_width = 1.0
+ sparse_wrapper = None
+
+ def __init__(self):
+ self.config = {}
+
+ @classmethod
+ @torch.compiler.disable
+ def prepare_mask(cls, seqlen, head_num, head_dim):
+ if seqlen == cls.seqlen:
+ return
+ block_num = (seqlen + cls.block_size - 1) // cls.block_size
+ block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
+ mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
+ mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
+ block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size
+ block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1)
+ block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1)
+ float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+ cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
+ cls.sparse_wrapper.plan(
+ block_mask_map=mask,
+ block_row_sz=block_rowcol_size,
+ block_col_sz=block_rowcol_size,
+ num_qo_heads=head_num,
+ num_kv_heads=head_num,
+ head_dim=head_dim,
+ q_data_type=torch.bfloat16,
+ )
+ cls.seqlen = seqlen
+ logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
+ sparsity = 1 - mask.sum().item() / mask.numel()
+ logger.info(f"Attention sparsity: {sparsity}")
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ """
+ q: [seqlen, head_num, head_dim]
+ k: [seqlen, head_num, head_dim]
+ v: [seqlen, head_num, head_dim]
+ """
+ self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2])
+ q = q.transpose(0, 1)
+ k = k.transpose(0, 1)
+ v = v.transpose(0, 1)
+ out = self.sparse_wrapper.run(q, k, v)
+ out = out.transpose(0, 1)
+ return out.reshape(out.shape[0], -1)
diff --git a/lightx2v/common/ops/attn/radial_attn.py b/lightx2v/common/ops/attn/radial_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..002f149b75f17f29ec9af6b5d47de05364e94252
--- /dev/null
+++ b/lightx2v/common/ops/attn/radial_attn.py
@@ -0,0 +1,185 @@
+import torch
+from loguru import logger
+
+try:
+ from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
+except ImportError:
+ magi_ffa_func = None
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+
+def shrinkMaskStrict(mask, block_size=128):
+ seqlen = mask.shape[0]
+ block_num = seqlen // block_size
+ mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
+ col_densities = mask.sum(dim=1) / block_size
+ # we want the minimum non-zero column density in the block
+ non_zero_densities = col_densities > 0
+ high_density_cols = col_densities > 1 / 3
+ frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
+ block_mask = frac_high_density_cols > 0.6
+ block_mask[0:0] = True
+ block_mask[-1:-1] = True
+ return block_mask
+
+
+def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
+ assert sparse_type in ["radial"]
+ dist = abs(i - j)
+ if model_type == "wan":
+ if dist < 1:
+ return token_per_frame
+ if dist == 1:
+ return token_per_frame // 2
+ elif model_type == "hunyuan":
+ if dist <= 1:
+ return token_per_frame
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+ group = dist.bit_length()
+ decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
+ threshold = block_size
+ if decay_length >= threshold:
+ return decay_length
+ else:
+ return threshold
+
+
+def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device):
+ assert sparse_type in ["radial"]
+ dist = abs(i - j)
+ group = dist.bit_length()
+ threshold = 128 # hardcoded threshold for now, which is equal to block-size
+ decay_length = 2 ** token_per_frame.bit_length() / 2**group
+ if decay_length >= threshold:
+ return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
+
+ split_factor = int(threshold / decay_length)
+ modular = dist % split_factor
+ if modular == 0:
+ return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
+ else:
+ return torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
+
+
+def gen_log_mask_shrinked(device, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
+ """
+    A more memory-friendly version: we generate the attention mask for one frame pair at a time,
+    shrink it to block granularity, and accumulate it into the final result
+ """
+ final_log_mask = torch.zeros(((s + block_size - 1) // block_size, (s + block_size - 1) // block_size), device=device, dtype=torch.bool)
+ token_per_frame = video_token_num // num_frame
+ video_text_border = video_token_num // block_size
+
+ col_indices = torch.arange(0, token_per_frame, device=device).view(1, -1)
+ row_indices = torch.arange(0, token_per_frame, device=device).view(-1, 1)
+ final_log_mask[video_text_border:] = True
+ final_log_mask[:, video_text_border:] = True
+ for i in range(num_frame):
+ for j in range(num_frame):
+ local_mask = torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
+ if j == 0 and model_type == "wan": # this is attention sink
+ local_mask = torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
+ else:
+ window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
+ local_mask = torch.abs(col_indices - row_indices) <= window_width
+ split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device)
+ local_mask = torch.logical_and(local_mask, split_mask)
+
+ remainder_row = (i * token_per_frame) % block_size
+ remainder_col = (j * token_per_frame) % block_size
+ # get the padded size
+ all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
+ all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
+ padded_local_mask = torch.zeros((all_length_row, all_length_col), device=device, dtype=torch.bool)
+ padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
+ # shrink the mask
+ block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
+ # set the block mask to the final log mask
+ block_row_start = (i * token_per_frame) // block_size
+ block_col_start = (j * token_per_frame) // block_size
+ block_row_end = block_row_start + block_mask.shape[0]
+ block_col_end = block_col_start + block_mask.shape[1]
+ final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
+ return final_log_mask
+
+
+def generate_qk_ranges(mask, block_size, seqlen):
+ indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2]
+
+ i_indices = indices[:, 0] # [N]
+ j_indices = indices[:, 1] # [N]
+
+ q_start = i_indices * block_size # [N]
+ q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N]
+
+ k_start = j_indices * block_size # [N]
+ k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N]
+
+ q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2]
+ k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2]
+
+ return q_ranges, k_ranges
+
+
+@ATTN_WEIGHT_REGISTER("radial_attn")
+class RadialAttnWeight(AttnWeightTemplate):
+ block_size = 128
+ seqlen = None
+ attnmap_frame_num = None
+ q_ranges = None
+ k_ranges = None
+ attn_type_map = None
+
+ def __init__(self):
+ self.config = {}
+
+ @classmethod
+ def prepare_mask(cls, seqlen):
+ if seqlen == cls.seqlen:
+ return
+ mask = gen_log_mask_shrinked(
+ device="cuda", s=seqlen, video_token_num=seqlen, num_frame=cls.attnmap_frame_num, block_size=cls.block_size, sparse_type="radial", decay_factor=0.2, model_type="wan"
+ )
+ q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
+ attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
+ q_ranges = q_ranges.to(torch.int32).to("cuda")
+ k_ranges = k_ranges.to(torch.int32).to("cuda")
+ cls.seqlen = seqlen
+ cls.q_ranges = q_ranges
+ cls.k_ranges = k_ranges
+ cls.attn_type_map = attn_type_map
+ logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
+ sparsity = 1 - mask.sum().item() / mask.numel()
+ logger.info(f"Attention sparsity: {sparsity}")
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ """
+ q: [seqlen, head_num, head_dim]
+ k: [seqlen, head_num, head_dim]
+ v: [seqlen, head_num, head_dim]
+ """
+ self.prepare_mask(seqlen=q.shape[0])
+ out = magi_ffa_func(
+ q,
+ k,
+ v,
+ q_ranges=self.q_ranges,
+ k_ranges=self.k_ranges,
+ attn_type_map=self.attn_type_map,
+ auto_range_merge=True,
+ )[0]
+ return out.reshape(out.shape[0], -1)
diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e96117f8eabea7d2ac06c2073140383c4979962
--- /dev/null
+++ b/lightx2v/common/ops/attn/ring_attn.py
@@ -0,0 +1,179 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from loguru import logger
+
+from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+from .utils.ring_comm import RingComm
+
+try:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+ logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
+ flash_attn_varlen_func = None
+
+
+@torch.jit.script
+def _update_out_and_lse(
+ out,
+ lse,
+ block_out,
+ block_lse,
+):
+ block_out = block_out.to(torch.float32)
+ block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
+
+ # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+ # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
+ # For additional context and discussion, please refer to:
+ # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
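+    # Derivation of the closed form below: let new_lse = log(exp(lse) + exp(block_lse))
+    # and s = sigmoid(block_lse - lse). Then s = exp(block_lse - new_lse) and
+    # 1 - s = exp(lse - new_lse), so the merged output
+    # exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
+    # simplifies to out - s * (out - block_out), and new_lse = lse - logsigmoid(lse - block_lse).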
+ out = out - F.sigmoid(block_lse - lse) * (out - block_out)
+ lse = lse - F.logsigmoid(lse - block_lse)
+ return out, lse
+
+
+@ATTN_WEIGHT_REGISTER("ring")
+class RingAttnWeight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
+ """
+ 执行 Ring 注意力机制,结合图像和文本的查询、键和值。
+
+ 参数:
+ q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
+ k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
+ v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
+ img_qkv_len (int): 图像查询、键和值的长度
+ cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
+ attention_type (str): 注意力类型,默认为 "flash_attn2"
+
+ 返回:
+ torch.Tensor: 计算得到的注意力结果
+ """
+        assert not use_fp8_comm, "RingAttn does not support fp8 comm yet."
+
+        # Rank of the current process and world size within the sequence-parallel group
+        cur_rank = dist.get_rank(seq_p_group)
+        world_size = dist.get_world_size(seq_p_group)
+
+        if len(cu_seqlens_qkv) == 3:
+            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
+            txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
+        elif len(cu_seqlens_qkv) == 2:
+            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
+            txt_mask_len = 0
+
+ # if RING_COMM is None:
+ # init_ring_comm()
+
+ RING_COMM = RingComm(seq_p_group)
+
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+
+ img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
+ txt_q, txt_k, txt_v = (
+ q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
+ k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
+ v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
+ )
+
+ out, lse, next_k, next_v = None, None, None, None
+
+ if len(cu_seqlens_qkv) == 3:
+ q = torch.cat((img_q, txt_q), dim=1)
+ k = img_k
+ v = img_v
+
+ for step in range(world_size):
+ if step + 1 != world_size:
+ next_k = RING_COMM.send_recv(k)
+ next_v = RING_COMM.send_recv(v)
+ RING_COMM.commit()
+
+ if step + 1 == world_size:
+ k = torch.cat((k, txt_k), dim=1)
+ v = torch.cat((v, txt_v), dim=1)
+
+ block_out, block_lse = self.ring_attn_sub(q, k, v)
+
+ out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
+
+ if step + 1 != world_size:
+ RING_COMM.wait()
+ k = next_k
+ v = next_v
+
+ attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
+
+ if txt_mask_len > 0:
+ attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
+ q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
+ k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
+ v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
+ dropout_p=0.0,
+ softmax_scale=q.shape[-1] ** (-0.5),
+ causal=False,
+ window_size_left=-1,
+ window_size_right=-1,
+ softcap=0.0,
+ alibi_slopes=None,
+ return_softmax=False,
+ )
+
+ attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
+ attn1 = torch.cat([attn1, attn2], dim=0)
+
+ return attn1
+
+ def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+ block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
+ q,
+ k,
+ v,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size_left=window_size[0],
+ window_size_right=window_size[1],
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ return_softmax=return_softmax,
+ )
+ return block_out, block_lse
+
+ def update_out_and_lse(
+ self,
+ out,
+ lse,
+ block_out,
+ block_lse,
+ slice_=None,
+ ):
+ if out is None:
+ if slice_ is not None:
+ raise RuntimeError("first update_out_and_lse should not pass slice_ args")
+ out = block_out.to(torch.float32)
+ lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
+ elif slice_ is not None:
+ slice_out, slice_lse = out[slice_], lse[slice_]
+ slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
+ out[slice_], lse[slice_] = slice_out, slice_lse
+ else:
+ out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
+ return out, lse
diff --git a/lightx2v/common/ops/attn/sage_attn.py b/lightx2v/common/ops/attn/sage_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cde1eaf6f8b311863a9cf3850b6d04843523431
--- /dev/null
+++ b/lightx2v/common/ops/attn/sage_attn.py
@@ -0,0 +1,83 @@
+import torch
+from loguru import logger
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
+ try:
+ from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
+ except ImportError:
+ logger.info("sageattn not found, please install sageattention first")
+ sageattn = None
+else:
+ try:
+ from sageattention import sageattn
+ except ImportError:
+ logger.info("sageattn not found, please install sageattention first")
+ sageattn = None
+
+try:
+ from sageattn3 import sageattn3_blackwell
+except ImportError:
+ logger.info("sageattn3 not found, please install sageattention first")
+ sageattn3_blackwell = None
+
+
+@ATTN_WEIGHT_REGISTER("sage_attn2")
+class SageAttn2Weight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+ if len(q.shape) == 3:
+ bs = 1
+ q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+ elif len(q.shape) == 4:
+ bs = q.shape[0]
+ x = sageattn(
+ q,
+ k,
+ v,
+ tensor_layout="NHD",
+ ).view(bs * max_seqlen_q, -1)
+ return x
+
+
+@ATTN_WEIGHT_REGISTER("sage_attn3")
+class SageAttn3Weight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+ if len(q.shape) == 3:
+ bs = 1
+ q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+ elif len(q.shape) == 4:
+ bs = q.shape[0]
+
+ x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
+ return x
diff --git a/lightx2v/common/ops/attn/spassage_attn.py b/lightx2v/common/ops/attn/spassage_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7fbf4644a5a7be634a2be957845f4cbcf3a8a3d
--- /dev/null
+++ b/lightx2v/common/ops/attn/spassage_attn.py
@@ -0,0 +1,76 @@
+import os
+
+import torch
+
+try:
+ import spas_sage_attn
+except ImportError:
+ spas_sage_attn = None
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+
+@ATTN_WEIGHT_REGISTER("spas_sage_attn")
+class SageAttnWeight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ @classmethod
+    def apply(cls, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, tensor_layout="HND"):
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ attn_out = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout)
+ _, H, N, D = attn_out.shape
+ attn_out = attn_out.permute(2, 1, 3, 0).contiguous().view(N, H * D)
+ return attn_out
+
+
+if __name__ == "__main__":
+ import matplotlib.pyplot as plt
+
+    # 1. Build the inputs
+ q = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
+ k = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
+ v = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
+
+    # 2. Reference attention computed with PyTorch SDPA (per head; the flash / memory-efficient
+    #    backends avoid materializing the full 32760 x 32760 attention matrix)
+    q_ = q.transpose(0, 1).unsqueeze(0)  # (1, 12, 32760, 128)
+    k_ = k.transpose(0, 1).unsqueeze(0)
+    v_ = v.transpose(0, 1).unsqueeze(0)
+    output_pt = torch.nn.functional.scaled_dot_product_attention(q_, k_, v_).float()  # (1, 12, 32760, 128)
+
+    # 3. Attention computed with spas_sage2_attn_meansim_cuda
+ q = q.unsqueeze(0) # shape: (1, 32760, 12, 128)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+ q = q.transpose(1, 2) # shape: (1, 12, 32760, 128)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ output_cuda = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout="HND")
+ output_cuda = output_cuda.float()
+
+    # 4. Take head 0 and the first 3000 tokens from both outputs
+    output_pt_crop = output_pt[0, 0, :3000, :].cpu().detach().numpy()
+    output_cuda_crop = output_cuda[0, 0, :3000, :].cpu().detach().numpy()
+
+    # 5. Save the figures
+ save_dir = os.path.expanduser("~/Log/10-22/")
+ os.makedirs(save_dir, exist_ok=True)
+
+ plt.imshow(output_pt_crop, aspect="auto")
+ plt.title("PyTorch Attention (left-top 3000x3000)")
+ plt.savefig(os.path.join(save_dir, "attn.png"))
+ plt.close()
+
+ plt.imshow(output_cuda_crop, aspect="auto")
+ plt.title("spas_sage2_attn_meansim_cuda (left-top 3000x3000)")
+ plt.savefig(os.path.join(save_dir, "spas_attn.png"))
+ plt.close()
diff --git a/lightx2v/common/ops/attn/svg2_attn.py b/lightx2v/common/ops/attn/svg2_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fed193a32eb6de338ea30be9ef0578c325f6802
--- /dev/null
+++ b/lightx2v/common/ops/attn/svg2_attn.py
@@ -0,0 +1,355 @@
+from typing import Optional
+
+# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
+try:
+ import flashinfer
+except ImportError:
+ flashinfer = None
+
+import torch
+import triton
+import triton.language as tl
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .svg2_attn_utils import (
+ batch_kmeans_Euclid,
+ identify_dynamic_map,
+)
+from .template import AttnWeightTemplate
+
+
+@triton.jit
+def _permute_kernel(
+ X_ptr,
+ IDX_ptr,
+ Y_ptr,
+ S: tl.constexpr,
+ D: tl.constexpr,
+ BLOCK_S: tl.constexpr,
+):
+ """Each program permutes BLOCK_S tokens *all* hidden features (D). No inner python loop."""
+
+ pid_bh = tl.program_id(0)
+ tile_s = tl.program_id(1)
+
+ # Offsets along sequence
+ s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
+ token_mask = s_offsets < S
+
+ # Gather source indices for these tokens
+ idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
+ src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
+
+ # Broadcast to create 2-D pointer matrix (BLOCK_S, D)
+ d_offsets = tl.arange(0, D)
+
+ src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
+ dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]
+
+ full_mask = token_mask[:, None]
+
+ values = tl.load(src_ptrs, mask=full_mask, other=0.0)
+ tl.store(dst_ptrs, values, mask=full_mask)
+
+
+def permute_tensor_by_labels_triton(
+ tensor: torch.Tensor,
+ labels: Optional[torch.Tensor],
+ dim: int,
+ *,
+ sorted_indices: Optional[torch.Tensor] = None,
+):
+ """
+ Permute `tensor` along `dim` according to ascending order of `labels`.
+
+ This is a Triton-accelerated replacement for the original implementation.
+    It currently supports 4-D CUDA tensors of shape [B, H, S, D] and `dim == 2`;
+    these conditions are enforced by the assertions below.
+ """
+
+ # Assertions – we only support the optimized CUDA path.
+ assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
+ assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
+ assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"
+
+ B, H, S, D = tensor.shape
+ BH = B * H
+
+ # Determine sorted indices
+ if sorted_indices is not None:
+ sorted_indices = sorted_indices.to(torch.int32).contiguous()
+ else:
+ assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
+ labels = labels.to(tensor.device)
+ sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()
+
+ # Flatten tensor and allocate output
+ inp_flat = tensor.reshape(BH, S, D).contiguous()
+ out_flat = torch.empty_like(inp_flat)
+
+ # Triton kernel tile size
+ BLOCK_S = 64 # number of tokens per program, tunable
+
+ n_s_tiles = triton.cdiv(S, BLOCK_S)
+ grid = (BH, n_s_tiles)
+
+ _permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
+
+ permuted_tensor = out_flat.reshape(B, H, S, D)
+ return permuted_tensor, sorted_indices
+
+
+@triton.jit
+def _inverse_permute_kernel(
+ X_ptr,
+ IDX_ptr,
+ Y_ptr,
+ S: tl.constexpr,
+ D: tl.constexpr,
+ BLOCK_S: tl.constexpr,
+):
+ """Inverse permutation: scatter BLOCK_S tokens back in one shot."""
+
+ pid_bh = tl.program_id(0)
+ tile_s = tl.program_id(1)
+
+ s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
+ token_mask = s_offsets < S
+
+ idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
+ src_pos_idx = s_offsets.to(tl.int32)
+ dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
+
+ d_offsets = tl.arange(0, D)
+
+ src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
+ dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]
+
+ full_mask = token_mask[:, None]
+
+ values = tl.load(src_ptrs, mask=full_mask, other=0.0)
+ tl.store(dst_ptrs, values, mask=full_mask)
+
+
+def apply_inverse_permutation_triton(
+ permuted_tensor: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ dim: int,
+):
+ """
+ Triton implementation of inverse permutation. Inverse the permutation applied by `permute_tensor_by_labels`.
+
+ Args:
+ permuted_tensor: (B, H, S, D).
+ sorted_indices: (B, H, S).
+        dim: Dimension along which to apply the inverse permutation. Typically 2, i.e. the sequence length dimension.
+
+ Returns:
+ Tensor of shape (B, H, S, D).
+ """
+
+ assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
+ assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
+ assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"
+
+ B, H, S, D = permuted_tensor.shape
+ BH = B * H
+
+ # Ensure index dtype
+ sorted_indices = sorted_indices.to(torch.int32).contiguous()
+
+ # Flatten inputs
+ inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
+ out_flat = torch.empty_like(inp_flat)
+
+ BLOCK_S = 64
+ n_s_tiles = triton.cdiv(S, BLOCK_S)
+ grid = (BH, n_s_tiles)
+
+ _inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
+
+ original_tensor = out_flat.reshape(B, H, S, D)
+ return original_tensor
+
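+# Round-trip property (comments only): for a CUDA tensor x of shape [B, H, S, D],
+#   y, idx = permute_tensor_by_labels_triton(x, labels, dim=2)
+#   x_back = apply_inverse_permutation_triton(y, idx, dim=2)
+# recovers x exactly, since both kernels are pure gather/scatter along the sequence
+# dimension driven by the same index tensor.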
+
+@ATTN_WEIGHT_REGISTER("svg2_attn")
+class Svg2AttnWeight(AttnWeightTemplate):
+ centroids_init = False
+ num_q_centroids = 300
+ num_k_centroids = 1000
+ kmeans_iter_init = 50
+ top_p_kmeans = 0.9
+ min_kc_ratio = 0.10
+ kmeans_iter_step = 2
+
+ def __init__(self):
+ self.config = {}
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ q = q.unsqueeze(0).transpose(1, 2)
+ k = k.unsqueeze(0).transpose(1, 2)
+ v = v.unsqueeze(0).transpose(1, 2)
+ bs, num_heads, seq_len, dim = q.size()
+ q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
+
+ output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
+
+ attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
+
+ return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
+
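+    # Summary of the `apply` pipeline above (descriptive comments): queries and keys are
+    # clustered with k-means; identify_dynamic_map selects the (q-cluster, k-cluster) block
+    # pairs to keep (controlled by top_p_kmeans / min_kc_ratio); tokens are permuted so each
+    # cluster is contiguous; FlashInfer's variable-block-sparse kernel runs on the permuted
+    # tensors; and the inverse permutation restores the original token order.
+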
+ def dynamic_block_sparse_fwd_flashinfer(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ block_mask_map: torch.Tensor,
+ block_row_sz: torch.Tensor,
+ block_col_sz: torch.Tensor,
+ is_cpu: bool = True,
+ ):
+ """
+ Launcher for the Flashinfer dynamic block sparse attention kernel.
+
+ Args:
+ q (torch.Tensor): Query tensor, shape [B, H, S, D].
+ k (torch.Tensor): Key tensor, shape [B, H, S, D].
+ v (torch.Tensor): Value tensor, shape [B, H, S, D].
+            block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Must currently be on CPU.
+            block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Must currently be on CPU.
+            block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Must currently be on CPU.
+            is_cpu (bool): Whether the block metadata is kept on CPU. FlashInfer plans on CPU by default; we can switch to GPU for faster planning. Default is True.
+ """
+ # Input shape check
+ B, H, S, D = q.shape
+ qc_num = block_row_sz.shape[-1]
+ kc_num = block_col_sz.shape[-1]
+ assert block_mask_map.shape == (B, H, qc_num, kc_num)
+
+ assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz]) if is_cpu else True
+
+        # Check that the total row/col block sizes are identical across batch and heads
+ assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
+ assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])
+
+ # Prepare flashinfer wrapper
+ float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
+ vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
+ wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
+ wrapper.reset_workspace_buffer(
+ float_workspace_buffer=wrapper._float_workspace_buffer,
+ int_workspace_buffer=wrapper._int_workspace_buffer,
+ vector_sparse_indices_buffer=vector_sparse_indices_buffer, # Only reset this buffer size
+ vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
+ )
+
+ block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
+ block_row_sz = block_row_sz.reshape(B * H, qc_num)
+ block_col_sz = block_col_sz.reshape(B * H, kc_num)
+
+ wrapper.plan(
+ block_mask_map=block_mask_map,
+ block_row_sz=block_row_sz,
+ block_col_sz=block_col_sz,
+ num_qo_heads=B * H,
+ num_kv_heads=B * H,
+ head_dim=D,
+ q_data_type=q.dtype,
+ kv_data_type=k.dtype,
+ )
+
+ # print_memory_usage("After plan")
+
+ q = q.reshape(B * H, S, D)
+ k = k.reshape(B * H, S, D)
+ v = v.reshape(B * H, S, D)
+ o = wrapper.run(q, k, v) # [num_qo_heads, qo_len, head_dim]
+ o = o.reshape(B, H, S, D)
+ return o
+
+ def semantic_aware_permutation(self, query, key, value):
+ cfg, num_heads, seq_len, dim = query.size()
+
+ # 1. Kmeans clustering
+ qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)
+
+ # 2. Identify dynamic map
+ q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
+ k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
+
+ dynamic_map = identify_dynamic_map(
+ qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
+ kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
+ q_cluster_sizes,
+ k_cluster_sizes,
+ self.top_p_kmeans,
+ self.min_kc_ratio,
+ )
+
+ # 3. Permute the query, key, value
+ q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
+ k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
+ v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)
+
+ return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices
+
+ def kmeans_clustering(self, query, key):
+ if not self.centroids_init:
+ qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
+ self.centroids_init = True
+ else:
+ qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
+
+ return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
+
+ def kmeans_init(self, query, key):
+ cfg, num_heads, seq_len, dim = query.size()
+ qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
+ klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
+
+ self.q_centroids = qcentroids
+ self.k_centroids = kcentroids
+
+ return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
+
+ def kmeans_step(self, query, key):
+ cfg, num_heads, seq_len, dim = query.size()
+ qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
+ query.view(cfg * num_heads, seq_len, dim),
+ n_clusters=self.num_q_centroids,
+ max_iters=self.kmeans_iter_step,
+ init_centroids=self.q_centroids,
+ )
+ klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
+ key.view(cfg * num_heads, seq_len, dim),
+ n_clusters=self.num_k_centroids,
+ max_iters=self.kmeans_iter_step,
+ init_centroids=self.k_centroids,
+ )
+
+ self.q_centroids = qcentroids
+ self.k_centroids = kcentroids
+
+ return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
+
+
+if __name__ == "__main__":
+ q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
+
+ svg2_attn = Svg2AttnWeight()
+ print("Svg2AttnWeight initialized.")
+
+ out = svg2_attn.apply(q, k, v)
+ print(f"out: {out.shape}, {out.dtype}, {out.device}")
diff --git a/lightx2v/common/ops/attn/svg2_attn_utils.py b/lightx2v/common/ops/attn/svg2_attn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e65815e455daeec11375af2822ffffb8297c12f
--- /dev/null
+++ b/lightx2v/common/ops/attn/svg2_attn_utils.py
@@ -0,0 +1,1359 @@
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+try:
+ from cuvs.cluster.kmeans import KMeansParams, fit
+except ImportError:
+ KMeansParams = None
+ fit = None
+
+# --- New functions ---
+
+
+def density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes):
+ """
+ Calculate the density of the dynamic map. Currently only batch size = 1 and head size = 1 are supported.
+
+ Input:
+ dynamic_map: [cfg, num_heads, qc_num, kc_num]
+ q_cluster_sizes: [cfg, num_heads, qc_num]
+ k_cluster_sizes: [cfg, num_heads, kc_num]
+ """
+ cfg, num_heads, qc_num, kc_num = dynamic_map.shape
+
+ # Calculate the block size of each block
+ clustered_block_size = q_cluster_sizes[:, :, :, None] * k_cluster_sizes[:, :, None, :]
+ masked_block_size = clustered_block_size * dynamic_map
+
+ # Calculate the density of each block
+ density = torch.sum(masked_block_size, dim=(2, 3)) / torch.sum(clustered_block_size, dim=(2, 3))
+ return density
+
+
+# --- Functions from analyze/kmeans_rapidai.py ---
+
+
+def pairwise_distance(x, y):
+ """
+ Computes pairwise squared Euclidean distance between two sets of points.
+ """
+ x_norm = (x**2).sum(1).view(-1, 1)
+ y_norm = (y**2).sum(1).view(1, -1)
+ dist = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)), min=0.0)
+ return dist
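+
+# Identity used above: ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, evaluated for
+# all pairs with a single matrix product; the clamp guards against small negative values
+# introduced by floating-point cancellation.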
+
+
+def kmeans_predict(centroids, input_tensor): # Removed unused params argument
+ """
+ Predict the labels for the input tensor using the centroids.
+ """
+ input_tensor = input_tensor.to(torch.float32)
+ dist = pairwise_distance(input_tensor, centroids)
+ labels = torch.argmin(dist, dim=1)
+ return labels
+
+
+def kmeans_rapidai(tensor, k, max_iter=5, tol=1e-4, init_method="Array", centroids_init=None): # Renamed centroids to centroids_init
+ """
+ Performs K-means clustering using cuVS.
+ """
+
+ assert tensor.dtype == torch.float32, "Tensor must be float32 for cuVS KMeans"
+ assert tensor.ndim == 2, f"Tensor must be 2D, but got {tensor.shape}"
+ # assert init_method == "Array", "init_method must be 'Array' for now"
+
+ L, D = tensor.shape
+
+ # cuVS KMeans in RAPIDS >=23.10 uses 'centroids_init' for initial centroids
+ current_centroids = centroids_init
+ if current_centroids is None:
+ # Default init: cuVS handles KMeansPlusPlus if centroids_init is None and init_method is KMeansPlusPlus
+ # If you need to pass an empty tensor for cuVS to initialize:
+ current_centroids = torch.empty(k, D, device=tensor.device, dtype=torch.float32) # Or pass None
+ else:
+ assert current_centroids.dtype == torch.float32, "Initial centroids must be float32"
+ assert current_centroids.shape == (
+ k,
+ D,
+ ), f"Initial centroids shape mismatch, got {current_centroids.shape}, expected ({k}, {D})"
+ # cuVS uses 'init_method="Array"' when 'centroids_init' is provided.
+
+ # import IPython; IPython.embed()
+
+    params = KMeansParams(n_clusters=k, max_iter=max_iter, tol=tol, init_method=init_method)
+
+ # Call fit with centroids_init (can be None)
+    new_centroids, inertia, n_iter_ = fit(params, tensor, current_centroids)
+
+ labels = kmeans_predict(new_centroids, tensor)
+ return labels, new_centroids, n_iter_
+
+
+@triton.jit
+def _centroid_update_kernel(
+ x_ptr, # *f16 [B, N, D]
+ cluster_ptr, # *i32 [B, N]
+ sum_ptr, # *f32 [B, K, D]
+ count_ptr, # *i32 [B, K]
+ B: tl.constexpr,
+ N: tl.constexpr,
+ D: tl.constexpr,
+ K: tl.constexpr,
+ BLOCK_D: tl.constexpr, # number of dims processed per program
+):
+ """Each program processes 1 point (token) across BLOCK_D dimensions with atomics."""
+ pid = tl.program_id(axis=0)
+ token_idx = pid # range: [0, B * N)
+
+ # Derive (b, n) indices
+ b = token_idx // N
+ n = token_idx % N
+
+ # Pointers to the token features and its cluster id
+ x_offset = (b * N + n) * D
+ x_ptr = x_ptr + x_offset
+
+ cluster_idx = tl.load(cluster_ptr + b * N + n) # int32
+
+ # Guard for invalid cluster ids (should not happen)
+ cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0)
+
+ # Base pointer for this centroid in the output sum tensor
+ centroid_base = (b * K + cluster_idx) * D
+
+ # Process feature vector in chunks of BLOCK_D
+ offs = tl.arange(0, BLOCK_D)
+ for d_start in range(0, D, BLOCK_D):
+ mask = offs + d_start < D
+ feats = tl.load(x_ptr + d_start + offs, mask=mask, other=0.0)
+ feats = feats.to(tl.float32)
+
+ dest_ptr = sum_ptr + centroid_base + d_start + offs
+ tl.atomic_add(dest_ptr, feats, mask=mask)
+
+ # Update counts (only once per point)
+ tl.atomic_add(count_ptr + b * K + cluster_idx, 1)
+
+
+def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
+ """Compute centroids using custom Triton kernel.
+
+ Args:
+ x_norm (Tensor): (B, N, D) normalized input vectors (float16/float32)
+ cluster_ids (LongTensor): (B, N) cluster assignment per point
+ old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x_norm)
+
+ Returns:
+ Tensor: (B, K, D) updated and L2-normalized centroids (dtype == x_norm.dtype)
+ """
+ assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
+ B, N, D = x_norm.shape
+ K = old_centroids.shape[1]
+ assert cluster_ids.shape == (B, N)
+
+ # Allocate accumulation buffers
+ centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
+ centroid_counts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)
+
+ # Launch Triton kernel – one program per token
+ total_tokens = B * N
+ BLOCK_D = 128 # tuneable
+
+ grid = (total_tokens,)
+ _centroid_update_kernel[grid](
+ x_norm,
+ cluster_ids.to(torch.int32),
+ centroid_sums,
+ centroid_counts,
+ B,
+ N,
+ D,
+ K,
+ BLOCK_D=BLOCK_D,
+ )
+
+ # Compute means; keep old centroid if empty cluster
+ counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
+ centroids = centroid_sums / counts_f
+
+ # For clusters with zero count, revert to old centroids
+ zero_mask = (centroid_counts == 0).unsqueeze(-1)
+ centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
+
+ centroids = centroids.to(x_norm.dtype)
+ centroids = F.normalize(centroids, p=2, dim=-1)
+ return centroids
+
+
+def torch_loop_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
+ """Reference Python implementation (double for-loop)"""
+ B, N, D = x_norm.shape
+ K = old_centroids.shape[1]
+ new_centroids = torch.zeros_like(old_centroids)
+ for b in range(B):
+ for k in range(K):
+ mask = cluster_ids[b] == k
+ if mask.any():
+ new_centroids[b, k] = F.normalize(x_norm[b][mask].mean(dim=0, dtype=x_norm.dtype), p=2, dim=0)
+ else:
+ new_centroids[b, k] = old_centroids[b, k]
+ return new_centroids
+
+
+def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
+ """Compute centroids for Euclidean KMeans using Triton.
+
+ Args:
+ x (Tensor): (B, N, D) input vectors (float16/float32)
+ cluster_ids (LongTensor): (B, N) cluster assignment per point
+ old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x)
+
+ Returns:
+ Tensor: (B, K, D) updated centroids (dtype == x.dtype)
+ """
+ assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
+ B, N, D = x.shape
+ K = old_centroids.shape[1]
+ assert cluster_ids.shape == (B, N)
+
+ # Allocate accumulation buffers
+ centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
+ centroid_counts = torch.zeros((B, K), device=x.device, dtype=torch.int32)
+
+ total_tokens = B * N
+ BLOCK_D = 128 # tuneable
+ grid = (total_tokens,)
+
+ _centroid_update_kernel[grid](
+ x,
+ cluster_ids.to(torch.int32),
+ centroid_sums,
+ centroid_counts,
+ B,
+ N,
+ D,
+ K,
+ BLOCK_D=BLOCK_D,
+ )
+
+ # Compute means; keep old centroid if empty cluster
+ counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
+ centroids = centroid_sums / counts_f
+
+ # For clusters with zero count, revert to old centroids
+ zero_mask = (centroid_counts == 0).unsqueeze(-1)
+ centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
+
+ return centroids.to(x.dtype)
+
+
+# ------------------------------ NEW: chunk-wise centroid update (sorted ids) ------------------------------
+
+
+@triton.jit
+def _centroid_update_chunk_kernel(
+ x_ptr, # *f16 / *f32 [B, N, D] – ORIGINAL ORDER
+ sorted_idx_ptr, # *i32 [B, N] – indices after sort
+ sorted_cluster_ptr, # *i32 [B, N] – cluster ids in sorted order
+ sum_ptr, # *f32 [B, K, D]
+ count_ptr, # *i32 [B, K]
+ B: tl.constexpr,
+ N: tl.constexpr,
+ D: tl.constexpr,
+ K: tl.constexpr,
+ BLOCK_N: tl.constexpr, # how many tokens (points) each program processes
+):
+ """Each program processes **BLOCK_N consecutive, already-sorted tokens**.
+
+ Because the tokens are sorted by cluster id, identical ids appear in
+ contiguous runs. We therefore accumulate a local sum/count for the
+ current run and perform **a single atomic update per run**, instead of
+ per-token.
+ """
+ # program indices – 2-D launch grid: (chunk_id, batch_id)
+ pid_chunk = tl.program_id(axis=0)
+ pid_b = tl.program_id(axis=1)
+
+ b = pid_b
+ chunk_start = pid_chunk * BLOCK_N # position of the first token handled by this program
+
+ # Nothing to do – out of range
+ if chunk_start >= N:
+ return
+
+ # base pointers for this batch
+ idx_batch_base = sorted_idx_ptr + b * N
+ cid_batch_base = sorted_cluster_ptr + b * N
+ x_batch_base = x_ptr + b * N * D # for pointer arithmetic
+
+ # helper aranges
+ offs_token = tl.arange(0, BLOCK_N)
+ offs_dim = tl.arange(0, D)
+
+ # first token index & validity mask
+ token_idx = chunk_start + offs_token
+ valid_tok = token_idx < N
+ first_token_idx = chunk_start
+ last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1
+
+ # Load first cluster id to initialise the running accumulator
+ first_id = tl.load(cid_batch_base + first_token_idx)
+ last_id = tl.load(cid_batch_base + last_token_idx)
+ all_ids = tl.load(cid_batch_base + token_idx, mask=valid_tok, other=-1)
+
+ all_tokens_idxs = tl.load(idx_batch_base + token_idx, mask=valid_tok, other=-1) # [BLOCK_N]
+
+ load_mask = all_tokens_idxs[:, None] * D + offs_dim[None, :]
+
+ for cid in range(first_id, last_id + 1):
+ cluster_mask = all_ids == cid
+ cluster_size = tl.sum(cluster_mask.to(tl.int32))
+ if cluster_size != 0:
+ cluster_feats = tl.load(x_batch_base + load_mask, mask=cluster_mask[:, None], other=0.0) # [BLOCK_N, D]
+ cluster_feats = cluster_feats.to(tl.float32)
+ sum_feats = tl.sum(cluster_feats, axis=0)
+ dest_ptr = sum_ptr + (b * K + cid) * D + offs_dim
+ tl.atomic_add(dest_ptr, sum_feats)
+ tl.atomic_add(count_ptr + b * K + cid, cluster_size)
+
+
+# ---------------------------------------------------------------------------------------------
+
+
+def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
+ """Fast centroid update assuming **cluster_ids are sorted along N**.
+
+ This helper will sort the assignments (together with `x_norm`) and launch the
+ chunk kernel above. Compared to the naive per-token kernel it performs *one
+ atomic add per run of identical ids* instead of per token, providing large
+ speed-ups when clusters are reasonably sized.
+ """
+ assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA"
+ B, N, D = x_norm.shape
+ K = old_centroids.shape[1]
+ assert cluster_ids.shape == (B, N)
+
+ # -------- sort per-batch --------
+ sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
+ sorted_idx_int = sorted_idx.to(torch.int32)
+
+ # accumulation buffers
+ centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
+ centroid_cnts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)
+
+ grid = (triton.cdiv(N, BLOCK_N), B)
+ _centroid_update_chunk_kernel[grid](
+ x_norm,
+ sorted_idx_int,
+ sorted_cluster_ids.to(torch.int32),
+ centroid_sums,
+ centroid_cnts,
+ B,
+ N,
+ D,
+ K,
+ BLOCK_N=BLOCK_N,
+ )
+
+ # finalise – convert to means, handle empty clusters, renormalise
+ counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
+ centroids = centroid_sums / counts_f
+ empty_mask = (centroid_cnts == 0).unsqueeze(-1)
+ centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
+ centroids = centroids.to(x_norm.dtype)
+ centroids = F.normalize(centroids, p=2, dim=-1)
+ return centroids
+
+
+def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
+ """Fast centroid update for *Euclidean* KMeans assuming cluster IDs are pre-sorted.
+
+ Parameters
+ ----------
+ x : Tensor [B, N, D]
+ Input feature vectors (no normalization assumed).
+ cluster_ids : LongTensor [B, N]
+ Cluster assignment for each point.
+ old_centroids : Tensor [B, K, D]
+ Previous centroids (used to fill empty clusters).
+ BLOCK_N : int, optional
+ Tokens per Triton program (affects occupancy/perf).
+ """
+ assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device"
+ B, N, D = x.shape
+ K = old_centroids.shape[1]
+
+ # Batch-wise sort of cluster assignments
+ sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
+ sorted_idx_int = sorted_idx.to(torch.int32)
+
+ centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
+ centroid_cnts = torch.zeros((B, K), device=x.device, dtype=torch.int32)
+
+ grid = (triton.cdiv(N, BLOCK_N), B)
+ _centroid_update_chunk_kernel[grid](
+ x, # original features
+ sorted_idx_int, # gather indices
+ sorted_cluster_ids.to(torch.int32),
+ centroid_sums,
+ centroid_cnts,
+ B,
+ N,
+ D,
+ K,
+ BLOCK_N=BLOCK_N,
+ )
+
+ # Convert sums to means; replace empty clusters with old centroids
+ counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
+ centroids = centroid_sums / counts_f
+ empty_mask = (centroid_cnts == 0).unsqueeze(-1)
+ centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
+ return centroids.to(x.dtype), centroid_cnts
+
+
+# ===============================================================
+# Triton kernel: compute nearest-centroid IDs (Euclidean distance)
+# Inputs:
+# x : (B, N, D) float16 / float32
+# centroids : (B, K, D) same dtype as x
+# x_sq : (B, N) float32 – pre-computed ||x||^2 per point
+# Output:
+# cluster_ids : (B, N) int32 – nearest centroid index per point
+# ===============================================================
+
+
+def _ceil_div(a: int, b: int) -> int:
+ return (a + b - 1) // b
+
+
+# -----------------------------------------------------------------------------
+# Auto-tuning setup – explore various tile sizes / warp counts
+# -----------------------------------------------------------------------------
+
+_TUNE_CONFIGS = [triton.Config({"BLOCK_N": BN, "BLOCK_K": BK}, num_stages=4, num_warps=wp) for BN in [32, 64, 128] for BK in [32, 64, 128] for wp in [4, 8]]
+
+
+def _cfg_keep(conf):
+ """Basic heuristic to prune unbalanced configs."""
+ BN = conf.kwargs["BLOCK_N"]
+ BK = conf.kwargs["BLOCK_K"]
+ # Avoid tiny tiles on many warps
+ if BN * BK < 32 * 32 and conf.num_warps > 4:
+ return False
+ return True
+
+
+_TUNE_CONFIGS = list(filter(_cfg_keep, _TUNE_CONFIGS))
+
+
+@triton.autotune(_TUNE_CONFIGS, key=["N", "K"])
+@triton.jit
+def _euclid_assign_kernel(
+ x_ptr, # *f16 / *f32 [B, N, D]
+ c_ptr, # *f16 / *f32 [B, K, D]
+ x_sq_ptr, # *f32 [B, N]
+ out_ptr, # *i32 [B, N]
+ B: tl.constexpr,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ D: tl.constexpr,
+ stride_x_b: tl.constexpr,
+ stride_x_n: tl.constexpr,
+ stride_x_d: tl.constexpr,
+ stride_c_b: tl.constexpr,
+ stride_c_k: tl.constexpr,
+ stride_c_d: tl.constexpr,
+ stride_xsq_b: tl.constexpr,
+ stride_xsq_n: tl.constexpr,
+ stride_out_b: tl.constexpr,
+ stride_out_n: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ """Each program handles a tile of BLOCK_N points for a given batch element.
+
+ The kernel iterates over the centroid dimension K in chunks of BLOCK_K and
+ maintains the running minimum distance as well as the corresponding index
+ for every point in the tile.
+ """
+ pid_n = tl.program_id(0) # tile index along N dimension
+ pid_b = tl.program_id(1) # batch index
+
+ n_start = pid_n * BLOCK_N
+ n_offsets = n_start + tl.arange(0, BLOCK_N)
+ n_mask = n_offsets < N
+
+ # ------------------------------------------------------------------
+ # Load x tile (BLOCK_N, D)
+ # ------------------------------------------------------------------
+ offs_d = tl.arange(0, D)
+ # Compute pointer for x block: base + b*stride_x_b + n*stride_x_n + d*stride_x_d
+ x_ptrs = x_ptr + pid_b * stride_x_b + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d
+ x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0) # keep input dtype; the dot result is cast to f32 below
+
+ # Pre-load x_sq for the tile (BLOCK_N,)
+ xsq_ptrs = x_sq_ptr + pid_b * stride_xsq_b + n_offsets * stride_xsq_n
+ x_sq_tile = tl.load(xsq_ptrs, mask=n_mask, other=0.0).to(tl.float32)
+
+ # Init best distance / index
+ best_dist = tl.full((BLOCK_N,), 3.4e38, tl.float32) # large number
+ best_idx = tl.zeros((BLOCK_N,), tl.int32)
+
+ # ------------------------------------------------------------------
+ # Iterate over centroids in chunks of BLOCK_K
+ # ------------------------------------------------------------------
+ for k_start in range(0, K, BLOCK_K):
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < K
+
+ # Load centroid tile (D, BLOCK_K)
+ c_ptrs = c_ptr + pid_b * stride_c_b + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d
+ c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0) # keep input dtype for tl.dot
+
+ # Compute centroid squared norms (BLOCK_K,)
+ cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32)
+
+ # Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
+ cross = tl.dot(x_tile, c_tile).to(tl.float32) # float32
+
+ # Squared Euclidean distance
+ dist = x_sq_tile[:, None] + cent_sq[None, :] - 2.0 * cross
+ dist = tl.maximum(dist, 0.0)
+
+ # Mask out invalid centroid columns before reduction
+ dist = tl.where(k_mask[None, :], dist, 3.4e38)
+
+ curr_min = tl.min(dist, axis=1)
+ curr_idx = tl.argmin(dist, axis=1)
+
+ update = curr_min < best_dist
+ best_dist = tl.where(update, curr_min, best_dist)
+ best_idx = tl.where(update, k_start + curr_idx, best_idx)
+
+ # ------------------------------------------------------------------
+ # Write results
+ # ------------------------------------------------------------------
+ out_ptrs = out_ptr + pid_b * stride_out_b + n_offsets * stride_out_n
+ tl.store(out_ptrs, best_idx, mask=n_mask)
+
+
+# ---------------------------------------------------------------
+# Python wrapper
+# ---------------------------------------------------------------
+
+
+def euclid_assign_triton(
+ x: torch.Tensor,
+ centroids: torch.Tensor,
+ x_sq: torch.Tensor,
+ out: torch.Tensor = None,
+ *,
+ BLOCK_N: int = 128,
+ BLOCK_K: int = 128,
+) -> torch.Tensor:
+ """Return nearest-centroid indices using Triton kernel.
+
+ Args:
+ x : (B, N, D) float16 / float32 (on CUDA)
+ centroids : (B, K, D) same dtype/device as x
+ x_sq : (B, N) float32 – ||x||^2 per point (on CUDA)
+
+ Returns:
+ cluster_ids (B, N): int64 by default; if an integer ``out`` buffer is supplied, its dtype is used instead
+ """
+ assert x.is_cuda and centroids.is_cuda and x_sq.is_cuda, "All tensors must be on CUDA"
+ # assert x.dtype in (torch.float16, torch.float32), "x must be fp16/fp32"
+ assert centroids.dtype == x.dtype, "centroids dtype mismatch"
+
+ B, N, D = x.shape
+ K = centroids.shape[1]
+ assert centroids.shape == (B, K, D), "centroids shape mismatch"
+ assert x_sq.shape == (B, N), "x_sq shape mismatch"
+
+ # x = x.contiguous()
+ # centroids = centroids.contiguous()
+ # x_sq = x_sq.contiguous()
+
+ if out is None:
+ out = torch.empty((B, N), device=x.device, dtype=torch.int64)
+
+ # Strides (in elements)
+ stride_x_b, stride_x_n, stride_x_d = x.stride()
+ stride_c_b, stride_c_k, stride_c_d = centroids.stride()
+ stride_xsq_b, stride_xsq_n = x_sq.stride()
+ stride_out_b, stride_out_n = out.stride()
+
+ grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]), B) # noqa
+
+ _euclid_assign_kernel[grid](
+ x,
+ centroids,
+ x_sq,
+ out,
+ B,
+ N,
+ K,
+ D,
+ stride_x_b,
+ stride_x_n,
+ stride_x_d,
+ stride_c_b,
+ stride_c_k,
+ stride_c_d,
+ stride_xsq_b,
+ stride_xsq_n,
+ stride_out_b,
+ stride_out_n,
+ )
+ return out
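+
+
+# Illustrative cross-check of euclid_assign_triton against a dense PyTorch argmin.
+# This hypothetical helper is never called by the module; agreement is only up to
+# ties and floating-point rounding, so it reports a match rate rather than asserting
+# exact equality. CUDA inputs are assumed.
+def _example_euclid_assign_check(x, centroids):
+ x_sq = (x.float() ** 2).sum(dim=-1) # (B, N) float32, as the kernel expects
+ dist = x_sq[:, :, None] + (centroids.float() ** 2).sum(-1)[:, None, :] - 2.0 * torch.einsum("bnd,bkd->bnk", x.float(), centroids.float())
+ ref_ids = dist.argmin(dim=-1)
+ triton_ids = euclid_assign_triton(x, centroids, x_sq)
+ return (ref_ids == triton_ids.long()).float().mean()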
+
+
+# 1. Euclidean
+def _euclid_iter(x, x_sq, centroids):
+ # cent_sq = (centroids ** 2).sum(dim=-1)
+ # cross = torch.einsum('bnd,bkd->bnk', x, centroids)
+ # dist_sq = (x_sq[:,:,None] + cent_sq[:,None,:] - 2.0 * cross).clamp_min_(0.0)
+
+ # cluster_ids = dist_sq.argmin(dim=-1)
+ cluster_ids = euclid_assign_triton(x, centroids, x_sq)
+ centroids_new, cluster_sizes = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
+ # centroids_new = triton_centroid_update_euclid(x, cluster_ids, centroids)
+
+ # centroids_new = centroids_new.clone() # avoid CUDA graphs aliasing
+
+ shift = (centroids_new - centroids).norm(dim=-1).max()
+ return centroids_new, shift, cluster_ids, cluster_sizes
+
+
+# 2. Cosine
+def _cosine_iter(x_norm, centroids):
+ cos_sim = torch.einsum("bnd,bkd->bnk", x_norm, centroids)
+ cluster_ids = cos_sim.argmax(dim=-1)
+ centroids_new = triton_centroid_update_cosine(x_norm, cluster_ids, centroids)
+ # centroids_new = centroids_new.clone()
+ shift = (centroids_new - centroids).norm(dim=-1).max()
+ return centroids_new, shift, cluster_ids
+
+
+# 3. Dot-product
+def _dot_iter(x, centroids):
+ sim = torch.einsum("bnd,bkd->bnk", x, centroids)
+ cluster_ids = sim.argmax(dim=-1)
+ centroids_new = triton_centroid_update_cosine(x, cluster_ids, centroids)
+ # centroids_new = centroids_new.clone()
+ shift = (centroids_new - centroids).norm(dim=-1).max()
+ return centroids_new, shift, cluster_ids
+
+
+COMPILE_FLAG = False
+
+# Try to compile the per-iteration helpers; if torch.compile is unavailable (or COMPILE_FLAG is off), fall back to the eager functions
+try:
+ if COMPILE_FLAG:
+ _euclid_iter_compiled = torch.compile(_euclid_iter, dynamic=True, mode="reduce-overhead")
+ _cosine_iter_compiled = torch.compile(_cosine_iter, dynamic=True, mode="reduce-overhead")
+ _dot_iter_compiled = torch.compile(_dot_iter, dynamic=True, mode="reduce-overhead")
+ else:
+ _euclid_iter_compiled = _euclid_iter
+ _cosine_iter_compiled = _cosine_iter
+ _dot_iter_compiled = _dot_iter
+except Exception: # pragma: no cover
+ _euclid_iter_compiled = _euclid_iter
+ _cosine_iter_compiled = _cosine_iter
+ _dot_iter_compiled = _dot_iter
+
+
+def batch_kmeans_Euclid(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
+ """
+ Batched KMeans clustering in PyTorch using Euclidean distance.
+
+ Args:
+ x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
+ n_clusters: Number of clusters.
+ max_iters: Max number of iterations.
+ tol: Convergence threshold on the maximum centroid shift.
+ verbose: Print the center shift for each iteration.
+ Returns:
+ cluster_ids: (B, N) LongTensor, cluster assignment for each point.
+ centroids: (B, n_clusters, D) final cluster centers.
+ cluster_sizes: (B, n_clusters) IntTensor (int32), number of points per cluster.
+ n_iters: actual number of iterations executed (int)
+ """
+ B, N, D = x.shape
+
+ # Pre-compute squared L2 norm of all points (constant during iterations)
+ x_sq = (x**2).sum(dim=-1) # (B, N)
+
+ if init_centroids is None:
+ # Randomly select initial centers from x
+ indices = torch.randint(0, N, (B, n_clusters), device=x.device)
+ centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D)
+ else:
+ # centroids = init_centroids.clone()
+ centroids = init_centroids
+
+ centroids = centroids.view(B, n_clusters, D)
+
+ for it in range(max_iters):
+ # ---- compiled single iteration ----
+ centroids_new, center_shift, cluster_ids, cluster_sizes = _euclid_iter_compiled(x, x_sq, centroids)
+
+ # Check for convergence
+ if verbose:
+ print(f"Iter {it}, center shift: {center_shift.item():.6f}")
+ if center_shift < tol:
+ break
+ # centroids = centroids_new.clone()
+ centroids = centroids_new
+
+ # # --- compute cluster sizes ---
+ # ones = torch.ones_like(cluster_ids, dtype=torch.int64)
+ # cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
+ # cluster_sizes.scatter_add_(1, cluster_ids, ones)
+
+ return cluster_ids, centroids, cluster_sizes, it + 1
+ # return cluster_ids.clone(), centroids.clone(), cluster_sizes.clone(), it + 1
+
+
+# batch_kmeans_Euclid = torch.compile(batch_kmeans_Euclid, dynamic=True, mode="reduce-overhead")
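+
+
+# Illustrative call of the batched Euclidean KMeans driver. This hypothetical helper
+# is never called by the module; sizes are arbitrary example values and a CUDA device
+# is assumed.
+def _example_batch_kmeans_euclid():
+ x = torch.randn(2, 8192, 64, device="cuda", dtype=torch.float16)
+ ids, centroids, sizes, n_iters = batch_kmeans_Euclid(x, n_clusters=32, max_iters=20, tol=1e-4)
+ # ids: (2, 8192); centroids: (2, 32, 64); sizes: (2, 32); n_iters <= 20
+ return ids, centroids, sizes, n_iters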
+
+
+def batch_kmeans_Cosine(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
+ """
+ Batched KMeans clustering in PyTorch using Cosine similarity.
+
+ Args:
+ x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
+ n_clusters: Number of clusters.
+ max_iters: Max number of iterations.
+ tol: Convergence threshold on the maximum centroid shift.
+ verbose: Print the center shift for each iteration.
+ Returns:
+ cluster_ids: (B, N) LongTensor, cluster assignment for each point.
+ centroids: (B, n_clusters, D) final cluster centers.
+ cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
+ n_iters: actual number of iterations executed (int)
+ """
+ B, N, D = x.shape
+
+ # Normalize input vectors for cosine similarity
+ x_norm = F.normalize(x, p=2, dim=-1) # (B, N, D)
+
+ if init_centroids is None:
+ # Randomly select initial centers from x_norm
+ indices = torch.randint(0, N, (B, n_clusters), device=x.device)
+ centroids = torch.gather(x_norm, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D)
+ else:
+ centroids = init_centroids
+
+ centroids = centroids.view(B, n_clusters, D)
+ centroids = F.normalize(centroids, p=2, dim=-1) # Ensure centroids are normalized
+
+ for it in range(max_iters):
+ # ---- compiled single iteration ----
+ centroids_new, center_shift, cluster_ids = _cosine_iter_compiled(x_norm, centroids)
+
+ # Check for convergence
+ if verbose:
+ print(f"Iter {it}, center shift: {center_shift.item():.6f}")
+ if center_shift < tol:
+ break
+ centroids = centroids_new.clone()
+
+ # --- compute cluster sizes ---
+ ones = torch.ones_like(cluster_ids, dtype=torch.int64)
+ cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
+ cluster_sizes.scatter_add_(1, cluster_ids, ones)
+
+ return cluster_ids, centroids, cluster_sizes, it + 1
+
+
+def batch_kmeans_Dot(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
+ """
+ Batched KMeans clustering in PyTorch using raw dot-product as similarity.
+
+ """
+ B, N, D = x.shape
+
+ if init_centroids is None:
+ # Randomly initialize centroids
+ indices = torch.randint(0, N, (B, n_clusters), device=x.device)
+ centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D))
+ else:
+ centroids = init_centroids
+
+ centroids = centroids.view(B, n_clusters, D)
+
+ for it in range(max_iters):
+ # ---- compiled single iteration ----
+ centroids_new, center_shift, cluster_ids = _dot_iter_compiled(x, centroids)
+
+ # Check for convergence
+ if verbose:
+ print(f"Iter {it} (dot), center shift: {center_shift.item():.6f}")
+ if center_shift < tol:
+ break
+ centroids = centroids_new.clone()
+
+ # --- compute cluster sizes ---
+ ones = torch.ones_like(cluster_ids, dtype=torch.int64)
+ cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
+ cluster_sizes.scatter_add_(1, cluster_ids, ones)
+
+ return cluster_ids, centroids, cluster_sizes, it + 1
+
+
+# --- Functions from analyze/kmeans_block_sparse_attention.py (helpers) ---
+
+
+def permute_tensor_by_labels(tensor, labels, dim):
+ labels = labels.to(tensor.device)
+ sorted_indices = torch.argsort(labels, dim=-1)
+ gather_indices = sorted_indices
+ for i in range(dim + 1, tensor.dim()):
+ gather_indices = gather_indices.unsqueeze(-1)
+ expand_shape = list(tensor.shape)
+ gather_indices = gather_indices.expand(expand_shape)
+ permuted_tensor = torch.gather(tensor, dim, gather_indices)
+ return permuted_tensor, sorted_indices
+
+
+def apply_inverse_permutation(permuted_tensor, sorted_indices, dim):
+ inverse_indices = torch.argsort(sorted_indices, dim=-1)
+ gather_indices = inverse_indices
+ for i in range(dim + 1, permuted_tensor.dim()):
+ gather_indices = gather_indices.unsqueeze(-1)
+ gather_indices = gather_indices.expand(permuted_tensor.shape)
+ original_tensor = torch.gather(permuted_tensor, dim, gather_indices)
+ return original_tensor
+
+
+def weighted_softmax(scores, weights):
+ input_dtype = scores.dtype
+ scores = scores.float()
+ weights = weights.float()
+ max_score = torch.max(scores, dim=-1, keepdim=True)[0]
+ exp_scores = torch.exp(scores - max_score)
+ weighted_exp = weights * exp_scores
+ softmax_out = weighted_exp / torch.sum(weighted_exp, dim=-1, keepdim=True).clamp(min=1e-12)
+ return softmax_out.to(input_dtype)
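+
+
+# Note: for positive weights, weighted_softmax(scores, weights) equals
+# softmax(scores + log(weights)) along the last dim, so a centroid weighted by its
+# cluster size behaves like that many identical keys. The hypothetical check below
+# is never called by the module and only sketches that identity.
+def _example_weighted_softmax_identity():
+ scores = torch.randn(1, 2, 4, 8)
+ weights = torch.randint(1, 5, (1, 2, 1, 8)).float()
+ ref = torch.softmax(scores + weights.log(), dim=-1)
+ return torch.allclose(weighted_softmax(scores, weights), ref, atol=1e-5)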
+
+
+def identify_dynamic_map(
+ query_centroids,
+ key_centroids,
+ q_cluster_sizes,
+ k_cluster_sizes,
+ p,
+ min_kc_ratio=0,
+):
+ B, H, qc_num, D = query_centroids.shape
+ kc_num = key_centroids.shape[2]
+ device = query_centroids.device
+
+ attn_scores = torch.matmul(query_centroids, key_centroids.transpose(-2, -1)) / (D**0.5)
+ k_weights = k_cluster_sizes.unsqueeze(-2).float()
+
+ weighted_attn_probs = weighted_softmax(attn_scores, k_weights)
+ sorted_probs, sorted_indices = torch.sort(weighted_attn_probs, dim=-1, descending=True)
+
+ cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+ remove_indices = cumsum_probs > p
+ remove_indices[..., 1:] = remove_indices[..., :-1].clone()
+ remove_indices[..., 0] = False
+
+ if min_kc_ratio > 0:
+ preserve_length = int(min_kc_ratio * kc_num)
+ remove_indices[..., :preserve_length] = False
+
+ sorted_clusters_to_keep = ~remove_indices
+
+ dynamic_map = torch.zeros(B, H, qc_num, kc_num, dtype=torch.bool, device=device)
+ dynamic_map.scatter_(-1, sorted_indices, sorted_clusters_to_keep)
+ return dynamic_map
+
+
+# --- Functions from analyze/dynamic_block_sparse_attention.py ---
+
+
+def dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size):
+ """
+ Computes dynamic block sparse attention using pure PyTorch.
+
+ Args:
+ q (torch.Tensor): Query tensor, shape [B, H, S, D].
+ k (torch.Tensor): Key tensor, shape [B, H, S, D].
+ v (torch.Tensor): Value tensor, shape [B, H, S, D].
+ dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
+ qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
+ kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
+
+ Returns:
+ torch.Tensor: Output tensor, shape [B, H, S, D].
+ """
+ B, H, S, D = q.shape
+ qc_num = qc_size.shape[-1]
+ kc_num = kc_size.shape[-1]
+ device = q.device
+ dtype = q.dtype
+
+ # Ensure sequence lengths match sum of block sizes
+ assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
+ assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
+
+ # Precompute cumulative sizes for block indexing
+ # Add a 0 at the beginning for easier slicing
+ qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1)
+ kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1)
+
+ out = torch.zeros_like(q)
+ scale = D**-0.5
+
+ # Naive implementation: Iterate through batch, head, and blocks
+ for b in range(B):
+ for h in range(H):
+ # Precompute start/end indices for this batch/head
+ q_starts = qc_cum_size[b, h, :-1]
+ q_ends = qc_cum_size[b, h, 1:]
+ k_starts = kc_cum_size[b, h, :-1]
+ k_ends = kc_cum_size[b, h, 1:]
+
+ # Iterate through query blocks
+ for i in range(qc_num):
+ q_start, q_end = q_starts[i], q_ends[i]
+ q_block = q[b, h, q_start:q_end, :] # Shape: [qc_i, D]
+
+ if q_block.shape[0] == 0:
+ continue # Skip empty blocks
+
+ m_i = torch.full((q_block.shape[0], 1), -float("inf"), device=device, dtype=dtype)
+ l_i = torch.zeros((q_block.shape[0], 1), device=device, dtype=dtype)
+ acc_o_i = torch.zeros_like(q_block) # Shape: [qc_i, D]
+
+ # Iterate through key/value blocks for the current query block
+ for j in range(kc_num):
+ # Check if this block needs computation
+ if dynamic_map[b, h, i, j]:
+ k_start, k_end = k_starts[j], k_ends[j]
+ k_block = k[b, h, k_start:k_end, :] # Shape: [kc_j, D]
+ v_block = v[b, h, k_start:k_end, :] # Shape: [kc_j, D]
+
+ if k_block.shape[0] == 0:
+ continue # Skip empty blocks
+
+ # Compute attention scores for the block
+ # QK^T: [qc_i, D] @ [D, kc_j] -> [qc_i, kc_j]
+ s_ij = (q_block @ k_block.transpose(-1, -2)) * scale
+
+ # --- Online Softmax ---
+ # Find max score per query token in this block
+ m_ij = torch.max(s_ij, dim=-1, keepdim=True)[0] # Shape: [qc_i, 1]
+
+ # Update overall max score (m_i)
+ m_new = torch.maximum(m_i, m_ij) # Shape: [qc_i, 1]
+
+ # Calculate scaling factors for previous accumulator and current block
+ p_ij = torch.exp(s_ij - m_new) # Shape: [qc_i, kc_j]
+ exp_m_diff = torch.exp(m_i - m_new) # Shape: [qc_i, 1]
+
+ # Update softmax denominator (l_i)
+ l_i = (l_i * exp_m_diff) + torch.sum(p_ij, dim=-1, keepdim=True) # Shape: [qc_i, 1]
+
+ # Update output accumulator (acc_o_i)
+ # P_ij @ V_j: [qc_i, kc_j] @ [kc_j, D] -> [qc_i, D]
+ acc_o_i = (acc_o_i * exp_m_diff) + (p_ij @ v_block) # Shape: [qc_i, D]
+
+ # Update max score for next iteration
+ m_i = m_new
+
+ # Normalize the accumulated output
+ out[b, h, q_start:q_end, :] = acc_o_i / l_i.clamp(min=1e-12) # Avoid division by zero
+
+ return out
+
+
+# --- Triton Implementation ---
+
+
+@triton.jit
+def _dynamic_block_sparse_fwd_kernel(
+ Q,
+ K,
+ V,
+ Out,
+ dynamic_map,
+ qc_cum_size,
+ kc_cum_size,
+ stride_qb,
+ stride_qh,
+ stride_qs,
+ stride_qd,
+ stride_kb,
+ stride_kh,
+ stride_ks,
+ stride_kd,
+ stride_vb,
+ stride_vh,
+ stride_vs,
+ stride_vd,
+ stride_ob,
+ stride_oh,
+ stride_os,
+ stride_od,
+ stride_dmap_b,
+ stride_dmap_h,
+ stride_dmap_qc,
+ stride_dmap_kc,
+ stride_qcs_b,
+ stride_qcs_h,
+ stride_qcs_qc,
+ stride_kcs_b,
+ stride_kcs_h,
+ stride_kcs_kc,
+ B,
+ H,
+ S,
+ D,
+ scale,
+ QC_NUM: tl.constexpr,
+ KC_NUM: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ """
+ Triton kernel for dynamic block sparse attention.
+ Each program computes attention for one query block within a batch/head.
+ Processes query block in chunks of BLOCK_M.
+ Iterates through key blocks, checking dynamic_map.
+ Processes key/value blocks in chunks of BLOCK_N.
+ Uses online softmax.
+ """
+ # --- Grid Calculation ---
+ # Each program instance handles one query block for a specific batch and head
+ pid = tl.program_id(axis=0)
+ # grid size: B * H * QC_NUM programs in total
+
+ # Calculate batch, head, and query block index
+ pid_q_block_global = pid # 0 to B*H*QC_NUM - 1
+ # pid_bh = pid // QC_NUM # Deprecated: Causes issues if QC_NUM is not constant across BH
+ # pid_q_block_idx = pid % QC_NUM
+
+ # Need to map pid (0.. B*H*QC_NUM-1) back to (b, h, q_block_idx)
+ # q_block_idx changes fastest, then h, then b
+ q_block_idx = pid_q_block_global % QC_NUM
+ pid_h_temp = pid_q_block_global // QC_NUM
+ h = pid_h_temp % H
+ b = pid_h_temp // H
+
+ # --- Load Q block info (start/end offsets) ---
+ qcs_offset = b * stride_qcs_b + h * stride_qcs_h
+ q_start_offset = tl.load(qc_cum_size + qcs_offset + q_block_idx * stride_qcs_qc)
+ q_end_offset = tl.load(qc_cum_size + qcs_offset + (q_block_idx + 1) * stride_qcs_qc)
+ q_block_size = q_end_offset - q_start_offset
+
+ # Early exit if the query block is empty
+ if q_block_size == 0:
+ return
+
+ # --- Pointers setup ---
+ q_ptr_base = Q + b * stride_qb + h * stride_qh + q_start_offset * stride_qs
+ k_ptr_base = K + b * stride_kb + h * stride_kh
+ v_ptr_base = V + b * stride_vb + h * stride_vh
+ out_ptr_base = Out + b * stride_ob + h * stride_oh + q_start_offset * stride_os
+ dmap_ptr = dynamic_map + b * stride_dmap_b + h * stride_dmap_h + q_block_idx * stride_dmap_qc
+ kcs_ptr = kc_cum_size + b * stride_kcs_b + h * stride_kcs_h
+
+ # --- Iterate over the query block rows in chunks of BLOCK_M ---
+ offs_qm = tl.arange(0, BLOCK_M) # Query block row offsets [0, 1, ..., BLOCK_M-1]
+ offs_d = tl.arange(0, BLOCK_D) # Dimension offsets [0, 1, ..., BLOCK_D-1]
+
+ for q_chunk_start in range(0, q_block_size, BLOCK_M):
+ q_chunk_rows = offs_qm + q_chunk_start
+ q_rows_mask = q_chunk_rows < q_block_size # Mask for valid rows in this Q chunk [BLOCK_M]
+
+ # --- Initialize accumulators for this Q chunk ---
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # Max score
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # Sum of exp(scores - max)
+ acc_o = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) # Accumulated output
+
+ # --- Load Q chunk ---
+ q_ptr = q_ptr_base + q_chunk_rows[:, None] * stride_qs + offs_d[None, :]
+ # Mask ensures we don't read out of bounds for the query block or dimension D
+ mask_q = q_rows_mask[:, None] & (offs_d[None, :] < D)
+ q_chunk = tl.load(q_ptr, mask=mask_q, other=0.0) # Shape: [BLOCK_M, BLOCK_D]
+
+ # --- Inner loop over K blocks (columns in the block sparse map) ---
+ for k_block_idx in range(KC_NUM):
+ # --- Check dynamic_map: Is this block active? ---
+ is_active = tl.load(dmap_ptr + k_block_idx * stride_dmap_kc)
+ if is_active: # Process block only if it's active
+ # --- Load K block info (start/end offsets) ---
+ k_start_offset = tl.load(kcs_ptr + k_block_idx * stride_kcs_kc)
+ k_end_offset = tl.load(kcs_ptr + (k_block_idx + 1) * stride_kcs_kc)
+ k_block_size = k_end_offset - k_start_offset
+
+ # Skip if the key block is empty (inside the active block check)
+ if k_block_size > 0:
+ k_block_ptr_base = k_ptr_base + k_start_offset * stride_ks
+ v_block_ptr_base = v_ptr_base + k_start_offset * stride_vs
+
+ # --- Loop over K block chunks (size BLOCK_N) ---
+ offs_kn = tl.arange(0, BLOCK_N) # Key block row offsets [0, ..., BLOCK_N-1]
+ for k_chunk_start in range(0, k_block_size, BLOCK_N):
+ k_chunk_rows = offs_kn + k_chunk_start
+ k_rows_mask = k_chunk_rows < k_block_size # Mask for valid rows in this K/V chunk [BLOCK_N]
+
+ # --- Load K, V chunks ---
+ k_ptr = k_block_ptr_base + k_chunk_rows[:, None] * stride_ks + offs_d[None, :]
+ v_ptr = v_block_ptr_base + k_chunk_rows[:, None] * stride_vs + offs_d[None, :]
+
+ # Mask ensures we don't read out of bounds for the key block or dimension D
+ mask_kv = k_rows_mask[:, None] & (offs_d[None, :] < D)
+ k_chunk = tl.load(k_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D]
+ v_chunk = tl.load(v_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D]
+
+ # --- Compute Scores (Attention) ---
+ # QK^T: [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
+ s_ij_chunk = tl.dot(q_chunk, k_chunk.T) * scale
+
+ # IMPORTANT: Mask out scores corresponding to padding in K before max/softmax
+ # Set scores for invalid K elements to -inf
+ s_ij_chunk = tl.where(k_rows_mask[None, :], s_ij_chunk, -float("inf"))
+ # Mask invalid Q rows as well (padded q rows load as zeros; masking is a precaution, their outputs are discarded at the store)
+ s_ij_chunk = tl.where(q_rows_mask[:, None], s_ij_chunk, -float("inf"))
+
+ # --- Online Softmax Update ---
+ # Current max for this Q-K chunk interaction
+ m_ij_chunk = tl.max(s_ij_chunk, axis=1) # Shape: [BLOCK_M]
+
+ # Update overall max (across K chunks seen so far for this Q chunk)
+ m_new = tl.maximum(m_i, m_ij_chunk) # Shape: [BLOCK_M]
+
+ # Calculate scaled probabilities P_ij = exp(S_ij - m_new)
+ p_ij_chunk = tl.exp(s_ij_chunk - m_new[:, None]) # Shape: [BLOCK_M, BLOCK_N]
+ # Zero out probabilities for masked K elements before summing
+ p_ij_chunk = tl.where(k_rows_mask[None, :], p_ij_chunk, 0.0)
+
+ # Calculate scaling factor for previous accumulator state
+ exp_m_diff = tl.exp(m_i - m_new) # Shape: [BLOCK_M]
+
+ # Update sum accumulator (denominator L)
+ l_i_chunk = tl.sum(p_ij_chunk, axis=1) # Sum probabilities for this chunk, shape [BLOCK_M]
+ l_i = (l_i * exp_m_diff) + l_i_chunk # Shape: [BLOCK_M]
+
+ # Update output accumulator O
+ # P_ij @ V_j: [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
+ # Ensure p_ij_chunk is the correct dtype for dot product
+ p_ij_chunk_casted = p_ij_chunk.to(V.dtype.element_ty)
+ o_chunk = tl.dot(p_ij_chunk_casted, v_chunk) # Shape: [BLOCK_M, BLOCK_D]
+
+ acc_o = (acc_o * exp_m_diff[:, None]) + o_chunk # Shape: [BLOCK_M, BLOCK_D]
+
+ # Update max for the next K chunk/block
+ m_i = m_new
+ # End of 'if is_active:' block
+ # --- End of loop over K blocks ---
+
+ # --- Finalize output for this Q chunk ---
+ # Normalize the accumulated output: O = acc_o / l_i
+ # Add epsilon to l_i to avoid division by zero
+ l_i_safe = tl.where(l_i == 0, 1.0, l_i) # Avoid 0/0 -> NaN
+ o_final_chunk = acc_o / (l_i_safe[:, None])
+ o_final_chunk = tl.where(l_i[:, None] == 0, 0.0, o_final_chunk) # Ensure output is 0 if l_i was 0
+
+ # --- Write output chunk to global memory ---
+ out_ptr = out_ptr_base + q_chunk_rows[:, None] * stride_os + offs_d[None, :]
+ # Mask ensures we don't write out of bounds for the query block or dimension D
+ mask_out = q_rows_mask[:, None] & (offs_d[None, :] < D)
+ tl.store(out_ptr, o_final_chunk.to(Out.dtype.element_ty), mask=mask_out)
+
+ # --- (Optional: Write L and M stats if needed) ---
+ # Example:
+ # l_ptr = L + b * stride_lb + h * stride_lh + (q_start_offset + q_chunk_rows) * stride_ls
+ # tl.store(l_ptr, l_i, mask=q_rows_mask)
+ # m_ptr = M + ...
+ # tl.store(m_ptr, m_i, mask=q_rows_mask)
+
+ # --- End of loop over Q chunks ---
+
+
+def dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size):
+ """
+ Launcher for the Triton dynamic block sparse attention kernel.
+
+ Args:
+ q (torch.Tensor): Query tensor, shape [B, H, S, D].
+ k (torch.Tensor): Key tensor, shape [B, H, S, D].
+ v (torch.Tensor): Value tensor, shape [B, H, S, D].
+ dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
+ qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
+ kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
+
+ Returns:
+ torch.Tensor: Output tensor, shape [B, H, S, D].
+ """
+ B, H, S, D = q.shape
+ qc_num = qc_size.shape[-1]
+ kc_num = kc_size.shape[-1]
+ dtype = q.dtype
+
+ # Assertions and checks
+ assert q.is_cuda and k.is_cuda and v.is_cuda, "Inputs must be CUDA tensors"
+ assert dynamic_map.is_cuda and qc_size.is_cuda and kc_size.is_cuda
+ assert q.dtype == k.dtype == v.dtype, "Input dtypes must match"
+ assert dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert D in [16, 32, 64, 128], "Head dimension D must be 16, 32, 64, or 128 for efficient Triton dot"
+ # Ensure sequence lengths match sum of block sizes (check on one batch/head for simplicity)
+ assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
+ assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
+ # Ensure dynamic_map is boolean
+ assert dynamic_map.dtype == torch.bool
+
+ # Calculate scale factor (using float32 for stability)
+ scale = D**-0.5
+
+ # Precompute cumulative block boundaries (kept on the same device as the inputs)
+ qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1).int()
+ kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1).int()
+
+ # Output tensor
+ out = torch.empty_like(q)
+
+ # Triton kernel config
+ # BLOCK_M/N can be tuned: larger tiles do more work per program but need more shared memory.
+ BLOCK_D = D
+ if S <= 1024: # shorter sequences: moderate tiles
+ BLOCK_M = 64
+ BLOCK_N = 64
+ else: # longer sequences: a larger query tile tends to help
+ BLOCK_M = 128
+ BLOCK_N = 64
+
+ # Adjust block size if sequence length is smaller
+ BLOCK_M = min(BLOCK_M, S)
+ BLOCK_N = min(BLOCK_N, S)
+
+ # Launch grid: One program per query block per batch/head
+ grid = (B * H * qc_num,)
+
+ # Call the kernel
+ _dynamic_block_sparse_fwd_kernel[grid](
+ q,
+ k,
+ v,
+ out,
+ dynamic_map,
+ qc_cum_size,
+ kc_cum_size,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ q.stride(3),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ k.stride(3),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ v.stride(3),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ out.stride(3),
+ dynamic_map.stride(0),
+ dynamic_map.stride(1),
+ dynamic_map.stride(2),
+ dynamic_map.stride(3),
+ qc_cum_size.stride(0),
+ qc_cum_size.stride(1),
+ qc_cum_size.stride(2),
+ kc_cum_size.stride(0),
+ kc_cum_size.stride(1),
+ kc_cum_size.stride(2),
+ B,
+ H,
+ S,
+ D,
+ scale,
+ QC_NUM=qc_num,
+ KC_NUM=kc_num,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_D=BLOCK_D,
+ # num_warps=4 # Can tune this
+ )
+
+ return out
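+
+
+# Illustrative parity check between the Triton launcher and the pure-PyTorch reference.
+# This hypothetical helper is never called by the module; inputs must satisfy the
+# launcher's asserts (CUDA tensors, supported dtype and head dim, block sizes summing to S).
+def _example_dynamic_block_sparse_parity(q, k, v, dynamic_map, qc_size, kc_size):
+ out_ref = dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size)
+ out_tri = dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size)
+ return (out_ref - out_tri).abs().max() # expected to be small (fp16/bf16 rounding)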
+
+
+# ---------------- Batch wrapper for cuVS KMeans -----------------
+
+
+def batch_kmeans_rapidai(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
+ """Batched K-Means using RAPIDS cuVS implementation.
+
+ Args:
+ x (Tensor): (B, N, D) float32 tensor on CUDA.
+ n_clusters (int): K.
+ max_iters (int): maximum iterations.
+ tol (float): tolerance.
+ init_centroids (Tensor|None): optional initial centroids (B,K,D) float32.
+ verbose (bool): print per-batch info.
+
+ Returns:
+ cluster_ids (B, N) LongTensor
+ centroids (B, K, D) float32
+ cluster_sizes (B, K) LongTensor
+ n_iters_list (List[int]) iterations per batch
+ """
+ B, N, D = x.shape
+ if init_centroids is not None:
+ assert init_centroids.shape == (B, n_clusters, D)
+
+ cluster_ids_list = []
+ centroids_list = []
+ # cluster_sizes_list = []
+ n_iters_list = []
+
+ x_float = x.float()
+ if init_centroids is not None:
+ init_centroids_float = init_centroids.float()
+
+ for b in range(B):
+ xb = x_float[b]
+ if init_centroids is None:
+ centroids_init_b = None
+ init_method = "KMeansPlusPlus"
+ else:
+ centroids_init_b = init_centroids_float[b]
+ init_method = "Array"
+ labels_b, centroids_b, n_iter_b = kmeans_rapidai(xb, n_clusters, max_iter=max_iters, tol=tol, init_method=init_method, centroids_init=centroids_init_b)
+
+ cluster_ids_list.append(labels_b.to(torch.int64)) # (N,)
+ centroids_list.append(centroids_b)
+ # cluster_sizes_b = torch.bincount(labels_b, minlength=n_clusters).to(torch.int64)
+ # cluster_sizes_list.append(cluster_sizes_b)
+ n_iters_list.append(n_iter_b)
+ # if verbose:
+ # print(f"Batch {b}: iters={n_iter_b}, cluster sizes min={cluster_sizes_b.min().item()} max={cluster_sizes_b.max().item()}")
+
+ cluster_ids = torch.stack(cluster_ids_list, dim=0) # (B,N)
+ centroids = torch.stack(centroids_list, dim=0).to(x.dtype) # (B,K,D)
+ # cluster_sizes = torch.stack(cluster_sizes_list, dim=0) # (B,K)
+ # --- compute cluster sizes ---
+ ones = torch.ones_like(cluster_ids, dtype=torch.int64)
+ cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
+ cluster_sizes.scatter_add_(1, cluster_ids, ones)
+
+ return cluster_ids, centroids, cluster_sizes, n_iters_list
diff --git a/lightx2v/common/ops/attn/svg_attn.py b/lightx2v/common/ops/attn/svg_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..66db7d1b2b34c7e7e7eec4fc5155adebf2b27c8a
--- /dev/null
+++ b/lightx2v/common/ops/attn/svg_attn.py
@@ -0,0 +1,409 @@
+import math
+from functools import lru_cache
+from math import ceil
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+from loguru import logger
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+
+@triton.jit
+def wan_hidden_states_placement_kernel(
+ hidden_states_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
+ hidden_states_out_ptr, # [cfg, num_heads, seq_len, head_dim]
+ best_mask_idx_ptr, # [cfg, num_heads]
+ hidden_states_stride_b,
+ hidden_states_stride_h,
+ hidden_states_stride_s,
+ hidden_states_stride_d,
+ mask_idx_stride_b,
+ mask_idx_stride_h,
+ seq_len: tl.constexpr,
+ head_dim: tl.constexpr,
+ context_length: tl.constexpr,
+ num_frame: tl.constexpr,
+ frame_size: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # Copy hidden_states to output
+ # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
+ cfg = tl.program_id(0)
+ head = tl.program_id(1)
+ block_id = tl.program_id(2)
+
+ start_id = block_id * BLOCK_SIZE
+ end_id = start_id + BLOCK_SIZE
+ end_id = tl.where(end_id > seq_len, seq_len, end_id)
+
+ # Load best mask idx (0 is spatial, 1 is temporal)
+ is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
+
+ offset_token = tl.arange(0, BLOCK_SIZE) + start_id
+ offset_mask = offset_token < seq_len
+ offset_d = tl.arange(0, head_dim)
+
+ if is_temporal:
+ patch_id = offset_token // num_frame
+ frame_id = offset_token - patch_id * num_frame
+ offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)
+
+ offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
+ offset_hidden_states = hidden_states_ptr + offset_load
+
+ offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
+ offset_hidden_states_out = hidden_states_out_ptr + offset_store
+
+ # Maybe tune the pipeline here
+ hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
+ tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
+ else:
+ offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
+ offset_hidden_states = hidden_states_ptr + offset_load
+
+ offset_store = offset_load
+ offset_hidden_states_out = hidden_states_out_ptr + offset_store
+
+ # Maybe tune the pipeline here
+ hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
+ tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
+
+
+def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
+ cfg, num_heads, seq_len, head_dim = hidden_states.shape
+ BLOCK_SIZE = 128
+ assert seq_len == context_length + num_frame * frame_size
+
+ grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
+
+ wan_hidden_states_placement_kernel[grid](
+ hidden_states,
+ hidden_states_out,
+ best_mask_idx,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ hidden_states.stride(2),
+ hidden_states.stride(3),
+ best_mask_idx.stride(0),
+ best_mask_idx.stride(1),
+ seq_len,
+ head_dim,
+ context_length,
+ num_frame,
+ frame_size,
+ BLOCK_SIZE,
+ )
+
+ return hidden_states_out
+
+
+@triton.jit
+def wan_sparse_head_placement_kernel(
+ query_ptr,
+ key_ptr,
+ value_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
+ query_out_ptr,
+ key_out_ptr,
+ value_out_ptr, # [cfg, num_heads, seq_len, head_dim]
+ best_mask_idx_ptr, # [cfg, num_heads]
+ query_stride_b,
+ query_stride_h,
+ query_stride_s,
+ query_stride_d,
+ mask_idx_stride_b,
+ mask_idx_stride_h,
+ seq_len: tl.constexpr,
+ head_dim: tl.constexpr,
+ context_length: tl.constexpr,
+ num_frame: tl.constexpr,
+ frame_size: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # Copy query, key, value to output
+ # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
+ cfg = tl.program_id(0)
+ head = tl.program_id(1)
+ block_id = tl.program_id(2)
+
+ start_id = block_id * BLOCK_SIZE
+ end_id = start_id + BLOCK_SIZE
+ end_id = tl.where(end_id > seq_len, seq_len, end_id)
+
+ # Load best mask idx (0 is spatial, 1 is temporal)
+ is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
+
+ offset_token = tl.arange(0, BLOCK_SIZE) + start_id
+ offset_mask = offset_token < seq_len
+ offset_d = tl.arange(0, head_dim)
+
+ if is_temporal:
+ frame_id = offset_token // frame_size
+ patch_id = offset_token - frame_id * frame_size
+ offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)
+
+ offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
+ offset_query = query_ptr + offset_load
+ offset_key = key_ptr + offset_load
+ offset_value = value_ptr + offset_load
+
+ offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
+ offset_query_out = query_out_ptr + offset_store
+ offset_key_out = key_out_ptr + offset_store
+ offset_value_out = value_out_ptr + offset_store
+
+ # Maybe tune the pipeline here
+ query = tl.load(offset_query, mask=offset_mask[:, None])
+ tl.store(offset_query_out, query, mask=offset_mask[:, None])
+ key = tl.load(offset_key, mask=offset_mask[:, None])
+ tl.store(offset_key_out, key, mask=offset_mask[:, None])
+ value = tl.load(offset_value, mask=offset_mask[:, None])
+ tl.store(offset_value_out, value, mask=offset_mask[:, None])
+
+ else:
+ offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
+ offset_query = query_ptr + offset_load
+ offset_key = key_ptr + offset_load
+ offset_value = value_ptr + offset_load
+
+ offset_store = offset_load
+ offset_query_out = query_out_ptr + offset_store
+ offset_key_out = key_out_ptr + offset_store
+ offset_value_out = value_out_ptr + offset_store
+
+ # Maybe tune the pipeline here
+ query = tl.load(offset_query, mask=offset_mask[:, None])
+ tl.store(offset_query_out, query, mask=offset_mask[:, None])
+ key = tl.load(offset_key, mask=offset_mask[:, None])
+ tl.store(offset_key_out, key, mask=offset_mask[:, None])
+ value = tl.load(offset_value, mask=offset_mask[:, None])
+ tl.store(offset_value_out, value, mask=offset_mask[:, None])
+
+
+def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
+ cfg, num_heads, seq_len, head_dim = query.shape
+ BLOCK_SIZE = 128
+ assert seq_len == context_length + num_frame * frame_size
+
+ grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
+
+ wan_sparse_head_placement_kernel[grid](
+ query,
+ key,
+ value,
+ query_out,
+ key_out,
+ value_out,
+ best_mask_idx,
+ query.stride(0),
+ query.stride(1),
+ query.stride(2),
+ query.stride(3),
+ best_mask_idx.stride(0),
+ best_mask_idx.stride(1),
+ seq_len,
+ head_dim,
+ context_length,
+ num_frame,
+ frame_size,
+ BLOCK_SIZE,
+ )
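+
+
+# Layout note: for "temporal" heads the video tokens are permuted from frame-major
+# order (index = frame_id * frame_size + patch_id) to patch-major order
+# (index = patch_id * num_frame + frame_id); wan_hidden_states_placement applies the
+# inverse permutation after attention. The hypothetical helper below is never called
+# by the module and only sketches a pure-PyTorch equivalent of the forward permutation
+# (the context tokens at the tail stay in place).
+def _example_temporal_permutation(x, context_length, num_frame, frame_size):
+ split = x.shape[2] - context_length # = num_frame * frame_size
+ video, ctx = x[..., :split, :], x[..., split:, :]
+ b, h, _, d = video.shape
+ video = video.reshape(b, h, num_frame, frame_size, d).permute(0, 1, 3, 2, 4).reshape(b, h, split, d)
+ return torch.cat([video, ctx], dim=2)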
+
+
+def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
+ def round_to_multiple(idx):
+ return ceil(idx / 128) * 128
+
+ def temporal_mask_mod(b, h, q_idx, kv_idx):
+ two_frame = round_to_multiple(mul * token_per_frame)
+ temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame
+
+ # return temporal_head_mask
+ first_frame_mask = kv_idx < token_per_frame
+ video_mask = first_frame_mask | temporal_head_mask
+ return video_mask
+
+ return temporal_mask_mod
+
+
+@lru_cache
+def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
+ block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
+ return block_mask
+
+
+def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
+ assert diag_width == multiplier, f"diag_width ({diag_width}) must equal multiplier ({multiplier})"
+ seq_len = context_length + num_frame * frame_size
+ mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier)
+ block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True)
+ return block_mask
+
+
+def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
+ seq_len = context_length + num_frame * frame_size
+ total_elements = seq_len**2
+
+ sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements
+
+ width = seq_len * (1 - math.sqrt(1 - sparsity))
+ width_frame = width / frame_size
+
+ return width_frame
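+
+
+# Derivation note (a sketch of the intent, reading `sparsity` as the kept fraction of
+# the attention map, which is what the formula suggests): a diagonal band of half-width
+# `width` in a seq_len x seq_len map covers seq_len**2 - (seq_len - width)**2 cells, so
+# a kept fraction p gives width = seq_len * (1 - sqrt(1 - p)). The 2 * seq_len *
+# context_length term first discounts the always-kept text/context rows and columns,
+# and the result is returned in units of frames (width / frame_size).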
+
+
+def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size):
+ attention_mask = torch.zeros((context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu")
+
+ # TODO: fix hard coded mask
+ if mask_name == "spatial":
+ pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
+
+ pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
+
+ block_size, block_thres = 128, frame_size * 2
+ num_block = math.ceil(num_frame * frame_size / block_size)
+ for i in range(num_block):
+ for j in range(num_block):
+ if abs(i - j) < block_thres // block_size:
+ pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
+ attention_mask = pixel_attn_mask
+ else:
+ pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
+
+ pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
+
+ block_size, block_thres = 128, frame_size * 2
+ num_block = math.ceil(num_frame * frame_size / block_size)
+ for i in range(num_block):
+ for j in range(num_block):
+ if abs(i - j) < block_thres // block_size:
+ pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
+
+ pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(frame_size * num_frame, frame_size * num_frame)
+ attention_mask = pixel_attn_mask
+
+ attention_mask = attention_mask[:sample_mse_max_row].cuda()
+ return attention_mask
+
+
+@ATTN_WEIGHT_REGISTER("svg_attn")
+class SvgAttnWeight(AttnWeightTemplate):
+ head_num = None
+ head_dim = None
+ sample_mse_max_row = None
+ num_sampled_rows = None
+ context_length = None
+ attnmap_frame_num = None
+ seqlen = None
+ sparsity = None
+ mask_name_list = ["spatial", "temporal"]
+ attention_masks = None
+ block_mask = None
+
+ @classmethod
+ def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity):
+ cls.head_num = head_num
+ cls.head_dim = head_dim
+ cls.sample_mse_max_row = sample_mse_max_row
+ cls.num_sampled_rows = num_sampled_rows
+ cls.context_length = context_length
+ cls.sparsity = sparsity
+ torch._dynamo.config.cache_size_limit = 192 * 3
+ torch._dynamo.config.accumulated_cache_size_limit = 192 * 3
+ logger.info(
+ f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}"
+ )
+
+ def __init__(self):
+ self.config = {}
+ self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
+
+ @classmethod
+ def prepare_mask(cls, seqlen):
+ # Use class attributes so updates affect all instances of this class
+ if seqlen == cls.seqlen:
+ return
+ frame_size = seqlen // cls.attnmap_frame_num
+ cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
+ multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
+ cls.block_mask = prepare_flexattention(
+ 1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier
+ )
+ cls.seqlen = seqlen
+ logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ q = q.unsqueeze(0).transpose(1, 2)
+ k = k.unsqueeze(0).transpose(1, 2)
+ v = v.unsqueeze(0).transpose(1, 2)
+ bs, num_heads, seq_len, dim = q.size()
+
+ self.prepare_mask(seq_len)
+ sampled_mses = self.sample_mse(q, k, v)
+ best_mask_idx = torch.argmin(sampled_mses, dim=0)
+
+ output_hidden_states = torch.zeros_like(q)
+ query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
+
+ query_out, key_out, value_out = self.fast_sparse_head_placement(
+ q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num
+ )
+
+ hidden_states = self.sparse_attention(query_out, key_out, value_out)
+ wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num)
+
+ return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
+
+ def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
+ wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
+ return query_out, key_out, value_out
+
+ def sample_mse(self, query, key, value):
+ cfg, num_heads, seq_len, dim = query.size()
+ num_sampled_rows = min(self.num_sampled_rows, seq_len)
+ sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,))
+ sampled_q = query[:, :, sampled_rows, :]
+ sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)
+
+ sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
+ sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value) # (cfg, num_heads, num_sampled_rows, dim)
+
+ sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)
+
+ # Only two candidate masks: spatial (tri-diagonal) and temporal (striped)
+ for mask_idx, attn_mask in enumerate(self.attention_masks):
+ sampled_attention_mask = attn_mask[sampled_rows, :]
+ sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
+ sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
+ sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
+ mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
+ sampled_mses[mask_idx] = mse
+
+ return sampled_mses
+
+
+if __name__ == "__main__":
+ q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
+
+ SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25)
+ svg_attn = SvgAttnWeight()
+ print("SvgAttnWeight initialized.")
+
+ out = svg_attn.apply(q, k, v)
+ print(f"out: {out.shape}, {out.dtype}, {out.device}")
diff --git a/lightx2v/common/ops/attn/template.py b/lightx2v/common/ops/attn/template.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1f447e9f7b693b530743cad8e3aa72459685ff
--- /dev/null
+++ b/lightx2v/common/ops/attn/template.py
@@ -0,0 +1,35 @@
+from abc import ABCMeta, abstractmethod
+
+
+class AttnWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name):
+ self.weight_name = weight_name
+ self.config = {}
+
+ def load(self, weight_dict):
+ pass
+
+ @abstractmethod
+ def apply(self, input_tensor):
+ pass
+
+ def set_config(self, config=None):
+ if config is not None:
+ self.config = config
+
+ def to_cpu(self, non_blocking=False):
+ pass
+
+ def to_cuda(self, non_blocking=False):
+ pass
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ return {}
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ pass
diff --git a/lightx2v/common/ops/attn/torch_sdpa.py b/lightx2v/common/ops/attn/torch_sdpa.py
new file mode 100644
index 0000000000000000000000000000000000000000..65a002b6ae16f52975e0bf476fa99ab26dfdcce0
--- /dev/null
+++ b/lightx2v/common/ops/attn/torch_sdpa.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn.functional as F
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+
+
+@ATTN_WEIGHT_REGISTER("torch_sdpa")
+class TorchSDPAWeight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(
+ self,
+ q,
+ k,
+ v,
+ drop_rate=0,
+ attn_mask=None,
+ causal=False,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ model_cls=None,
+ ):
+ if q.ndim == 3:
+ q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(q.dtype)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
+ x = x.transpose(1, 2)
+ b, s, a, d = x.shape
+ out = x.reshape(b, s, -1)
+ return out.squeeze(0)
diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..197998b6e75055c24b028b43ec891decc8f0825a
--- /dev/null
+++ b/lightx2v/common/ops/attn/ulysses_attn.py
@@ -0,0 +1,415 @@
+import torch
+import torch.distributed as dist
+from loguru import logger
+
+from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+from .template import AttnWeightTemplate
+from .utils.all2all import all2all_head2seq, all2all_seq2head
+
+
+@ATTN_WEIGHT_REGISTER("ulysses")
+class UlyssesAttnWeight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+
+ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
+ """
+ 执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
+
+ 参数:
+ q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
+ k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
+ v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
+ img_qkv_len (int): 图像查询、键和值的长度
+ cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
+ attention_type (str): 注意力类型,默认为 "flash_attn2"
+
+ 返回:
+ torch.Tensor: 计算得到的注意力结果
+ """
+ if len(q.shape) == 4:
+ q = q.reshape(-1, q.shape[-2], q.shape[-1])
+ k = k.reshape(-1, k.shape[-2], k.shape[-1])
+ v = v.reshape(-1, v.shape[-2], v.shape[-1])
+
+ # Get the rank of this process and the world size of the sequence-parallel group
+ world_size = dist.get_world_size(seq_p_group)
+ cur_rank = dist.get_rank(seq_p_group)
+
+ # Get the sequence length and the text-related lengths
+ seq_len = q.shape[0]
+ if len(cu_seqlens_qkv) == 3:
+ txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text part of q/k/v
+ txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # length of the text mask
+ elif len(cu_seqlens_qkv) == 2:
+ txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text part of q/k/v
+ txt_mask_len = None
+
+ # Head count and hidden dim of the query tensor
+ _, heads, hidden_dims = q.shape
+ shard_heads = heads // world_size # number of heads handled by each process
+ shard_seqlen = img_qkv_len # sequence length handled by each process
+
+ # Split q/k/v into image and text parts
+ img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
+ txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
+
+ # Convert the image q/k/v from sequence-sharded to head-sharded layout
+ if use_fp8_comm:
+ original_dtype = img_q.dtype
+ original_shape = img_q.shape
+ img_q_fp8, q_scale = quant_fp8_vllm(img_q.reshape(-1, original_shape[-1]))
+ img_k_fp8, k_scale = quant_fp8_vllm(img_k.reshape(-1, original_shape[-1]))
+ img_v_fp8, v_scale = quant_fp8_vllm(img_v.reshape(-1, original_shape[-1]))
+ img_q_fp8 = all2all_seq2head(img_q_fp8.reshape(original_shape), group=seq_p_group)
+ img_k_fp8 = all2all_seq2head(img_k_fp8.reshape(original_shape), group=seq_p_group)
+ img_v_fp8 = all2all_seq2head(img_v_fp8.reshape(original_shape), group=seq_p_group)
+ q_scale = all2all_seq2head(q_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+ k_scale = all2all_seq2head(k_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+ v_scale = all2all_seq2head(v_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+ img_q = dequant_fp8_vllm(img_q_fp8, q_scale, original_dtype)
+ img_k = dequant_fp8_vllm(img_k_fp8, k_scale, original_dtype)
+ img_v = dequant_fp8_vllm(img_v_fp8, v_scale, original_dtype)
+ else:
+ img_q = all2all_seq2head(img_q, group=seq_p_group)
+ img_k = all2all_seq2head(img_k, group=seq_p_group)
+ img_v = all2all_seq2head(img_v, group=seq_p_group)
+
+ # For the text q/k/v, keep only the heads owned by this process
+ txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+ txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+ txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+
+ # Concatenate the image and text q/k/v
+ q = torch.cat((img_q, txt_q), dim=0)
+ k = torch.cat((img_k, txt_k), dim=0)
+ v = torch.cat((img_v, txt_v), dim=0)
+
+ # Initialize the cumulative sequence length tensor
+ cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
+ s = txt_qkv_len + img_q.shape[0] # total length of the text and image parts
+ s1 = s # end position of the current sample
+ cu_seqlens_qkv[1] = s1 # set the cumulative sequence length
+ if txt_mask_len:
+ s2 = txt_mask_len + img_q.shape[0] # end position of the text mask
+ cu_seqlens_qkv = torch.cat([cu_seqlens_qkv, s2.reshape(1).to(dtype=cu_seqlens_qkv.dtype, device=cu_seqlens_qkv.device)]) # append as a third boundary
+ max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0] # maximum sequence length
+
+ # Compute attention with the provided attention module
+ # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
+ attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
+
+ # Split the attention output back into image and text parts
+ img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :, :]
+
+ # Gather the text attention outputs from all processes
+ gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
+ dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
+
+ img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
+
+ txt_attn = torch.cat(gathered_txt_attn, dim=1) # concatenate the text attention from all processes
+
+ # Concatenate the image and text attention outputs
+ attn = torch.cat([img_attn, txt_attn], dim=0)
+
+ return attn # final attention result
+
+ @torch.compiler.disable
+ def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
+ img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # reshape the image attention output
+
+ # Convert back from head-sharded to sequence-sharded layout
+ if use_fp8_comm:
+ original_dtype = img_attn.dtype
+ original_shape = img_attn.shape
+ img_attn_fp8, attn_scale = quant_fp8_vllm(img_attn.reshape(-1, original_shape[-1]))
+ img_attn_fp8 = all2all_head2seq(img_attn_fp8.reshape(original_shape), group=seq_p_group)
+ attn_scale = all2all_head2seq(attn_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+ img_attn = dequant_fp8_vllm(img_attn_fp8, attn_scale, original_dtype)
+ else:
+ img_attn = all2all_head2seq(img_attn, group=seq_p_group)
+
+        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
+ return img_attn
+
+
+@ATTN_WEIGHT_REGISTER("ulysses-4090")
+class Ulysses4090AttnWeight(AttnWeightTemplate):
+ def __init__(self):
+ self.config = {}
+ self.rounds = []
+
+ def generate_round_robin_pairs(self, seq_p_group=None):
+ """
+        Generate a round-robin pairing schedule and record, for each pair, which rank is the
+        smaller one, so a simple rule fixes the communication order and avoids deadlocks.
+ """
+ cur_rank = dist.get_rank(seq_p_group)
+ world_size = dist.get_world_size(seq_p_group)
+ if world_size % 2 != 0:
+ raise ValueError("world_size必须是偶数,奇数情况需要特殊处理")
+
+ teams = list(range(world_size))
+ for _ in range(world_size - 1):
+ round_schedule = {}
+ for i in range(world_size // 2):
+ team1, team2 = teams[i], teams[world_size - 1 - i]
+ smaller, larger = min(team1, team2), max(team1, team2)
+ round_schedule[smaller] = (larger, True)
+ round_schedule[larger] = (smaller, False)
+ self.rounds.append(round_schedule)
+            # Rotate the list (keeping the first element fixed)
+ teams = [teams[0]] + [teams[-1]] + teams[1:-1]
+
+ # if cur_rank == 0:
+ # self.print_pairing_schedule(seq_p_group)
+
+ def print_pairing_schedule(self, seq_p_group):
+ """打印通信调度表"""
+ world_size = dist.get_world_size(seq_p_group)
+ logger.info("循环赛通信调度表:")
+ logger.info("=" * 50)
+ for i, round_schedule in enumerate(self.rounds):
+ logger.info(f"第 {i + 1} 轮:")
+ for cur_rank in range(world_size):
+ partner, is_smaller_in_pair = round_schedule[cur_rank]
+ logger.info(f" 进程 {cur_rank} ←→ 进程 {partner}")
+ logger.info("=" * 50)
+
+ def load_balanced_all_to_all(self, shards, seq_p_group=None):
+ """
+        Load-balanced all-to-all communication driven by the round-robin schedule.
+ """
+ world_size = dist.get_world_size(seq_p_group)
+ cur_rank = dist.get_rank(seq_p_group)
+ global_rank = dist.get_global_rank(seq_p_group, cur_rank)
+ cfg_p_group_index = global_rank // world_size
+
+        # Prepare the receive buffers
+ gathered_shards = [None] * world_size
+ for target_rank in range(world_size):
+ if target_rank != cur_rank:
+ gathered_shards[target_rank] = torch.empty_like(shards[target_rank])
+ else:
+ gathered_shards[cur_rank] = shards[cur_rank]
+
+ for i, round_schedule in enumerate(self.rounds):
+            # Look up this rank's partner for the current round
+            partner = None
+            is_smaller_in_pair = False
+            if cur_rank in round_schedule:
+                partner, is_smaller_in_pair = round_schedule[cur_rank]
+
+            # No partner found: this rank is idle in this round
+            if partner is None:
+                continue
+
+            # Compute the partner's global rank
+ partner_global_rank = cfg_p_group_index * world_size + partner
+
+ if is_smaller_in_pair:
+                # This rank is the smaller one in the pair: send first, then receive
+ send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
+ recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
+ send_req.wait()
+ recv_req.wait()
+ else:
+                # This rank is the larger one in the pair: receive first, then send
+ recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
+ send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
+ recv_req.wait()
+ send_req.wait()
+
+ return gathered_shards
+
+ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
+ """
+        Run Ulysses attention over the combined image and text query/key/value.
+
+        Args:
+            q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
+            k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
+            v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
+            img_qkv_len (int): length of the image query/key/value
+            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image parts
+            attention_module: attention implementation whose apply() computes the core attention
+
+        Returns:
+            torch.Tensor: the computed attention result
+ """
+ if len(self.rounds) == 0:
+ self.generate_round_robin_pairs(seq_p_group)
+
+ if len(q.shape) == 4:
+ q = q.reshape(-1, q.shape[-2], q.shape[-1])
+ k = k.reshape(-1, k.shape[-2], k.shape[-1])
+ v = v.reshape(-1, v.shape[-2], v.shape[-1])
+        # Get the current rank and the local/global world sizes
+ world_size = dist.get_world_size(seq_p_group)
+ cur_rank = dist.get_rank(seq_p_group)
+ global_world_size = dist.get_world_size()
+ global_rank = dist.get_global_rank(seq_p_group, cur_rank)
+ cfg_p_group_index = global_rank // world_size
+
+        # Get the sequence length and the text-related lengths
+        seq_len = q.shape[0]
+        if len(cu_seqlens_qkv) == 3:
+            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
+            txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
+        elif len(cu_seqlens_qkv) == 2:
+            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
+            txt_mask_len = None
+
+        # Get the head count and hidden dimension from the query tensor
+        _, heads, hidden_dims = q.shape
+        shard_heads = heads // world_size  # heads handled by each rank
+        shard_seqlen = img_qkv_len  # sequence length handled by each rank
+
+        # Split the query/key/value into image and text parts
+ img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
+ txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
+
+        # Compute the head shard each rank should hold
+ num_heads = img_q.shape[1]
+ shard_heads = num_heads // world_size
+
+        # Stack the image QKV, then split it along the head dimension into N shards of size D/N each
+ img_qkv = torch.stack([img_q, img_k, img_v], dim=0)
+ qkv_shards = [img_qkv[:, :, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
+ qkv_dtype = img_qkv.dtype
+
+ if use_fp8_comm:
+ qkv_fp8_byte_tensors = []
+ qkv_fp8_bytes = 0
+ qkv_fp8_dtype = None
+ qkv_scale_dtype = None
+ for i in range(world_size):
+ qkv_fp8, qkv_scale = quant_fp8_vllm(qkv_shards[i].reshape(-1, hidden_dims))
+ if i == 0:
+ qkv_fp8_bytes = qkv_fp8.numel() * qkv_fp8.element_size()
+ qkv_fp8_dtype = qkv_fp8.dtype
+ qkv_scale_dtype = qkv_scale.dtype
+ qkv_fp8_byte_tensors.append(torch.cat([qkv_fp8.contiguous().reshape(-1).view(torch.uint8), qkv_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
+
+ gathered_qkv_fp8_byte_tensors = self.load_balanced_all_to_all(qkv_fp8_byte_tensors, seq_p_group)
+
+ gathered_q_shards = []
+ gathered_k_shards = []
+ gathered_v_shards = []
+ for i in range(world_size):
+ qkv_fp8_byte_tensor = gathered_qkv_fp8_byte_tensors[i]
+ qkv_fp8 = qkv_fp8_byte_tensor[:qkv_fp8_bytes].view(qkv_fp8_dtype).reshape(3, -1, hidden_dims)
+ qkv_scale = qkv_fp8_byte_tensor[qkv_fp8_bytes:].view(qkv_scale_dtype).reshape(3, -1, 1)
+ q_shards_new = dequant_fp8_vllm(qkv_fp8[0], qkv_scale[0], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
+ k_shards_new = dequant_fp8_vllm(qkv_fp8[1], qkv_scale[1], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
+ v_shards_new = dequant_fp8_vllm(qkv_fp8[2], qkv_scale[2], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
+ gathered_q_shards.append(q_shards_new)
+ gathered_k_shards.append(k_shards_new)
+ gathered_v_shards.append(v_shards_new)
+ else:
+ gathered_qkv_byte_tensors = self.load_balanced_all_to_all(qkv_shards, seq_p_group)
+
+ gathered_q_shards = []
+ gathered_k_shards = []
+ gathered_v_shards = []
+ for i in range(world_size):
+ qkv_tensor = gathered_qkv_byte_tensors[i].view(qkv_dtype).reshape(3, -1, shard_heads, hidden_dims)
+ gathered_q_shards.append(qkv_tensor[0])
+ gathered_k_shards.append(qkv_tensor[1])
+ gathered_v_shards.append(qkv_tensor[2])
+
+        # Concatenate all shards along the sequence dimension.
+        # Each gathered_*_shards[i] has shape (seq_len/N, num_heads/N, head_dim);
+        # after concatenation the shape is (seq_len, num_heads/N, head_dim).
+ img_q = torch.cat(gathered_q_shards, dim=0)
+ img_k = torch.cat(gathered_k_shards, dim=0)
+ img_v = torch.cat(gathered_v_shards, dim=0)
+
+        # Process the text query/key/value: select the heads owned by the current rank
+ txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+ txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+ txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+
+        # Concatenate the image and text query/key/value
+ q = torch.cat((img_q, txt_q), dim=0)
+ k = torch.cat((img_k, txt_k), dim=0)
+ v = torch.cat((img_v, txt_v), dim=0)
+
+        # Initialize the cumulative sequence length tensor
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
+        s = txt_qkv_len + img_q.shape[0]  # total length of the text and image tokens
+        s1 = s  # end position of the current sample
+        cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
+        if txt_mask_len:
+            s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
+            cu_seqlens_qkv = torch.cat([cu_seqlens_qkv, s2.reshape(1).to(dtype=torch.int32, device=cu_seqlens_qkv.device)])
+        max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
+
+        # Run the attention computation
+ # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
+ attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
+
+        # Split the attention output back into image and text parts
+        img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :, :]
+
+        # Gather the text attention results from all ranks
+ gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
+ dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
+
+ img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
+
+        txt_attn = torch.cat(gathered_txt_attn, dim=1)  # merge the text attention results from all ranks
+
+        # Concatenate the image and text attention results
+        attn = torch.cat([img_attn, txt_attn], dim=0)
+
+        return attn  # return the final attention result
+
+ @torch.compiler.disable
+ def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
+ cur_rank = dist.get_rank(seq_p_group)
+ global_world_size = dist.get_world_size()
+ global_rank = dist.get_global_rank(seq_p_group, cur_rank)
+ cfg_p_group_index = global_rank // world_size
+
+        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention output
+        attn_dtype = img_attn.dtype
+
+        # Split into N shards along the sequence dimension
+ attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
+
+ if use_fp8_comm:
+ attn_fp8_byte_tensors = []
+ attn_fp8_bytes = 0
+ attn_fp8_dtype = None
+ attn_scale_dtype = None
+ for i in range(world_size):
+ attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, hidden_dims))
+ if i == 0:
+ attn_fp8_bytes = attn_fp8.numel() * attn_fp8.element_size()
+ attn_fp8_dtype = attn_fp8.dtype
+ attn_scale_dtype = attn_scale.dtype
+ attn_fp8_byte_tensors.append(torch.cat([attn_fp8.contiguous().reshape(-1).view(torch.uint8), attn_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
+
+ gathered_attn_fp8_byte_tensors = self.load_balanced_all_to_all(attn_fp8_byte_tensors, seq_p_group)
+
+ gathered_attn_shards = []
+ for i in range(world_size):
+ attn_fp8_byte_tensor = gathered_attn_fp8_byte_tensors[i]
+ attn_fp8 = attn_fp8_byte_tensor[:attn_fp8_bytes].view(attn_fp8_dtype).reshape(-1, hidden_dims)
+ attn_scale = attn_fp8_byte_tensor[attn_fp8_bytes:].view(attn_scale_dtype).reshape(-1, 1)
+ attn_shards_new = dequant_fp8_vllm(attn_fp8, attn_scale, attn_dtype).reshape(-1, shard_heads, hidden_dims)
+ gathered_attn_shards.append(attn_shards_new)
+
+ else:
+ gathered_attn_shards = self.load_balanced_all_to_all(attn_shards, seq_p_group)
+
+        # Concatenate all shards along the head dimension
+        img_attn = torch.cat(gathered_attn_shards, dim=1)
+        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
+
+ return img_attn
diff --git a/lightx2v/common/ops/attn/utils/all2all.py b/lightx2v/common/ops/attn/utils/all2all.py
new file mode 100644
index 0000000000000000000000000000000000000000..757ce74e8abaa1fae6c44d0b7251106ab8823137
--- /dev/null
+++ b/lightx2v/common/ops/attn/utils/all2all.py
@@ -0,0 +1,89 @@
+import torch
+import torch._dynamo as dynamo
+import torch.distributed as dist
+
+
+@dynamo.disable
+def all2all_seq2head(input, group=None):
+ """
+    Convert the input tensor from [seq_len/N, heads, hidden_dims] to [seq_len, heads/N, hidden_dims].
+
+    Args:
+        input (torch.Tensor): input tensor of shape [seq_len/N, heads, hidden_dims]
+
+    Returns:
+        torch.Tensor: converted output tensor of shape [seq_len, heads/N, hidden_dims]
+    """
+    # Ensure the input is a 3D tensor
+    assert input.dim() == 3, "input must be a 3D tensor"
+
+    # Get the world size of the process group
+    world_size = dist.get_world_size(group=group)
+
+    # Get the input tensor shape
+    shard_seq_len, heads, hidden_dims = input.shape
+    seq_len = shard_seq_len * world_size  # total sequence length
+    shard_heads = heads // world_size  # heads handled by each rank
+
+    # Reshape the input tensor for the all-to-all operation
+    input_t = (
+        input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims)  # reshape to [shard_seq_len, world_size, shard_heads, hidden_dims]
+        .transpose(0, 1)  # move the rank dimension to the front
+        .contiguous()  # ensure contiguous memory
+ )
+
+    # Allocate an output tensor with the same shape as the reshaped input
+    output = torch.empty_like(input_t)
+
+    # Perform the all-to-all, exchanging the shards across all ranks
+    dist.all_to_all_single(output, input_t, group=group)
+
+    # Reshape the output to [seq_len, heads/N, hidden_dims]
+    output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()
+
+    return output  # return the converted output tensor
+
+
+@dynamo.disable
+def all2all_head2seq(input, group=None):
+ """
+    Convert the input tensor from [seq_len, heads/N, hidden_dims] to [seq_len/N, heads, hidden_dims].
+
+    Args:
+        input (torch.Tensor): input tensor of shape [seq_len, heads/N, hidden_dims]
+
+    Returns:
+        torch.Tensor: converted output tensor of shape [seq_len/N, heads, hidden_dims]
+    """
+    # Ensure the input is a 3D tensor
+    assert input.dim() == 3, "input must be a 3D tensor"
+
+    # Get the world size of the process group
+    world_size = dist.get_world_size(group=group)
+
+    # Get the input tensor shape
+    seq_len, shard_heads, hidden_dims = input.shape
+    heads = shard_heads * world_size  # total number of heads
+    shard_seq_len = seq_len // world_size  # sequence length handled by each rank
+
+    # Reshape the input tensor for the all-to-all operation
+    input_t = (
+        input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims)  # reshape to [world_size, shard_seq_len, shard_heads, hidden_dims]
+        .transpose(1, 2)  # swap the sequence and head dimensions
+        .contiguous()  # ensure contiguous memory
+        .reshape(world_size, shard_heads, shard_seq_len, hidden_dims)  # reshape again to [world_size, shard_heads, shard_seq_len, hidden_dims]
+ )
+
+    # Allocate an output tensor with the same shape as the reshaped input
+    output = torch.empty_like(input_t)
+
+    # Perform the all-to-all, exchanging the shards across all ranks
+    dist.all_to_all_single(output, input_t, group=group)
+
+    # Reshape the output to [heads, shard_seq_len, hidden_dims]
+    output = output.reshape(heads, shard_seq_len, hidden_dims)
+
+    # Transpose and reshape to [shard_seq_len, heads, hidden_dims]
+    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)
+
+    return output  # return the converted output tensor
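+
+
+# Shape round-trip sketch (assumes world_size=4, a per-rank sequence of 1024, 16 heads, head dim 128,
+# and that `seq_p_group` is the sequence-parallel process group):
+#   x = torch.randn(1024, 16, 128, device="cuda")  # [seq_len/N, heads, hidden_dims]
+#   y = all2all_seq2head(x, group=seq_p_group)     # -> [4096, 4, 128] = [seq_len, heads/N, hidden_dims]
+#   z = all2all_head2seq(y, group=seq_p_group)     # -> [1024, 16, 128], back to the original layout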
diff --git a/lightx2v/common/ops/attn/utils/ring_comm.py b/lightx2v/common/ops/attn/utils/ring_comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0f30a463b53eecc7e7d5598bf424256c0a8aa9
--- /dev/null
+++ b/lightx2v/common/ops/attn/utils/ring_comm.py
@@ -0,0 +1,46 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+
+class RingComm:
+ def __init__(self, process_group: dist.ProcessGroup = None):
+ self._process_group = process_group
+ self._ops = []
+ self.rank = dist.get_rank(self._process_group)
+ self.world_size = dist.get_world_size(self._process_group)
+ self._reqs = None
+
+ self.send_rank = (self.rank + 1) % self.world_size
+ self.recv_rank = (self.rank - 1) % self.world_size
+
+ if process_group is not None:
+ self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
+ self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
+
+ def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if recv_tensor is None:
+ res = torch.empty_like(to_send)
+ # logger.info(f"send_recv: empty_like {to_send.shape}")
+ else:
+ res = recv_tensor
+
+ send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
+ recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
+ self._ops.append(send_op)
+ self._ops.append(recv_op)
+ return res
+
+ def commit(self):
+ if self._reqs is not None:
+ raise RuntimeError("commit called twice")
+ self._reqs = dist.batch_isend_irecv(self._ops)
+
+ def wait(self):
+ if self._reqs is None:
+ raise RuntimeError("wait called before commit")
+ for req in self._reqs:
+ req.wait()
+ self._reqs = None
+ self._ops = []
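+
+
+# Minimal usage sketch (assumes `group` is an initialized process group forming the ring):
+#   comm = RingComm(group)
+#   next_k = comm.send_recv(k)  # queue isend to rank + 1 and irecv from rank - 1
+#   comm.commit()               # launch the batched P2P operations
+#   comm.wait()                 # block until both transfers finish; next_k is then valid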
diff --git a/lightx2v/common/ops/conv/__init__.py b/lightx2v/common/ops/conv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4cf739e6027771a21e11d8a629cf444546e1163
--- /dev/null
+++ b/lightx2v/common/ops/conv/__init__.py
@@ -0,0 +1,2 @@
+from .conv2d import *
+from .conv3d import *
diff --git a/lightx2v/common/ops/conv/conv2d.py b/lightx2v/common/ops/conv/conv2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6665087ad03468d532aa1b42810903e5c773562
--- /dev/null
+++ b/lightx2v/common/ops/conv/conv2d.py
@@ -0,0 +1,61 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+
+class Conv2dWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name, bias_name, stride, padding, dilation, groups):
+ self.weight_name = weight_name
+ self.bias_name = bias_name
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.config = {}
+
+ @abstractmethod
+ def load(self, weight_dict):
+ pass
+
+ @abstractmethod
+ def apply(self, input_tensor):
+ pass
+
+ def set_config(self, config=None):
+ if config is not None:
+ self.config = config
+
+
+@CONV2D_WEIGHT_REGISTER("Default")
+class Conv2dWeight(Conv2dWeightTemplate):
+ def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
+ super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
+
+ def load(self, weight_dict):
+ self.weight = weight_dict[self.weight_name].to(AI_DEVICE)
+ self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None
+
+ def apply(self, input_tensor):
+ input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+ return input_tensor
+
+ def to_cpu(self, non_blocking=False):
+ self.weight = self.weight.cpu(non_blocking=non_blocking)
+ if self.bias is not None:
+ self.bias = self.bias.cpu(non_blocking=non_blocking)
+
+ def to_cuda(self, non_blocking=False):
+ self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking)
+ if self.bias is not None:
+ self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.weight.cpu().detach().clone()
+ if self.bias is not None:
+ destination[self.bias_name] = self.bias.cpu().detach().clone()
+ return destination
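+
+
+# Usage sketch (hypothetical weight names; assumes `weight_dict` holds the corresponding tensors):
+#   conv = Conv2dWeight("patch_embed.proj.weight", "patch_embed.proj.bias", stride=2, padding=1)
+#   conv.load(weight_dict)
+#   out = conv.apply(x)  # x: [N, C_in, H, W] -> [N, C_out, H_out, W_out]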
diff --git a/lightx2v/common/ops/conv/conv3d.py b/lightx2v/common/ops/conv/conv3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e45f3391ec48cf67008ba5748c9065677328ec7
--- /dev/null
+++ b/lightx2v/common/ops/conv/conv3d.py
@@ -0,0 +1,94 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+
+class Conv3dWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
+ self.weight_name = weight_name
+ self.bias_name = bias_name
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.config = {}
+
+ @abstractmethod
+ def load(self, weight_dict):
+ pass
+
+ @abstractmethod
+ def apply(self, input_tensor):
+ pass
+
+ def set_config(self, config=None):
+ if config is not None:
+ self.config = config
+
+
+@CONV3D_WEIGHT_REGISTER("Default")
+class Conv3dWeight(Conv3dWeightTemplate):
+ def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
+ super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
+
+ def load(self, weight_dict):
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_shape = weight_dict[self.weight_name].shape
+ weight_dtype = weight_dict[self.weight_name].dtype
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+
+ if self.bias_name is not None:
+ bias_shape = weight_dict[self.bias_name].shape
+ bias_dtype = weight_dict[self.bias_name].dtype
+ self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+ self.pin_bias.copy_(weight_dict[self.bias_name])
+ else:
+ self.bias = None
+ self.pin_bias = None
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+ if self.bias_name is not None:
+ self.bias = weight_dict[self.bias_name]
+ else:
+ self.bias = None
+
+ def apply(self, input_tensor):
+ input_tensor = torch.nn.functional.conv3d(
+ input_tensor,
+ weight=self.weight,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+ return input_tensor
+
+ def to_cuda(self, non_blocking=False):
+ self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
+ if hasattr(self, "pin_bias") and self.pin_bias is not None:
+ self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_weight"):
+ self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+ if self.bias is not None:
+ self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+ else:
+ self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+ if hasattr(self, "bias") and self.bias is not None:
+ self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight # .cpu().detach().clone().contiguous()
+ if self.bias_name is not None:
+ destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias # .cpu().detach().clone()
+ return destination
diff --git a/lightx2v/common/ops/embedding/__init__.py b/lightx2v/common/ops/embedding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..df915a383076749252e8c22370e58537918c5f36
--- /dev/null
+++ b/lightx2v/common/ops/embedding/__init__.py
@@ -0,0 +1 @@
+from .embedding_weight import *
diff --git a/lightx2v/common/ops/embedding/embedding_weight.py b/lightx2v/common/ops/embedding/embedding_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..8268d18188176f6bd17ed1a31eec324c9075dbb6
--- /dev/null
+++ b/lightx2v/common/ops/embedding/embedding_weight.py
@@ -0,0 +1,72 @@
+import re
+from abc import ABCMeta
+
+import torch
+import torch.nn.functional as F
+
+from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+
+class EmbeddingWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ self.weight_name = weight_name
+ self.create_cuda_buffer = create_cuda_buffer
+ self.create_cpu_buffer = create_cpu_buffer
+ self.lazy_load = lazy_load
+ self.lazy_load_file = lazy_load_file
+ self.is_post_adapter = is_post_adapter
+ self.config = {}
+
+ def load(self, weight_dict):
+ if not self.lazy_load:
+ if self.create_cuda_buffer:
+ self.weight_cuda_buffer = weight_dict[self.weight_name].to(AI_DEVICE)
+ else:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_shape = weight_dict[self.weight_name].shape
+ weight_dtype = weight_dict[self.weight_name].dtype
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+
+ def to_cuda(self, non_blocking=False):
+ self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_weight"):
+ self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+ else:
+ self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ if weight_name not in destination:
+ self.weight = None
+ return
+ self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+
+
+@EMBEDDING_WEIGHT_REGISTER("Default")
+class EmbeddingWeight(EmbeddingWeightTemplate):
+ def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, lazy_load=lazy_load, lazy_load_file=lazy_load_file)
+
+ def apply(self, input_indices):
+ output = F.embedding(input=input_indices, weight=self.weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
+
+ return output
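+
+
+# Usage sketch (hypothetical weight name; assumes the embedding table is already on AI_DEVICE):
+#   emb = EmbeddingWeight("text_embedding.weight")
+#   emb.load({"text_embedding.weight": table})  # table: [vocab_size, dim]
+#   vectors = emb.apply(token_ids)              # token_ids: LongTensor of indices -> [..., dim]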
diff --git a/lightx2v/common/ops/mm/__init__.py b/lightx2v/common/ops/mm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9898101b00e91590a78d636e2066ff1234a4d8
--- /dev/null
+++ b/lightx2v/common/ops/mm/__init__.py
@@ -0,0 +1 @@
+from .mm_weight import *
diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..330a50e0252432f12c5c6dad9414ad167f1f445f
--- /dev/null
+++ b/lightx2v/common/ops/mm/mm_weight.py
@@ -0,0 +1,1325 @@
+import os
+import re
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+
+import torch
+from safetensors import safe_open
+
+from lightx2v.utils.envs import *
+from lightx2v.utils.ggml_tensor import GGMLTensor
+from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tensor
+from lightx2v.utils.global_paras import CALIB
+from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+try:
+ from lightx2v_kernel.gemm import (
+ cutlass_scaled_mxfp4_mm,
+ cutlass_scaled_mxfp6_mxfp8_mm,
+ cutlass_scaled_mxfp8_mm,
+ cutlass_scaled_nvfp4_mm,
+ scaled_mxfp4_quant,
+ scaled_mxfp6_quant,
+ scaled_mxfp8_quant,
+ scaled_nvfp4_quant,
+ )
+except ImportError:
+ scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
+ scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
+ scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
+ scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None
+
+try:
+ from vllm import _custom_ops as ops
+except ImportError:
+ ops = None
+
+try:
+ import sgl_kernel
+except ImportError:
+ sgl_kernel = None
+
+try:
+ from q8_kernels.functional.linear import q8_linear
+except ImportError:
+ q8_linear = None
+
+try:
+ from q8_kernels.functional.linear import fp8_linear
+except ImportError:
+ fp8_linear = None
+
+try:
+ import deep_gemm
+except ImportError:
+ deep_gemm = None
+
+try:
+ from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
+except ImportError:
+ quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
+
+try:
+ import gguf
+except ImportError:
+ gguf = None
+
+try:
+ import marlin_cuda_quant
+except ImportError:
+ marlin_cuda_quant = None
+
+
+class MMWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ self.weight_name = weight_name
+ self.bias_name = bias_name
+ self.create_cuda_buffer = create_cuda_buffer
+ self.create_cpu_buffer = create_cpu_buffer
+ self.lazy_load = lazy_load
+ self.lazy_load_file = lazy_load_file
+ self.is_post_adapter = is_post_adapter
+ self.config = {}
+
+ @abstractmethod
+ def load(self, weight_dict):
+ pass
+
+ @abstractmethod
+ def apply(self):
+ pass
+
+    def set_config(self, config=None):
+        self.config = config if config is not None else {}
+
+ def to_cuda(self, non_blocking=False):
+ self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
+ if hasattr(self, "pin_weight_scale"):
+ self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
+ if hasattr(self, "pin_bias") and self.pin_bias is not None:
+ self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_weight"):
+ self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+ if hasattr(self, "weight_scale_name"):
+ self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
+ if self.bias is not None:
+ self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+ else:
+ self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+ if hasattr(self, "weight_scale"):
+ self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
+ if hasattr(self, "bias") and self.bias is not None:
+ self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+
+
+@MM_WEIGHT_REGISTER("Default")
+class MMWeight(MMWeightTemplate):
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+
+ def load(self, weight_dict):
+ if self.create_cuda_buffer:
+ self._load_cuda_buffers(weight_dict)
+ elif self.create_cpu_buffer:
+ self._load_cpu_pin_buffers()
+ else:
+ self._load_default_tensors(weight_dict)
+
+ def _get_source_tensor(self, source_name, weight_dict=None):
+ if self.lazy_load:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ return lazy_load_file.get_tensor(source_name)
+ return weight_dict[source_name]
+
+ def _create_pin_tensor(self, tensor, transpose=False):
+ pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+ pin_tensor = pin_tensor.copy_(tensor)
+ if transpose:
+ pin_tensor = pin_tensor.t()
+ del tensor
+ return pin_tensor
+
+ def _load_cuda_buffers(self, weight_dict):
+ self.weight_cuda_buffer = self._get_source_tensor(self.weight_name, weight_dict).t().to(AI_DEVICE)
+ if self.bias_name is not None:
+ self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)
+
+ def _load_cpu_pin_buffers(self):
+ if self.lazy_load:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ weight_tensor = lazy_load_file.get_tensor(self.weight_name)
+ self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
+
+ if self.bias_name is not None:
+ bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+ self.pin_bias = self._create_pin_tensor(bias_tensor)
+ else:
+ self.bias = None
+ self.pin_bias = None
+
+ def _load_default_tensors(self, weight_dict):
+ if not self.lazy_load:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_tensor = weight_dict[self.weight_name]
+ self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
+
+ if self.bias_name is not None:
+ bias_tensor = weight_dict[self.bias_name]
+ self.pin_bias = self._create_pin_tensor(bias_tensor)
+ else:
+ self.bias = None
+ self.pin_bias = None
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name].t()
+ self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+
+ def apply(self, input_tensor):
+ shape = (input_tensor.shape[0], self.weight.shape[1])
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+ output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+ if self.bias is None:
+ return torch.mm(input_tensor, self.weight, out=output_tensor)
+ return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+ if self.bias_name is not None:
+ destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
+ return destination
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ if self.bias_name is not None:
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+ else:
+ self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
+ self.pin_weight = self.pin_weight.copy_(weight_tensor)
+ del weight_tensor
+
+ if self.bias_name is not None:
+ bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+ self.pin_bias.copy_(bias_tensor)
+ del bias_tensor
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ if weight_name not in destination:
+ self.weight = None
+ return
+
+ self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+
+ if self.bias_name is not None:
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+ else:
+ bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+ self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
+ else:
+ self.bias = None
+
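+# Usage sketch for the default MM weight (hypothetical names; assumes the tensors already live on AI_DEVICE):
+#   mm = MMWeight("blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias")
+#   mm.load({"blocks.0.attn.qkv.weight": w, "blocks.0.attn.qkv.bias": b})  # w: [out, in], stored transposed
+#   y = mm.apply(x)  # x: [tokens, in_features] -> y: [tokens, out_features]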
+
+@MM_WEIGHT_REGISTER("Default-Force-FP32")
+class MMWeightForceFP32(MMWeight):
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+
+ def load(self, weight_dict):
+ if not self.lazy_load:
+ super().load(weight_dict)
+ self.weight = self.weight.to(torch.float32)
+ if hasattr(self, "bias") and self.bias is not None:
+ self.bias = self.bias.to(torch.float32)
+
+
+class MMWeightQuantTemplate(MMWeightTemplate):
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
+ self.load_func = None
+ self.weight_need_transpose = True
+ self.act_quant_func = None
+ self.lazy_load = lazy_load
+ self.lazy_load_file = lazy_load_file
+ self.infer_dtype = GET_DTYPE()
+ self.bias_force_fp32 = False
+
+ # =========================
+ # weight load functions
+ # =========================
+ def load(self, weight_dict):
+ self.load_quantized(weight_dict)
+ if self.weight_need_transpose:
+ if hasattr(self, "weight") and self.weight is not None:
+ self.weight = self.weight.t()
+ if hasattr(self, "pin_weight") and self.pin_weight is not None:
+ self.pin_weight = self.pin_weight.t()
+ if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None:
+ self.weight_cuda_buffer = self.weight_cuda_buffer.t()
+
+ def load_quantized(self, weight_dict):
+ if self.create_cuda_buffer:
+ self._load_cuda_buffers(weight_dict)
+ elif self.create_cpu_buffer:
+ self._load_cpu_pin_buffers()
+ else:
+ self._load_default_tensors(weight_dict)
+
+ def _load_cuda_buffers(self, weight_dict):
+ if self.lazy_load:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+ self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+ self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+ else:
+ source = weight_dict
+ self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+ self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+
+ def _get_cuda_tensor_pair(self, source, is_lazy):
+ if is_lazy:
+ weight = source.get_tensor(self.weight_name).to(AI_DEVICE)
+ scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE)
+ else:
+ weight = source[self.weight_name].to(AI_DEVICE)
+ scale = source[self.weight_scale_name].float().to(AI_DEVICE)
+ return weight, scale
+
+ def _get_cuda_bias_tensor(self, source, is_lazy):
+ if self.bias_name is None:
+ return None
+ if is_lazy:
+ bias = source.get_tensor(self.bias_name)
+ dtype = self.infer_dtype
+ else:
+ bias = source[self.bias_name]
+ dtype = bias.dtype
+ if self.bias_force_fp32:
+ bias = bias.to(torch.float32)
+ else:
+ bias = bias.to(dtype)
+ return bias.to(AI_DEVICE)
+
+ def _load_cpu_pin_buffers(self):
+ self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True)
+ self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True)
+ self.bias = None
+
+ def _get_cpu_pin_tensor_pair(self, source, is_lazy):
+ if is_lazy:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+ weight_tensor = source.get_tensor(self.weight_name)
+ scale_tensor = source.get_tensor(self.weight_scale_name)
+ scale_dtype = torch.float
+ pin_weight = self._create_pin_tensor(weight_tensor)
+ pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
+ else:
+ weight_tensor = source[self.weight_name]
+ scale_tensor = source[self.weight_scale_name]
+ scale_dtype = torch.float
+ pin_weight = self._create_pin_tensor(weight_tensor)
+ pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
+ return pin_weight, pin_scale
+
+ def _get_cpu_pin_bias_tensor(self, source, is_lazy):
+ if self.bias_name is None:
+ return None
+ if is_lazy:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+ bias_tensor = source.get_tensor(self.bias_name)
+                if self.bias_force_fp32:
+                    bias_tensor = bias_tensor.to(torch.float32)
+                else:
+                    bias_tensor = bias_tensor.to(self.infer_dtype)
+ return self._create_pin_tensor(bias_tensor)
+ else:
+ bias_tensor = source[self.bias_name]
+ if self.bias_force_fp32:
+ bias_tensor = bias_tensor.to(torch.float32)
+ return self._create_pin_tensor(bias_tensor)
+
+ def _create_pin_tensor(self, tensor, dtype=None):
+ dtype = dtype or tensor.dtype
+ pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+ pin_tensor.copy_(tensor)
+ del tensor
+ return pin_tensor
+
+ def _load_default_tensors(self, weight_dict):
+ if not self.lazy_load:
+ self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict)
+ self._load_default_bias(weight_dict)
+ else:
+ self.bias = None
+ self.pin_bias = None
+
+ def _get_device_tensor_pair(self, source):
+ device = source[self.weight_name].device
+ if device.type == "cpu":
+ pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False)
+ return None, None, pin_weight, pin_scale
+ else:
+ return source[self.weight_name], source[self.weight_scale_name].float(), None, None
+
+ def _load_default_bias(self, source):
+ if self.bias_name is None:
+ self.bias = None
+ self.pin_bias = None
+ self.bias_cuda_buffer = None
+ return
+
+ if self.create_cuda_buffer:
+ self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False)
+ self.bias = None
+ self.pin_bias = None
+ else:
+ bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name]
+ device = bias_tensor.device
+ if device.type == "cpu":
+ self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False)
+ self.bias = None
+ else:
+ self.bias = bias_tensor
+ self.pin_bias = None
+
+ def load_fp8_perchannel_sym(self, weight_dict):
+ if self.config.get("weight_auto_quant", False):
+ self.weight = weight_dict[self.weight_name].to(torch.float32)
+ w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
+ self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
+ self.weight = self.weight.to(torch.float8_e4m3fn)
+ self.weight_scale = self.weight_scale.to(torch.float32)
+ else:
+ self.load_quantized(weight_dict)
+
+ def load_int8_perchannel_sym(self, weight_dict):
+ if self.config.get("weight_auto_quant", False):
+ self.weight = weight_dict[self.weight_name].to(torch.float32)
+ w_quantizer = IntegerQuantizer(8, True, "per_channel")
+ self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
+ self.weight = self.weight.to(torch.int8)
+ self.weight_scale = self.weight_scale.to(torch.float32)
+ else:
+ self.load_quantized(weight_dict)
+
+ def load_mxfp4(self, weight_dict):
+ if self.config.get("weight_auto_quant", False):
+ device = weight_dict[self.weight_name].device
+ self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
+ self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
+ self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
+ else:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_shape = weight_dict[self.weight_name].shape
+ weight_dtype = weight_dict[self.weight_name].dtype
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+
+ weight_scale_shape = weight_dict[self.weight_scale_name].shape
+ weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
+ self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+ self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+ self.weight_scale = weight_dict[self.weight_scale_name]
+
+ def load_mxfp6(self, weight_dict):
+ if self.config.get("weight_auto_quant", False):
+ device = weight_dict[self.weight_name].device
+ self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
+ self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
+ self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
+ else:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_shape = weight_dict[self.weight_name].shape
+ weight_dtype = weight_dict[self.weight_name].dtype
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+
+ weight_scale_shape = weight_dict[self.weight_scale_name].shape
+ weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
+ self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+ self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+ self.weight_scale = weight_dict[self.weight_scale_name]
+
+ def load_mxfp8(self, weight_dict):
+ if self.config.get("weight_auto_quant", False):
+ device = weight_dict[self.weight_name].device
+ self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
+ self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
+ self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
+ else:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_shape = weight_dict[self.weight_name].shape
+ weight_dtype = weight_dict[self.weight_name].dtype
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+
+ weight_scale_shape = weight_dict[self.weight_scale_name].shape
+ weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
+ self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+ self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+ self.weight_scale = weight_dict[self.weight_scale_name]
+
+ def load_nvfp4(self, weight_dict):
+ device = weight_dict[self.weight_name].device
+
+ input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")]
+ input_global_scale = (2688.0 / input_absmax).to(torch.float32)
+ weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
+ alpha = 1.0 / (input_global_scale * weight_global_scale)
+
+ if device.type == "cpu":
+ weight_shape = weight_dict[self.weight_name].shape
+ weight_dtype = weight_dict[self.weight_name].dtype
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+
+ weight_scale_shape = weight_dict[self.weight_scale_name].shape
+ weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
+ self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+ self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
+
+ input_global_scale_shape = input_global_scale.shape
+ input_global_scale_dtype = input_global_scale.dtype
+ self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype)
+ self.pin_input_global_scale.copy_(input_global_scale)
+
+ alpha_shape = alpha.shape
+ alpha_dtype = alpha.dtype
+ self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype)
+ self.pin_alpha.copy_(alpha)
+
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+ self.weight_scale = weight_dict[self.weight_scale_name]
+ self.input_global_scale = input_global_scale
+ self.alpha = alpha
+
+ if self.bias_name is not None:
+ if self.create_cuda_buffer:
+ self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE)
+ else:
+ device = weight_dict[self.bias_name].device
+ if device.type == "cpu":
+ bias_shape = weight_dict[self.bias_name].shape
+ bias_dtype = weight_dict[self.bias_name].dtype
+ self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+ self.pin_bias.copy_(weight_dict[self.bias_name])
+ else:
+ self.bias = weight_dict[self.bias_name]
+ else:
+ self.bias = None
+ self.pin_bias = None
+
+ def load_fp8_perblock128_sym(self, weight_dict):
+ if self.config.get("weight_auto_quant", False):
+ self.weight = weight_dict[self.weight_name]
+ self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
+ else:
+ self.load_quantized(weight_dict)
+
+ def per_block_cast_to_fp8(self, x):
+ assert x.dim() == 2
+ m, n = x.shape
+ x_padded = torch.zeros(
+ (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ x_padded[:m, :n] = x
+ x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+ x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+ x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+ return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
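+
+    # Blockwise layout sketch: for a weight of shape [3000, 2048] the tensor is padded to
+    # [3072, 2048], split into 128x128 blocks, and each block is scaled so that its absmax
+    # maps to 448.0 (the fp8 e4m3 maximum); the per-block scales come back with shape [24, 16]
+    # while the quantized weight is cropped back to the original [3000, 2048].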
+
+ # =========================
+ # act quant kernels
+ # =========================
+ def act_quant_int8_perchannel_sym_torchao(self, x):
+ input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_fp8_perchannel_sym_vllm(self, x):
+ input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_fp8_perchannel_sym_sgl(self, x):
+ m, k = x.shape
+ input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
+ input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
+ sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_int8_perchannel_sym_vllm(self, x):
+ input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_nvfp4(self, x):
+ input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_mxfp4(self, x):
+ input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_mxfp8(self, x):
+ input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
+ return input_tensor_quant, input_tensor_scale
+
+ def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
+ assert x.dim() == 2 and x.size(1) % 128 == 0
+ m, n = x.shape
+ x_view = x.view(m, -1, 128)
+ x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
+ return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+
+ def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
+ m, k = x.shape
+ input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
+ input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
+ sgl_kernel.sgl_per_token_group_quant_fp8(
+ x,
+ input_tensor_quant,
+ input_tensor_scale,
+ group_size=128,
+ eps=1e-10,
+ fp8_min=-448.0,
+ fp8_max=448.0,
+ )
+ return input_tensor_quant, input_tensor_scale
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+ if self.bias_name is not None:
+ destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
+ destination[self.weight_scale_name] = self.pin_weight_scale if hasattr(self, "pin_weight_scale") else self.weight_scale
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
+ else:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+ weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
+
+ if weight_name not in destination:
+ self.weight = None
+ return
+
+ self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+ self.weight_scale = self.weight_scale_cuda_buffer.copy_(destination[weight_scale_name], non_blocking=True)
+
+ if self.bias_name is not None:
+ bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+ self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
+ else:
+ self.bias = None
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
+ else:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+ self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
+
+ if self.bias_name is not None:
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+ else:
+ self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ if self.weight_need_transpose:
+ weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
+ else:
+ weight_tensor = lazy_load_file.get_tensor(self.weight_name)
+
+ self.pin_weight = self.pin_weight.copy_(weight_tensor)
+ del weight_tensor
+
+ weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
+ self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
+ del weight_scale_tensor
+
+ if self.bias_name is not None:
+ bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+ self.pin_bias.copy_(bias_tensor)
+ del bias_tensor
+
+
+@MM_WEIGHT_REGISTER("fp8-vllm")
+class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
+ """
+ Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
+
+ Quant MM:
+ Weight: fp8 perchannel sym
+ Act: fp8 perchannel dynamic sym
+ Kernel: vllm
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_fp8_perchannel_sym
+ self.weight_need_transpose = True
+ self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
+
+ def apply(self, input_tensor):
+ shape = (input_tensor.shape[0], self.weight.shape[1])
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+ output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ torch.ops._C.cutlass_scaled_mm(
+ output_tensor,
+ input_tensor_quant,
+ self.weight,
+ input_tensor_scale,
+ self.weight_scale,
+ self.bias if self.bias is not None else None,
+ )
+ return output_tensor
+
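+# What the cutlass_scaled_mm call above computes, as a sketch of the math:
+#   out[i, j] = (sum_k a_q[i, k] * w_q[k, j]) * a_scale[i] * w_scale[j] + bias[j]
+# i.e. the low-precision GEMM accumulates in higher precision, the per-token activation scale and
+# per-channel weight scale are folded back in, and the result is cast to the activation dtype.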
+
+@MM_WEIGHT_REGISTER("int8-vllm")
+class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
+ """
+ Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
+
+ Quant MM:
+ Weight: int8 perchannel sym
+ Act: int8 perchannel dynamic sym
+ Kernel: vllm
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_int8_perchannel_sym
+ self.weight_need_transpose = True
+ self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+
+ def apply(self, input_tensor):
+ shape = (input_tensor.shape[0], self.weight.shape[1])
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+ output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ torch.ops._C.cutlass_scaled_mm(
+ output_tensor,
+ input_tensor_quant,
+ self.weight,
+ input_tensor_scale,
+ self.weight_scale,
+ self.bias if self.bias is not None else None,
+ )
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("mxfp4")
+class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
+ """
+ Name: W-mxfp4-A-mxfp4-dynamic
+
+ Quant MM:
+ Weight: mxfp4
+ Act: mxfp4
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_mxfp4
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_mxfp4
+ self.set_alpha()
+
+ def set_alpha(self):
+ self.alpha = torch.tensor(1.0, dtype=torch.float32)
+
+ def apply(self, input_tensor):
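+ # Weight and activation both use mxfp4 block scaling; alpha is a global fp32 factor (1.0 here) forwarded to the cutlass kernel together with the optional bias.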
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ self.alpha = self.alpha.to(self.weight.device)
+ output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
+class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
+ """
+ Name: W-mxfp6-A-mxfp8-dynamic
+
+ Quant MM:
+ Weight: mxfp6
+ Act: mxfp8
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_mxfp6
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_mxfp8
+ self.set_alpha()
+
+ def set_alpha(self):
+ self.alpha = torch.tensor(1.0, dtype=torch.float32)
+
+ def apply(self, input_tensor):
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ self.alpha = self.alpha.to(self.weight.device)
+ output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("mxfp8")
+class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
+ """
+ Name: W-mxfp8-A-mxfp8-dynamic
+
+ Quant MM:
+ Weight: mxfp8
+ Act: mxfp8
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_mxfp8
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_mxfp8
+ self.set_alpha()
+
+ def set_alpha(self):
+ self.alpha = torch.tensor(1.0, dtype=torch.float32)
+
+ def apply(self, input_tensor):
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ self.alpha = self.alpha.to(self.weight.device)
+ output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("nvfp4")
+class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
+ """
+ Name: W-nvfp4-A-nvfp4-dynamic
+
+ Quant MM:
+ Weight: nvfp4
+ Act: nvfp4
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_nvfp4
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_nvfp4
+
+ def apply(self, input_tensor):
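+ # Unlike the mx variants above, alpha is a loaded buffer (moved to the device in to_cuda below) rather than a fixed 1.0.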
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+ return output_tensor
+
+ def to_cuda(self, non_blocking=False):
+ self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
+ if hasattr(self, "pin_weight_scale"):
+ self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
+ self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking)
+ self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking)
+ if hasattr(self, "pin_bias") and self.pin_bias is not None:
+ self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_weight"):
+ self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+ if hasattr(self, "weight_scale_name"):
+ self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
+ self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
+ self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
+ if self.bias is not None:
+ self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+ else:
+ self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+ if hasattr(self, "weight_scale"):
+ self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
+ self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
+ self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
+ if hasattr(self, "bias") and self.bias is not None:
+ self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+
+
+@MM_WEIGHT_REGISTER("Calib")
+class MMCalibNvfp4(MMWeight):
+ """
+ Name: calib
+
+ Calib:
+ absmax: torch.max(torch.abs(input_tensor))
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.running_absmax = None
+ self.count = 0
+ self.decay = 0.9
+
+ def apply(self, input_tensor):
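+ # Calibration pass: track a running absmax of the activation (EMA with decay 0.9, sampled every other call), record it in CALIB["absmax"], then run the unquantized matmul.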
+ shape = (input_tensor.shape[0], self.weight.shape[1])
+ dtype, device = input_tensor.dtype, input_tensor.device
+
+ current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
+ if self.count % 2 == 0:
+ if self.running_absmax is None:
+ self.running_absmax = current_absmax
+ else:
+ self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
+ CALIB["absmax"][self.weight_name] = self.running_absmax
+ self.count = self.count + 1
+
+ output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+ if self.bias is None:
+ return torch.mm(input_tensor, self.weight, out=output_tensor)
+ return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
+
+
+@MM_WEIGHT_REGISTER("fp8-q8f")
+class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
+ """
+ Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
+
+ Quant MM:
+ Weight: fp8 perchannel sym
+ Act: fp8 perchannel dynamic sym
+ Kernel: Q8F
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_fp8_perchannel_sym
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
+ self.bias_force_fp32 = True
+
+ def apply(self, input_tensor):
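+ # fp8_linear may return a batched (1, M, N) tensor, hence the squeeze on 3D outputs; the bias is passed in fp32 (see bias_force_fp32 above).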
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ output_tensor = fp8_linear(
+ input_tensor_quant,
+ self.weight,
+ self.bias.float() if self.bias is not None else None,
+ input_tensor_scale,
+ self.weight_scale,
+ out_dtype=self.infer_dtype,
+ )
+ return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
+
+
+@MM_WEIGHT_REGISTER("int8-q8f")
+class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
+ """
+ Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
+
+ Quant MM:
+ Weight: int8 perchannel sym
+ Act: int8 perchannel dynamic sym
+ Kernel: Q8F
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_int8_perchannel_sym
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+
+ def apply(self, input_tensor):
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ output_tensor = q8_linear(
+ input_tensor_quant,
+ self.weight,
+ self.bias.float() if self.bias is not None else None,
+ input_tensor_scale,
+ self.weight_scale,
+ fuse_gelu=False,
+ out_dtype=self.infer_dtype,
+ )
+ return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
+
+
+@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
+class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
+ """
+ Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
+
+ Quant MM:
+ Weight: fp8 perblock 128x128 sym
+ Act: fp8 pertoken-pergroup group=128 dynamic sym
+ Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_fp8_perblock128_sym
+ self.weight_need_transpose = False
+ self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
+
+ def apply(self, input_tensor):
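+ # deep_gemm consumes (quantized tensor, scale) pairs for both operands and writes into the preallocated output buffer; the bias, if present, is added in place afterwards.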
+ shape = (input_tensor.shape[0], self.weight.shape[0])
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+ output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ deep_gemm.gemm_fp8_fp8_bf16_nt(
+ (input_tensor_quant, input_tensor_scale),
+ (self.weight, self.weight_scale),
+ output_tensor,
+ )
+ if hasattr(self, "bias") and self.bias is not None:
+ output_tensor.add_(self.bias)
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("fp8-sgl")
+class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
+ """
+ Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
+
+ Quant MM:
+ Weight: fp8 perchannel sym
+ Act: fp8 perchannel dynamic sym
+ Kernel: Sgl-kernel
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_fp8_perchannel_sym
+ self.weight_need_transpose = True
+ self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
+
+ def apply(self, input_tensor):
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ output_tensor = sgl_kernel.fp8_scaled_mm(
+ input_tensor_quant,
+ self.weight,
+ input_tensor_scale,
+ self.weight_scale,
+ self.infer_dtype,
+ self.bias if self.bias is not None else None,
+ )
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("int8-sgl")
+class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
+ """
+ Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
+
+ Quant MM:
+ Weight: int8 perchannel sym
+ Act: int8 perchannel dynamic sym
+ Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_int8_perchannel_sym
+ self.weight_need_transpose = True
+ self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+
+ def apply(self, input_tensor):
+ shape = (input_tensor.shape[0], self.weight.shape[1])
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+ output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ output_tensor = sgl_kernel.int8_scaled_mm(
+ input_tensor_quant,
+ self.weight,
+ input_tensor_scale,
+ self.weight_scale,
+ self.infer_dtype,
+ self.bias if self.bias is not None else None,
+ )
+ return output_tensor
+
+
+@MM_WEIGHT_REGISTER("int8-torchao")
+class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
+ """
+ Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
+
+ Quant MM:
+ Weight: int8 perchannel sym
+ Act: int8 perchannel dynamic sym
+ Kernel: Torchao
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_int8_perchannel_sym
+ self.weight_need_transpose = True
+ self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
+
+ def apply(self, input_tensor):
+ input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+ output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
+ if self.bias is not None:
+ output_tensor = output_tensor + self.bias
+
+ return output_tensor
+
+
+class MMWeightGGUFTemplate(MMWeightTemplate):
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+
+ def load(self, weight_dict):
+ if not self.lazy_load:
+ assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
+ self.weight = weight_dict[self.weight_name]
+
+ weight_shape = self.weight.shape
+ weight_dtype = self.weight.dtype
+
+ if isinstance(self.weight, GGMLTensor):
+ self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
+ self.pin_weight.copy_from(self.weight)
+ else:
+ self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+ self.pin_weight.copy_(weight_dict[self.weight_name])
+
+ if self.bias_name is not None:
+ self.bias = weight_dict[self.bias_name]
+ if isinstance(self.bias, GGMLTensor):
+ self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
+ self.pin_bias.copy_from(self.bias)
+ else:
+ self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+ self.pin_bias.copy_(weight_dict[self.bias_name])
+ else:
+ self.bias = None
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ if weight_name not in destination:
+ self.weight = None
+ return
+
+ self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+
+ if self.bias_name is not None:
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+ else:
+ bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+ self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
+ else:
+ self.bias = None
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+ if self.bias_name is not None:
+ destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
+
+ return destination
+
+ def get_weight(self, tensor, dtype):
+ if tensor is None:
+ return
+
+ weight = gguf_dequantize_tensor(tensor, dtype)
+ if isinstance(weight, GGMLTensor):
+ weight = torch.Tensor(weight)
+
+ return weight
+
+ def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None):
+ if input_tensor is not None:
+ if dtype is None:
+ dtype = getattr(input_tensor, "dtype", torch.float32)
+
+ bias = None
+ if self.bias is not None:
+ bias = self.get_weight(self.bias, dtype)
+
+ weight = self.get_weight(self.weight, dtype)
+ return weight, bias
+
+ def apply(self, input_tensor):
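+ # GGUF weights stay quantized in memory and are dequantized to the activation dtype on every call before a standard F.linear.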
+ weight, bias = self.cast_bias_weight(input_tensor)
+ return torch.nn.functional.linear(input_tensor, weight, bias)
+
+
+@MM_WEIGHT_REGISTER("gguf-BF16")
+class MMWeightGGUFBF16(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.BF16
+
+
+@MM_WEIGHT_REGISTER("gguf-Q8_0")
+class MMWeightGGUFQ80(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q8_0
+
+
+@MM_WEIGHT_REGISTER("gguf-Q6_K")
+class MMWeightGGUFQ6K(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q6_K
+
+
+@MM_WEIGHT_REGISTER("gguf-Q5_K_S")
+class MMWeightGGUFQ5KS(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q5_K
+
+
+@MM_WEIGHT_REGISTER("gguf-Q5_K_M")
+class MMWeightGGUFQ5KM(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q5_K
+
+
+@MM_WEIGHT_REGISTER("gguf-Q5_1")
+class MMWeightGGUFQ51(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q5_1
+
+
+@MM_WEIGHT_REGISTER("gguf-Q5_0")
+class MMWeightGGUFQ50(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q5_0
+
+
+@MM_WEIGHT_REGISTER("gguf-Q4_K_M")
+class MMWeightGGUFQ4KM(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q4_K
+
+
+@MM_WEIGHT_REGISTER("gguf-Q4_K_S")
+class MMWeightGGUFQ4KS(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q4_K
+
+
+@MM_WEIGHT_REGISTER("gguf-Q4_1")
+class MMWeightGGUFQ41(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q4_1
+
+
+@MM_WEIGHT_REGISTER("gguf-Q4_0")
+class MMWeightGGUFQ40(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q4_0
+
+
+@MM_WEIGHT_REGISTER("gguf-Q3_K_M")
+class MMWeightGGUFQ3KM(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q3_K
+
+
+@MM_WEIGHT_REGISTER("gguf-Q3_K_S")
+class MMWeightGGUFQ3KS(MMWeightGGUFTemplate):
+ qtype = gguf.GGMLQuantizationType.Q3_K
+
+
+@MM_WEIGHT_REGISTER("int4-g128-marlin")
+class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
+ """
+ Name: "W-int4-group128-sym-Marlin
+
+ Quant int4 x FP16:
+ Weight: int4 pergroup sym
+ Kernel: Marlin
+ """
+
+ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
+ self.load_func = self.load_quantized
+
+ def load(self, weight_dict):
+ assert not self.lazy_load
+ self.load_func(weight_dict)
+ self.workspace = weight_dict[f"{self.weight_name}_workspace"]
+
+ if self.bias_name is not None:
+ bias_shape = weight_dict[self.bias_name].shape
+ bias_dtype = weight_dict[self.bias_name].dtype
+ self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+ self.bias.copy_(weight_dict[self.bias_name])
+ else:
+ self.bias = None
+
+ def apply(self, input_tensor):
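+ # Marlin multiplies the half-precision activation by the packed int4 weight into output_tensor; the trailing -1 arguments presumably leave tiling/parallelism choices to the kernel defaults.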
+ output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
+ marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
+ if hasattr(self, "bias") and self.bias is not None:
+ output_tensor.add_(self.bias)
+ return output_tensor
diff --git a/lightx2v/common/ops/norm/__init__.py b/lightx2v/common/ops/norm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e724797374c92b0ca99f38778067216867174e0
--- /dev/null
+++ b/lightx2v/common/ops/norm/__init__.py
@@ -0,0 +1,2 @@
+from .layer_norm_weight import *
+from .rms_norm_weight import *
diff --git a/lightx2v/common/ops/norm/layer_norm_weight.py b/lightx2v/common/ops/norm/layer_norm_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..38db511751673035761c339d10b151dc555e1305
--- /dev/null
+++ b/lightx2v/common/ops/norm/layer_norm_weight.py
@@ -0,0 +1,220 @@
+import os
+import re
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+
+import torch
+from safetensors import safe_open
+
+from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+from .triton_ops import norm_infer
+
+
+class LNWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ self.weight_name = weight_name
+ self.bias_name = bias_name
+ self.eps = eps
+ self.create_cuda_buffer = create_cuda_buffer
+ self.create_cpu_buffer = create_cpu_buffer
+ self.lazy_load = lazy_load
+ self.lazy_load_file = lazy_load_file
+ self.is_post_adapter = is_post_adapter
+ self.config = {}
+ self.infer_dtype = GET_DTYPE()
+ self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+
+ def load(self, weight_dict):
+ if self.create_cuda_buffer:
+ self._load_cuda_buffers(weight_dict)
+ elif self.create_cpu_buffer:
+ self._load_cpu_pin_buffers()
+ else:
+ self._load_default_tensors(weight_dict)
+
+ def _load_default_tensors(self, weight_dict):
+ if not self.lazy_load and self.weight_name is not None:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_tensor = weight_dict[self.weight_name]
+ self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
+ bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None
+ self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None
+ self.bias = None
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+ self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+ else:
+ self.weight = None
+ self.bias = None
+
+ def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False):
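+ # When lazy-loading from a directory, the block index is parsed from the second dot-separated field of the tensor name (e.g. a hypothetical "blocks.3.norm.weight" maps to block_3.safetensors).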
+ if name is None:
+ return None
+ if self.lazy_load:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ tensor = lazy_load_file.get_tensor(name)
+ if use_infer_dtype:
+ tensor = tensor.to(self.infer_dtype)
+ else:
+ tensor = weight_dict[name]
+ return tensor
+
+ def _create_cpu_pin_tensor(self, tensor):
+ if tensor is None:
+ return None
+ pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+ pin_tensor.copy_(tensor)
+ del tensor
+ return pin_tensor
+
+ def _load_cuda_buffers(self, weight_dict):
+ weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load)
+ if weight_tensor is not None:
+ self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
+
+ bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load)
+ if bias_tensor is not None:
+ self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE)
+
+ def _load_cpu_pin_buffers(self):
+ weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True)
+ if weight_tensor is not None:
+ self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
+ else:
+ self.weight = None
+
+ bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True)
+ if bias_tensor is not None:
+ self.pin_bias = self._create_cpu_pin_tensor(bias_tensor)
+ else:
+ self.bias = None
+ self.pin_bias = None
+
+ @abstractmethod
+ def apply(self, input_tensor):
+ pass
+
+ def set_config(self, config=None):
+ if config is not None:
+ self.config = config
+
+ def to_cuda(self, non_blocking=False):
+ if hasattr(self, "pin_weight") and self.pin_weight is not None:
+ self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
+ else:
+ self.weight = None
+ if hasattr(self, "pin_bias") and self.pin_bias is not None:
+ self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
+ else:
+ self.bias = None
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_weight") and self.pin_weight is not None:
+ self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+ if self.bias is not None:
+ self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+ elif hasattr(self, "weight") and self.weight is not None:
+ self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+ if hasattr(self, "bias") and self.bias is not None:
+ self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ if self.weight_name is not None:
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+ if self.bias_name is not None:
+ destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.weight_name is not None:
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ if weight_name not in destination:
+ self.weight = None
+ return
+ self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+ else:
+ self.weight = None
+
+ if self.bias_name is not None:
+ bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+ self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
+ else:
+ self.bias = None
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ if self.weight_name is not None:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+ if self.is_post_adapter:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+ self.pin_weight = self.pin_weight.copy_(weight_tensor)
+ del weight_tensor
+
+ if self.bias_name is not None:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+ else:
+ self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
+ self.pin_bias.copy_(bias_tensor)
+ del bias_tensor
+
+
+@LN_WEIGHT_REGISTER("Default")
+class LNWeight(LNWeightTemplate):
+ def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+ def apply(self, input_tensor):
+ if self.sensitive_layer_dtype != self.infer_dtype:
+ input_tensor = torch.nn.functional.layer_norm(
+ input_tensor.float(),
+ (input_tensor.shape[-1],),
+ self.weight,
+ self.bias,
+ self.eps,
+ ).to(self.infer_dtype)
+ else:
+ input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
+
+ return input_tensor
+
+
+@LN_WEIGHT_REGISTER("Triton")
+class LNWeightTriton(LNWeightTemplate):
+ def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+ def apply(self, input_tensor):
+ input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
+ return input_tensor
diff --git a/lightx2v/common/ops/norm/rms_norm_weight.py b/lightx2v/common/ops/norm/rms_norm_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..98d60ef2cba0aab9495e7f690f17d301eb82a81d
--- /dev/null
+++ b/lightx2v/common/ops/norm/rms_norm_weight.py
@@ -0,0 +1,204 @@
+import os
+import re
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+
+import torch
+from safetensors import safe_open
+
+from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+try:
+ import sgl_kernel
+except ImportError:
+ sgl_kernel = None
+
+
+class RMSWeightTemplate(metaclass=ABCMeta):
+ def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ self.weight_name = weight_name
+ self.eps = eps
+ self.create_cuda_buffer = create_cuda_buffer
+ self.create_cpu_buffer = create_cpu_buffer
+ self.lazy_load = lazy_load
+ self.lazy_load_file = lazy_load_file
+ self.is_post_adapter = is_post_adapter
+ self.infer_dtype = GET_DTYPE()
+ self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+ self.config = {}
+
+ def load(self, weight_dict):
+ if self.create_cuda_buffer:
+ self._load_cuda_buffer(weight_dict)
+ elif self.create_cpu_buffer:
+ self._load_cpu_pin_buffer()
+ else:
+ self._load_default_tensors(weight_dict)
+
+ def _load_default_tensors(self, weight_dict):
+ if not self.lazy_load:
+ device = weight_dict[self.weight_name].device
+ if device.type == "cpu":
+ weight_tensor = weight_dict[self.weight_name]
+ self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
+ del weight_dict[self.weight_name]
+ else:
+ self.weight = weight_dict[self.weight_name]
+
+ def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
+ if self.lazy_load:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ tensor = lazy_load_file.get_tensor(self.weight_name)
+ if use_infer_dtype:
+ tensor = tensor.to(self.infer_dtype)
+ else:
+ tensor = weight_dict[self.weight_name]
+ return tensor
+
+ def _create_cpu_pin_weight(self, tensor):
+ pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+ pin_tensor.copy_(tensor)
+ del tensor
+ return pin_tensor
+
+ def _load_cuda_buffer(self, weight_dict):
+ weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
+ self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
+
+ def _load_cpu_pin_buffer(self):
+ weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
+ self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
+
+ @abstractmethod
+ def apply(self, input_tensor):
+ pass
+
+ def set_config(self, config=None):
+ if config is not None:
+ self.config = config
+
+ def to_cuda(self, non_blocking=False):
+ self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_weight"):
+ self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+ else:
+ self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+ if weight_name not in destination:
+ self.weight = None
+ return
+ self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+ else:
+ self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+ self.pin_weight = self.pin_weight.copy_(weight_tensor)
+ del weight_tensor
+
+
+@RMS_WEIGHT_REGISTER("Default")
+class RMSWeight(RMSWeightTemplate):
+ def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+ def apply(self, input_tensor):
+ if GET_SENSITIVE_DTYPE() != GET_DTYPE():
+ input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
+ else:
+ input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
+ return input_tensor
+
+
+@RMS_WEIGHT_REGISTER("sgl-kernel")
+class RMSWeightSgl(RMSWeight):
+ def __init__(
+ self,
+ weight_name,
+ create_cuda_buffer=False,
+ create_cpu_buffer=False,
+ lazy_load=False,
+ lazy_load_file=None,
+ is_post_adapter=False,
+ eps=1e-6,
+ ):
+ super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+ def apply(self, input_tensor):
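+ # sgl_kernel.rmsnorm expects a contiguous 2D (tokens, hidden) input, so flatten first and restore the original shape afterwards.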
+ if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
+ input_tensor = input_tensor.contiguous()
+ orig_shape = input_tensor.shape
+ input_tensor = input_tensor.view(-1, orig_shape[-1])
+ input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
+ else:
+ # Fall back to the eager implementation when sgl_kernel is unavailable or the sensitive-layer dtype differs from the inference dtype
+ if self.sensitive_layer_dtype != self.infer_dtype:
+ input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
+ input_tensor = (input_tensor * self.weight).to(self.infer_dtype)
+ else:
+ input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
+ input_tensor = input_tensor * self.weight
+
+ return input_tensor
+
+
+@RMS_WEIGHT_REGISTER("fp32_variance")
+class RMSWeightFP32(RMSWeight):
+ def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+ def apply(self, input_tensor):
+ input_dtype = input_tensor.dtype
+ variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
+
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ if self.weight is not None:
+ hidden_states = hidden_states * self.weight
+ hidden_states = hidden_states.to(input_dtype)
+
+ return hidden_states
+
+
+@RMS_WEIGHT_REGISTER("self_forcing")
+class RMSWeightSF(RMSWeight):
+ def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+ super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+ def apply(self, x):
+ return self._norm(x.float()).type_as(x) * self.weight
diff --git a/lightx2v/common/ops/norm/triton_ops.py b/lightx2v/common/ops/norm/triton_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e0ad507bc467600547dd491a012e5ca22fc6cb4
--- /dev/null
+++ b/lightx2v/common/ops/norm/triton_ops.py
@@ -0,0 +1,900 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang
+
+# TODO: temporary implementation, pending a refactor
+from typing import Optional
+
+import torch
+import triton # type: ignore
+import triton.language as tl # type: ignore
+from torch import Tensor
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_N": 64}, num_warps=2),
+ triton.Config({"BLOCK_N": 128}, num_warps=4),
+ triton.Config({"BLOCK_N": 256}, num_warps=4),
+ triton.Config({"BLOCK_N": 512}, num_warps=4),
+ triton.Config({"BLOCK_N": 1024}, num_warps=8),
+ ],
+ key=["inner_dim"],
+)
+@triton.jit
+def _fused_scale_shift_4d_kernel(
+ output_ptr,
+ normalized_ptr,
+ scale_ptr,
+ shift_ptr,
+ rows,
+ inner_dim,
+ seq_len,
+ num_frames,
+ frame_seqlen,
+ BLOCK_N: tl.constexpr,
+):
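+ # Each program handles one (row, column-block) tile; a row indexes (batch, token) and its per-frame scale/shift row is b_idx * num_frames + t_idx // frame_seqlen.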
+ pid_row = tl.program_id(0)
+ pid_col = tl.program_id(1)
+
+ col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask = col_offsets < inner_dim
+
+ # Pointers for normalized and output
+ row_base = pid_row * inner_dim
+ norm_ptrs = normalized_ptr + row_base + col_offsets
+ out_ptrs = output_ptr + row_base + col_offsets
+
+ # Pointers for scale and shift for 4D
+ b_idx = pid_row // seq_len
+ t_idx = pid_row % seq_len
+ frame_idx_in_batch = t_idx // frame_seqlen
+
+ scale_row_idx = b_idx * num_frames + frame_idx_in_batch
+ scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
+ shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
+
+ normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
+ scale = tl.load(scale_ptrs, mask=mask, other=0.0)
+ shift = tl.load(shift_ptrs, mask=mask, other=0.0)
+
+ one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
+ output = normalized * (one + scale) + shift
+
+ tl.store(out_ptrs, output, mask=mask)
+
+
+@triton.jit
+def fuse_scale_shift_kernel_blc_opt(
+ x_ptr,
+ shift_ptr,
+ scale_ptr,
+ y_ptr,
+ B,
+ L,
+ C,
+ stride_x_b,
+ stride_x_l,
+ stride_x_c,
+ stride_s_b,
+ stride_s_l,
+ stride_s_c,
+ stride_sc_b,
+ stride_sc_l,
+ stride_sc_c,
+ SCALE_IS_SCALAR: tl.constexpr,
+ SHIFT_IS_SCALAR: tl.constexpr,
+ BLOCK_L: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+):
+ pid_l = tl.program_id(0)
+ pid_c = tl.program_id(1)
+ pid_b = tl.program_id(2)
+
+ l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
+ c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
+
+ mask_l = l_offsets < L
+ mask_c = c_offsets < C
+ mask = mask_l[:, None] & mask_c[None, :]
+
+ x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c
+ x = tl.load(x_ptr + x_off, mask=mask, other=0)
+
+ if SHIFT_IS_SCALAR:
+ shift_val = tl.load(shift_ptr)
+ shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
+ else:
+ s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c
+ shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
+
+ if SCALE_IS_SCALAR:
+ scale_val = tl.load(scale_ptr)
+ scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
+ else:
+ sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c
+ scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
+
+ y = x * (1 + scale) + shift
+ tl.store(y_ptr + x_off, y, mask=mask)
+
+
+def fuse_scale_shift_kernel(
+ x: torch.Tensor,
+ scale: torch.Tensor,
+ shift: torch.Tensor,
+ block_l: int = 128,
+ block_c: int = 128,
+):
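+ # Computes y = x * (1 + scale) + shift. A 4D scale/shift of shape [B, F, 1, C] is applied per frame;
+ # lower-rank scale/shift are broadcast over (B, L, C), with a fast copy path when both are scalar zeros.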
+ assert x.is_cuda and scale.is_cuda
+ assert x.is_contiguous()
+
+ B, L, C = x.shape
+ output = torch.empty_like(x)
+
+ if scale.dim() == 4:
+ # scale/shift: [B, F, 1, C]
+ rows = B * L
+ x_2d = x.view(rows, C)
+ output_2d = output.view(rows, C)
+ grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa
+ num_frames = scale.shape[1]
+ assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift"
+ frame_seqlen = L // num_frames
+
+ # Compact [B, F, C] without the singleton dim into [B*F, C]
+ scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
+ shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
+
+ _fused_scale_shift_4d_kernel[grid](
+ output_2d,
+ x_2d,
+ scale_reshaped,
+ shift_reshaped,
+ rows,
+ C,
+ L,
+ num_frames,
+ frame_seqlen,
+ )
+ else:
+ # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
+ # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
+ # Also support scalar (0D or 1-element)
+ if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
+ scale_blc = scale.reshape(1)
+ elif scale.dim() == 2:
+ scale_blc = scale[:, None, :]
+ elif scale.dim() == 3:
+ scale_blc = scale
+ else:
+ raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
+
+ if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
+ shift_blc = shift.reshape(1)
+ elif shift.dim() == 2:
+ shift_blc = shift[:, None, :]
+ elif shift.dim() == 3:
+ shift_blc = shift
+ else:
+ # broadcast later via expand if possible
+ shift_blc = shift
+
+ need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
+ need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
+
+ if not need_scale_scalar:
+ scale_exp = scale_blc.expand(B, L, C)
+ s_sb, s_sl, s_sc = scale_exp.stride()
+ else:
+ s_sb = s_sl = s_sc = 0
+
+ if not need_shift_scalar:
+ shift_exp = shift_blc.expand(B, L, C)
+ sh_sb, sh_sl, sh_sc = shift_exp.stride()
+ else:
+ sh_sb = sh_sl = sh_sc = 0
+
+ # If both scalars and both zero, copy fast-path
+ if need_scale_scalar and need_shift_scalar:
+ if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
+ output.copy_(x)
+ return output
+
+ grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
+ fuse_scale_shift_kernel_blc_opt[grid](
+ x,
+ shift_blc if need_shift_scalar else shift_exp,
+ scale_blc if need_scale_scalar else scale_exp,
+ output,
+ B,
+ L,
+ C,
+ x.stride(0),
+ x.stride(1),
+ x.stride(2),
+ sh_sb,
+ sh_sl,
+ sh_sc,
+ s_sb,
+ s_sl,
+ s_sc,
+ SCALE_IS_SCALAR=need_scale_scalar,
+ SHIFT_IS_SCALAR=need_shift_scalar,
+ BLOCK_L=block_l,
+ BLOCK_C=block_c,
+ num_warps=4,
+ num_stages=2,
+ )
+ return output
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
+ triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
+ triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
+ triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
+ ],
+ key=["head_size", "interleaved"],
+)
+@triton.jit
+def _rotary_embedding_kernel(
+ output_ptr,
+ x_ptr,
+ cos_ptr,
+ sin_ptr,
+ num_heads,
+ head_size,
+ num_tokens,
+ stride_x_row,
+ stride_cos_row,
+ stride_sin_row,
+ interleaved: tl.constexpr,
+ BLOCK_HS_HALF: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ token_idx = (row_idx // num_heads) % num_tokens
+
+ x_row_ptr = x_ptr + row_idx * stride_x_row
+ cos_row_ptr = cos_ptr + token_idx * stride_cos_row
+ sin_row_ptr = sin_ptr + token_idx * stride_sin_row
+ output_row_ptr = output_ptr + row_idx * stride_x_row
+
+ # half size for x1 and x2
+ head_size_half = head_size // 2
+
+ for block_start in range(0, head_size_half, BLOCK_HS_HALF):
+ offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
+ mask = offsets_half < head_size_half
+
+ cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
+ sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
+
+ offsets_x1 = 2 * offsets_half
+ offsets_x2 = 2 * offsets_half + 1
+
+ x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
+ x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
+
+ x1_fp32 = x1_vals.to(tl.float32)
+ x2_fp32 = x2_vals.to(tl.float32)
+ cos_fp32 = cos_vals.to(tl.float32)
+ sin_fp32 = sin_vals.to(tl.float32)
+ o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
+ o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
+
+ tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
+ tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
+
+
+def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
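+ # Rotates interleaved pairs (x1, x2) = (x[..., 2i], x[..., 2i+1]) as o1 = x1*cos - x2*sin, o2 = x1*sin + x2*cos, computed in fp32.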
+ output = torch.empty_like(x)
+
+ if x.dim() > 3:
+ bsz, num_tokens, num_heads, head_size = x.shape
+ else:
+ num_tokens, num_heads, head_size = x.shape
+ bsz = 1
+
+ assert head_size % 2 == 0, "head_size must be divisible by 2"
+
+ x_reshaped = x.view(-1, head_size)
+ output_reshaped = output.view(-1, head_size)
+
+ # One program per (batch, token, head) row
+ grid = (bsz * num_tokens * num_heads,)
+
+ if interleaved and cos.shape[-1] == head_size:
+ cos = cos[..., ::2].contiguous()
+ sin = sin[..., ::2].contiguous()
+ else:
+ cos = cos.contiguous()
+ sin = sin.contiguous()
+
+ _rotary_embedding_kernel[grid](
+ output_reshaped,
+ x_reshaped,
+ cos,
+ sin,
+ num_heads,
+ head_size,
+ num_tokens,
+ x_reshaped.stride(0),
+ cos.stride(0),
+ sin.stride(0),
+ interleaved,
+ )
+
+ return output
+
+
+# RMSNorm-fp32
+def maybe_contiguous_lastdim(x):
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+
+
+def maybe_contiguous(x):
+ return x.contiguous() if x is not None else None
+
+
+def triton_autotune_configs():
+ if not torch.cuda.is_available():
+ return []
+ # Return configs with a valid warp count for the current device
+ configs = []
+ # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
+ max_threads_per_block = 1024
+ # Default to warp size 32 if not defined by device
+ warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
+ # Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
+ return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block]
+ # return [triton.Config({}, num_warps=8)]
+
+
+# Copied from flash-attn
+@triton.autotune(
+ configs=triton_autotune_configs(),
+ key=[
+ "N",
+ "HAS_RESIDUAL",
+ "STORE_RESIDUAL_OUT",
+ "IS_RMS_NORM",
+ "HAS_BIAS",
+ "HAS_WEIGHT",
+ "HAS_X1",
+ "HAS_W1",
+ "HAS_B1",
+ ],
+)
+# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
+# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
+# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
+# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
+@triton.jit
+def _layer_norm_fwd_1pass_kernel(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ B, # pointer to the biases
+ RESIDUAL, # pointer to the residual
+ X1,
+ W1,
+ B1,
+ Y1,
+ RESIDUAL_OUT, # pointer to the residual
+ ROWSCALE,
+ SEEDS, # Dropout seeds for each row
+ DROPOUT_MASK,
+ DROPOUT_MASK1,
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride_x_row, # how much to increase the pointer when moving by 1 row
+ stride_y_row,
+ stride_res_row,
+ stride_res_out_row,
+ stride_x1_row,
+ stride_y1_row,
+ M, # number of rows in X
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ dropout_p, # Dropout probability
+ zero_centered_weight, # If true, add 1.0 to the weight
+ IS_RMS_NORM: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ HAS_RESIDUAL: tl.constexpr,
+ STORE_RESIDUAL_OUT: tl.constexpr,
+ HAS_WEIGHT: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ HAS_DROPOUT: tl.constexpr,
+ STORE_DROPOUT_MASK: tl.constexpr,
+ HAS_ROWSCALE: tl.constexpr,
+ HAS_X1: tl.constexpr,
+ HAS_W1: tl.constexpr,
+ HAS_B1: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ X += row * stride_x_row
+ Y += row * stride_y_row
+ if HAS_RESIDUAL:
+ RESIDUAL += row * stride_res_row
+ if STORE_RESIDUAL_OUT:
+ RESIDUAL_OUT += row * stride_res_out_row
+ if HAS_X1:
+ X1 += row * stride_x1_row
+ if HAS_W1:
+ Y1 += row * stride_y1_row
+ # Compute mean and variance
+ cols = tl.arange(0, BLOCK_N)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ if HAS_ROWSCALE:
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+ x *= rowscale
+ if HAS_DROPOUT:
+ # Compute dropout mask
+ # 7 rounds is good enough, and reduces register pressure
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
+ if STORE_DROPOUT_MASK:
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
+ if HAS_X1:
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
+ if HAS_ROWSCALE:
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
+ x1 *= rowscale
+ if HAS_DROPOUT:
+ # Compute dropout mask
+ # 7 rounds is good enough, and reduces register pressure
+ keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
+ if STORE_DROPOUT_MASK:
+ tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
+ x += x1
+ if HAS_RESIDUAL:
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
+ x += residual
+ if STORE_RESIDUAL_OUT:
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+ if not IS_RMS_NORM:
+ mean = tl.sum(x, axis=0) / N
+ tl.store(Mean + row, mean)
+ xbar = tl.where(cols < N, x - mean, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ else:
+ xbar = tl.where(cols < N, x, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ tl.store(Rstd + row, rstd)
+ # Normalize and apply linear transformation
+ mask = cols < N
+ if HAS_WEIGHT:
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
+ if zero_centered_weight:
+ w += 1.0
+ if HAS_BIAS:
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+ if HAS_WEIGHT:
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
+ else:
+ y = x_hat + b if HAS_BIAS else x_hat
+ # Write output
+ tl.store(Y + cols, y, mask=mask)
+ if HAS_W1:
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
+ if zero_centered_weight:
+ w1 += 1.0
+ if HAS_B1:
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
+ tl.store(Y1 + cols, y1, mask=mask)
+
+
+def _layer_norm_fwd(
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ eps: float,
+ residual: Optional[Tensor] = None,
+ x1: Optional[Tensor] = None,
+ weight1: Optional[Tensor] = None,
+ bias1: Optional[Tensor] = None,
+ dropout_p: float = 0.0,
+ rowscale: Optional[Tensor] = None,
+ out_dtype: Optional[torch.dtype] = None,
+ residual_dtype: Optional[torch.dtype] = None,
+ zero_centered_weight: bool = False,
+ is_rms_norm: bool = False,
+ return_dropout_mask: bool = False,
+ out: Optional[Tensor] = None,
+ residual_out: Optional[Tensor] = None,
+) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
+ # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
+ # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
+ # so that _layer_norm_fwd_impl doesn't have to return them.
+ if out is None:
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+ if residual is not None:
+ residual_dtype = residual.dtype
+ if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None):
+ residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype)
+ else:
+ residual_out = None
+ y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
+ x,
+ weight,
+ bias,
+ eps,
+ out,
+ residual=residual,
+ x1=x1,
+ weight1=weight1,
+ bias1=bias1,
+ dropout_p=dropout_p,
+ rowscale=rowscale,
+ zero_centered_weight=zero_centered_weight,
+ is_rms_norm=is_rms_norm,
+ return_dropout_mask=return_dropout_mask,
+ residual_out=residual_out,
+ )
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
+ if residual_out is None:
+ residual_out = x
+ return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
+
+
+# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
+# since we're returning a tuple of tensors
+def _layer_norm_fwd_impl(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Tensor,
+ eps: float,
+ out: Tensor,
+ residual: Optional[Tensor] = None,
+ x1: Optional[Tensor] = None,
+ weight1: Optional[Tensor] = None,
+ bias1: Optional[Tensor] = None,
+ dropout_p: float = 0.0,
+ rowscale: Optional[Tensor] = None,
+ zero_centered_weight: bool = False,
+ is_rms_norm: bool = False,
+ return_dropout_mask: bool = False,
+ residual_out: Optional[Tensor] = None,
+) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
+ M, N = x.shape
+ assert x.stride(-1) == 1
+ if residual is not None:
+ assert residual.stride(-1) == 1
+ assert residual.shape == (M, N)
+ if weight is not None:
+ assert weight.shape == (N,)
+ assert weight.stride(-1) == 1
+ if bias is not None:
+ assert bias.stride(-1) == 1
+ assert bias.shape == (N,)
+ if x1 is not None:
+ assert x1.shape == x.shape
+ assert rowscale is None
+ assert x1.stride(-1) == 1
+ if weight1 is not None:
+ assert weight1.shape == (N,)
+ assert weight1.stride(-1) == 1
+ if bias1 is not None:
+ assert bias1.shape == (N,)
+ assert bias1.stride(-1) == 1
+ if rowscale is not None:
+ assert rowscale.is_contiguous()
+ assert rowscale.shape == (M,)
+ assert out.shape == x.shape
+ assert out.stride(-1) == 1
+ if residual_out is not None:
+ assert residual_out.shape == x.shape
+ assert residual_out.stride(-1) == 1
+ if weight1 is not None:
+ y1 = torch.empty_like(out)
+ assert y1.stride(-1) == 1
+ else:
+ y1 = None
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ if dropout_p > 0.0:
+ seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64)
+ else:
+ seeds = None
+ if return_dropout_mask and dropout_p > 0.0:
+ dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
+ if x1 is not None:
+ dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
+ else:
+ dropout_mask1 = None
+ else:
+ dropout_mask, dropout_mask1 = None, None
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_N:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ with torch.cuda.device(x.device.index):
+ torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
+ x,
+ out,
+ weight if weight is not None else x, # unused when HAS_WEIGHT == False
+ bias,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ y1,
+ residual_out,
+ rowscale,
+ seeds,
+ dropout_mask,
+ dropout_mask1,
+ mean,
+ rstd,
+ x.stride(0),
+ out.stride(0),
+ residual.stride(0) if residual is not None else 0,
+ residual_out.stride(0) if residual_out is not None else 0,
+ x1.stride(0) if x1 is not None else 0,
+ y1.stride(0) if y1 is not None else 0,
+ M,
+ N,
+ eps,
+ dropout_p,
+ # Passing a bool makes torch inductor very unhappy since it then tries to compare it to int_max
+ int(zero_centered_weight),
+ is_rms_norm,
+ BLOCK_N,
+ residual is not None,
+ residual_out is not None,
+ weight is not None,
+ bias is not None,
+ dropout_p > 0.0,
+ dropout_mask is not None,
+ rowscale is not None,
+ HAS_X1=x1 is not None,
+ HAS_W1=weight1 is not None,
+ HAS_B1=bias1 is not None,
+ )
+ return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
+
+
+class LayerNormFn:
+ @staticmethod
+ def forward(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ zero_centered_weight=False,
+ is_rms_norm=False,
+ return_dropout_mask=False,
+ out_dtype=None,
+ out=None,
+ residual_out=None,
+ ):
+ x_shape_og = x.shape
+ # reshape input data into 2D tensor
+ x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
+ if residual is not None:
+ assert residual.shape == x_shape_og
+ residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
+ if x1 is not None:
+ assert x1.shape == x_shape_og
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
+ x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
+ # weight can be None when elementwise_affine=False for LayerNorm
+ if weight is not None:
+ weight = weight.contiguous()
+ bias = maybe_contiguous(bias)
+ weight1 = maybe_contiguous(weight1)
+ bias1 = maybe_contiguous(bias1)
+ if rowscale is not None:
+ rowscale = rowscale.reshape(-1).contiguous()
+ residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
+ if out is not None:
+ out = out.reshape(-1, out.shape[-1])
+ if residual_out is not None:
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
+ x,
+ weight,
+ bias,
+ eps,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ dropout_p=dropout_p,
+ rowscale=rowscale,
+ out_dtype=out_dtype,
+ residual_dtype=residual_dtype,
+ zero_centered_weight=zero_centered_weight,
+ is_rms_norm=is_rms_norm,
+ return_dropout_mask=return_dropout_mask,
+ out=out,
+ residual_out=residual_out,
+ )
+ y = y.reshape(x_shape_og)
+ return y
+
+
+def layer_norm_fn(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ zero_centered_weight=False,
+ is_rms_norm=False,
+ return_dropout_mask=False,
+ out_dtype=None,
+ out=None,
+ residual_out=None,
+):
+ return LayerNormFn.forward(
+ x,
+ weight,
+ bias,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ eps,
+ dropout_p,
+ rowscale,
+ prenorm,
+ residual_in_fp32,
+ zero_centered_weight,
+ is_rms_norm,
+ return_dropout_mask,
+ out_dtype,
+ out,
+ residual_out,
+ )
+
+
+@triton.jit
+def _norm_infer_kernel(
+ X,
+ Y,
+ W,
+ B,
+ stride_x_row,
+ stride_y_row,
+ M,
+ N,
+ eps,
+ IS_RMS_NORM: tl.constexpr,
+ HAS_WEIGHT: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ row = tl.program_id(0)
+ X += row * stride_x_row
+ Y += row * stride_y_row
+ if HAS_WEIGHT:
+ W += 0
+ if HAS_BIAS:
+ B += 0
+ cols = tl.arange(0, BLOCK_N)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ if not IS_RMS_NORM:
+ mean = tl.sum(x, axis=0) / N
+ xbar = tl.where(cols < N, x - mean, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ else:
+ xbar = tl.where(cols < N, x, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+ if HAS_WEIGHT:
+ w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
+ y = x_hat * w
+ else:
+ y = x_hat
+ if HAS_BIAS:
+ b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
+ y += b
+ tl.store(Y + cols, y, mask=cols < N)
+
+
+def norm_infer(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ eps: float,
+ is_rms_norm: bool = False,
+ out: Optional[Tensor] = None,
+):
+ M, N = x.shape
+ assert x.stride(-1) == 1
+ if weight is not None:
+ assert weight.shape == (N,)
+ assert weight.stride(-1) == 1
+ if bias is not None:
+ assert bias.shape == (N,)
+ assert bias.stride(-1) == 1
+ if out is None:
+ out = torch.empty_like(x)
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_N:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
+ _norm_infer_kernel[(M,)](
+ x,
+ out,
+ weight if weight is not None else x, # dummy when HAS_WEIGHT=False
+ bias if bias is not None else x, # dummy when HAS_BIAS=False
+ x.stride(0),
+ out.stride(0),
+ M,
+ N,
+ eps,
+ IS_RMS_NORM=is_rms_norm,
+ HAS_WEIGHT=weight is not None,
+ HAS_BIAS=bias is not None,
+ BLOCK_N=BLOCK_N,
+ num_warps=num_warps,
+ )
+ return out
+
+
+def rms_norm_fn(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ zero_centered_weight=False,
+ return_dropout_mask=False,
+ out_dtype=None,
+ out=None,
+ residual_out=None,
+):
+ return LayerNormFn.forward(
+ x,
+ weight,
+ bias,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ eps,
+ dropout_p,
+ rowscale,
+ prenorm,
+ residual_in_fp32,
+ zero_centered_weight,
+ True,
+ return_dropout_mask,
+ out_dtype,
+ out,
+ residual_out,
+ )
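+
+
+ # Usage sketch (illustrative only; it assumes a CUDA device with Triton available and
+ # that the forward helpers allocate `out` internally when it is not provided):
+ #     x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
+ #     w = torch.ones(1024, device="cuda", dtype=torch.float16)
+ #     y_rms = rms_norm_fn(x, w, bias=None, eps=1e-6)            # fused RMSNorm
+ #     y_ln = layer_norm_fn(x, w, bias=None, eps=1e-6)           # fused LayerNorm
+ #     y_inf = norm_infer(x, w, None, 1e-6, is_rms_norm=True)    # inference-only kernel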
diff --git a/lightx2v/common/ops/tensor/__init__.py b/lightx2v/common/ops/tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa13faea725555a9f71d9a73a5a18bc4d1a97b0d
--- /dev/null
+++ b/lightx2v/common/ops/tensor/__init__.py
@@ -0,0 +1 @@
+from .tensor import DefaultTensor
diff --git a/lightx2v/common/ops/tensor/tensor.py b/lightx2v/common/ops/tensor/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..069d5e3ac9b13360b381b7029d40e0b3c8984d75
--- /dev/null
+++ b/lightx2v/common/ops/tensor/tensor.py
@@ -0,0 +1,110 @@
+import os
+import re
+from pathlib import Path
+
+import torch
+from safetensors import safe_open
+
+from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import TENSOR_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+
+@TENSOR_REGISTER("Default")
+class DefaultTensor:
+ def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+ self.tensor_name = tensor_name
+ self.lazy_load = lazy_load
+ self.lazy_load_file = lazy_load_file
+ self.is_post_adapter = is_post_adapter
+ self.create_cuda_buffer = create_cuda_buffer
+ self.create_cpu_buffer = create_cpu_buffer
+ self.infer_dtype = GET_DTYPE()
+ self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+
+ def load(self, weight_dict):
+ if self.create_cuda_buffer:
+ self._load_cuda_buffer(weight_dict)
+ elif self.create_cpu_buffer:
+ self._load_cpu_pin_buffer()
+ else:
+ self._load_default_tensors(weight_dict)
+
+ def _load_default_tensors(self, weight_dict):
+ if not self.lazy_load:
+ device = weight_dict[self.tensor_name].device
+ if device.type == "cpu":
+ tensor = weight_dict[self.tensor_name]
+ self.pin_tensor = self._create_cpu_pin_tensor(tensor)
+ del weight_dict[self.tensor_name]
+ else:
+ self.tensor = weight_dict[self.tensor_name]
+
+ def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
+ if self.lazy_load:
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ tensor = lazy_load_file.get_tensor(self.tensor_name)
+ if use_infer_dtype:
+ tensor = tensor.to(self.infer_dtype)
+ else:
+ tensor = weight_dict[self.tensor_name]
+ return tensor
+
+ def _create_cpu_pin_tensor(self, tensor):
+ pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+ pin_tensor.copy_(tensor)
+ del tensor
+ return pin_tensor
+
+ def _load_cuda_buffer(self, weight_dict):
+ tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
+ self.tensor_cuda_buffer = tensor.to(AI_DEVICE)
+
+ def _load_cpu_pin_buffer(self):
+ tensor = self._get_tensor(use_infer_dtype=True)
+ self.pin_tensor = self._create_cpu_pin_tensor(tensor)
+
+ def to_cuda(self, non_blocking=False):
+ self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)
+
+ def to_cpu(self, non_blocking=False):
+ if hasattr(self, "pin_tensor"):
+ self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
+ else:
+ self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
+
+ def state_dict(self, destination=None):
+ if destination is None:
+ destination = {}
+ destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
+ return destination
+
+ def load_state_dict(self, destination, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
+ else:
+ tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
+ if tensor_name not in destination:
+ self.tensor = None
+ return
+ self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)
+
+ def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+ if self.is_post_adapter:
+ assert adapter_block_index is not None
+ self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
+ else:
+ self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
+ if Path(self.lazy_load_file).is_file():
+ lazy_load_file_path = self.lazy_load_file
+ else:
+ lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+ with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+ tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
+ self.pin_tensor = self.pin_tensor.copy_(tensor)
+ del tensor
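+
+
+ # Usage sketch (illustrative only; it assumes the environment variables read by
+ # GET_DTYPE / GET_SENSITIVE_DTYPE and the AI_DEVICE backend are configured):
+ #     weights = {"blocks.0.attn.qkv.bias": torch.randn(1024)}
+ #     t = DefaultTensor("blocks.0.attn.qkv.bias")
+ #     t.load(weights)   # CPU weights are copied into a pinned host buffer
+ #     t.to_cuda()       # upload the pinned buffer to AI_DEVICE as t.tensor
+ #     ...               # run inference with t.tensor
+ #     t.to_cpu()        # copy back into the pinned buffer to release device memory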
diff --git a/lightx2v/common/transformer_infer/transformer_infer.py b/lightx2v/common/transformer_infer/transformer_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..424f79b00568c9f19ccf532a11fff01b55bd97b1
--- /dev/null
+++ b/lightx2v/common/transformer_infer/transformer_infer.py
@@ -0,0 +1,46 @@
+import math
+from abc import ABC, abstractmethod
+
+
+class BaseTransformerInfer(ABC):
+ @abstractmethod
+ def infer(self):
+ pass
+
+ def set_scheduler(self, scheduler):
+ self.scheduler = scheduler
+ self.scheduler.transformer_infer = self
+
+
+class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
+ @abstractmethod
+ def infer_calculating(self):
+ pass
+
+ @abstractmethod
+ def infer_using_cache(self):
+ pass
+
+ @abstractmethod
+ def get_taylor_step_diff(self):
+ pass
+
+ # 1. when fully calculated, store the result in the cache
+ def derivative_approximation(self, block_cache, module_name, out):
+ if module_name not in block_cache:
+ block_cache[module_name] = {0: out}
+ else:
+ step_diff = self.get_taylor_step_diff()
+
+ previous_out = block_cache[module_name][0]
+ block_cache[module_name][0] = out
+ block_cache[module_name][1] = (out - previous_out) / step_diff
+
+ def taylor_formula(self, tensor_dict):
+ x = self.get_taylor_step_diff()
+
+ output = 0
+ for i in range(len(tensor_dict)):
+ output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
+
+ return output
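+
+
+ # Illustrative numeric example (not part of the runtime path), assuming `infer` is a
+ # concrete subclass whose get_taylor_step_diff() returns 4.0. derivative_approximation
+ # keeps {0: latest output, 1: finite-difference derivative} per module, and
+ # taylor_formula evaluates the first-order expansion out_0 + x * out_1:
+ #     cache = {}
+ #     infer.derivative_approximation(cache, "attn", out_prev)   # seeds order 0
+ #     infer.derivative_approximation(cache, "attn", out_curr)   # adds order 1
+ #     approx = infer.taylor_formula(cache["attn"])
+ #     # approx == out_curr + 4.0 * (out_curr - out_prev) / 4.0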
diff --git a/lightx2v/deploy/__init__.py b/lightx2v/deploy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/deploy/common/__init__.py b/lightx2v/deploy/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/deploy/common/aliyun.py b/lightx2v/deploy/common/aliyun.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aaff6b47d27964d0cc05d2daea0d7c60fa20772
--- /dev/null
+++ b/lightx2v/deploy/common/aliyun.py
@@ -0,0 +1,81 @@
+import asyncio
+import json
+import os
+import sys
+
+from alibabacloud_dypnsapi20170525 import models as dypnsapi_models
+from alibabacloud_dypnsapi20170525.client import Client
+from alibabacloud_tea_openapi import models as openapi_models
+from alibabacloud_tea_util import models as util_models
+from loguru import logger
+
+
+class AlibabaCloudClient:
+ def __init__(self):
+ config = openapi_models.Config(
+ access_key_id=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"),
+ access_key_secret=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
+ https_proxy=os.getenv("auth_https_proxy", None),
+ )
+ self.client = Client(config)
+ self.runtime = util_models.RuntimeOptions()
+
+ def check_ok(self, res, prefix):
+ logger.info(f"{prefix}: {res}")
+ if not isinstance(res, dict) or "statusCode" not in res or res["statusCode"] != 200:
+ logger.warning(f"{prefix}: error response: {res}")
+ return False
+ if "body" not in res or "Code" not in res["body"] or "Success" not in res["body"]:
+ logger.warning(f"{prefix}: error body: {res}")
+ return False
+ if res["body"]["Code"] != "OK" or res["body"]["Success"] is not True:
+ logger.warning(f"{prefix}: sms error: {res}")
+ return False
+ return True
+
+ async def send_sms(self, phone_number):
+ try:
+ req = dypnsapi_models.SendSmsVerifyCodeRequest(
+ phone_number=phone_number,
+ sign_name="速通互联验证服务",
+ template_code="100001",
+ template_param=json.dumps({"code": "##code##", "min": "5"}),
+ valid_time=300,
+ )
+ res = await self.client.send_sms_verify_code_with_options_async(req, self.runtime)
+ ok = self.check_ok(res.to_map(), "AlibabaCloudClient send sms")
+ logger.info(f"AlibabaCloudClient send sms for {phone_number}: {ok}")
+ return ok
+
+ except Exception as e:
+ logger.warning(f"AlibabaCloudClient send sms for {phone_number}: {e}")
+ return False
+
+ async def check_sms(self, phone_number, verify_code):
+ try:
+ req = dypnsapi_models.CheckSmsVerifyCodeRequest(
+ phone_number=phone_number,
+ verify_code=verify_code,
+ )
+ res = await self.client.check_sms_verify_code_with_options_async(req, self.runtime)
+ ok = self.check_ok(res.to_map(), "AlibabaCloudClient check sms")
+ logger.info(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {ok}")
+ return ok
+
+ except Exception as e:
+ logger.warning(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {e}")
+ return False
+
+
+async def test(args):
+ assert len(args) in [1, 2], "Usage: python aliyun.py <phone_number> [verify_code]"
+ phone_number = args[0]
+ client = AlibabaCloudClient()
+ if len(args) == 1:
+ await client.send_sms(phone_number)
+ else:
+ await client.check_sms(phone_number, args[1])
+
+
+if __name__ == "__main__":
+ asyncio.run(test(sys.argv[1:]))
diff --git a/lightx2v/deploy/common/audio_separator.py b/lightx2v/deploy/common/audio_separator.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f06d6ef8257259b517549185256aaee44b0d25
--- /dev/null
+++ b/lightx2v/deploy/common/audio_separator.py
@@ -0,0 +1,376 @@
+# -*- coding: utf-8 -*-
+"""
+Audio Source Separation Module
+Separates different voice tracks in audio, supports multi-person audio separation
+"""
+
+import base64
+import io
+import os
+import tempfile
+import traceback
+from collections import defaultdict
+from typing import Dict, Optional, Union
+
+import torch
+import torchaudio
+from loguru import logger
+
+# Import pyannote.audio for speaker diarization
+from pyannote.audio import Audio, Pipeline
+
+_origin_torch_load = torch.load
+
+
+def our_torch_load(checkpoint_file, *args, **kwargs):
+ kwargs["weights_only"] = False
+ return _origin_torch_load(checkpoint_file, *args, **kwargs)
+
+
+class AudioSeparator:
+ """
+ Audio separator for separating different voice tracks in audio using pyannote.audio
+ Supports multi-person conversation separation, maintains duration (other speakers' tracks are empty)
+ """
+
+ def __init__(
+ self,
+ model_path: str = None,
+ device: str = None,
+ sample_rate: int = 16000,
+ ):
+ """
+ Initialize audio separator
+
+ Args:
+ model_path: Model path (if using custom model), default uses pyannote/speaker-diarization-community-1
+ device: Device ('cpu', 'cuda', etc.), None for auto selection
+ sample_rate: Target sample rate, default 16000
+ """
+ self.sample_rate = sample_rate
+ self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+ self._init_pyannote(model_path)
+
+ def _init_pyannote(self, model_path: str = None):
+ """Initialize pyannote.audio pipeline"""
+ try:
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
+ model_name = model_path or "pyannote/speaker-diarization-community-1"
+
+ try:
+ torch.load = our_torch_load
+ # Try loading with token if available
+ if huggingface_token:
+ self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
+ else:
+ # Try without token (may work for public models)
+ self.pipeline = Pipeline.from_pretrained(model_name)
+ except Exception as e:
+ if "gated" in str(e).lower() or "token" in str(e).lower():
+ raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
+ raise RuntimeError(f"Failed to load pyannote model: {e}")
+ finally:
+ torch.load = _origin_torch_load
+
+ # Move pipeline to specified device
+ if self.device:
+ self.pipeline.to(torch.device(self.device))
+
+ # Initialize Audio helper for waveform loading
+ self.pyannote_audio = Audio()
+
+ logger.info("Initialized pyannote.audio speaker diarization pipeline")
+ except Exception as e:
+ logger.error(f"Failed to initialize pyannote: {e}")
+ raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")
+
+ def separate_speakers(
+ self,
+ audio_path: Union[str, bytes],
+ num_speakers: Optional[int] = None,
+ min_speakers: int = 1,
+ max_speakers: int = 5,
+ ) -> Dict:
+ """
+ Separate different speakers in audio
+
+ Args:
+ audio_path: Audio file path or bytes data
+ num_speakers: Specified number of speakers, None for auto detection
+ min_speakers: Minimum number of speakers
+ max_speakers: Maximum number of speakers
+
+ Returns:
+ Dict containing:
+ - speakers: List of speaker audio segments, each containing:
+ - speaker_id: Speaker ID (0, 1, 2, ...)
+ - audio: torch.Tensor audio data [channels, samples]
+ - segments: List of (start_time, end_time) tuples
+ - sample_rate: Sample rate
+ """
+ try:
+ # Load audio
+ if isinstance(audio_path, bytes):
+ # Try to infer the audio format from the byte data
+ # Check for WAV format (RIFF header)
+ is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
+ # Check for MP3 format (ID3 tag or MPEG frame header)
+ is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"
+
+ # Choose the temp-file suffix based on the detected format
+ if is_wav:
+ suffix = ".wav"
+ elif is_mp3:
+ suffix = ".mp3"
+ else:
+ # Default to trying WAV; an error will be raised downstream if it fails
+ suffix = ".wav"
+
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
+ tmp_file.write(audio_path)
+ tmp_audio_path = tmp_file.name
+ try:
+ result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
+ finally:
+ # Make sure the temporary file is deleted
+ try:
+ os.unlink(tmp_audio_path)
+ except Exception as e:
+ logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
+ return result
+ else:
+ return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)
+
+ except Exception as e:
+ logger.error(f"Speaker separation failed: {traceback.format_exc()}")
+ raise RuntimeError(f"Audio separation error: {e}")
+
+ def _separate_speakers_internal(
+ self,
+ audio_path: str,
+ num_speakers: Optional[int] = None,
+ min_speakers: int = 1,
+ max_speakers: int = 5,
+ ) -> Dict:
+ """Internal method: execute speaker separation"""
+
+ # Load audio
+ waveform, original_sr = torchaudio.load(audio_path)
+ if original_sr != self.sample_rate:
+ resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
+ waveform = resampler(waveform)
+
+ # Convert to mono if stereo
+ if waveform.shape[0] > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+
+ # Ensure waveform is float32 and normalized (pyannote expects this format)
+ if waveform.dtype != torch.float32:
+ waveform = waveform.float()
+
+ # Ensure waveform is in range [-1, 1] (normalize if needed)
+ if waveform.abs().max() > 1.0:
+ waveform = waveform / waveform.abs().max()
+
+ if self.pipeline is None:
+ raise RuntimeError("Pyannote pipeline not initialized")
+
+ return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)
+
+ def _separate_with_pyannote(
+ self,
+ audio_path: str,
+ waveform: torch.Tensor,
+ num_speakers: Optional[int],
+ min_speakers: int,
+ max_speakers: int,
+ ) -> Dict:
+ """Use pyannote.audio for speaker diarization"""
+ try:
+ # Use waveform dict to avoid AudioDecoder dependency issues
+ # Pipeline can accept either file path or waveform dict
+ # Using waveform dict is more reliable when torchcodec is not properly installed
+ audio_input = {
+ "waveform": waveform,
+ "sample_rate": self.sample_rate,
+ }
+
+ # Run speaker diarization
+ output = self.pipeline(
+ audio_input,
+ min_speakers=min_speakers if num_speakers is None else num_speakers,
+ max_speakers=max_speakers if num_speakers is None else num_speakers,
+ )
+
+ # Extract audio segments for each speaker
+ speakers_dict = defaultdict(list)
+ for turn, speaker in output.speaker_diarization:
+ print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
+ start_time = turn.start
+ end_time = turn.end
+ start_sample = int(start_time * self.sample_rate)
+ end_sample = int(end_time * self.sample_rate)
+
+ # Extract audio segment for this time period
+ segment_audio = waveform[:, start_sample:end_sample]
+ speakers_dict[speaker].append((start_time, end_time, segment_audio))
+
+ # Generate complete audio for each speaker (other speakers' segments are empty)
+ speakers = []
+ audio_duration = waveform.shape[1] / self.sample_rate
+ num_samples = waveform.shape[1]
+
+ for speaker_id, segments in speakers_dict.items():
+ # Create zero-filled audio
+ speaker_audio = torch.zeros_like(waveform)
+
+ # Fill in this speaker's segments
+ for start_time, end_time, segment_audio in segments:
+ start_sample = int(start_time * self.sample_rate)
+ end_sample = int(end_time * self.sample_rate)
+ # Ensure no out-of-bounds
+ end_sample = min(end_sample, num_samples)
+ segment_len = end_sample - start_sample
+ if segment_len > 0 and segment_audio.shape[1] > 0:
+ actual_len = min(segment_len, segment_audio.shape[1])
+ speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len]
+
+ speakers.append(
+ {
+ "speaker_id": speaker_id,
+ "audio": speaker_audio,
+ "segments": [(s[0], s[1]) for s in segments],
+ "sample_rate": self.sample_rate,
+ }
+ )
+
+ logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
+ return {"speakers": speakers, "method": "pyannote"}
+
+ except Exception as e:
+ logger.error(f"Pyannote separation failed: {e}")
+ raise RuntimeError(f"Audio separation failed: {e}")
+
+ def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None):
+ """
+ Save speaker audio to file
+
+ Args:
+ speaker_audio: Audio tensor [channels, samples]
+ output_path: Output path
+ sample_rate: Sample rate, if None uses self.sample_rate
+ """
+ sr = sample_rate if sample_rate else self.sample_rate
+ torchaudio.save(output_path, speaker_audio, sr)
+ logger.info(f"Saved speaker audio to {output_path}")
+
+ def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str:
+ """
+ Convert speaker audio tensor to base64 encoded string without saving to file
+
+ Args:
+ speaker_audio: Audio tensor [channels, samples]
+ sample_rate: Sample rate, if None uses self.sample_rate
+ format: Audio format (default: "wav")
+
+ Returns:
+ Base64 encoded audio string
+ """
+ sr = sample_rate if sample_rate else self.sample_rate
+
+ # Use BytesIO to save audio to memory
+ buffer = io.BytesIO()
+ torchaudio.save(buffer, speaker_audio, sr, format=format)
+
+ # Get the audio bytes
+ audio_bytes = buffer.getvalue()
+
+ # Encode to base64
+ audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
+ logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
+ return audio_base64
+
+ def separate_and_save(
+ self,
+ audio_path: Union[str, bytes],
+ output_dir: str,
+ num_speakers: Optional[int] = None,
+ min_speakers: int = 1,
+ max_speakers: int = 5,
+ ) -> Dict:
+ """
+ Separate audio and save to files
+
+ Args:
+ audio_path: Input audio path or bytes data
+ output_dir: Output directory
+ num_speakers: Specified number of speakers
+ min_speakers: Minimum number of speakers
+ max_speakers: Maximum number of speakers
+
+ Returns:
+ Separation result dictionary, containing output file paths
+ """
+ os.makedirs(output_dir, exist_ok=True)
+
+ result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)
+
+ output_paths = []
+ for speaker in result["speakers"]:
+ speaker_id = speaker["speaker_id"]
+ output_path = os.path.join(output_dir, f"{speaker_id}.wav")
+ self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
+ output_paths.append(output_path)
+ speaker["output_path"] = output_path
+
+ result["output_paths"] = output_paths
+ return result
+
+
+def separate_audio_tracks(
+ audio_path: str,
+ output_dir: str = None,
+ num_speakers: int = None,
+ model_path: str = None,
+) -> Dict:
+ """
+ Convenience function: separate different audio tracks
+
+ Args:
+ audio_path: Audio file path
+ output_dir: Output directory, if None does not save files
+ num_speakers: Number of speakers
+ model_path: Model path (optional)
+
+ Returns:
+ Separation result dictionary
+ """
+ separator = AudioSeparator(model_path=model_path)
+
+ if output_dir:
+ return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
+ else:
+ return separator.separate_speakers(audio_path, num_speakers=num_speakers)
+
+
+if __name__ == "__main__":
+ # Test code
+ import sys
+
+ if len(sys.argv) < 2:
+ print("Usage: python audio_separator.py [output_dir] [num_speakers]")
+ sys.exit(1)
+
+ audio_path = sys.argv[1]
+ output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
+ num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None
+
+ separator = AudioSeparator()
+ result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
+
+ print(f"Separated audio into {len(result['speakers'])} speakers:")
+ for speaker in result["speakers"]:
+ print(f" Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
+ if "output_path" in speaker:
+ print(f" Saved to: {speaker['output_path']}")
diff --git a/lightx2v/deploy/common/face_detector.py b/lightx2v/deploy/common/face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..24240e330979994711df70767cf8c1a828ab0502
--- /dev/null
+++ b/lightx2v/deploy/common/face_detector.py
@@ -0,0 +1,277 @@
+# -*- coding: utf-8 -*-
+"""
+Face Detection Module using YOLO
+Supports detecting faces in images, including human faces, animal faces, anime faces, sketches, etc.
+"""
+
+import io
+import traceback
+from typing import Dict, List, Union
+
+import numpy as np
+from PIL import Image, ImageDraw
+from loguru import logger
+from ultralytics import YOLO
+
+
+class FaceDetector:
+ """
+ Face detection using YOLO models
+ Supports detecting: human faces, animal faces, anime faces, sketch faces, etc.
+ """
+
+ def __init__(self, model_path: str = None, conf_threshold: float = 0.25, device: str = None):
+ """
+ Initialize face detector
+
+ Args:
+ model_path: YOLO model path, if None uses default pretrained model
+ conf_threshold: Confidence threshold, default 0.25
+ device: Device ('cpu', 'cuda', '0', '1', etc.), None for auto selection
+ """
+
+ self.conf_threshold = conf_threshold
+ self.device = device
+
+ if model_path is None:
+ # Use YOLO11 pretrained model, can detect COCO dataset classes (including person)
+ # Or use dedicated face detection model
+ logger.info("Loading default YOLO11n model for face detection")
+ try:
+ self.model = YOLO("yolo11n.pt") # Lightweight model
+ except Exception as e:
+ logger.warning(f"Failed to load default model, trying yolov8n: {e}")
+ self.model = YOLO("yolov8n.pt")
+ else:
+ logger.info(f"Loading YOLO model from {model_path}")
+ self.model = YOLO(model_path)
+
+ # Person class ID in COCO dataset is 0
+ # YOLO can detect person, for more precise face detection, recommend using dedicated face detection models
+ # Such as YOLOv8-face or RetinaFace, can be specified via model_path parameter
+ # First use YOLO to detect person region, then can further detect faces within
+ self.target_classes = {
+ "person": 0, # Face (by detecting person class)
+ # Can be extended to detect animal faces (cat, dog, etc.) and other classes
+ }
+
+ def detect_faces(
+ self,
+ image: Union[str, Image.Image, bytes, np.ndarray],
+ return_image: bool = False,
+ ) -> Dict:
+ """
+ Detect faces in image
+
+ Args:
+ image: Input image, can be path, PIL Image, bytes or numpy array
+ return_image: Whether to return annotated image with detection boxes
+
+ Returns:
+ Dict containing:
+ - faces: List of face detection results, each containing:
+ - bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates)
+ - confidence: Confidence score (0.0-1.0)
+ - class_id: Class ID
+ - class_name: Class name
+ - image (optional): PIL Image with detection boxes drawn (if return_image=True)
+ """
+ try:
+ # Load image
+ if isinstance(image, str):
+ img = Image.open(image).convert("RGB")
+ elif isinstance(image, bytes):
+ img = Image.open(io.BytesIO(image)).convert("RGB")
+ elif isinstance(image, np.ndarray):
+ img = Image.fromarray(image).convert("RGB")
+ elif isinstance(image, Image.Image):
+ img = image.convert("RGB")
+ else:
+ raise ValueError(f"Unsupported image type: {type(image)}")
+
+ # Use YOLO for detection
+ # Note: YOLO by default detects person, we focus on person detection
+ # For more precise face detection, can train or use dedicated face detection models
+ results = self.model.predict(
+ source=img,
+ conf=self.conf_threshold,
+ device=self.device,
+ verbose=False,
+ )
+
+ faces = []
+ annotated_img = img.copy() if return_image else None
+
+ if len(results) > 0:
+ result = results[0]
+ boxes = result.boxes
+
+ if boxes is not None and len(boxes) > 0:
+ for i in range(len(boxes)):
+ # Get bounding box coordinates (xyxy format)
+ bbox = boxes.xyxy[i].cpu().numpy().tolist()
+ confidence = float(boxes.conf[i].cpu().numpy())
+ class_id = int(boxes.cls[i].cpu().numpy())
+
+ # Get class name
+ class_name = result.names.get(class_id, "unknown")
+
+ # Process target classes (person, etc.)
+ # For person, the entire body box contains face region
+ # For more precise face detection, can:
+ # 1. Use dedicated face detection models (RetinaFace, YOLOv8-face)
+ # 2. Further use face detection model within current person box
+ # 3. Use specifically trained multi-class detection models (faces, animal faces, anime faces, etc.)
+ if class_id in self.target_classes.values():
+ face_info = {
+ "bbox": bbox, # [x1, y1, x2, y2] - absolute pixel coordinates
+ "confidence": confidence,
+ "class_id": class_id,
+ "class_name": class_name,
+ }
+ faces.append(face_info)
+
+ # Draw annotations on image if needed
+ if return_image and annotated_img is not None:
+ draw = ImageDraw.Draw(annotated_img)
+ x1, y1, x2, y2 = bbox
+ # Draw bounding box
+ draw.rectangle(
+ [x1, y1, x2, y2],
+ outline="red",
+ width=2,
+ )
+ # Draw label
+ label = f"{class_name} {confidence:.2f}"
+ draw.text((x1, y1 - 15), label, fill="red")
+
+ result_dict = {"faces": faces}
+
+ if return_image and annotated_img is not None:
+ result_dict["image"] = annotated_img
+
+ logger.info(f"Detected {len(faces)} faces in image")
+ return result_dict
+
+ except Exception as e:
+ logger.error(f"Face detection failed: {traceback.format_exc()}")
+ raise RuntimeError(f"Face detection error: {e}")
+
+ def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict:
+ """
+ Detect faces from byte data
+
+ Args:
+ image_bytes: Image byte data
+ **kwargs: Additional parameters passed to detect_faces
+
+ Returns:
+ Detection result dictionary
+ """
+ return self.detect_faces(image_bytes, **kwargs)
+
+ def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]:
+ """
+ Extract detected face regions
+
+ Args:
+ image: Input image
+ expand_ratio: Bounding box expansion ratio to include more context
+
+ Returns:
+ List of extracted face region images
+ """
+ result = self.detect_faces(image)
+ faces = result["faces"]
+
+ # Load original image
+ if isinstance(image, str):
+ img = Image.open(image).convert("RGB")
+ elif isinstance(image, bytes):
+ img = Image.open(io.BytesIO(image)).convert("RGB")
+ elif isinstance(image, Image.Image):
+ img = image.convert("RGB")
+ else:
+ raise ValueError(f"Unsupported image type: {type(image)}")
+
+ face_regions = []
+ img_width, img_height = img.size
+
+ for face in faces:
+ x1, y1, x2, y2 = face["bbox"]
+
+ # Expand bounding box
+ width = x2 - x1
+ height = y2 - y1
+ expand_x = width * expand_ratio
+ expand_y = height * expand_ratio
+
+ x1 = max(0, int(x1 - expand_x))
+ y1 = max(0, int(y1 - expand_y))
+ x2 = min(img_width, int(x2 + expand_x))
+ y2 = min(img_height, int(y2 + expand_y))
+
+ # Crop region
+ face_region = img.crop((x1, y1, x2, y2))
+ face_regions.append(face_region)
+
+ return face_regions
+
+ def count_faces(self, image: Union[str, Image.Image, bytes]) -> int:
+ """
+ Count number of faces in image
+
+ Args:
+ image: Input image
+
+ Returns:
+ Number of detected faces
+ """
+ result = self.detect_faces(image, return_image=False)
+ return len(result["faces"])
+
+
+def detect_faces_in_image(
+ image_path: str,
+ model_path: str = None,
+ conf_threshold: float = 0.25,
+ return_image: bool = False,
+) -> Dict:
+ """
+ Convenience function: detect faces in image
+
+ Args:
+ image_path: Image path
+ model_path: YOLO model path
+ conf_threshold: Confidence threshold
+ return_image: Whether to return annotated image
+
+ Returns:
+ Detection result dictionary containing:
+ - faces: List of face detection results with bbox coordinates [x1, y1, x2, y2]
+ - image (optional): Annotated image with detection boxes
+ """
+ detector = FaceDetector(model_path=model_path, conf_threshold=conf_threshold)
+ return detector.detect_faces(image_path, return_image=return_image)
+
+
+if __name__ == "__main__":
+ # Test code
+ import sys
+
+ if len(sys.argv) < 2:
+ print("Usage: python face_detector.py ")
+ sys.exit(1)
+
+ image_path = sys.argv[1]
+ detector = FaceDetector()
+ result = detector.detect_faces(image_path, return_image=True)
+
+ print(f"Detected {len(result['faces'])} faces:")
+ for i, face in enumerate(result["faces"]):
+ print(f" Face {i + 1}: {face}")
+
+ output_path = "detected_faces.png"
+ result["image"].save(output_path)
+ print(f"Annotated image saved to: {output_path}")
diff --git a/lightx2v/deploy/common/pipeline.py b/lightx2v/deploy/common/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d4868ddbe8ab719b722a430a5762d7cc8c26b6c
--- /dev/null
+++ b/lightx2v/deploy/common/pipeline.py
@@ -0,0 +1,167 @@
+import json
+import sys
+
+from loguru import logger
+
+
+class Pipeline:
+ def __init__(self, pipeline_json_file):
+ self.pipeline_json_file = pipeline_json_file
+ x = json.load(open(pipeline_json_file))
+ self.data = x["data"]
+ self.meta = x["meta"]
+ self.inputs = {}
+ self.outputs = {}
+ self.temps = {}
+ self.model_lists = []
+ self.types = {}
+ self.queues = set()
+ self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
+ self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
+ self.tidy_pipeline()
+
+ def init_dict(self, base, task, model_cls):
+ if task not in base:
+ base[task] = {}
+ if model_cls not in base[task]:
+ base[task][model_cls] = {}
+
+ # tidy each task item, e.g. ['t2v', 'wan2.1', 'multi_stage']
+ def tidy_task(self, task, model_cls, stage, v3):
+ out2worker = {}
+ out2num = {}
+ cur_inps = set()
+ cur_temps = set()
+ cur_types = {}
+ for worker_name, worker_item in v3.items():
+ prevs = []
+ for inp in worker_item["inputs"]:
+ cur_types[inp] = self.get_type(inp)
+ if inp in out2worker:
+ prevs.append(out2worker[inp])
+ out2num[inp] -= 1
+ if out2num[inp] <= 0:
+ cur_temps.add(inp)
+ else:
+ cur_inps.add(inp)
+ worker_item["previous"] = prevs
+
+ for out in worker_item["outputs"]:
+ cur_types[out] = self.get_type(out)
+ out2worker[out] = worker_name
+ if out not in out2num:
+ out2num[out] = 0
+ out2num[out] += 1
+
+ if "queue" not in worker_item:
+ worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
+ self.queues.add(worker_item["queue"])
+
+ cur_outs = [out for out, num in out2num.items() if num > 0]
+ self.inputs[task][model_cls][stage] = list(cur_inps)
+ self.outputs[task][model_cls][stage] = cur_outs
+ self.temps[task][model_cls][stage] = list(cur_temps)
+ self.types[task][model_cls][stage] = cur_types
+
+ # tidy previous dependence workers and queue name
+ def tidy_pipeline(self):
+ for task, v1 in self.data.items():
+ for model_cls, v2 in v1.items():
+ for stage, v3 in v2.items():
+ self.init_dict(self.inputs, task, model_cls)
+ self.init_dict(self.outputs, task, model_cls)
+ self.init_dict(self.temps, task, model_cls)
+ self.init_dict(self.types, task, model_cls)
+ self.tidy_task(task, model_cls, stage, v3)
+ self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
+ logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
+ logger.info(f"inputs: {self.inputs}")
+ logger.info(f"outputs: {self.outputs}")
+ logger.info(f"temps: {self.temps}")
+ logger.info(f"types: {self.types}")
+ logger.info(f"model_lists: {self.model_lists}")
+ logger.info(f"queues: {self.queues}")
+
+ def get_item_by_keys(self, keys):
+ item = self.data
+ for k in keys:
+ if k not in item:
+ raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
+ item = item[k]
+ return item
+
+ # e.g. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
+ def get_worker(self, keys):
+ return self.get_item_by_keys(keys)
+
+ # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
+ def get_workers(self, keys):
+ return self.get_item_by_keys(keys)
+
+ # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
+ def get_inputs(self, keys):
+ item = self.inputs
+ for k in keys:
+ if k not in item:
+ raise Exception(f"{keys} are not in inputs!")
+ item = item[k]
+ return item
+
+ # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
+ def get_outputs(self, keys):
+ item = self.outputs
+ for k in keys:
+ if k not in item:
+ raise Exception(f"{keys} are not in outputs!")
+ item = item[k]
+ return item
+
+ # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
+ def get_temps(self, keys):
+ item = self.temps
+ for k in keys:
+ if k not in item:
+ raise Exception(f"{keys} are not in temps!")
+ item = item[k]
+ return item
+
+ # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
+ def get_types(self, keys):
+ item = self.types
+ for k in keys:
+ if k not in item:
+ raise Exception(f"{keys} are not in types!")
+ item = item[k]
+ return item
+
+ def check_item_by_keys(self, keys):
+ item = self.data
+ for k in keys:
+ if k not in item:
+ return False
+ item = item[k]
+ return True
+
+ def get_model_lists(self):
+ return self.model_lists
+
+ def get_type(self, name):
+ return self.meta["special_types"].get(name, "OBJECT")
+
+ def get_monitor_config(self):
+ return self.meta["monitor"]
+
+ def get_queues(self):
+ return self.queues
+
+ def inner_model_name(self, name):
+ return self.model_name_outer_to_inner.get(name, name)
+
+ def outer_model_name(self, name):
+ return self.model_name_inner_to_outer.get(name, name)
+
+
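+# Minimal pipeline JSON sketch (illustrative; the real config shipped with the repo may
+# contain more fields). "data" maps task -> model_cls -> stage -> workers, and each
+# worker lists its inputs/outputs so tidy_pipeline can infer dependencies and queues:
+# {
+#   "meta": {"special_types": {"input_image": "IMAGE"}, "monitor": {},
+#            "model_name_inner_to_outer": {}, "model_name_outer_to_inner": {}},
+#   "data": {"t2v": {"wan2.1": {"multi_stage": {
+#     "text_encoder": {"inputs": ["prompt"], "outputs": ["text_emb"]},
+#     "dit": {"inputs": ["prompt", "text_emb"], "outputs": ["latents"]},
+#     "vae_decoder": {"inputs": ["latents"], "outputs": ["video"]}
+#   }}}}
+# }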
+if __name__ == "__main__":
+ pipeline = Pipeline(sys.argv[1])
+ print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"]))
+ print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"]))
diff --git a/lightx2v/deploy/common/podcasts.py b/lightx2v/deploy/common/podcasts.py
new file mode 100644
index 0000000000000000000000000000000000000000..52d255dc23afa864b90cc496faeb175c4a8be998
--- /dev/null
+++ b/lightx2v/deploy/common/podcasts.py
@@ -0,0 +1,696 @@
+# -*- coding: utf-8 -*-
+
+import asyncio
+import io
+import json
+import os
+import struct
+import uuid
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import Callable, List, Optional
+
+import websockets
+from loguru import logger
+from pydub import AudioSegment
+
+
+# Protocol definitions (from podcasts_protocols)
+class MsgType(IntEnum):
+ """Message type enumeration"""
+
+ Invalid = 0
+ FullClientRequest = 0b1
+ AudioOnlyClient = 0b10
+ FullServerResponse = 0b1001
+ AudioOnlyServer = 0b1011
+ FrontEndResultServer = 0b1100
+ Error = 0b1111
+ ServerACK = AudioOnlyServer
+
+
+class MsgTypeFlagBits(IntEnum):
+ """Message type flag bits"""
+
+ NoSeq = 0
+ PositiveSeq = 0b1
+ LastNoSeq = 0b10
+ NegativeSeq = 0b11
+ WithEvent = 0b100
+
+
+class VersionBits(IntEnum):
+ """Version bits"""
+
+ Version1 = 1
+
+
+class HeaderSizeBits(IntEnum):
+ """Header size bits"""
+
+ HeaderSize4 = 1
+ HeaderSize8 = 2
+ HeaderSize12 = 3
+ HeaderSize16 = 4
+
+
+class SerializationBits(IntEnum):
+ """Serialization method bits"""
+
+ Raw = 0
+ JSON = 0b1
+ Thrift = 0b11
+ Custom = 0b1111
+
+
+class CompressionBits(IntEnum):
+ """Compression method bits"""
+
+ None_ = 0
+ Gzip = 0b1
+ Custom = 0b1111
+
+
+class EventType(IntEnum):
+ """Event type enumeration"""
+
+ None_ = 0
+ StartConnection = 1
+ StartTask = 1
+ FinishConnection = 2
+ FinishTask = 2
+ ConnectionStarted = 50
+ TaskStarted = 50
+ ConnectionFailed = 51
+ TaskFailed = 51
+ ConnectionFinished = 52
+ TaskFinished = 52
+ StartSession = 100
+ CancelSession = 101
+ FinishSession = 102
+ SessionStarted = 150
+ SessionCanceled = 151
+ SessionFinished = 152
+ SessionFailed = 153
+ UsageResponse = 154
+ ChargeData = 154
+ TaskRequest = 200
+ UpdateConfig = 201
+ AudioMuted = 250
+ SayHello = 300
+ TTSSentenceStart = 350
+ TTSSentenceEnd = 351
+ TTSResponse = 352
+ TTSEnded = 359
+ PodcastRoundStart = 360
+ PodcastRoundResponse = 361
+ PodcastRoundEnd = 362
+ PodcastEnd = 363
+
+
+@dataclass
+class Message:
+ """Message object"""
+
+ version: VersionBits = VersionBits.Version1
+ header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
+ type: MsgType = MsgType.Invalid
+ flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
+ serialization: SerializationBits = SerializationBits.JSON
+ compression: CompressionBits = CompressionBits.None_
+ event: EventType = EventType.None_
+ session_id: str = ""
+ connect_id: str = ""
+ sequence: int = 0
+ error_code: int = 0
+ payload: bytes = b""
+
+ @classmethod
+ def from_bytes(cls, data: bytes) -> "Message":
+ """Create message object from bytes"""
+ if len(data) < 3:
+ raise ValueError(f"Data too short: expected at least 3 bytes, got {len(data)}")
+ type_and_flag = data[1]
+ msg_type = MsgType(type_and_flag >> 4)
+ flag = MsgTypeFlagBits(type_and_flag & 0b00001111)
+ msg = cls(type=msg_type, flag=flag)
+ msg.unmarshal(data)
+ return msg
+
+ def marshal(self) -> bytes:
+ """Serialize message to bytes"""
+ buffer = io.BytesIO()
+ header = [
+ (self.version << 4) | self.header_size,
+ (self.type << 4) | self.flag,
+ (self.serialization << 4) | self.compression,
+ ]
+ header_size = 4 * self.header_size
+ if padding := header_size - len(header):
+ header.extend([0] * padding)
+ buffer.write(bytes(header))
+ writers = self._get_writers()
+ for writer in writers:
+ writer(buffer)
+ return buffer.getvalue()
+
+ def unmarshal(self, data: bytes) -> None:
+ """Deserialize message from bytes"""
+ buffer = io.BytesIO(data)
+ version_and_header_size = buffer.read(1)[0]
+ self.version = VersionBits(version_and_header_size >> 4)
+ self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)
+ buffer.read(1)
+ serialization_compression = buffer.read(1)[0]
+ self.serialization = SerializationBits(serialization_compression >> 4)
+ self.compression = CompressionBits(serialization_compression & 0b00001111)
+ header_size = 4 * self.header_size
+ read_size = 3
+ if padding_size := header_size - read_size:
+ buffer.read(padding_size)
+ readers = self._get_readers()
+ for reader in readers:
+ reader(buffer)
+ remaining = buffer.read()
+ if remaining:
+ raise ValueError(f"Unexpected data after message: {remaining}")
+
+ def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
+ """Get list of writer functions"""
+ writers = []
+ if self.flag == MsgTypeFlagBits.WithEvent:
+ writers.extend([self._write_event, self._write_session_id])
+ if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
+ if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
+ writers.append(self._write_sequence)
+ elif self.type == MsgType.Error:
+ writers.append(self._write_error_code)
+ else:
+ raise ValueError(f"Unsupported message type: {self.type}")
+ writers.append(self._write_payload)
+ return writers
+
+ def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
+ """Get list of reader functions"""
+ readers = []
+ if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
+ if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
+ readers.append(self._read_sequence)
+ elif self.type == MsgType.Error:
+ readers.append(self._read_error_code)
+ if self.flag == MsgTypeFlagBits.WithEvent:
+ readers.extend([self._read_event, self._read_session_id, self._read_connect_id])
+ readers.append(self._read_payload)
+ return readers
+
+ def _write_event(self, buffer: io.BytesIO) -> None:
+ buffer.write(struct.pack(">i", self.event))
+
+ def _write_session_id(self, buffer: io.BytesIO) -> None:
+ if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed]:
+ return
+ session_id_bytes = self.session_id.encode("utf-8")
+ size = len(session_id_bytes)
+ if size > 0xFFFFFFFF:
+ raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")
+ buffer.write(struct.pack(">I", size))
+ if size > 0:
+ buffer.write(session_id_bytes)
+
+ def _write_sequence(self, buffer: io.BytesIO) -> None:
+ buffer.write(struct.pack(">i", self.sequence))
+
+ def _write_error_code(self, buffer: io.BytesIO) -> None:
+ buffer.write(struct.pack(">I", self.error_code))
+
+ def _write_payload(self, buffer: io.BytesIO) -> None:
+ size = len(self.payload)
+ if size > 0xFFFFFFFF:
+ raise ValueError(f"Payload size ({size}) exceeds max(uint32)")
+ buffer.write(struct.pack(">I", size))
+ buffer.write(self.payload)
+
+ def _read_event(self, buffer: io.BytesIO) -> None:
+ event_bytes = buffer.read(4)
+ if event_bytes:
+ self.event = EventType(struct.unpack(">i", event_bytes)[0])
+
+ def _read_session_id(self, buffer: io.BytesIO) -> None:
+ if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
+ return
+ size_bytes = buffer.read(4)
+ if size_bytes:
+ size = struct.unpack(">I", size_bytes)[0]
+ if size > 0:
+ session_id_bytes = buffer.read(size)
+ if len(session_id_bytes) == size:
+ self.session_id = session_id_bytes.decode("utf-8")
+
+ def _read_connect_id(self, buffer: io.BytesIO) -> None:
+ if self.event in [EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
+ size_bytes = buffer.read(4)
+ if size_bytes:
+ size = struct.unpack(">I", size_bytes)[0]
+ if size > 0:
+ self.connect_id = buffer.read(size).decode("utf-8")
+
+ def _read_sequence(self, buffer: io.BytesIO) -> None:
+ sequence_bytes = buffer.read(4)
+ if sequence_bytes:
+ self.sequence = struct.unpack(">i", sequence_bytes)[0]
+
+ def _read_error_code(self, buffer: io.BytesIO) -> None:
+ error_code_bytes = buffer.read(4)
+ if error_code_bytes:
+ self.error_code = struct.unpack(">I", error_code_bytes)[0]
+
+ def _read_payload(self, buffer: io.BytesIO) -> None:
+ size_bytes = buffer.read(4)
+ if size_bytes:
+ size = struct.unpack(">I", size_bytes)[0]
+ if size > 0:
+ self.payload = buffer.read(size)
+
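+
+# Wire-format sketch (derived from Message.marshal/unmarshal above): each frame starts
+# with a header of 4 * header_size bytes whose first three bytes pack
+#   byte 0: (version << 4) | header_size
+#   byte 1: (type << 4) | flag
+#   byte 2: (serialization << 4) | compression
+# and is followed, depending on type/flag, by either a big-endian event id plus a
+# length-prefixed session/connect id (WithEvent), a sequence number (Positive/NegativeSeq),
+# or an error code (Error), and finally a length-prefixed payload.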
+
+async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
+ """Receive message from websocket"""
+ try:
+ data = await websocket.recv()
+ if isinstance(data, str):
+ raise ValueError(f"Unexpected text message: {data}")
+ elif isinstance(data, bytes):
+ msg = Message.from_bytes(data)
+ # logger.debug(f"Received: {msg}")
+ return msg
+ else:
+ raise ValueError(f"Unexpected message type: {type(data)}")
+ except Exception as e:
+ logger.error(f"Failed to receive message: {e}")
+ raise
+
+
+async def wait_for_event(websocket: websockets.WebSocketClientProtocol, msg_type: MsgType, event_type: EventType) -> Message:
+ """Wait for specific event"""
+ msg = await receive_message(websocket)
+ if msg.type != msg_type or msg.event != event_type:
+ raise ValueError(f"Unexpected message: {msg}")
+ return msg
+
+
+async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
+ """Start connection"""
+ msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
+ msg.event = EventType.StartConnection
+ msg.payload = b"{}"
+ logger.debug(f"Sending: {msg}")
+ await websocket.send(msg.marshal())
+
+
+async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
+ """Finish connection"""
+ msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
+ msg.event = EventType.FinishConnection
+ msg.payload = b"{}"
+ logger.debug(f"Sending: {msg}")
+ await websocket.send(msg.marshal())
+
+
+async def start_session(websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str) -> None:
+ """Start session"""
+ msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
+ msg.event = EventType.StartSession
+ msg.session_id = session_id
+ msg.payload = payload
+ logger.debug(f"Sending: {msg}")
+ await websocket.send(msg.marshal())
+
+
+async def finish_session(websocket: websockets.WebSocketClientProtocol, session_id: str) -> None:
+ """Finish session"""
+ msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
+ msg.event = EventType.FinishSession
+ msg.session_id = session_id
+ msg.payload = b"{}"
+ logger.debug(f"Sending: {msg}")
+ await websocket.send(msg.marshal())
+
+
+class PodcastRoundPostProcessor:
+ def __init__(self, session_id, data_manager):
+ self.session_id = session_id
+ self.data_manager = data_manager
+
+ self.temp_merged_audio_name = "merged_audio.mp3"
+ self.output_merged_audio_name = f"{session_id}-merged_audio.mp3"
+ self.subtitle_timestamps = []  # subtitle timestamps recorded per round
+ self.current_audio_duration = 0.0  # current merged audio duration in seconds
+ self.merged_audio = None  # merged AudioSegment accumulated across rounds
+ self.merged_audio_bytes = None
+
+ async def init(self):
+ if self.data_manager:
+ await self.data_manager.create_podcast_temp_session_dir(self.session_id)
+
+ async def postprocess_round(self, current_round, voice, audio, podcast_texts):
+ text = ""
+ if podcast_texts:
+ text = podcast_texts[-1].get("text", "")
+ logger.debug(f"Processing round: {current_round}, voice: {voice}, text: {text}, audio: {len(audio)} bytes")
+
+ new_segment = AudioSegment.from_mp3(io.BytesIO(bytes(audio)))
+ round_duration = len(new_segment) / 1000.0
+
+ if self.merged_audio is None:
+ self.merged_audio = new_segment
+ else:
+ self.merged_audio = self.merged_audio + new_segment
+
+ # Save the merged audio to a temp file (for real-time access from the frontend)
+ merged_io = io.BytesIO()
+ self.merged_audio.export(merged_io, format="mp3")
+ self.merged_audio_bytes = merged_io.getvalue()
+ if self.data_manager:
+ await self.data_manager.save_podcast_temp_session_file(self.session_id, self.temp_merged_audio_name, self.merged_audio_bytes)
+ merged_file_size = len(self.merged_audio_bytes)
+
+ # Record the subtitle timestamp for this round
+ self.subtitle_timestamps.append(
+ {
+ "start": self.current_audio_duration,
+ "end": self.current_audio_duration + round_duration,
+ "text": text,
+ "speaker": voice,
+ }
+ )
+ self.current_audio_duration += round_duration
+ logger.debug(f"Merged audio updated: {merged_file_size} bytes, duration: {self.current_audio_duration:.2f}s")
+
+ return {
+ "url": f"/api/v1/podcast/audio?session_id={self.session_id}&filename={self.temp_merged_audio_name}",
+ "size": merged_file_size,
+ "duration": self.current_audio_duration,
+ "round": current_round,
+ "text": text,
+ "speaker": voice,
+ }
+
+ async def postprocess_final(self):
+ if self.data_manager:
+ await self.data_manager.save_podcast_output_file(self.output_merged_audio_name, self.merged_audio_bytes)
+ return {
+ "subtitles": self.subtitle_timestamps,
+ "audio_name": self.output_merged_audio_name,
+ }
+
+ async def cleanup(self):
+ if self.data_manager:
+ await self.data_manager.clear_podcast_temp_session_dir(self.session_id)
+ self.data_manager = None
+
+
+class VolcEnginePodcastClient:
+ """
+ VolcEngine Podcast客户端
+
+ 支持多种播客类型:
+ - action=0: 文本转播客
+ - action=3: NLP文本转播客
+ - action=4: 提示词生成播客
+ """
+
+ def __init__(self):
+ self.endpoint = "wss://openspeech.bytedance.com/api/v3/sami/podcasttts"
+ self.appid = os.getenv("VOLCENGINE_PODCAST_APPID")
+ self.access_token = os.getenv("VOLCENGINE_PODCAST_ACCESS_TOKEN")
+ self.app_key = "aGjiRDfUWi"
+ self.proxy = os.getenv("HTTPS_PROXY", None)
+ if self.proxy:
+ logger.info(f"volcengine podcast use proxy: {self.proxy}")
+
+ async def podcast_request(
+ self,
+ session_id: str,
+ data_manager=None,
+ text: str = "",
+ input_url: str = "",
+ prompt_text: str = "",
+ nlp_texts: str = "",
+ action: int = 0,
+ resource_id: str = "volc.service_type.10050",
+ encoding: str = "mp3",
+ input_id: str = "test_podcast",
+ speaker_info: str = '{"random_order":false}',
+ use_head_music: bool = False,
+ use_tail_music: bool = False,
+ only_nlp_text: bool = False,
+ return_audio_url: bool = False,
+ skip_round_audio_save: bool = False,
+ on_round_complete: Optional[Callable] = None,
+ ):
+ """
+ Execute a podcast request
+
+ Args:
+ session_id: Session identifier used for temp and output file naming
+ data_manager: Optional data manager used to persist per-round and final audio
+ text: Input text (used when action=0)
+ input_url: Web URL or file URL (used when action=0)
+ prompt_text: Prompt text (required when action=4)
+ nlp_texts: NLP text (required when action=3)
+ action: Podcast type (0/3/4)
+ resource_id: Audio resource ID
+ encoding: Audio format (mp3/wav)
+ input_id: Unique input identifier
+ speaker_info: Podcast speaker info (JSON string)
+ use_head_music: Whether to add opening music
+ use_tail_music: Whether to add closing music
+ only_nlp_text: Whether to return only the podcast text
+ return_audio_url: Whether to return audio URLs
+ skip_round_audio_save: Whether to skip saving per-round audio
+ on_round_complete: Callback invoked after each round completes
+ """
+ if not self.appid or not self.access_token:
+ logger.error("APP ID or Access Key is required")
+ return None, None
+
+ headers = {
+ "X-Api-App-Id": self.appid,
+ "X-Api-App-Key": self.app_key,
+ "X-Api-Access-Key": self.access_token,
+ "X-Api-Resource-Id": resource_id,
+ "X-Api-Connect-Id": str(uuid.uuid4()),
+ }
+
+ is_podcast_round_end = True
+ audio_received = False
+ last_round_id = -1
+ task_id = ""
+ websocket = None
+ retry_num = 5
+ audio = bytearray()
+ voice = ""
+ current_round = 0
+ podcast_texts = []
+
+ post_processor = PodcastRoundPostProcessor(session_id, data_manager)
+ await post_processor.init()
+
+ try:
+ while retry_num > 0:
+ # Establish the WebSocket connection
+ websocket = await websockets.connect(self.endpoint, additional_headers=headers)
+ logger.debug(f"WebSocket connected: {websocket.response.headers}")
+
+ # Build the request parameters
+ if input_url:
+ req_params = {
+ "input_id": input_id,
+ "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
+ "prompt_text": prompt_text,
+ "action": action,
+ "use_head_music": use_head_music,
+ "use_tail_music": use_tail_music,
+ "input_info": {
+ "input_url": input_url,
+ "return_audio_url": return_audio_url,
+ "only_nlp_text": only_nlp_text,
+ },
+ "speaker_info": json.loads(speaker_info) if speaker_info else None,
+ "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
+ }
+ else:
+ req_params = {
+ "input_id": input_id,
+ "input_text": text,
+ "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
+ "prompt_text": prompt_text,
+ "action": action,
+ "use_head_music": use_head_music,
+ "use_tail_music": use_tail_music,
+ "input_info": {
+ "input_url": input_url,
+ "return_audio_url": return_audio_url,
+ "only_nlp_text": only_nlp_text,
+ },
+ "speaker_info": json.loads(speaker_info) if speaker_info else None,
+ "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
+ }
+
+ logger.debug(f"Request params: {json.dumps(req_params, indent=2, ensure_ascii=False)}")
+
+ if not is_podcast_round_end:
+ req_params["retry_info"] = {"retry_task_id": task_id, "last_finished_round_id": last_round_id}
+
+ # Start connection
+ await start_connection(websocket)
+ await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted)
+
+ session_id = str(uuid.uuid4())
+ if not task_id:
+ task_id = session_id
+
+ # Start session
+ await start_session(websocket, json.dumps(req_params).encode(), session_id)
+ await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted)
+
+ # Finish session
+ await finish_session(websocket, session_id)
+
+ while True:
+ msg = await receive_message(websocket)
+
+ # Audio data chunk
+ if msg.type == MsgType.AudioOnlyServer and msg.event == EventType.PodcastRoundResponse:
+ if not audio_received and audio:
+ audio_received = True
+ audio.extend(msg.payload)
+
+ # Error message
+ elif msg.type == MsgType.Error:
+ raise RuntimeError(f"Server error: {msg.payload.decode()}")
+
+ elif msg.type == MsgType.FullServerResponse:
+ # Podcast round start
+ if msg.event == EventType.PodcastRoundStart:
+ data = json.loads(msg.payload.decode())
+ if data.get("text"):
+ filtered_payload = {"text": data.get("text"), "speaker": data.get("speaker")}
+ podcast_texts.append(filtered_payload)
+ voice = data.get("speaker")
+ current_round = data.get("round_id")
+ if current_round == -1:
+ voice = "head_music"
+ if current_round == 9999:
+ voice = "tail_music"
+ is_podcast_round_end = False
+ logger.debug(f"New round started: {data}")
+
+ # Podcast round end
+ if msg.event == EventType.PodcastRoundEnd:
+ data = json.loads(msg.payload.decode())
+ logger.debug(f"Podcast round end: {data}")
+ if data.get("is_error"):
+ break
+ is_podcast_round_end = True
+ last_round_id = current_round
+ if audio:
+ round_info = await post_processor.postprocess_round(current_round, voice, audio, podcast_texts)
+ if on_round_complete:
+ await on_round_complete(round_info)
+ audio.clear()
+
+ # Podcast end
+ if msg.event == EventType.PodcastEnd:
+ data = json.loads(msg.payload.decode())
+ logger.info(f"Podcast end: {data}")
+
+ # Session finished
+ if msg.event == EventType.SessionFinished:
+ break
+
+ if not audio_received and not only_nlp_text:
+ raise RuntimeError("No audio data received")
+
+ # Finish the connection
+ await finish_connection(websocket)
+ await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionFinished)
+
+ # Podcast finished, save the final audio file
+ if is_podcast_round_end:
+ podcast_info = await post_processor.postprocess_final()
+ return podcast_info
+ else:
+ logger.error(f"Current podcast not finished, resuming from round {last_round_id}")
+ retry_num -= 1
+ await asyncio.sleep(1)
+ if websocket:
+ await websocket.close()
+
+ finally:
+ # note: no return here -- a return inside finally would swallow exceptions
+ # and override the podcast_info returned from the try block
+ await post_processor.cleanup()
+ if websocket:
+ await websocket.close()
+
+
+async def test(args):
+ """
+ Podcast测试函数
+
+ Args:
+ args: dict, 包含所有podcast参数
+ """
+ client = VolcEnginePodcastClient()
+
+ # Default parameters
+ params = {
+ "session_id": "test_session",  # required by podcast_request; arbitrary value for local testing
+ "text": "",
+ "input_url": "https://zhuanlan.zhihu.com/p/607822576",
+ "prompt_text": "",
+ "nlp_texts": "",
+ "action": 0,
+ "resource_id": "volc.service_type.10050",
+ "encoding": "mp3",
+ "input_id": "test_podcast",
+ "speaker_info": '{"random_order":false}',
+ "use_head_music": False,
+ "use_tail_music": False,
+ "only_nlp_text": False,
+ "return_audio_url": True,
+ "skip_round_audio_save": False,
+ "output_dir": "output",
+ }
+
+ # Override defaults with the provided args
+ if args:
+ params.update(args)
+ # output_dir is accepted on the CLI but podcast_request has no such parameter, so drop it
+ params.pop("output_dir", None)
+
+ await client.podcast_request(**params)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--text", default="", help="Input text Use when action in [0]")
+ parser.add_argument("--input_url", default="", help="Web url or file url Use when action in [0]")
+ parser.add_argument("--prompt_text", default="", help="Input Prompt Text must not empty when action in [4]")
+ parser.add_argument("--nlp_texts", default="", help="Input NLP Texts must not empty when action in [3]")
+ parser.add_argument("--resource_id", default="volc.service_type.10050", help="Audio Resource ID")
+ parser.add_argument("--encoding", default="mp3", choices=["mp3", "wav"], help="Audio format")
+ parser.add_argument("--input_id", default="test_podcast", help="Unique input identifier")
+ parser.add_argument("--speaker_info", default='{"random_order":false}', help="Podcast Speaker Info")
+ parser.add_argument("--use_head_music", default=False, action="store_true", help="Enable head music")
+ parser.add_argument("--use_tail_music", default=False, action="store_true", help="Enable tail music")
+ parser.add_argument("--only_nlp_text", default=False, action="store_true", help="Enable only podcast text when action in [0, 4]")
+ parser.add_argument("--return_audio_url", default=False, action="store_true", help="Enable return audio url that can download")
+ parser.add_argument("--action", default=0, type=int, choices=[0, 3, 4], help="different podcast type")
+ parser.add_argument("--skip_round_audio_save", default=False, action="store_true", help="skip round audio save")
+ parser.add_argument("--output_dir", default="output", help="Output directory")
+
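+ # Example invocation (illustrative; the script name is a placeholder for wherever this
+ # module lives, and the URL is arbitrary):
+ #   python podcast_client.py --input_url https://example.com/article --action 0 --return_audio_url
+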
+ args = parser.parse_args()
+
+ kwargs = {k: v for k, v in vars(args).items() if v is not None and not (isinstance(v, bool) and not v)}
+
+ asyncio.run(test(kwargs))
diff --git a/lightx2v/deploy/common/utils.py b/lightx2v/deploy/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6258cc9bc1f4eb53f47793358fb2a552241b91b6
--- /dev/null
+++ b/lightx2v/deploy/common/utils.py
@@ -0,0 +1,253 @@
+import asyncio
+import base64
+import io
+import os
+import subprocess
+import tempfile
+import time
+import traceback
+from datetime import datetime
+
+import httpx
+import torchaudio
+from PIL import Image
+from loguru import logger
+
+FMT = "%Y-%m-%d %H:%M:%S"
+
+
+def current_time():
+ return datetime.now().timestamp()
+
+
+def time2str(t):
+ d = datetime.fromtimestamp(t)
+ return d.strftime(FMT)
+
+
+def str2time(s):
+ d = datetime.strptime(s, FMT)
+ return d.timestamp()
+
+
+def try_catch(func):
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception:
+ logger.error(f"Error in {func.__name__}:")
+ traceback.print_exc()
+ return None
+
+ return wrapper
+
+
+def class_try_catch(func):
+ def wrapper(self, *args, **kwargs):
+ try:
+ return func(self, *args, **kwargs)
+ except Exception:
+ logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
+ traceback.print_exc()
+ return None
+
+ return wrapper
+
+
+def class_try_catch_async(func):
+ async def wrapper(self, *args, **kwargs):
+ try:
+ return await func(self, *args, **kwargs)
+ except Exception:
+ logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
+ traceback.print_exc()
+ return None
+
+ return wrapper
+
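+# Illustrative sketch of how these decorators are meant to be used (the names below are
+# hypothetical): exceptions are logged with a traceback and swallowed, and the wrapped
+# callable returns None on failure.
+#
+#   @try_catch
+#   def parse(s):
+#       return int(s)             # parse("oops") -> None (and an error log)
+#
+#   class Worker:
+#       @class_try_catch
+#       def run(self, x):
+#           return 1 / x          # run(0) -> None
+#
+#       @class_try_catch_async
+#       async def run_async(self, x):
+#           return 1 / x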
+
+def data_name(x, task_id):
+ if x == "input_image":
+ x = x + ".png"
+ elif x == "input_video":
+ x = x + ".mp4"
+ elif x == "output_video":
+ x = x + ".mp4"
+ return f"{task_id}-{x}"
+
+
+async def fetch_resource(url, timeout):
+ logger.info(f"Begin to download resource from url: {url}")
+ t0 = time.time()
+ async with httpx.AsyncClient() as client:
+ async with client.stream("GET", url, timeout=timeout) as response:
+ response.raise_for_status()
+ ans_bytes = []
+ async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
+ ans_bytes.append(chunk)
+ # chunks are at most 1 MiB each, so this caps the download at roughly 128 MiB
+ if len(ans_bytes) > 128:
+ raise Exception(f"url {url} recv data is too big")
+ content = b"".join(ans_bytes)
+ logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
+ return content
+
+
+# check size, resize if needed, and apply EXIF orientation metadata
+def format_image_data(data, max_size=1280):
+ image = Image.open(io.BytesIO(data)).convert("RGB")
+ exif = image.getexif()
+ changed = False
+ w, h = image.size
+ assert w > 0 and h > 0, "image is empty"
+ logger.info(f"load image: {w}x{h}, exif: {exif}")
+
+ if w > max_size or h > max_size:
+ ratio = max_size / max(w, h)
+ w = int(w * ratio)
+ h = int(h * ratio)
+ image = image.resize((w, h))
+ logger.info(f"resize image to: {image.size}")
+ changed = True
+
+ # 274 is the EXIF Orientation tag id
+ orientation_key = 274
+ if orientation_key and orientation_key in exif:
+ orientation = exif[orientation_key]
+ if orientation == 2:
+ image = image.transpose(Image.FLIP_LEFT_RIGHT)
+ elif orientation == 3:
+ image = image.rotate(180, expand=True)
+ elif orientation == 4:
+ image = image.transpose(Image.FLIP_TOP_BOTTOM)
+ elif orientation == 5:
+ image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
+ elif orientation == 6:
+ image = image.rotate(270, expand=True)
+ elif orientation == 7:
+ image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
+ elif orientation == 8:
+ image = image.rotate(90, expand=True)
+
+ # reset orientation to 1
+ if orientation != 1:
+ logger.info(f"reset orientation from {orientation} to 1")
+ exif[orientation_key] = 1
+ changed = True
+
+ if not changed:
+ return data
+ output = io.BytesIO()
+ image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
+ return output.getvalue()
+
+
+def media_to_wav(data):
+ with tempfile.NamedTemporaryFile() as fin:
+ fin.write(data)
+ fin.flush()
+ cmd = ["ffmpeg", "-i", fin.name, "-f", "wav", "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2", "pipe:1"]
+ p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ assert p.returncode == 0, f"media to wav failed: {p.stderr.decode()}"
+ return p.stdout
+
+
+def format_audio_data(data):
+ if len(data) < 4:
+ raise ValueError("Audio file too short")
+ data = media_to_wav(data)
+ waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
+ logger.info(f"load audio: {waveform.size()}, {sample_rate}")
+ assert waveform.numel() > 0, "audio is empty"
+ assert sample_rate > 0, "audio sample rate is not valid"
+ return data
+
+
+async def preload_data(inp, inp_type, typ, val):
+ try:
+ if typ == "url":
+ timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
+ data = await fetch_resource(val, timeout=timeout)
+ elif typ == "base64":
+ # Decode base64 in background thread to avoid blocking event loop
+ data = await asyncio.to_thread(base64.b64decode, val)
+ # For multi-person audio directory, val should be a dict with file structure
+ elif typ == "directory":
+ data = {}
+ for fname, b64_data in val.items():
+ data[fname] = await asyncio.to_thread(base64.b64decode, b64_data)
+ return {"type": "directory", "data": data}
+ elif typ == "stream":
+ # no byte data needs to be saved by the data_manager
+ data = None
+ else:
+ raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
+
+ # check if valid image bytes
+ if inp_type == "IMAGE":
+ data = await asyncio.to_thread(format_image_data, data)
+ elif inp_type == "AUDIO":
+ if typ != "stream" and typ != "directory":
+ data = await asyncio.to_thread(format_audio_data, data)
+ elif inp_type == "VIDEO":
+ # Video data doesn't need special formatting, just validate it's not empty
+ if len(data) == 0:
+ raise ValueError("Video file is empty")
+ logger.info(f"load video: {len(data)} bytes")
+ else:
+ raise Exception(f"cannot parse inp_type={inp_type} data")
+ return data
+
+ except Exception as e:
+ raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")
+
+
+async def load_inputs(params, raw_inputs, types):
+ inputs_data = {}
+ for inp in raw_inputs:
+ item = params.pop(inp)
+ bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
+
+ # Handle multi-person audio directory
+ if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory":
+ fs = []
+ for fname, fdata in bytes_data["data"].items():
+ inputs_data[f"{inp}/{fname}"] = fdata
+ fs.append(f"{inp}/{fname}")
+ params["extra_inputs"] = {inp: fs}
+ elif bytes_data is not None:
+ inputs_data[inp] = bytes_data
+ else:
+ params[inp] = item
+ return inputs_data
+
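+# Sketch of the request structure load_inputs expects (field names taken from the code
+# above; the concrete values are illustrative only):
+#
+#   params = {
+#       "input_image": {"type": "url", "data": "https://example.com/a.png"},
+#       "input_audio": {"type": "base64", "data": "<base64 string>"},
+#   }
+#   # multi-person audio instead uses a directory payload, one base64 blob per file:
+#   #   {"type": "directory", "data": {"p1.wav": "<base64>", "p2.wav": "<base64>"}}
+#   inputs_data = await load_inputs(params, ["input_image", "input_audio"], types)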
+
+def check_params(params, raw_inputs, raw_outputs, types):
+ stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
+ stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
+ for x in raw_inputs + raw_outputs:
+ if x in params and "type" in params[x]:
+ if params[x]["type"] == "stream":
+ if types[x] == "AUDIO":
+ assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
+ elif types[x] == "VIDEO":
+ assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
+ elif params[x]["type"] == "directory":
+ # Multi-person audio directory is only supported for AUDIO type
+ assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}"
+
+
+if __name__ == "__main__":
+ # https://github.com/recurser/exif-orientation-examples
+ exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
+ out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
+ os.makedirs(out_dir, exist_ok=True)
+
+ for base_name in ["Landscape", "Portrait"]:
+ for i in range(9):
+ fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
+ fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
+ logger.info(f"format image: {fin_name} -> {fout_name}")
+ with open(fin_name, "rb") as f:
+ data = f.read()
+ data = format_image_data(data)
+ with open(fout_name, "wb") as f:
+ f.write(data)
diff --git a/lightx2v/deploy/common/va_controller.py b/lightx2v/deploy/common/va_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..688bce66fa4de2cab9f15121414d53fa66b3c027
--- /dev/null
+++ b/lightx2v/deploy/common/va_controller.py
@@ -0,0 +1,202 @@
+import math
+import os
+from typing import Any
+
+import torch
+import torch.distributed as dist
+from loguru import logger
+
+from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+
+class NextControl:
+ def __init__(self, action: str, data: Any = None):
+ # action: switch, data: prev_video tensor
+ # action: wait, data: None
+ # action: fetch, data: None
+ self.action = action
+ self.data = data
+
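+# Sketch of how a runner loop might consume NextControl (the runner methods named here
+# are hypothetical; only the action semantics come from this module):
+#
+#   control = va_controller.next_control()
+#   if control.action == "switch":
+#       runner.restart_from(control.data)    # prev video tensor broadcast from the recorder rank
+#   elif control.action == "wait":
+#       pass                                 # enough frames are buffered; skip this inference step
+#   else:  # "fetch"
+#       runner.infer_next_segment()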
+
+class VAController:
+ def __init__(self, model_runner):
+ self.reader = None
+ self.recorder = None
+ self.rank = 0
+ self.world_size = 1
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
+ self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
+ self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
+ self.init_recorder()
+ self.init_reader(model_runner)
+
+ def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
+ self.audio_path = input_info.audio_path
+ self.output_video_path = input_info.save_result_path
+ if isinstance(self.output_video_path, dict):
+ self.output_video_path = self.output_video_path["data"]
+
+ self.audio_sr = config.get("audio_sr", 16000)
+ self.target_fps = config.get("target_fps", 16)
+ self.max_num_frames = config.get("target_video_length", 81)
+ self.prev_frame_length = config.get("prev_frame_length", 5)
+
+ self.record_fps = config.get("target_fps", 16)
+ if "video_frame_interpolation" in config and has_vfi_model:
+ self.record_fps = config["video_frame_interpolation"]["target_fps"]
+ self.record_fps = config.get("record_fps", self.record_fps)
+
+ self.tgt_h = input_info.target_shape[0]
+ self.tgt_w = input_info.target_shape[1]
+ self.record_h, self.record_w = self.tgt_h, self.tgt_w
+ if "video_super_resolution" in config and has_vsr_model:
+ _, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
+ self.record_w,
+ self.record_h,
+ scale=config["video_super_resolution"]["scale"],
+ multiple=128,
+ )
+
+ # how many frames to publish stream as a batch
+ self.slice_frame = config.get("slice_frame", self.prev_frame_length)
+ # estimate the max infer seconds, for immediate switch with local omni
+ slice_interval = self.slice_frame / self.record_fps
+ est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
+ self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
+ self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
+
+ def init_recorder(self):
+ if not self.output_video_path or self.rank != self.target_recorder_rank:
+ return
+ logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
+ whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
+ if whip_shared_path and self.output_video_path.startswith("http"):
+ from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
+
+ self.recorder = X264VARecorder(
+ whip_shared_path=whip_shared_path,
+ livestream_url=self.output_video_path,
+ fps=self.record_fps,
+ sample_rate=self.audio_sr,
+ slice_frame=self.slice_frame,
+ prev_frame=self.prev_frame_length,
+ )
+ else:
+ from lightx2v.deploy.common.va_recorder import VARecorder
+
+ self.recorder = VARecorder(
+ livestream_url=self.output_video_path,
+ fps=self.record_fps,
+ sample_rate=self.audio_sr,
+ slice_frame=self.slice_frame,
+ prev_frame=self.prev_frame_length,
+ )
+
+ def init_reader(self, model_runner=None):
+ if not isinstance(self.audio_path, dict):
+ return
+ assert self.audio_path["type"] == "stream", f"unexcept audio_path: {self.audio_path}"
+ segment_duration = self.max_num_frames / self.target_fps
+ prev_duration = self.prev_frame_length / self.target_fps
+ omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
+ if omni_work_dir:
+ from lightx2v.deploy.common.va_reader_omni import OmniVAReader
+
+ self.reader = OmniVAReader(
+ rank=self.rank,
+ world_size=self.world_size,
+ stream_url=self.audio_path["data"],
+ sample_rate=self.audio_sr,
+ segment_duration=segment_duration,
+ prev_duration=prev_duration,
+ target_rank=self.target_reader_rank,
+ model_runner=model_runner,
+ huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
+ )
+ else:
+ from lightx2v.deploy.common.va_reader import VAReader
+
+ self.reader = VAReader(
+ rank=self.rank,
+ world_size=self.world_size,
+ stream_url=self.audio_path["data"],
+ sample_rate=self.audio_sr,
+ segment_duration=segment_duration,
+ prev_duration=prev_duration,
+ target_rank=self.target_reader_rank,
+ )
+
+ def start(self):
+ self.reader.start()
+ if self.rank == self.target_recorder_rank:
+ assert self.recorder is not None, f"recorder is required for streaming output on rank {self.rank}"
+ self.recorder.start(self.record_w, self.record_h)
+ if self.world_size > 1:
+ dist.barrier()
+
+ def next_control(self):
+ from lightx2v.deploy.common.va_reader_omni import OmniVAReader
+
+ if isinstance(self.reader, OmniVAReader):
+ return self.omni_reader_next_control()
+ return NextControl(action="fetch")
+
+ def before_control(self):
+ from lightx2v.deploy.common.va_reader_omni import OmniVAReader
+
+ if isinstance(self.reader, OmniVAReader):
+ self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
+ self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
+ self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)
+
+ def omni_reader_next_control(self):
+ immediate_switch = self.reader.get_immediate_switch()
+ if immediate_switch == 1:
+ # truncate the stream buffer to keep the max infer time length
+ # and broadcast the prev video tensor to all ranks
+ if self.rank == self.target_recorder_rank:
+ logger.warning(f"runner recv immediate switch, truncate stream buffer")
+ video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
+ if video_tensor is not None:
+ self.flag_tensor.fill_(1)
+ self.prev_tensor.copy_(video_tensor)
+ else:
+ self.flag_tensor.fill_(0)
+ dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
+ if self.flag_tensor.item() == 1:
+ dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
+ return NextControl(action="switch", data=self.prev_tensor)
+ else:
+ # get the length of stream buffer, broadcast to all ranks
+ if self.rank == self.target_recorder_rank:
+ stream_buffer_length = self.recorder.get_buffer_stream_size()
+ self.len_tensor.fill_(stream_buffer_length)
+ dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
+ buffer_length = self.len_tensor.item()
+ # stream buffer is enough, skip infer
+ if buffer_length >= self.min_stay_queue_num:
+ return NextControl(action="wait")
+ return NextControl(action="fetch")
+
+ def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
+ if self.recorder.realtime:
+ self.recorder.buffer_stream(images, audios, gen_video)
+ else:
+ self.recorder.pub_livestream(images, audios)
+
+ def clear(self):
+ self.len_tensor = None
+ self.flag_tensor = None
+ self.prev_tensor = None
+ if self.reader is not None:
+ self.reader.stop()
+ self.reader = None
+ if self.recorder is not None:
+ self.recorder.stop()
+ self.recorder = None
+
+ def __del__(self):
+ self.clear()
diff --git a/lightx2v/deploy/common/va_reader.py b/lightx2v/deploy/common/va_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6579d51f0bc08272be9a5957734a5685222786d1
--- /dev/null
+++ b/lightx2v/deploy/common/va_reader.py
@@ -0,0 +1,274 @@
+import os
+import queue
+import signal
+import subprocess
+import threading
+import time
+import traceback
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from loguru import logger
+
+
+class VAReader:
+ def __init__(
+ self,
+ rank: int,
+ world_size: int,
+ stream_url: str,
+ segment_duration: float = 5.0,
+ sample_rate: int = 16000,
+ audio_channels: int = 1,
+ buffer_size: int = 1,
+ prev_duration: float = 0.3125,
+ target_rank: int = 0,
+ ):
+ self.rank = rank
+ self.world_size = world_size
+ self.stream_url = stream_url
+ self.segment_duration = segment_duration
+ self.sample_rate = sample_rate
+ self.audio_channels = audio_channels
+ self.prev_duration = prev_duration
+ # int16 = 2 bytes
+ self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
+ self.prev_size = int(self.prev_duration * self.sample_rate) * 2
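+ # Worked example with the values VAController passes by default: 81 frames at 16 fps
+ # give segment_duration = 5.0625 s, so chunk_size = int(5.0625 * 16000) * 2 = 162000
+ # bytes; prev_duration = 5 / 16 s gives prev_size = int(0.3125 * 16000) * 2 = 10000
+ # bytes of overlap carried into the next segment.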
+ self.prev_chunk = None
+ self.buffer_size = buffer_size
+
+ self.audio_queue = queue.Queue(maxsize=self.buffer_size)
+ self.audio_thread = None
+ self.ffmpeg_process = None
+ self.bytes_buffer = bytearray()
+
+ self.target_rank = target_rank % self.world_size
+
+ self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
+ self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
+
+ logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
+ logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
+
+ def start(self):
+ if self.rank == self.target_rank:
+ if self.stream_url.startswith("rtmp://"):
+ self.start_ffmpeg_process_rtmp()
+ elif self.stream_url.startswith("http"):
+ self.start_ffmpeg_process_whep()
+ else:
+ raise Exception(f"Unsupported stream URL: {self.stream_url}")
+ self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
+ self.audio_thread.start()
+ logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
+ else:
+ logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
+ if self.world_size > 1:
+ logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
+ dist.barrier()
+ logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")
+
+ def start_ffmpeg_process_rtmp(self):
+ """Start ffmpeg process read audio from stream"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-i",
+ self.stream_url,
+ "-vn",
+ # "-acodec",
+ # "pcm_s16le",
+ "-ar",
+ str(self.sample_rate),
+ "-ac",
+ str(self.audio_channels),
+ "-f",
+ "s16le",
+ "-",
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
+ logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg process: {e}")
+ raise
+
+ def start_ffmpeg_process_whep(self):
+ """Start gstream process read audio from stream"""
+ ffmpeg_cmd = [
+ "gst-launch-1.0",
+ "-q",
+ "whepsrc",
+ f"whep-endpoint={self.stream_url}",
+ "video-caps=none",
+ "!rtpopusdepay",
+ "!opusdec",
+ "plc=false",
+ "!audioconvert",
+ "!audioresample",
+ f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
+ "!fdsink",
+ "fd=1",
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(
+ ffmpeg_cmd,
+ stdout=subprocess.PIPE,
+ # stderr=subprocess.PIPE,
+ bufsize=0,
+ )
+ logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg process: {e}")
+ raise
+
+ def audio_worker(self):
+ logger.info("Audio pull worker thread started")
+ try:
+ while True:
+ if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
+ logger.warning("FFmpeg process exited, audio worker thread stopped")
+ break
+ self.fetch_audio_data()
+ time.sleep(0.01)
+ except: # noqa
+ logger.error(f"Audio pull worker error: {traceback.format_exc()}")
+ finally:
+ logger.warning("Audio pull worker thread stopped")
+
+ def fetch_audio_data(self):
+ """Fetch audio data from ffmpeg process"""
+ try:
+ audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
+ if not audio_bytes:
+ return
+ self.bytes_buffer.extend(audio_bytes)
+ # logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
+
+ if len(self.bytes_buffer) >= self.chunk_size:
+ audio_data = self.bytes_buffer[: self.chunk_size]
+ self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
+
+ # first chunk, read original 81 frames
+ # for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames
+ if self.prev_chunk is None:
+ logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
+ self.chunk_size -= self.prev_size
+ else:
+ audio_data = self.prev_chunk + audio_data
+ self.prev_chunk = audio_data[-self.prev_size :]
+
+ try:
+ self.audio_queue.put_nowait(audio_data)
+ except queue.Full:
+ logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk")
+ self.audio_queue.get_nowait()
+ self.audio_queue.put_nowait(audio_data)
+ logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}")
+
+ except: # noqa
+ logger.error(f"Fetch audio data error: {traceback.format_exc()}")
+
+ def broadcast_audio_data(self, audio_data):
+ if self.rank == self.target_rank:
+ if audio_data is None:
+ self.flag_tensor.fill_(0)
+ else:
+ self.flag_tensor.fill_(1)
+ self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
+ logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
+
+ dist.broadcast(self.flag_tensor, src=self.target_rank)
+ if self.flag_tensor.item() == 0:
+ return None
+
+ dist.broadcast(self.audio_tensor, src=self.target_rank)
+ if self.rank != self.target_rank:
+ logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
+ audio_data = self.audio_tensor.cpu().numpy().tobytes()
+ return audio_data
+
+ def bytes_to_ndarray(self, audio_data):
+ if audio_data is None:
+ return None
+ audio_data = np.frombuffer(audio_data, dtype=np.int16)
+ audio_data = audio_data.astype(np.float32) / 32768.0
+ logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
+ return audio_data
+
+ def get_audio_segment(self, timeout: float = 1.0):
+ audio_data = None
+ if self.rank == self.target_rank:
+ try:
+ audio_data = self.audio_queue.get(timeout=timeout)
+ except: # noqa
+ logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
+ if self.world_size > 1:
+ audio_data = self.broadcast_audio_data(audio_data)
+ audio_data = self.bytes_to_ndarray(audio_data)
+ return audio_data
+
+ def stop(self):
+ # Stop ffmpeg process
+ if self.ffmpeg_process:
+ self.ffmpeg_process.send_signal(signal.SIGINT)
+ try:
+ self.ffmpeg_process.wait(timeout=5)
+ except subprocess.TimeoutExpired:
+ self.ffmpeg_process.kill()
+ logger.warning("FFmpeg reader process stopped")
+
+ # Wait for threads to finish
+ if self.audio_thread and self.audio_thread.is_alive():
+ self.audio_thread.join(timeout=5)
+ if self.audio_thread.is_alive():
+ logger.error("Audio pull thread did not stop gracefully")
+
+ while self.audio_queue and self.audio_queue.qsize() > 0:
+ self.audio_queue.get_nowait()
+ self.audio_queue = None
+ logger.warning("Audio pull queue cleaned")
+
+ def __del__(self):
+ self.stop()
+
+
+if __name__ == "__main__":
+ WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
+ RANK = int(os.environ.get("RANK", 0))
+ if WORLD_SIZE > 1:
+ dist.init_process_group(backend="nccl")
+ torch.cuda.set_device(dist.get_rank())
+ logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
+
+ reader = VAReader(
+ RANK,
+ WORLD_SIZE,
+ # "rtmp://localhost/live/test_audio",
+ "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
+ segment_duration=1.0,
+ sample_rate=16000,
+ audio_channels=1,
+ prev_duration=1 / 16,
+ )
+ reader.start()
+ fail_count = 0
+ max_fail_count = 2
+
+ try:
+ while True:
+ audio_data = reader.get_audio_segment(timeout=2)
+ if audio_data is not None:
+ # logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
+ fail_count = 0
+ else:
+ fail_count += 1
+ if fail_count > max_fail_count:
+ logger.warning("Failed to get audio chunk, stop reader")
+ reader.stop()
+ break
+ time.sleep(0.95)
+ finally:
+ reader.stop()
diff --git a/lightx2v/deploy/common/va_reader_omni.py b/lightx2v/deploy/common/va_reader_omni.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdec87abd0226df993e65d716b3cdf3cae18bcc9
--- /dev/null
+++ b/lightx2v/deploy/common/va_reader_omni.py
@@ -0,0 +1,508 @@
+import datetime
+import json
+import os
+import random
+import subprocess
+import threading
+import time
+import traceback
+from collections import deque
+from copy import deepcopy
+
+import jsonschema
+import numpy as np
+import torch
+import torch.distributed as dist
+import zmq
+from loguru import logger
+
+try:
+ from bson import BSON
+except ImportError:
+ BSON = None
+ logger.warning("BSON is not installed")
+from scipy.signal import resample
+
+
+class AudioInfo:
+ def __init__(self, info: dict):
+ self.sample_count = info["sample_count"]
+ self.sample_rate = info["sample_rate"]
+ self.channel_count = info["channel_count"]
+ self.sample_fmt = info["sample_fmt"]
+ self.pts = info["pts"]
+
+ def is_spec_equal(self, other: "AudioInfo") -> bool:
+ return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count
+
+ def duration(self) -> datetime.timedelta:
+ return datetime.timedelta(seconds=self.sample_count / self.sample_rate)
+
+ def __str__(self):
+ return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts)
+
+
+class ByteBuffer:
+ def __init__(self):
+ self.buffer = deque()
+ self.current_size = 0
+ # is the audio belonging to current turn finished
+ self.audio_finished = False
+
+ def add(self, byte_data: bytes):
+ self.buffer.append(byte_data)
+ self.current_size += len(byte_data)
+
+ def get(self, size=1024):
+ data = bytearray()
+
+ while size > 0 and len(self.buffer) > 0:
+ chunk = self.buffer.popleft()
+ if len(chunk) <= size:
+ # if this chunk fits within the requested size, take all of it
+ data.extend(chunk)
+ self.current_size -= len(chunk)
+ size -= len(chunk)
+ else:
+ # if this chunk is larger than the requested size, take only the needed part
+ data.extend(chunk[:size])
+ self.buffer.appendleft(chunk[size:]) # push the remainder back to the buffer
+ self.current_size -= size
+ size = 0
+
+ return bytes(data)
+
+ def mark_finished(self):
+ self.audio_finished = True
+
+ def has_more_voice(self):
+ return not self.audio_finished
+
+ def __len__(self):
+ return self.current_size
+
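+# Small illustrative example of ByteBuffer semantics (not executed anywhere):
+#
+#   buf = ByteBuffer()
+#   buf.add(b"abcd")
+#   buf.add(b"ef")
+#   buf.get(5)            # -> b"abcde"; b"f" stays queued and len(buf) == 1
+#   buf.mark_finished()
+#   buf.has_more_voice()  # -> False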
+
+class ChatAdapter:
+ def __init__(
+ self,
+ omni_work_dir: str,
+ whep_url: str,
+ session_id: str,
+ account: str,
+ config_files: list[str],
+ config_schema_path: str,
+ seg_duration: float,
+ model_runner,
+ huoshan_tts_voice_type,
+ ):
+ assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist"
+ self.omni_work_dir = omni_work_dir
+ self.context = zmq.Context()
+ self.w2f_socket = self.context.socket(zmq.PULL)
+ self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket)
+ self.f2w_socket = self.context.socket(zmq.PUSH)
+ self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket)
+ self.recv_thread = None
+ self.audio_buffer = ByteBuffer()
+ self.audio_info = None
+ self.chat_server_cmd = [
+ os.path.join(self.omni_work_dir, "bin", "seko-chatter"),
+ "--session-id",
+ session_id,
+ "--account",
+ account,
+ "--whep-server-url",
+ whep_url,
+ "--w2f-endpoint",
+ self.w2f_url,
+ "--f2w-endpoint",
+ self.f2w_url,
+ "--config-files",
+ *config_files,
+ ]
+ override_config = {}
+ if huoshan_tts_voice_type is not None:
+ logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}")
+ override_config["TTS"] = {
+ "default_voice_info": {
+ "voice_type": huoshan_tts_voice_type,
+ "provider": "huoshan_stream_tts",
+ }
+ }
+ with open(config_schema_path, "r") as f:
+ schema = json.load(f)
+ jsonschema.validate(instance=override_config, schema=schema)
+ if override_config is not None:
+ self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)])
+ self.chatter_proc = None
+
+ self.seg_duration = seg_duration
+ self.reset_prev = False
+ self.status = "blank"
+ self.immediate_switch = 0
+ self.model_runner = model_runner
+
+ def launch_chat_server(self):
+ env = {
+ "RUST_LOG": "info,duplex_server=debug,backend_5o=debug",
+ "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"),
+ "PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"),
+ }
+ self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir)
+
+ @staticmethod
+ def select_and_bind(socket: zmq.Socket) -> str:
+ # randomly select a port between 1024 and 65535
+ retry_count = 20
+ err = None
+ while retry_count > 0:
+ try:
+ port = random.randint(1024, 65535)
+ # port = 5555
+ url = f"tcp://localhost:{port}"
+ socket.bind(url)
+ return url
+ except zmq.error.ZMQError as e:
+ retry_count -= 1
+ err = e
+ raise err
+
+ # immediate switch to status, discard prev_bytes, set immediate_switch to 1
+ def immediate_switch_to(self, status):
+ logger.warning(f"VA reader immediate switch to {status}")
+ self.reset_prev = True
+ self.status = status
+ self.immediate_switch = 1
+ if self.model_runner is not None:
+ self.model_runner.pause_signal = True
+ logger.warning(f"Model runner pause signal set to True")
+
+ def recv_loop(self):
+ while True:
+ try:
+ message = self.w2f_socket.recv()
+ except Exception:
+ logger.error(f"Error receiving message: {traceback.format_exc()}")
+ break
+ try:
+ message = BSON.decode(message)
+ msg_type = message["type"]
+ logger.debug("Received message type: {}".format(msg_type))
+ if msg_type == "AgentAudio":
+ audio = message["audio"]
+ if audio["type"] != "Pcm":
+ logger.error("Unsupported audio type: {}".format(audio["type"]))
+ continue
+ pcm_data = audio["data"]
+ audio_info = AudioInfo(audio["info"])
+ logger.debug("Received audio with duration: {}".format(audio_info.duration()))
+ if self.audio_info is None:
+ self.audio_info = audio_info
+ else:
+ # check if the audio info is the same
+ if not self.audio_info.is_spec_equal(audio_info):
+ raise ValueError("Audio info mismatch")
+ self.audio_buffer.add(pcm_data)
+ # if status is blank and has voice, set immediate switch to 1
+ if self.status == "blank" and self.has_voice(self.seg_duration):
+ self.immediate_switch_to("voice")
+ elif msg_type == "AgentStartPlay":
+ logger.debug("Received AgentStartPlay, create new audio buffer")
+ self.audio_buffer = ByteBuffer()
+ elif msg_type == "AgentEndPlay":
+ logger.debug("Received AgentEndPlay, mark audio finished")
+ self.audio_buffer.mark_finished()
+ elif msg_type == "ClearAgentAudio":
+ logger.warning("Received ClearAgentAudio, clear audio buffer")
+ self.audio_buffer = None
+ self.audio_info = None
+ if self.status == "voice":
+ self.status = "blank"
+ # self.immediate_switch_to("blank")
+ except Exception as e:
+ logger.error("Error decoding message: {}, continue".format(e))
+ continue
+ logger.warning("recv loop interrupted")
+
+ def start(self):
+ self.launch_chat_server()
+ self.recv_thread = threading.Thread(target=self.recv_loop)
+ self.recv_thread.start()
+
+ # returns False when not enough audio is buffered, otherwise the number of bytes to fetch
+ def has_voice(self, duration):
+ if self.audio_info is None or self.audio_buffer.current_size == 0:
+ return False
+ bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2 # S16LE assumed
+ # if not has enough bytes and maybe has more voice, return False
+ if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice():
+ logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}")
+ return False
+ return bytes_count
+
+ # returns (pcm_bytes, AudioInfo), or None when there is not enough buffered audio
+ def get_audio(self, fetch_duration):
+ bytes_count = self.has_voice(fetch_duration)
+ if bytes_count is False:
+ return None
+ pcm_data = self.audio_buffer.get(bytes_count)
+
+ # the actual sample count fetched
+ sample_count = len(pcm_data) // (self.audio_info.channel_count * 2)
+ logger.debug("Fetched {} bytes audio".format(sample_count))
+ logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size))
+ audio_info = deepcopy(self.audio_info)
+ audio_info.sample_count = sample_count
+ return (pcm_data, audio_info)
+
+ def stop(self):
+ self.model_runner = None
+ if self.chatter_proc is not None:
+ self.chatter_proc.terminate()
+ self.chatter_proc.wait()
+ self.chatter_proc = None
+ self.w2f_socket.close()
+ self.f2w_socket.close()
+
+ def __del__(self):
+ self.stop()
+
+
+class OmniVAReader:
+ def __init__(
+ self,
+ rank: int,
+ world_size: int,
+ stream_url: str,
+ segment_duration: float = 5.0625,
+ sample_rate: int = 16000,
+ audio_channels: int = 1,
+ buffer_size: int = 1,
+ prev_duration: float = 0.3125,
+ target_rank: int = 0,
+ model_runner=None,
+ huoshan_tts_voice_type=None,
+ ):
+ self.rank = rank
+ self.world_size = world_size
+ self.stream_url = stream_url
+ self.segment_duration = segment_duration
+ self.sample_rate = sample_rate
+
+ self.audio_channels = audio_channels
+ self.prev_duration = prev_duration
+ self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
+ self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate)
+ self.prev_seg_chunk = None
+
+ self.target_rank = target_rank % self.world_size
+ self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
+ self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
+ chunk_size = int(self.segment_duration * self.sample_rate) * 2
+ self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
+ self.chat_adapter = None
+ self.model_runner = model_runner
+ self.huoshan_tts_voice_type = huoshan_tts_voice_type
+
+ assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader"
+ logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
+ logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
+
+ def init_omni_env(self):
+ self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/")
+ self.session_id = os.getenv("OMNI_SESSION_ID", "")
+ self.account = os.getenv("OMNI_ACCOUNT", "")
+ self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",")
+ self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None)
+ assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist"
+ assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required"
+ logger.info(
+ f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}"
+ )
+
+ def start(self):
+ if self.rank == self.target_rank:
+ self.init_omni_env()
+ assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader"
+ self.chat_adapter = ChatAdapter(
+ omni_work_dir=self.omni_work_dir,
+ whep_url=self.stream_url,
+ session_id=self.session_id,
+ account=self.account,
+ config_files=self.config_files,
+ config_schema_path=self.config_schema_path,
+ seg_duration=self.segment_duration,
+ model_runner=self.model_runner,
+ huoshan_tts_voice_type=self.huoshan_tts_voice_type,
+ )
+ self.chat_adapter.start()
+ logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully")
+ else:
+ logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only")
+ if self.world_size > 1:
+ logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier")
+ dist.barrier()
+ logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier")
+
+ def broadcast_audio_data(self, audio_data):
+ if self.rank == self.target_rank:
+ if audio_data is None:
+ self.flag_tensor.fill_(0)
+ else:
+ self.flag_tensor.fill_(1)
+ self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
+ # logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
+
+ dist.broadcast(self.flag_tensor, src=self.target_rank)
+ if self.flag_tensor.item() == 0:
+ return None
+
+ dist.broadcast(self.audio_tensor, src=self.target_rank)
+ if self.rank != self.target_rank:
+ # logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
+ audio_data = self.audio_tensor.cpu().numpy().tobytes()
+ return audio_data
+
+ def bytes_to_ndarray(self, audio_data):
+ if audio_data is None:
+ return None
+ audio_data = np.frombuffer(audio_data, dtype=np.int16)
+ audio_data = audio_data.astype(np.float32) / 32768.0
+ # logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
+ return audio_data
+
+ def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info):
+ audio = np.frombuffer(audio_data, dtype=np.int16)
+ sample_count = audio_info.sample_count
+ assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}"
+ # convert to mono
+ if audio_info.channel_count > 1:
+ audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1)
+
+ # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}")
+ if audio_info.sample_rate != self.sample_rate:
+ sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate)
+ audio = resample(audio, sample_count).astype(np.int16)
+ # logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
+ logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
+ return audio, sample_count
+
+ def prepare_audio_data(self, chat_audio_result):
+ sample_count = 0
+ audio = np.array([], dtype=np.int16)
+
+ # convert chat audio result to mono and target sample rate
+ if chat_audio_result is not None:
+ audio_data, audio_info = chat_audio_result
+ audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info)
+
+ # if is not the first segment, concat with previous segment
+ if self.prev_seg_chunk is not None:
+ audio = np.concatenate([self.prev_seg_chunk, audio])
+ sample_count = len(audio)
+ assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}"
+
+ # pad 0 to the audio to make it the same length as all_seg_sample_count
+ if sample_count < self.all_seg_sample_count:
+ pad_count = self.all_seg_sample_count - sample_count
+ # logger.info(f"pad {pad_count} samples to audio")
+ audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0)
+ sample_count = len(audio)
+
+ # update prev seg chunk
+ self.prev_seg_chunk = audio[-self.prev_seg_sample_count :]
+ # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}")
+ return audio.tobytes()
+
+ def get_fetch_duration(self):
+ fetch_duration = self.segment_duration
+ # after immediate switch, reset prev seg chunk
+ if self.chat_adapter.reset_prev:
+ self.prev_seg_chunk = None
+ self.chat_adapter.reset_prev = False
+ logger.warning(f"Reset prev seg chunk")
+ # first segment, fetch segment_duration, else fetch segment_duration - prev_duration
+ if self.prev_seg_chunk is not None:
+ fetch_duration -= self.prev_duration
+ return fetch_duration
+
+ def get_audio_segment(self):
+ audio_data = None
+ if self.rank == self.target_rank:
+ try:
+ fetch_duration = self.get_fetch_duration()
+ # logger.info(f"Get segment, fetch_duration: {fetch_duration}")
+ if self.chat_adapter.status == "voice":
+ audio_result = self.chat_adapter.get_audio(fetch_duration)
+ audio_data = self.prepare_audio_data(audio_result)
+ # all buffered voice has been consumed, so naturally switch back to blank
+ if audio_result is None:
+ logger.info("All voice segments consumed, switching back to blank")
+ self.chat_adapter.status = "blank"
+ else:
+ audio_data = self.prepare_audio_data(None)
+ except Exception as e:
+ logger.warning(f"Failed to get voice segment: {e}")
+ return None
+ if self.world_size > 1:
+ audio_data = self.broadcast_audio_data(audio_data)
+ audio_data = self.bytes_to_ndarray(audio_data)
+ return audio_data
+
+ def get_immediate_switch(self):
+ if self.rank == self.target_rank:
+ if self.chat_adapter.immediate_switch == 1:
+ self.immediate_switch_tensor.fill_(1)
+ # reset immediate switch
+ self.chat_adapter.immediate_switch = 0
+ else:
+ self.immediate_switch_tensor.fill_(0)
+ dist.broadcast(self.immediate_switch_tensor, src=self.target_rank)
+ immediate_switch = self.immediate_switch_tensor.item()
+ return immediate_switch
+
+ def stop(self):
+ self.model_runner = None
+ if self.chat_adapter is not None:
+ self.chat_adapter.stop()
+ self.chat_adapter = None
+ logger.warning("OmniVAReader stopped")
+
+ def __del__(self):
+ self.stop()
+
+
+if __name__ == "__main__":
+ WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
+ RANK = int(os.environ.get("RANK", 0))
+ if WORLD_SIZE > 1:
+ dist.init_process_group(backend="nccl")
+ torch.cuda.set_device(dist.get_rank())
+ logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
+
+ reader = OmniVAReader(
+ RANK,
+ WORLD_SIZE,
+ "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000",
+ segment_duration=17 / 16,
+ sample_rate=16000,
+ audio_channels=1,
+ prev_duration=1 / 16,
+ )
+ reader.start()
+ fail_count = 0
+ max_fail_count = 100000000
+
+ try:
+ while True:
+ audio_data = reader.get_audio_segment()
+ if audio_data is not None:
+ logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
+ fail_count = 0
+ else:
+ fail_count += 1
+ if fail_count > max_fail_count:
+ logger.warning("Failed to get audio chunk, stop reader")
+ reader.stop()
+ break
+ time.sleep(0.95)
+ finally:
+ reader.stop()
diff --git a/lightx2v/deploy/common/va_recorder.py b/lightx2v/deploy/common/va_recorder.py
new file mode 100644
index 0000000000000000000000000000000000000000..de551d338098ea40d649c00b1441450baa27894a
--- /dev/null
+++ b/lightx2v/deploy/common/va_recorder.py
@@ -0,0 +1,657 @@
+import os
+import queue
+import socket
+import subprocess
+import threading
+import time
+import traceback
+
+import numpy as np
+import torch
+import torchaudio as ta
+from loguru import logger
+
+
+# derive a pseudo-random integer in [a, b] from the sub-second part of the current time
+def pseudo_random(a, b):
+ x = str(time.time()).split(".")[1]
+ y = int(float("0." + x) * 1000000)
+ return a + (y % (b - a + 1))
+
+
+class VARecorder:
+ def __init__(
+ self,
+ livestream_url: str,
+ fps: float = 16.0,
+ sample_rate: int = 16000,
+ slice_frame: int = 1,
+ prev_frame: int = 1,
+ ):
+ self.livestream_url = livestream_url
+ self.fps = fps
+ self.sample_rate = sample_rate
+ self.audio_port = pseudo_random(32000, 40000)
+ self.video_port = self.audio_port + 1
+ self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
+ logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
+
+ self.width = None
+ self.height = None
+ self.stoppable_t = None
+ self.realtime = False
+ if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
+ self.realtime = True
+
+ # ffmpeg process for mix video and audio data and push to livestream
+ self.ffmpeg_process = None
+
+ # TCP connection objects
+ self.audio_socket = None
+ self.video_socket = None
+ self.audio_conn = None
+ self.video_conn = None
+ self.audio_thread = None
+ self.video_thread = None
+
+ # queue for send data to ffmpeg process
+ self.audio_queue = queue.Queue()
+ self.video_queue = queue.Queue()
+
+ # buffer for stream data
+ self.audio_samples_per_frame = round(self.sample_rate / self.fps)
+ self.stream_buffer = []
+ self.stream_buffer_lock = threading.Lock()
+ self.stop_schedule = False
+ self.schedule_thread = None
+ self.slice_frame = slice_frame
+ self.prev_frame = prev_frame
+ assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"
+
+ def init_sockets(self):
+ # TCP socket for send and recv video and audio data
+ self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ self.video_socket.bind(("127.0.0.1", self.video_port))
+ self.video_socket.listen(1)
+
+ self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.audio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ self.audio_socket.bind(("127.0.0.1", self.audio_port))
+ self.audio_socket.listen(1)
+
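+ # Data flow sketch (descriptive only): tensors queued by pub_livestream / buffer_stream
+ # are serialized by audio_worker / video_worker as raw s16le PCM and rgb24 frames over
+ # the two local TCP sockets above, and a single ffmpeg process (started in
+ # start_ffmpeg_process_*) reads both sockets, muxes them, and pushes the result to
+ # livestream_url (local mp4, RTMP, or WHIP).
+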
+ def audio_worker(self):
+ try:
+ logger.info("Waiting for ffmpeg to connect to audio socket...")
+ self.audio_conn, _ = self.audio_socket.accept()
+ logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
+ fail_time, max_fail_time = 0, 10
+ while True:
+ try:
+ if self.audio_queue is None:
+ break
+ data = self.audio_queue.get()
+ if data is None:
+ logger.info("Audio thread received stop signal")
+ break
+ # Convert audio data to 16-bit integer format
+ audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16)
+ try:
+ self.audio_conn.send(audios[None].cpu().numpy().tobytes())
+ except (BrokenPipeError, OSError, ConnectionResetError) as e:
+ logger.info(f"Audio connection closed, stopping worker: {type(e).__name__}")
+ return
+ fail_time = 0
+ except (BrokenPipeError, OSError, ConnectionResetError):
+ logger.info("Audio connection closed during queue processing")
+ break
+ except Exception:
+ logger.error(f"Send audio data error: {traceback.format_exc()}")
+ fail_time += 1
+ if fail_time > max_fail_time:
+ logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
+ break
+ except Exception:
+ logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
+ finally:
+ logger.info("Audio push worker thread stopped")
+
+ def video_worker(self):
+ try:
+ logger.info("Waiting for ffmpeg to connect to video socket...")
+ self.video_conn, _ = self.video_socket.accept()
+ logger.info(f"Video connection established from {self.video_conn.getpeername()}")
+ fail_time, max_fail_time = 0, 10
+ packet_secs = 1.0 / self.fps
+ while True:
+ try:
+ if self.video_queue is None:
+ break
+ data = self.video_queue.get()
+ if data is None:
+ logger.info("Video thread received stop signal")
+ break
+
+ # Convert each frame to uint8 in [0, 255]; frames are sent to ffmpeg as raw rgb24
+ for i in range(data.shape[0]):
+ t0 = time.time()
+ frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
+ try:
+ self.video_conn.send(frame.tobytes())
+ except (BrokenPipeError, OSError, ConnectionResetError) as e:
+ logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
+ return
+ if self.realtime and i < data.shape[0] - 1:
+ time.sleep(max(0, packet_secs - (time.time() - t0)))
+
+ fail_time = 0
+ except (BrokenPipeError, OSError, ConnectionResetError):
+ logger.info("Video connection closed during queue processing")
+ break
+ except Exception:
+ logger.error(f"Send video data error: {traceback.format_exc()}")
+ fail_time += 1
+ if fail_time > max_fail_time:
+ logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
+ break
+ except Exception:
+ logger.error(f"Video push worker thread error: {traceback.format_exc()}")
+ finally:
+ logger.info("Video push worker thread stopped")
+
+ def start_ffmpeg_process_local(self):
+ """Start ffmpeg process that connects to our TCP sockets"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-fflags",
+ "nobuffer",
+ "-analyzeduration",
+ "0",
+ "-probesize",
+ "32",
+ "-flush_packets",
+ "1",
+ "-f",
+ "s16le",
+ "-ar",
+ str(self.sample_rate),
+ "-ac",
+ "1",
+ "-i",
+ f"tcp://127.0.0.1:{self.audio_port}",
+ "-f",
+ "rawvideo",
+ "-pix_fmt",
+ "rgb24",
+ "-color_range",
+ "pc",
+ "-colorspace",
+ "rgb",
+ "-color_primaries",
+ "bt709",
+ "-color_trc",
+ "iec61966-2-1",
+ "-r",
+ str(self.fps),
+ "-s",
+ f"{self.width}x{self.height}",
+ "-i",
+ f"tcp://127.0.0.1:{self.video_port}",
+ "-ar",
+ "44100",
+ "-b:v",
+ "4M",
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-tune",
+ "zerolatency",
+ "-g",
+ f"{self.fps}",
+ "-pix_fmt",
+ "yuv420p",
+ "-f",
+ "mp4",
+ self.livestream_url,
+ "-y",
+ "-loglevel",
+ self.ffmpeg_log_level,
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+ logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg: {e}")
+
+ def start_ffmpeg_process_rtmp(self):
+ """Start ffmpeg process that connects to our TCP sockets"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-re",
+ "-f",
+ "s16le",
+ "-ar",
+ str(self.sample_rate),
+ "-ac",
+ "1",
+ "-i",
+ f"tcp://127.0.0.1:{self.audio_port}",
+ "-f",
+ "rawvideo",
+ "-re",
+ "-pix_fmt",
+ "rgb24",
+ "-r",
+ str(self.fps),
+ "-s",
+ f"{self.width}x{self.height}",
+ "-i",
+ f"tcp://127.0.0.1:{self.video_port}",
+ "-ar",
+ "44100",
+ "-b:v",
+ "2M",
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-tune",
+ "zerolatency",
+ "-g",
+ f"{self.fps}",
+ "-pix_fmt",
+ "yuv420p",
+ "-f",
+ "flv",
+ self.livestream_url,
+ "-y",
+ "-loglevel",
+ self.ffmpeg_log_level,
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+ logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg: {e}")
+
+ def start_ffmpeg_process_whip(self):
+ """Start ffmpeg process that connects to our TCP sockets"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-re",
+ "-fflags",
+ "nobuffer",
+ "-analyzeduration",
+ "0",
+ "-probesize",
+ "32",
+ "-flush_packets",
+ "1",
+ "-f",
+ "s16le",
+ "-ar",
+ str(self.sample_rate),
+ "-ac",
+ "1",
+ "-ch_layout",
+ "mono",
+ "-i",
+ f"tcp://127.0.0.1:{self.audio_port}",
+ "-f",
+ "rawvideo",
+ "-re",
+ "-pix_fmt",
+ "rgb24",
+ "-r",
+ str(self.fps),
+ "-s",
+ f"{self.width}x{self.height}",
+ "-i",
+ f"tcp://127.0.0.1:{self.video_port}",
+ "-ar",
+ "48000",
+ "-c:a",
+ "libopus",
+ "-ac",
+ "2",
+ "-b:v",
+ "2M",
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-tune",
+ "zerolatency",
+ "-g",
+ f"{self.fps}",
+ "-pix_fmt",
+ "yuv420p",
+ "-threads",
+ "1",
+ "-bf",
+ "0",
+ "-f",
+ "whip",
+ self.livestream_url,
+ "-y",
+ "-loglevel",
+ self.ffmpeg_log_level,
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+ logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg: {e}")
+
+ def start(self, width: int, height: int):
+ self.set_video_size(width, height)
+ duration = 1.0
+ frames = int(self.fps * duration)
+ samples = int(self.sample_rate * (frames / self.fps))
+ self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16))
+ time.sleep(duration)
+
+ def set_video_size(self, width: int, height: int):
+ if self.width is not None and self.height is not None:
+ assert self.width == width and self.height == height, "Video size already set with different dimensions"
+ return
+ self.width = width
+ self.height = height
+ self.init_sockets()
+ if self.livestream_url.startswith("rtmp://"):
+ self.start_ffmpeg_process_rtmp()
+ elif self.livestream_url.startswith("http"):
+ self.start_ffmpeg_process_whip()
+ else:
+ self.start_ffmpeg_process_local()
+ self.audio_thread = threading.Thread(target=self.audio_worker)
+ self.video_thread = threading.Thread(target=self.video_worker)
+ self.audio_thread.start()
+ self.video_thread.start()
+ if self.realtime:
+ self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
+ self.schedule_thread.start()
+
+ # Publish ComfyUI Image tensor and audio tensor to livestream
+ def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
+ N, height, width, C = images.shape
+ M = audios.reshape(-1).shape[0]
+ assert C == 3, "Input must be [N, H, W, C] with C=3"
+
+ logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
+ audio_frames = round(M * self.fps / self.sample_rate)
+ if audio_frames != N:
+ logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
+
+ self.set_video_size(width, height)
+ self.audio_queue.put(audios)
+ self.video_queue.put(images)
+ logger.info(f"Published {N} frames and {M} audio samples")
+
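+ # Earliest safe stop time: once all queued audio has been played out, plus a 3s grace period.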
+ self.stoppable_t = time.time() + M / self.sample_rate + 3
+
+ def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
+ N, height, width, C = images.shape
+ M = audios.reshape(-1).shape[0]
+ assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
+ assert C == 3, "Input must be [N, H, W, C] with C=3"
+
+ audio_frames = round(M * self.fps / self.sample_rate)
+ if audio_frames != N:
+ logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
+ self.set_video_size(width, height)
+
+ # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
+ rets = []
+ for i in range(0, N, self.slice_frame):
+ end_frame = i + self.slice_frame
+ img = images[i:end_frame]
+ aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
+ gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
+ rets.append((img, aud, gen))
+
+ with self.stream_buffer_lock:
+ origin_size = len(self.stream_buffer)
+ self.stream_buffer.extend(rets)
+ logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
+
+ def get_buffer_stream_size(self):
+ return len(self.stream_buffer)
+
+ def truncate_stream_buffer(self, size: int):
+ with self.stream_buffer_lock:
+ self.stream_buffer = self.stream_buffer[:size]
+ logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
+ if len(self.stream_buffer) > 0:
+ return self.stream_buffer[-1][2] # return the last video tensor
+ else:
+ return None
+
+ def schedule_stream_buffer(self):
+ schedule_interval = self.slice_frame / self.fps
+ logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
+ t = None
+ while True:
+ try:
+ if self.stop_schedule:
+ break
+ img, aud, gen = None, None, None
+ with self.stream_buffer_lock:
+ if len(self.stream_buffer) > 0:
+ img, aud, gen = self.stream_buffer.pop(0)
+
+ if t is not None:
+ wait_secs = schedule_interval - (time.time() - t)
+ if wait_secs > 0:
+ time.sleep(wait_secs)
+ t = time.time()
+
+ if img is not None and aud is not None:
+ self.audio_queue.put(aud)
+ self.video_queue.put(img)
+ # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
+ del gen
+ self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
+ else:
+ logger.warning(f"No stream buffer to schedule")
+ except Exception:
+ logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
+ break
+ logger.info("Schedule stream buffer thread stopped")
+
+ def stop(self, wait=True):
+ if wait and self.stoppable_t:
+ t = self.stoppable_t - time.time()
+ if t > 0:
+ logger.warning(f"Waiting for {t} seconds to stop ...")
+ time.sleep(t)
+ self.stoppable_t = None
+
+ if self.schedule_thread:
+ self.stop_schedule = True
+ self.schedule_thread.join(timeout=5)
+ if self.schedule_thread and self.schedule_thread.is_alive():
+ logger.error(f"Schedule thread did not stop after 5s")
+
+ # Send stop signals to queues
+ if self.audio_queue:
+ self.audio_queue.put(None)
+ if self.video_queue:
+ self.video_queue.put(None)
+
+ # Wait for threads to finish processing queued data (increased timeout)
+ queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames
+ if self.audio_thread and self.audio_thread.is_alive():
+ self.audio_thread.join(timeout=queue_timeout)
+ if self.audio_thread.is_alive():
+ logger.error(f"Audio push thread did not stop after {queue_timeout}s")
+ if self.video_thread and self.video_thread.is_alive():
+ self.video_thread.join(timeout=queue_timeout)
+ if self.video_thread.is_alive():
+ logger.error(f"Video push thread did not stop after {queue_timeout}s")
+
+ # Shutdown connections to signal EOF to FFmpeg
+ # shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
+ if self.audio_conn:
+ try:
+ self.audio_conn.getpeername()
+ self.audio_conn.shutdown(socket.SHUT_WR)
+ logger.info("Audio connection shutdown initiated")
+ except OSError:
+ # Connection already closed, skip shutdown
+ pass
+
+ if self.video_conn:
+ try:
+ self.video_conn.getpeername()
+ self.video_conn.shutdown(socket.SHUT_WR)
+ logger.info("Video connection shutdown initiated")
+ except OSError:
+ # Connection already closed, skip shutdown
+ pass
+
+ if self.ffmpeg_process:
+ is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
+ # Local MP4 files need time to write moov atom and finalize the container
+ timeout_seconds = 30 if is_local_file else 10
+ logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
+ logger.info(f"FFmpeg output: {self.livestream_url}")
+
+ try:
+ returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
+ if returncode == 0:
+ logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
+ else:
+ logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
+ except subprocess.TimeoutExpired:
+ logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
+ try:
+ self.ffmpeg_process.terminate() # SIGTERM
+ returncode = self.ffmpeg_process.wait(timeout=5)
+ logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
+ except subprocess.TimeoutExpired:
+ logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
+ self.ffmpeg_process.kill()
+ self.ffmpeg_process.wait() # Wait for kill to complete
+ logger.error("FFmpeg process killed with SIGKILL")
+ finally:
+ self.ffmpeg_process = None
+
+ if self.audio_conn:
+ try:
+ self.audio_conn.close()
+ except Exception as e:
+ logger.debug(f"Error closing audio connection: {e}")
+ finally:
+ self.audio_conn = None
+
+ if self.video_conn:
+ try:
+ self.video_conn.close()
+ except Exception as e:
+ logger.debug(f"Error closing video connection: {e}")
+ finally:
+ self.video_conn = None
+
+ if self.audio_socket:
+ try:
+ self.audio_socket.close()
+ except Exception as e:
+ logger.debug(f"Error closing audio socket: {e}")
+ finally:
+ self.audio_socket = None
+
+ if self.video_socket:
+ try:
+ self.video_socket.close()
+ except Exception as e:
+ logger.debug(f"Error closing video socket: {e}")
+ finally:
+ self.video_socket = None
+
+ if self.audio_queue:
+ while self.audio_queue.qsize() > 0:
+ try:
+ self.audio_queue.get_nowait()
+ except: # noqa
+ break
+ if self.video_queue:
+ while self.video_queue.qsize() > 0:
+ try:
+ self.video_queue.get_nowait()
+ except: # noqa
+ break
+ self.audio_queue = None
+ self.video_queue = None
+ logger.info("VARecorder stopped and resources cleaned up")
+
+ def __del__(self):
+ self.stop(wait=False)
+
+
+def create_simple_video(frames=10, height=480, width=640):
+ video_data = []
+ for i in range(frames):
+ frame = np.zeros((height, width, 3), dtype=np.float32)
+ stripe_height = height // 8
+ colors = [
+ [1.0, 0.0, 0.0], # red
+ [0.0, 1.0, 0.0], # green
+ [0.0, 0.0, 1.0], # blue
+ [1.0, 1.0, 0.0], # yellow
+ [1.0, 0.0, 1.0], # magenta
+ [0.0, 1.0, 1.0], # cyan
+ [1.0, 1.0, 1.0], # white
+ [0.5, 0.5, 0.5], # gray
+ ]
+ for j, color in enumerate(colors):
+ start_y = j * stripe_height
+ end_y = min((j + 1) * stripe_height, height)
+ frame[start_y:end_y, :] = color
+ offset = int((i / frames) * width)
+ frame = np.roll(frame, offset, axis=1)
+ frame = torch.tensor(frame, dtype=torch.float32)
+ video_data.append(frame)
+ return torch.stack(video_data, dim=0)
+
+
+if __name__ == "__main__":
+ sample_rate = 16000
+ fps = 16
+ width = 640
+ height = 480
+
+ recorder = VARecorder(
+ # livestream_url="rtmp://localhost/live/test",
+ # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
+ livestream_url="/path/to/output_video.mp4",
+ fps=fps,
+ sample_rate=sample_rate,
+ )
+
+ audio_path = "/path/to/test_b_2min.wav"
+ audio_array, ori_sr = ta.load(audio_path)
+ audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
+ audio_array = audio_array.reshape(-1)
+ secs = audio_array.shape[0] // sample_rate
+ interval = 1
+
+ for i in range(0, secs, interval):
+ logger.info(f"{i} / {secs} s")
+ start = i * sample_rate
+ end = (i + interval) * sample_rate
+ cur_audio_array = audio_array[start:end]
+ logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
+
+ num_frames = int(interval * fps)
+ images = create_simple_video(num_frames, height, width)
+ logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
+
+ recorder.pub_livestream(images, cur_audio_array)
+ time.sleep(interval)
+ recorder.stop()
diff --git a/lightx2v/deploy/common/va_recorder_x264.py b/lightx2v/deploy/common/va_recorder_x264.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a82a7e6619af0d18cfe44c66dd4a5d44b51500
--- /dev/null
+++ b/lightx2v/deploy/common/va_recorder_x264.py
@@ -0,0 +1,321 @@
+import ctypes
+import queue
+import threading
+import time
+import traceback
+
+import numpy as np
+import torch
+import torchaudio as ta
+from loguru import logger
+from scipy.signal import resample
+
+
+class X264VARecorder:
+ def __init__(
+ self,
+ whip_shared_path: str,
+ livestream_url: str,
+ fps: float = 16.0,
+ sample_rate: int = 16000,
+ slice_frame: int = 1,
+ prev_frame: int = 1,
+ ):
+ assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
+ self.livestream_url = livestream_url
+ self.fps = fps
+ self.sample_rate = sample_rate
+
+ self.width = None
+ self.height = None
+ self.stoppable_t = None
+
+ # only enable whip shared api for whip http livestream
+ self.whip_shared_path = whip_shared_path
+ self.whip_shared_lib = None
+ self.whip_shared_handle = None
+
+ assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
+ self.realtime = True
+
+ # queue for send data to whip shared api
+ self.queue = queue.Queue()
+ self.worker_thread = None
+
+ # buffer for stream data
+ self.target_sample_rate = 48000
+ self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
+ self.target_chunks_per_frame = self.target_samples_per_frame * 2
+ self.stream_buffer = []
+ self.stream_buffer_lock = threading.Lock()
+ self.stop_schedule = False
+ self.schedule_thread = None
+ self.slice_frame = slice_frame
+ self.prev_frame = prev_frame
+ assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"
+
+ def worker(self):
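+ # Drain (audio, images) chunks from the queue and push them frame by frame through the
+ # WHIP shared library, pacing each frame to 1/fps seconds when realtime is enabled.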
+ try:
+ fail_time, max_fail_time = 0, 10
+ packet_secs = 1.0 / self.fps
+ while True:
+ try:
+ if self.queue is None:
+ break
+ data = self.queue.get()
+ if data is None:
+ logger.info("Worker thread received stop signal")
+ break
+ audios, images = data
+
+ for i in range(images.shape[0]):
+ t0 = time.time()
+ cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
+ audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
+ self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
+
+ cur_video = images[i].flatten()
+ video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
+ self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
+
+ if self.realtime and i < images.shape[0] - 1:
+ time.sleep(max(0, packet_secs - (time.time() - t0)))
+
+ fail_time = 0
+ except: # noqa
+ logger.error(f"Send A/V data error: {traceback.format_exc()}")
+ fail_time += 1
+ if fail_time > max_fail_time:
+ logger.error(f"A/V push worker thread failed {fail_time} times, stopping...")
+ break
+ except: # noqa
+ logger.error(f"A/V push worker thread error: {traceback.format_exc()}")
+ finally:
+ logger.info("A/V push worker thread stopped")
+
+ def start_libx264_whip_shared_api(self, width: int, height: int):
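+ # Load the WHIP shared library via ctypes, declare the C function signatures,
+ # then create a stream handle bound to livestream_url and the given frame size.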
+ self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path)
+
+ # define function argtypes and restype
+ self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
+ self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p
+
+ self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
+ self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
+
+ self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p]
+
+ whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8"))
+ self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height))
+ logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}")
+
+ def convert_data(self, audios, images):
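+ # Float audio in [-1, 1] -> int16 PCM, resampled to 48 kHz and duplicated to interleaved
+ # stereo (the WHIP/Opus path works at 48 kHz); float images in [0, 1] -> uint8 frames.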
+ # Convert audio data to 16-bit integer format
+ audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
+ # Convert to numpy uint8 and scale to [0, 255]; frames stay in RGB order
+ image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
+
+ logger.info(f"image_datas: {image_datas.shape} {image_datas.dtype} {image_datas.min()} {image_datas.max()}")
+ resampled_audios = resample(audio_datas, int(len(audio_datas) * 48000 / self.sample_rate))
+ stereo_audios = np.stack([resampled_audios, resampled_audios], axis=-1).astype(np.int16).reshape(-1)
+ return stereo_audios, image_datas
+
+ def start(self, width: int, height: int):
+ self.set_video_size(width, height)
+
+ def set_video_size(self, width: int, height: int):
+ if self.width is not None and self.height is not None:
+ assert self.width == width and self.height == height, "Video size already set with different dimensions"
+ return
+ self.width = width
+ self.height = height
+ self.start_libx264_whip_shared_api(width, height)
+ self.worker_thread = threading.Thread(target=self.worker)
+ self.worker_thread.start()
+ if self.realtime:
+ self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
+ self.schedule_thread.start()
+
+ def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
+ N, height, width, C = images.shape
+ M = audios.reshape(-1).shape[0]
+ assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
+ assert C == 3, "Input must be [N, H, W, C] with C=3"
+
+ audio_frames = round(M * self.fps / self.sample_rate)
+ if audio_frames != N:
+ logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
+ self.set_video_size(width, height)
+ audio_datas, image_datas = self.convert_data(audios, images)
+
+ # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
+ rets = []
+ for i in range(0, N, self.slice_frame):
+ end_frame = i + self.slice_frame
+ img = image_datas[i:end_frame]
+ aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
+ gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
+ rets.append((img, aud, gen))
+
+ with self.stream_buffer_lock:
+ origin_size = len(self.stream_buffer)
+ self.stream_buffer.extend(rets)
+ logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
+
+ def get_buffer_stream_size(self):
+ return len(self.stream_buffer)
+
+ def truncate_stream_buffer(self, size: int):
+ with self.stream_buffer_lock:
+ self.stream_buffer = self.stream_buffer[:size]
+ logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
+ if len(self.stream_buffer) > 0:
+ return self.stream_buffer[-1][2] # return the last video tensor
+ else:
+ return None
+
+ def schedule_stream_buffer(self):
+ schedule_interval = self.slice_frame / self.fps
+ logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
+ t = None
+ while True:
+ try:
+ if self.stop_schedule:
+ break
+ img, aud, gen = None, None, None
+ with self.stream_buffer_lock:
+ if len(self.stream_buffer) > 0:
+ img, aud, gen = self.stream_buffer.pop(0)
+
+ if t is not None:
+ wait_secs = schedule_interval - (time.time() - t)
+ if wait_secs > 0:
+ time.sleep(wait_secs)
+ t = time.time()
+
+ if img is not None and aud is not None:
+ self.queue.put((aud, img))
+ # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
+ del gen
+ self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
+ else:
+ logger.warning(f"No stream buffer to schedule")
+ except Exception:
+ logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
+ break
+ logger.info("Schedule stream buffer thread stopped")
+
+ def stop(self, wait=True):
+ if wait and self.stoppable_t:
+ t = self.stoppable_t - time.time()
+ if t > 0:
+ logger.warning(f"Waiting for {t} seconds to stop ...")
+ time.sleep(t)
+ self.stoppable_t = None
+
+ if self.schedule_thread:
+ self.stop_schedule = True
+ self.schedule_thread.join(timeout=5)
+ if self.schedule_thread and self.schedule_thread.is_alive():
+ logger.error(f"Schedule thread did not stop after 5s")
+
+ # Send stop signals to queues
+ if self.queue:
+ self.queue.put(None)
+
+ # Wait for threads to finish
+ if self.worker_thread and self.worker_thread.is_alive():
+ self.worker_thread.join(timeout=5)
+ if self.worker_thread.is_alive():
+ logger.warning("Worker thread did not stop gracefully")
+
+ # Destroy WHIP shared API
+ if self.whip_shared_lib and self.whip_shared_handle:
+ self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
+ self.whip_shared_handle = None
+ self.whip_shared_lib = None
+ logger.warning("WHIP shared API destroyed")
+
+ def __del__(self):
+ self.stop(wait=False)
+
+
+def create_simple_video(frames=10, height=480, width=640):
+ video_data = []
+ for i in range(frames):
+ frame = np.zeros((height, width, 3), dtype=np.float32)
+ stripe_height = height // 8
+ colors = [
+ [1.0, 0.0, 0.0], # red
+ [0.0, 1.0, 0.0], # green
+ [0.0, 0.0, 1.0], # blue
+ [1.0, 1.0, 0.0], # yellow
+ [1.0, 0.0, 1.0], # magenta
+ [0.0, 1.0, 1.0], # cyan
+ [1.0, 1.0, 1.0], # white
+ [0.5, 0.5, 0.5], # gray
+ ]
+ for j, color in enumerate(colors):
+ start_y = j * stripe_height
+ end_y = min((j + 1) * stripe_height, height)
+ frame[start_y:end_y, :] = color
+ offset = int((i / frames) * width)
+ frame = np.roll(frame, offset, axis=1)
+ frame = torch.tensor(frame, dtype=torch.float32)
+ video_data.append(frame)
+ return torch.stack(video_data, dim=0)
+
+
+if __name__ == "__main__":
+ sample_rate = 16000
+ fps = 16
+ width = 452
+ height = 352
+
+ recorder = X264VARecorder(
+ whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
+ livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
+ fps=fps,
+ sample_rate=sample_rate,
+ )
+ recorder.start(width, height)
+
+ # time.sleep(5)
+ audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
+ audio_array, ori_sr = ta.load(audio_path)
+ audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
+ audio_array = audio_array.numpy().reshape(-1)
+ secs = audio_array.shape[0] // sample_rate
+ interval = 1
+ space = 10
+
+ i = 0
+ while i < space:
+ t0 = time.time()
+ logger.info(f"space {i} / {space} s")
+ cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
+ num_frames = int(interval * fps)
+ images = create_simple_video(num_frames, height, width)
+ recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
+ i += interval
+ time.sleep(max(0, interval - (time.time() - t0)))
+
+ started = True
+
+ i = 0
+ while i < secs:
+ t0 = time.time()
+ start = int(i * sample_rate)
+ end = int((i + interval) * sample_rate)
+ cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
+ num_frames = int(interval * fps)
+ images = create_simple_video(num_frames, height, width)
+ logger.info(f"{i} / {secs} s")
+ if started:
+ logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
+ started = False
+ recorder.buffer_stream(images, cur_audio_array, images)
+ i += interval
+ time.sleep(max(0, interval - (time.time() - t0)))
+
+ recorder.stop()
diff --git a/lightx2v/deploy/common/video_recorder.py b/lightx2v/deploy/common/video_recorder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e58a18ede880d5faafc33af4bb00be6355f7ed81
--- /dev/null
+++ b/lightx2v/deploy/common/video_recorder.py
@@ -0,0 +1,422 @@
+import os
+import queue
+import socket
+import subprocess
+import threading
+import time
+import traceback
+
+import numpy as np
+import torch
+from loguru import logger
+
+
+def pseudo_random(a, b):
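+ # Map the fractional part of time.time() onto [a, b]; a cheap way to pick a listen port
+ # that is unlikely to collide, not a real random number generator.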
+ x = str(time.time()).split(".")[1]
+ y = int(float("0." + x) * 1000000)
+ return a + (y % (b - a + 1))
+
+
+class VideoRecorder:
+ def __init__(
+ self,
+ livestream_url: str,
+ fps: float = 16.0,
+ ):
+ self.livestream_url = livestream_url
+ self.fps = fps
+ self.video_port = pseudo_random(32000, 40000)
+ self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
+ logger.info(f"VideoRecorder video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
+
+ self.width = None
+ self.height = None
+ self.stoppable_t = None
+ self.realtime = True
+
+ # ffmpeg process for video data and push to livestream
+ self.ffmpeg_process = None
+
+ # TCP connection objects
+ self.video_socket = None
+ self.video_conn = None
+ self.video_thread = None
+
+ # queue for send data to ffmpeg process
+ self.video_queue = queue.Queue()
+
+ def init_sockets(self):
+ # TCP socket for send and recv video data
+ self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ self.video_socket.bind(("127.0.0.1", self.video_port))
+ self.video_socket.listen(1)
+
+ def video_worker(self):
+ try:
+ logger.info("Waiting for ffmpeg to connect to video socket...")
+ self.video_conn, _ = self.video_socket.accept()
+ logger.info(f"Video connection established from {self.video_conn.getpeername()}")
+ fail_time, max_fail_time = 0, 10
+ packet_secs = 1.0 / self.fps
+ while True:
+ try:
+ if self.video_queue is None:
+ break
+ data = self.video_queue.get()
+ if data is None:
+ logger.info("Video thread received stop signal")
+ break
+
+ # Convert to numpy uint8 and scale to [0, 255]; frames stay in RGB order to match ffmpeg's rgb24 input
+ for i in range(data.shape[0]):
+ t0 = time.time()
+ frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
+ try:
+ self.video_conn.send(frame.tobytes())
+ except (BrokenPipeError, OSError, ConnectionResetError) as e:
+ logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
+ return
+ if self.realtime:
+ time.sleep(max(0, packet_secs - (time.time() - t0)))
+
+ fail_time = 0
+ except (BrokenPipeError, OSError, ConnectionResetError):
+ logger.info("Video connection closed during queue processing")
+ break
+ except Exception:
+ logger.error(f"Send video data error: {traceback.format_exc()}")
+ fail_time += 1
+ if fail_time > max_fail_time:
+ logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
+ break
+ except Exception:
+ logger.error(f"Video push worker thread error: {traceback.format_exc()}")
+ finally:
+ logger.info("Video push worker thread stopped")
+
+ def start_ffmpeg_process_local(self):
+ """Start ffmpeg process that connects to our TCP sockets"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-fflags",
+ "nobuffer",
+ "-analyzeduration",
+ "0",
+ "-probesize",
+ "32",
+ "-flush_packets",
+ "1",
+ "-f",
+ "rawvideo",
+ "-pix_fmt",
+ "rgb24",
+ "-color_range",
+ "pc",
+ "-colorspace",
+ "rgb",
+ "-color_primaries",
+ "bt709",
+ "-color_trc",
+ "iec61966-2-1",
+ "-r",
+ str(self.fps),
+ "-s",
+ f"{self.width}x{self.height}",
+ "-i",
+ f"tcp://127.0.0.1:{self.video_port}",
+ "-b:v",
+ "4M",
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-tune",
+ "zerolatency",
+ "-g",
+ f"{self.fps}",
+ "-pix_fmt",
+ "yuv420p",
+ "-f",
+ "mp4",
+ self.livestream_url,
+ "-y",
+ "-loglevel",
+ self.ffmpeg_log_level,
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+ logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg: {e}")
+
+ def start_ffmpeg_process_rtmp(self):
+ """Start ffmpeg process that connects to our TCP sockets"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-f",
+ "rawvideo",
+ "-re",
+ "-pix_fmt",
+ "rgb24",
+ "-r",
+ str(self.fps),
+ "-s",
+ f"{self.width}x{self.height}",
+ "-i",
+ f"tcp://127.0.0.1:{self.video_port}",
+ "-b:v",
+ "2M",
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-tune",
+ "zerolatency",
+ "-g",
+ f"{self.fps}",
+ "-pix_fmt",
+ "yuv420p",
+ "-f",
+ "flv",
+ self.livestream_url,
+ "-y",
+ "-loglevel",
+ self.ffmpeg_log_level,
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+ logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg: {e}")
+
+ def start_ffmpeg_process_whip(self):
+ """Start ffmpeg process that connects to our TCP sockets"""
+ ffmpeg_cmd = [
+ "ffmpeg",
+ "-re",
+ "-fflags",
+ "nobuffer",
+ "-analyzeduration",
+ "0",
+ "-probesize",
+ "32",
+ "-flush_packets",
+ "1",
+ "-f",
+ "rawvideo",
+ "-re",
+ "-pix_fmt",
+ "rgb24",
+ "-r",
+ str(self.fps),
+ "-s",
+ f"{self.width}x{self.height}",
+ "-i",
+ f"tcp://127.0.0.1:{self.video_port}",
+ "-b:v",
+ "2M",
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-tune",
+ "zerolatency",
+ "-g",
+ f"{self.fps}",
+ "-pix_fmt",
+ "yuv420p",
+ "-threads",
+ "1",
+ "-bf",
+ "0",
+ "-f",
+ "whip",
+ self.livestream_url,
+ "-y",
+ "-loglevel",
+ self.ffmpeg_log_level,
+ ]
+ try:
+ self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+ logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+ logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+ except Exception as e:
+ logger.error(f"Failed to start FFmpeg: {e}")
+
+ def start(self, width: int, height: int):
+ self.set_video_size(width, height)
+ duration = 1.0
+ self.pub_video(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16))
+ time.sleep(duration)
+
+ def set_video_size(self, width: int, height: int):
+ if self.width is not None and self.height is not None:
+ assert self.width == width and self.height == height, "Video size already set with different dimensions"
+ return
+ self.width = width
+ self.height = height
+ self.init_sockets()
+ if self.livestream_url.startswith("rtmp://"):
+ self.start_ffmpeg_process_rtmp()
+ elif self.livestream_url.startswith("http"):
+ self.start_ffmpeg_process_whip()
+ else:
+ self.start_ffmpeg_process_local()
+ self.realtime = False
+ self.video_thread = threading.Thread(target=self.video_worker)
+ self.video_thread.start()
+
+ # Publish ComfyUI Image tensor to livestream
+ def pub_video(self, images: torch.Tensor):
+ N, height, width, C = images.shape
+ assert C == 3, "Input must be [N, H, W, C] with C=3"
+
+ logger.info(f"Publishing video [{N}x{width}x{height}]")
+
+ self.set_video_size(width, height)
+ self.video_queue.put(images)
+ logger.info(f"Published {N} frames")
+
+ self.stoppable_t = time.time() + N / self.fps + 3
+
+ def stop(self, wait=True):
+ if wait and self.stoppable_t:
+ t = self.stoppable_t - time.time()
+ if t > 0:
+ logger.warning(f"Waiting for {t} seconds to stop ...")
+ time.sleep(t)
+ self.stoppable_t = None
+
+ # Send stop signals to queues
+ if self.video_queue:
+ self.video_queue.put(None)
+
+ # Wait for threads to finish processing queued data (increased timeout)
+ queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames
+ if self.video_thread and self.video_thread.is_alive():
+ self.video_thread.join(timeout=queue_timeout)
+ if self.video_thread.is_alive():
+ logger.error(f"Video push thread did not stop after {queue_timeout}s")
+
+ # Shutdown connections to signal EOF to FFmpeg
+ # shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
+ if self.video_conn:
+ try:
+ self.video_conn.getpeername()
+ self.video_conn.shutdown(socket.SHUT_WR)
+ logger.info("Video connection shutdown initiated")
+ except OSError:
+ # Connection already closed, skip shutdown
+ pass
+
+ if self.ffmpeg_process:
+ is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
+ # Local MP4 files need time to write moov atom and finalize the container
+ timeout_seconds = 30 if is_local_file else 10
+ logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
+ logger.info(f"FFmpeg output: {self.livestream_url}")
+
+ try:
+ returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
+ if returncode == 0:
+ logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
+ else:
+ logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
+ except subprocess.TimeoutExpired:
+ logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
+ try:
+ self.ffmpeg_process.terminate() # SIGTERM
+ returncode = self.ffmpeg_process.wait(timeout=5)
+ logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
+ except subprocess.TimeoutExpired:
+ logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
+ self.ffmpeg_process.kill()
+ self.ffmpeg_process.wait() # Wait for kill to complete
+ logger.error("FFmpeg process killed with SIGKILL")
+ finally:
+ self.ffmpeg_process = None
+
+ if self.video_conn:
+ try:
+ self.video_conn.close()
+ except Exception as e:
+ logger.debug(f"Error closing video connection: {e}")
+ finally:
+ self.video_conn = None
+
+ if self.video_socket:
+ try:
+ self.video_socket.close()
+ except Exception as e:
+ logger.debug(f"Error closing video socket: {e}")
+ finally:
+ self.video_socket = None
+
+ if self.video_queue:
+ while self.video_queue.qsize() > 0:
+ try:
+ self.video_queue.get_nowait()
+ except: # noqa
+ break
+ self.video_queue = None
+ logger.info("VideoRecorder stopped and resources cleaned up")
+
+ def __del__(self):
+ self.stop(wait=False)
+
+
+def create_simple_video(frames=10, height=480, width=640):
+ video_data = []
+ for i in range(frames):
+ frame = np.zeros((height, width, 3), dtype=np.float32)
+ stripe_height = height // 8
+ colors = [
+ [1.0, 0.0, 0.0], # red
+ [0.0, 1.0, 0.0], # green
+ [0.0, 0.0, 1.0], # blue
+ [1.0, 1.0, 0.0], # yellow
+ [1.0, 0.0, 1.0], # magenta
+ [0.0, 1.0, 1.0], # cyan
+ [1.0, 1.0, 1.0], # white
+ [0.5, 0.5, 0.5], # gray
+ ]
+ for j, color in enumerate(colors):
+ start_y = j * stripe_height
+ end_y = min((j + 1) * stripe_height, height)
+ frame[start_y:end_y, :] = color
+ offset = int((i / frames) * width)
+ frame = np.roll(frame, offset, axis=1)
+ frame = torch.tensor(frame, dtype=torch.float32)
+ video_data.append(frame)
+ return torch.stack(video_data, dim=0)
+
+
+if __name__ == "__main__":
+ fps = 16
+ width = 640
+ height = 480
+
+ recorder = VideoRecorder(
+ # livestream_url="rtmp://localhost/live/test",
+ # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
+ livestream_url="/path/to/output_video.mp4",
+ fps=fps,
+ )
+
+ secs = 10 # 10-second video
+ interval = 1
+
+ for i in range(0, secs, interval):
+ logger.info(f"{i} / {secs} s")
+
+ num_frames = int(interval * fps)
+ images = create_simple_video(num_frames, height, width)
+ logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
+
+ recorder.pub_video(images)
+ time.sleep(interval)
+ recorder.stop()
diff --git a/lightx2v/deploy/common/volcengine_tts.py b/lightx2v/deploy/common/volcengine_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e60d3670215973bb3e1506baff1b40e26d0975
--- /dev/null
+++ b/lightx2v/deploy/common/volcengine_tts.py
@@ -0,0 +1,241 @@
+# -*- coding: utf-8 -*-
+
+import asyncio
+import base64
+import json
+import os
+import sys
+
+import aiohttp
+from loguru import logger
+
+
+class VolcEngineTTSClient:
+ """
+ VolcEngine TTS client
+
+ Parameter ranges:
+ - speech_rate: -50~100 (100 = 2x speed, -50 = 0.5x speed, 0 = normal speed)
+ - loudness_rate: -50~100 (100 = 2x volume, -50 = 0.5x volume, 0 = normal volume)
+ - emotion_scale: 1-5
+ """
+
+ def __init__(self, voices_list_file=None):
+ self.url = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
+ self.appid = os.getenv("VOLCENGINE_TTS_APPID")
+ self.access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN")
+ self.proxy = os.getenv("HTTPS_PROXY", None)
+ if self.proxy:
+ logger.info(f"volcengine tts use proxy: {self.proxy}")
+ if voices_list_file is not None:
+ with open(voices_list_file, "r", encoding="utf-8") as f:
+ self.voices_list = json.load(f)
+ else:
+ self.voices_list = None
+
+ def get_voice_list(self):
+ return self.voices_list
+
+ async def tts_http_stream(self, headers, params, audio_save_path):
+ """执行TTS流式请求"""
+ try:
+ logger.info(f"volcengine tts params: {params}")
+ audio_data = bytearray()
+ total_audio_size = 0
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(self.url, json=params, headers=headers, proxy=self.proxy) as response:
+ response.raise_for_status()
+ async for chunk in response.content:
+ if not chunk:
+ continue
+ try:
+ data = json.loads(chunk.decode("utf-8").strip())
+ if data.get("code", 0) == 0 and "data" in data and data["data"]:
+ chunk_audio = base64.b64decode(data["data"])
+ audio_size = len(chunk_audio)
+ total_audio_size += audio_size
+ audio_data.extend(chunk_audio)
+ continue
+ if data.get("code", 0) == 0 and "sentence" in data and data["sentence"]:
+ continue
+ if data.get("code", 0) == 20000000:
+ break
+ if data.get("code", 0) > 0:
+ logger.warning(f"volcengine tts error response: {data}")
+ break
+ except Exception as e:
+ logger.warning(f"Failed to parse volcengine tts chunk: {e}")
+
+ # save audio file
+ if audio_data:
+ with open(audio_save_path, "wb") as f:
+ f.write(audio_data)
+ logger.info(f"audio saved to {audio_save_path}, audio size: {len(audio_data) / 1024:.2f} KB")
+ # set correct permissions
+ os.chmod(audio_save_path, 0o644)
+ return True
+ else:
+ logger.warning("No tts audio data received")
+ return False
+
+ except Exception as e:
+ logger.warning(f"VolcEngineTTSClient tts request failed: {e}")
+ return False
+
+ async def tts_request(
+ self,
+ text,
+ voice_type="zh_female_vv_uranus_bigtts",
+ context_texts="",
+ emotion="",
+ emotion_scale=4,
+ speech_rate=0,
+ loudness_rate=0,
+ pitch=0,
+ output="tts_output.mp3",
+ resource_id="seed-tts-2.0",
+ app_key="aGjiRDfUWi",
+ uid="123123",
+ format="mp3",
+ sample_rate=24000,
+ enable_timestamp=True,
+ ):
+ """
+ Run a TTS request
+
+ Args:
+ text: text to synthesize
+ voice_type: voice type
+ emotion: emotion type
+ emotion_scale: emotion intensity (1-5)
+ speech_rate: speech rate adjustment (-50~100, 100 = 2x speed, -50 = 0.5x speed, 0 = normal)
+ loudness_rate: loudness adjustment (-50~100, 100 = 2x volume, -50 = 0.5x volume, 0 = normal)
+ pitch: pitch adjustment (-12~12, 12 = high pitch, -12 = low pitch, 0 = normal)
+ output: output file path
+ resource_id: resource ID
+ app_key: application key
+ uid: user ID
+ format: audio format
+ sample_rate: sample rate
+ enable_timestamp: whether to enable timestamps
+ """
+ # Validate parameter ranges
+ if not (-50 <= speech_rate <= 100):
+ logger.warning(f"speech_rate {speech_rate} is outside the valid range [-50, 100], using default value 0")
+ speech_rate = 0
+
+ if not (-50 <= loudness_rate <= 100):
+ logger.warning(f"loudness_rate {loudness_rate} is outside the valid range [-50, 100], using default value 0")
+ loudness_rate = 0
+
+ if not (1 <= emotion_scale <= 5):
+ logger.warning(f"emotion_scale {emotion_scale} is outside the valid range [1, 5], using default value 3")
+ emotion_scale = 3
+
+ if not (-12 <= pitch <= 12):
+ logger.warning(f"pitch {pitch} is outside the valid range [-12, 12], using default value 0")
+ pitch = 0
+
+ headers = {
+ "X-Api-App-Id": self.appid,
+ "X-Api-Access-Key": self.access_token,
+ "X-Api-Resource-Id": resource_id,
+ "X-Api-App-Key": app_key,
+ "Content-Type": "application/json",
+ "Connection": "keep-alive",
+ }
+ additions = json.dumps(
+ {"explicit_language": "zh", "disable_markdown_filter": True, "enable_timestamp": True, "context_texts": [context_texts] if context_texts else None, "post_process": {"pitch": pitch}}
+ )
+ payload = {
+ "user": {"uid": uid},
+ "req_params": {
+ "text": text,
+ "speaker": voice_type,
+ "audio_params": {
+ "format": format,
+ "sample_rate": sample_rate,
+ "enable_timestamp": enable_timestamp,
+ "emotion": emotion,
+ "emotion_scale": emotion_scale,
+ "speech_rate": speech_rate,
+ "loudness_rate": loudness_rate,
+ },
+ "additions": additions,
+ },
+ }
+ success = await self.tts_http_stream(headers=headers, params=payload, audio_save_path=output)
+ if success:
+ logger.info(f"VolcEngineTTSClient tts request for '{text}': success")
+ else:
+ logger.warning(f"VolcEngineTTSClient tts request for '{text}': failed")
+ return success
+
+
+async def test(args):
+ """
+ TTS test helper
+
+ Args:
+ args: list, e.g. [text, voice_type, context_texts, emotion, emotion_scale, speech_rate, loudness_rate, pitch, output, resource_id, app_key, uid, format, sample_rate, enable_timestamp]
+ Provide as many as needed, from left to right.
+
+ Parameter ranges:
+ - speech_rate: -50~100 (100 = 2x speed, -50 = 0.5x speed, 0 = normal speed)
+ - loudness_rate: -50~100 (100 = 2x volume, -50 = 0.5x volume, 0 = normal volume)
+ - emotion_scale: 1-5
+ - pitch: -12~12 (12 = high pitch, -12 = low pitch, 0 = normal pitch)
+ """
+ client = VolcEngineTTSClient()
+ # Default parameters
+ params = {
+ "text": "",
+ "voice_type": "zh_female_vv_uranus_bigtts",
+ "context_texts": "",
+ "emotion": "",
+ "emotion_scale": 4,
+ "speech_rate": 0,
+ "loudness_rate": 0,
+ "pitch": 12,
+ "output": "tts_output.mp3",
+ "resource_id": "seed-tts-2.0",
+ "app_key": "aGjiRDfUWi",
+ "uid": "123123",
+ "format": "mp3",
+ "sample_rate": 24000,
+ "enable_timestamp": True,
+ }
+ keys = list(params.keys())
+ # Override defaults with positional args
+ for i, arg in enumerate(args):
+ # Type conversion
+ if keys[i] == "sample_rate":
+ params[keys[i]] = int(arg)
+ elif keys[i] == "enable_timestamp":
+ # Accept several boolean spellings
+ params[keys[i]] = str(arg).lower() in ("1", "true", "yes", "on")
+ else:
+ params[keys[i]] = arg
+
+ await client.tts_request(
+ params["text"],
+ params["voice_type"],
+ params["context_texts"],
+ params["emotion"],
+ params["emotion_scale"],
+ params["speech_rate"],
+ params["loudness_rate"],
+ params["pitch"],
+ params["output"],
+ params["resource_id"],
+ params["app_key"],
+ params["uid"],
+ params["format"],
+ params["sample_rate"],
+ params["enable_timestamp"],
+ )
+
+
+if __name__ == "__main__":
+ asyncio.run(test(sys.argv[1:]))
diff --git a/lightx2v/deploy/data_manager/__init__.py b/lightx2v/deploy/data_manager/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..14ee9d3388f24d9e763c02600972371273ab9933
--- /dev/null
+++ b/lightx2v/deploy/data_manager/__init__.py
@@ -0,0 +1,248 @@
+import io
+import json
+import os
+
+import torch
+from PIL import Image
+
+from lightx2v.deploy.common.utils import class_try_catch_async
+
+
+class BaseDataManager:
+ def __init__(self):
+ self.template_images_dir = None
+ self.template_audios_dir = None
+ self.template_videos_dir = None
+ self.template_tasks_dir = None
+ self.podcast_temp_session_dir = None
+ self.podcast_output_dir = None
+
+ async def init(self):
+ pass
+
+ async def close(self):
+ pass
+
+ def fmt_path(self, base, filename, abs_path=None):
+ if abs_path:
+ return abs_path
+ else:
+ return os.path.join(base, filename)
+
+ def to_device(self, data, device):
+ if isinstance(data, dict):
+ return {key: self.to_device(value, device) for key, value in data.items()}
+ elif isinstance(data, list):
+ return [self.to_device(item, device) for item in data]
+ elif isinstance(data, torch.Tensor):
+ return data.to(device)
+ else:
+ return data
+
+ async def save_bytes(self, bytes_data, filename, abs_path=None):
+ raise NotImplementedError
+
+ async def load_bytes(self, filename, abs_path=None):
+ raise NotImplementedError
+
+ async def delete_bytes(self, filename, abs_path=None):
+ raise NotImplementedError
+
+ async def presign_url(self, filename, abs_path=None):
+ return None
+
+ async def recurrent_save(self, data, prefix):
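+ # Walk nested dicts/lists; tensors and PIL images are written out as "<prefix>-<key>.pt"/".png"
+ # side files and replaced by their path strings, so the remaining structure is JSON-serializable.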
+ if isinstance(data, dict):
+ return {k: await self.recurrent_save(v, f"{prefix}-{k}") for k, v in data.items()}
+ elif isinstance(data, list):
+ return [await self.recurrent_save(v, f"{prefix}-{idx}") for idx, v in enumerate(data)]
+ elif isinstance(data, torch.Tensor):
+ save_path = prefix + ".pt"
+ await self.save_tensor(data, save_path)
+ return save_path
+ elif isinstance(data, Image.Image):
+ save_path = prefix + ".png"
+ await self.save_image(data, save_path)
+ return save_path
+ else:
+ return data
+
+ async def recurrent_load(self, data, device, prefix):
+ if isinstance(data, dict):
+ return {k: await self.recurrent_load(v, device, f"{prefix}-{k}") for k, v in data.items()}
+ elif isinstance(data, list):
+ return [await self.recurrent_load(v, device, f"{prefix}-{idx}") for idx, v in enumerate(data)]
+ elif isinstance(data, str) and data == prefix + ".pt":
+ return await self.load_tensor(data, device)
+ elif isinstance(data, str) and data == prefix + ".png":
+ return await self.load_image(data)
+ else:
+ return data
+
+ async def recurrent_delete(self, data, prefix):
+ if isinstance(data, dict):
+ return {k: await self.recurrent_delete(v, f"{prefix}-{k}") for k, v in data.items()}
+ elif isinstance(data, list):
+ return [await self.recurrent_delete(v, f"{prefix}-{idx}") for idx, v in enumerate(data)]
+ elif isinstance(data, str) and data == prefix + ".pt":
+ await self.delete_bytes(data)
+ elif isinstance(data, str) and data == prefix + ".png":
+ await self.delete_bytes(data)
+
+ @class_try_catch_async
+ async def save_object(self, data, filename):
+ data = await self.recurrent_save(data, filename)
+ bytes_data = json.dumps(data, ensure_ascii=False).encode("utf-8")
+ await self.save_bytes(bytes_data, filename)
+
+ @class_try_catch_async
+ async def load_object(self, filename, device):
+ bytes_data = await self.load_bytes(filename)
+ data = json.loads(bytes_data.decode("utf-8"))
+ data = await self.recurrent_load(data, device, filename)
+ return data
+
+ @class_try_catch_async
+ async def delete_object(self, filename):
+ bytes_data = await self.load_bytes(filename)
+ data = json.loads(bytes_data.decode("utf-8"))
+ await self.recurrent_delete(data, filename)
+ await self.delete_bytes(filename)
+
+ @class_try_catch_async
+ async def save_tensor(self, data: torch.Tensor, filename):
+ buffer = io.BytesIO()
+ torch.save(data.to("cpu"), buffer)
+ await self.save_bytes(buffer.getvalue(), filename)
+
+ @class_try_catch_async
+ async def load_tensor(self, filename, device):
+ bytes_data = await self.load_bytes(filename)
+ buffer = io.BytesIO(bytes_data)
+ t = torch.load(buffer)
+ t = t.to(device)
+ return t
+
+ @class_try_catch_async
+ async def save_image(self, data: Image.Image, filename):
+ buffer = io.BytesIO()
+ data.save(buffer, format="PNG")
+ await self.save_bytes(buffer.getvalue(), filename)
+
+ @class_try_catch_async
+ async def load_image(self, filename):
+ bytes_data = await self.load_bytes(filename)
+ buffer = io.BytesIO(bytes_data)
+ img = Image.open(buffer).convert("RGB")
+ return img
+
+ def get_delete_func(self, type):
+ maps = {
+ "TENSOR": self.delete_bytes,
+ "IMAGE": self.delete_bytes,
+ "OBJECT": self.delete_object,
+ "VIDEO": self.delete_bytes,
+ }
+ return maps[type]
+
+ def get_template_dir(self, template_type):
+ if template_type == "audios":
+ return self.template_audios_dir
+ elif template_type == "images":
+ return self.template_images_dir
+ elif template_type == "videos":
+ return self.template_videos_dir
+ elif template_type == "tasks":
+ return self.template_tasks_dir
+ else:
+ raise ValueError(f"Invalid template type: {template_type}")
+
+ @class_try_catch_async
+ async def list_template_files(self, template_type):
+ template_dir = self.get_template_dir(template_type)
+ if template_dir is None:
+ return []
+ return await self.list_files(base_dir=template_dir)
+
+ @class_try_catch_async
+ async def load_template_file(self, template_type, filename):
+ template_dir = self.get_template_dir(template_type)
+ if template_dir is None:
+ return None
+ return await self.load_bytes(None, abs_path=os.path.join(template_dir, filename))
+
+ @class_try_catch_async
+ async def template_file_exists(self, template_type, filename):
+ template_dir = self.get_template_dir(template_type)
+ if template_dir is None:
+ return None
+ return await self.file_exists(None, abs_path=os.path.join(template_dir, filename))
+
+ @class_try_catch_async
+ async def delete_template_file(self, template_type, filename):
+ template_dir = self.get_template_dir(template_type)
+ if template_dir is None:
+ return None
+ return await self.delete_bytes(None, abs_path=os.path.join(template_dir, filename))
+
+ @class_try_catch_async
+ async def save_template_file(self, template_type, filename, bytes_data):
+ template_dir = self.get_template_dir(template_type)
+ if template_dir is None:
+ return None
+ abs_path = os.path.join(template_dir, filename)
+ return await self.save_bytes(bytes_data, None, abs_path=abs_path)
+
+ @class_try_catch_async
+ async def presign_template_url(self, template_type, filename):
+ template_dir = self.get_template_dir(template_type)
+ if template_dir is None:
+ return None
+ return await self.presign_url(None, abs_path=os.path.join(template_dir, filename))
+
+ @class_try_catch_async
+ async def list_podcast_temp_session_files(self, session_id):
+ session_dir = os.path.join(self.podcast_temp_session_dir, session_id)
+ return await self.list_files(base_dir=session_dir)
+
+ @class_try_catch_async
+ async def save_podcast_temp_session_file(self, session_id, filename, bytes_data):
+ fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename)
+ await self.save_bytes(bytes_data, None, abs_path=fpath)
+
+ @class_try_catch_async
+ async def load_podcast_temp_session_file(self, session_id, filename):
+ fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename)
+ return await self.load_bytes(None, abs_path=fpath)
+
+ @class_try_catch_async
+ async def delete_podcast_temp_session_file(self, session_id, filename):
+ fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename)
+ return await self.delete_bytes(None, abs_path=fpath)
+
+ @class_try_catch_async
+ async def save_podcast_output_file(self, filename, bytes_data):
+ fpath = os.path.join(self.podcast_output_dir, filename)
+ await self.save_bytes(bytes_data, None, abs_path=fpath)
+
+ @class_try_catch_async
+ async def load_podcast_output_file(self, filename):
+ fpath = os.path.join(self.podcast_output_dir, filename)
+ return await self.load_bytes(None, abs_path=fpath)
+
+ @class_try_catch_async
+ async def delete_podcast_output_file(self, filename):
+ fpath = os.path.join(self.podcast_output_dir, filename)
+ return await self.delete_bytes(None, abs_path=fpath)
+
+ @class_try_catch_async
+ async def presign_podcast_output_url(self, filename):
+ fpath = os.path.join(self.podcast_output_dir, filename)
+ return await self.presign_url(None, abs_path=fpath)
+
+
+# Import data manager implementations
+from .local_data_manager import LocalDataManager # noqa
+from .s3_data_manager import S3DataManager # noqa
+
+__all__ = ["BaseDataManager", "LocalDataManager", "S3DataManager"]
diff --git a/lightx2v/deploy/data_manager/local_data_manager.py b/lightx2v/deploy/data_manager/local_data_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e11e6cbadef5b7eccd6b9e65de657798b7d2924
--- /dev/null
+++ b/lightx2v/deploy/data_manager/local_data_manager.py
@@ -0,0 +1,120 @@
+import asyncio
+import os
+import shutil
+
+from loguru import logger
+
+from lightx2v.deploy.common.utils import class_try_catch_async
+from lightx2v.deploy.data_manager import BaseDataManager
+
+
+class LocalDataManager(BaseDataManager):
+ def __init__(self, local_dir, template_dir):
+ super().__init__()
+ self.local_dir = local_dir
+ self.name = "local"
+ if not os.path.exists(self.local_dir):
+ os.makedirs(self.local_dir)
+ if template_dir:
+ self.template_images_dir = os.path.join(template_dir, "images")
+ self.template_audios_dir = os.path.join(template_dir, "audios")
+ self.template_videos_dir = os.path.join(template_dir, "videos")
+ self.template_tasks_dir = os.path.join(template_dir, "tasks")
+ assert os.path.exists(self.template_images_dir), f"{self.template_images_dir} not exists!"
+ assert os.path.exists(self.template_audios_dir), f"{self.template_audios_dir} not exists!"
+ assert os.path.exists(self.template_videos_dir), f"{self.template_videos_dir} not exists!"
+ assert os.path.exists(self.template_tasks_dir), f"{self.template_tasks_dir} not exists!"
+
+ # podcast temp session dir and output dir
+ self.podcast_temp_session_dir = os.path.join(self.local_dir, "podcast_temp_session")
+ self.podcast_output_dir = os.path.join(self.local_dir, "podcast_output")
+ os.makedirs(self.podcast_temp_session_dir, exist_ok=True)
+ os.makedirs(self.podcast_output_dir, exist_ok=True)
+
+ @class_try_catch_async
+ async def save_bytes(self, bytes_data, filename, abs_path=None):
+ out_path = self.fmt_path(self.local_dir, filename, abs_path)
+ with open(out_path, "wb") as fout:
+ fout.write(bytes_data)
+ return True
+
+ @class_try_catch_async
+ async def load_bytes(self, filename, abs_path=None):
+ inp_path = self.fmt_path(self.local_dir, filename, abs_path)
+ with open(inp_path, "rb") as fin:
+ return fin.read()
+
+ @class_try_catch_async
+ async def delete_bytes(self, filename, abs_path=None):
+ inp_path = self.fmt_path(self.local_dir, filename, abs_path)
+ os.remove(inp_path)
+ logger.info(f"deleted local file {filename}")
+ return True
+
+ @class_try_catch_async
+ async def file_exists(self, filename, abs_path=None):
+ filename = self.fmt_path(self.local_dir, filename, abs_path)
+ return os.path.exists(filename)
+
+ @class_try_catch_async
+ async def list_files(self, base_dir=None):
+ prefix = base_dir if base_dir else self.local_dir
+ return os.listdir(prefix)
+
+ @class_try_catch_async
+ async def create_podcast_temp_session_dir(self, session_id):
+ dir_path = os.path.join(self.podcast_temp_session_dir, session_id)
+ os.makedirs(dir_path, exist_ok=True)
+ return dir_path
+
+ @class_try_catch_async
+ async def clear_podcast_temp_session_dir(self, session_id):
+ session_dir = os.path.join(self.podcast_temp_session_dir, session_id)
+ if os.path.isdir(session_dir):
+ shutil.rmtree(session_dir)
+ logger.info(f"cleared podcast temp session dir {session_dir}")
+ return True
+
+
+async def test():
+ import torch
+ from PIL import Image
+
+ m = LocalDataManager("/data/nvme1/liuliang1/lightx2v/local_data", None)
+ await m.init()
+
+ img = Image.open("/data/nvme1/liuliang1/lightx2v/assets/img_lightx2v.png")
+ tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0")
+
+ await m.save_image(img, "test_img.png")
+ print(await m.load_image("test_img.png"))
+
+ await m.save_tensor(tensor, "test_tensor.pt")
+ print(await m.load_tensor("test_tensor.pt", "cuda:0"))
+
+ await m.save_object(
+ {
+ "images": [img, img],
+ "tensor": tensor,
+ "list": [
+ [2, 0, 5, 5],
+ {
+ "1": "hello world",
+ "2": "world",
+ "3": img,
+ "t": tensor,
+ },
+ "0609",
+ ],
+ },
+ "test_object.json",
+ )
+ print(await m.load_object("test_object.json", "cuda:0"))
+
+ await m.get_delete_func("OBJECT")("test_object.json")
+ await m.get_delete_func("TENSOR")("test_tensor.pt")
+ await m.get_delete_func("IMAGE")("test_img.png")
+
+
+if __name__ == "__main__":
+ asyncio.run(test())
diff --git a/lightx2v/deploy/data_manager/s3_data_manager.py b/lightx2v/deploy/data_manager/s3_data_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d928fc197bb5d5bb871e7e4199d8ebd31bd8fc9
--- /dev/null
+++ b/lightx2v/deploy/data_manager/s3_data_manager.py
@@ -0,0 +1,254 @@
+import asyncio
+import hashlib
+import json
+import os
+
+import aioboto3
+import tos
+from botocore.client import Config
+from loguru import logger
+
+from lightx2v.deploy.common.utils import class_try_catch_async
+from lightx2v.deploy.data_manager import BaseDataManager
+
+
+class S3DataManager(BaseDataManager):
+ def __init__(self, config_string, template_dir, max_retries=3):
+ super().__init__()
+ self.name = "s3"
+ self.config = json.loads(config_string)
+ self.max_retries = max_retries
+ self.bucket_name = self.config["bucket_name"]
+ self.aws_access_key_id = self.config["aws_access_key_id"]
+ self.aws_secret_access_key = self.config["aws_secret_access_key"]
+ self.endpoint_url = self.config["endpoint_url"]
+ self.base_path = self.config["base_path"]
+ self.connect_timeout = self.config.get("connect_timeout", 60)
+ self.read_timeout = self.config.get("read_timeout", 60)
+ self.write_timeout = self.config.get("write_timeout", 10)
+ self.addressing_style = self.config.get("addressing_style", None)
+ self.region = self.config.get("region", None)
+ self.cdn_url = self.config.get("cdn_url", "")
+ self.session = None
+ self.s3_client = None
+ self.presign_client = None
+ if template_dir:
+ self.template_images_dir = os.path.join(template_dir, "images")
+ self.template_audios_dir = os.path.join(template_dir, "audios")
+ self.template_videos_dir = os.path.join(template_dir, "videos")
+ self.template_tasks_dir = os.path.join(template_dir, "tasks")
+
+ # podcast temp session dir and output dir
+ self.podcast_temp_session_dir = os.path.join(self.base_path, "podcast_temp_session")
+ self.podcast_output_dir = os.path.join(self.base_path, "podcast_output")
+
+ async def init_presign_client(self):
+ # init tos client for volces.com
+ if "volces.com" in self.endpoint_url:
+ self.presign_client = tos.TosClientV2(
+ self.aws_access_key_id,
+ self.aws_secret_access_key,
+ self.endpoint_url.replace("tos-s3-", "tos-"),
+ self.region,
+ )
+
+ async def init(self):
+ for i in range(self.max_retries):
+ try:
+ logger.info(f"S3DataManager init with config: {self.config} (attempt {i + 1}/{self.max_retries}) ...")
+ s3_config = {"payload_signing_enabled": True}
+ if self.addressing_style:
+ s3_config["addressing_style"] = self.addressing_style
+ self.session = aioboto3.Session()
+ self.s3_client = await self.session.client(
+ "s3",
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
+ endpoint_url=self.endpoint_url,
+ config=Config(
+ signature_version="s3v4",
+ s3=s3_config,
+ connect_timeout=self.connect_timeout,
+ read_timeout=self.read_timeout,
+ parameter_validation=False,
+ max_pool_connections=50,
+ ),
+ ).__aenter__()
+
+ try:
+ await self.s3_client.head_bucket(Bucket=self.bucket_name)
+ logger.info(f"check bucket {self.bucket_name} success")
+ except Exception as e:
+ logger.info(f"check bucket {self.bucket_name} error: {e}, try to create it...")
+ await self.s3_client.create_bucket(Bucket=self.bucket_name)
+
+ await self.init_presign_client()
+ logger.info(f"Successfully init S3 bucket: {self.bucket_name} with timeouts - connect: {self.connect_timeout}s, read: {self.read_timeout}s, write: {self.write_timeout}s")
+ return
+ except Exception as e:
+ logger.warning(f"Failed to connect to S3: {e}")
+ await asyncio.sleep(1)
+
+    async def close(self):
+        if self.s3_client:
+            await self.s3_client.__aexit__(None, None, None)
+            self.s3_client = None
+        if self.session:
+            self.session = None
+
+ @class_try_catch_async
+ async def save_bytes(self, bytes_data, filename, abs_path=None):
+ filename = self.fmt_path(self.base_path, filename, abs_path)
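+        # Compute a SHA-256 checksum so the object store can verify payload integrity on upload.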
+ content_sha256 = hashlib.sha256(bytes_data).hexdigest()
+ await self.s3_client.put_object(
+ Bucket=self.bucket_name,
+ Key=filename,
+ Body=bytes_data,
+ ChecksumSHA256=content_sha256,
+ ContentType="application/octet-stream",
+ )
+ return True
+
+ @class_try_catch_async
+ async def load_bytes(self, filename, abs_path=None):
+ filename = self.fmt_path(self.base_path, filename, abs_path)
+ response = await self.s3_client.get_object(Bucket=self.bucket_name, Key=filename)
+ return await response["Body"].read()
+
+ @class_try_catch_async
+ async def delete_bytes(self, filename, abs_path=None):
+ filename = self.fmt_path(self.base_path, filename, abs_path)
+ await self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
+ logger.info(f"deleted s3 file {filename}")
+ return True
+
+ @class_try_catch_async
+ async def file_exists(self, filename, abs_path=None):
+ filename = self.fmt_path(self.base_path, filename, abs_path)
+ try:
+ await self.s3_client.head_object(Bucket=self.bucket_name, Key=filename)
+ return True
+ except Exception:
+ return False
+
+ @class_try_catch_async
+ async def list_files(self, base_dir=None):
+ if base_dir:
+ prefix = self.fmt_path(self.base_path, None, abs_path=base_dir)
+ else:
+ prefix = self.base_path
+ prefix = prefix + "/" if not prefix.endswith("/") else prefix
+
+ # Handle pagination for S3 list_objects_v2
+ files = []
+ continuation_token = None
+ page = 1
+
+ while True:
+ list_kwargs = {"Bucket": self.bucket_name, "Prefix": prefix, "MaxKeys": 1000}
+ if continuation_token:
+ list_kwargs["ContinuationToken"] = continuation_token
+ response = await self.s3_client.list_objects_v2(**list_kwargs)
+
+ if "Contents" in response:
+ page_files = []
+ for obj in response["Contents"]:
+ # Remove the prefix from the key to get just the filename
+ key = obj["Key"]
+ if key.startswith(prefix):
+ filename = key[len(prefix) :]
+ if filename: # Skip empty filenames (the directory itself)
+ page_files.append(filename)
+ files.extend(page_files)
+ else:
+ logger.warning(f"[S3DataManager.list_files] Page {page}: No files found in this page.")
+
+ # Check if there are more pages
+ if response.get("IsTruncated", False):
+ continuation_token = response.get("NextContinuationToken")
+ page += 1
+ else:
+ break
+ return files
+
+ @class_try_catch_async
+ async def presign_url(self, filename, abs_path=None):
+ filename = self.fmt_path(self.base_path, filename, abs_path)
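+        # Prefer a CDN URL when configured; otherwise fall back to a presigned TOS URL
+        # (volces.com endpoints only). Returns None if neither is available.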
+ if self.cdn_url:
+ return f"{self.cdn_url}/{filename}"
+
+ if self.presign_client:
+ expires = self.config.get("presign_expires", 24 * 60 * 60)
+ out = await asyncio.to_thread(self.presign_client.pre_signed_url, tos.HttpMethodType.Http_Method_Get, self.bucket_name, filename, expires)
+ return out.signed_url
+ else:
+ return None
+
+ @class_try_catch_async
+ async def create_podcast_temp_session_dir(self, session_id):
+ pass
+
+ @class_try_catch_async
+ async def clear_podcast_temp_session_dir(self, session_id):
+ session_dir = os.path.join(self.podcast_temp_session_dir, session_id)
+ fs = await self.list_files(base_dir=session_dir)
+ logger.info(f"clear podcast temp session dir {session_dir} with files: {fs}")
+ for f in fs:
+ await self.delete_bytes(f, abs_path=os.path.join(session_dir, f))
+
+
+async def test():
+ import torch
+ from PIL import Image
+
+ s3_config = {
+ "aws_access_key_id": "xxx",
+ "aws_secret_access_key": "xx",
+ "endpoint_url": "xxx",
+ "bucket_name": "xxx",
+ "base_path": "xxx",
+ "connect_timeout": 10,
+ "read_timeout": 10,
+ "write_timeout": 10,
+ }
+
+ m = S3DataManager(json.dumps(s3_config), None)
+ await m.init()
+
+ img = Image.open("../../../assets/img_lightx2v.png")
+ tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0")
+
+ await m.save_image(img, "test_img.png")
+ print(await m.load_image("test_img.png"))
+
+ await m.save_tensor(tensor, "test_tensor.pt")
+ print(await m.load_tensor("test_tensor.pt", "cuda:0"))
+
+ await m.save_object(
+ {
+ "images": [img, img],
+ "tensor": tensor,
+ "list": [
+ [2, 0, 5, 5],
+ {
+ "1": "hello world",
+ "2": "world",
+ "3": img,
+ "t": tensor,
+ },
+ "0609",
+ ],
+ },
+ "test_object.json",
+ )
+ print(await m.load_object("test_object.json", "cuda:0"))
+
+ print("all files:", await m.list_files())
+ await m.get_delete_func("OBJECT")("test_object.json")
+ await m.get_delete_func("TENSOR")("test_tensor.pt")
+ await m.get_delete_func("IMAGE")("test_img.png")
+ print("after delete all files", await m.list_files())
+ await m.close()
+
+
+if __name__ == "__main__":
+ asyncio.run(test())
diff --git a/lightx2v/deploy/queue_manager/__init__.py b/lightx2v/deploy/queue_manager/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee49b849bbf38c7584a797eb718f7bc0cdf6be1
--- /dev/null
+++ b/lightx2v/deploy/queue_manager/__init__.py
@@ -0,0 +1,25 @@
+class BaseQueueManager:
+ def __init__(self):
+ pass
+
+ async def init(self):
+ pass
+
+ async def close(self):
+ pass
+
+ async def put_subtask(self, subtask):
+ raise NotImplementedError
+
+ async def get_subtasks(self, queue, max_batch, timeout):
+ raise NotImplementedError
+
+ async def pending_num(self, queue):
+ raise NotImplementedError
+
+
+# Import queue manager implementations
+from .local_queue_manager import LocalQueueManager # noqa
+from .rabbitmq_queue_manager import RabbitMQQueueManager # noqa
+
+__all__ = ["BaseQueueManager", "LocalQueueManager", "RabbitMQQueueManager"]
diff --git a/lightx2v/deploy/queue_manager/local_queue_manager.py b/lightx2v/deploy/queue_manager/local_queue_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a8752953ecf2bd78dc6acc38e1027a15ee04b32
--- /dev/null
+++ b/lightx2v/deploy/queue_manager/local_queue_manager.py
@@ -0,0 +1,113 @@
+import asyncio
+import json
+import os
+import time
+import traceback
+
+from loguru import logger
+
+from lightx2v.deploy.common.utils import class_try_catch_async
+from lightx2v.deploy.queue_manager import BaseQueueManager
+
+
+class LocalQueueManager(BaseQueueManager):
+ def __init__(self, local_dir):
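+        # Each queue is backed by a JSONL file under local_dir, one subtask per line.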
+ self.local_dir = local_dir
+ if not os.path.exists(self.local_dir):
+ os.makedirs(self.local_dir)
+
+ async def get_conn(self):
+ pass
+
+ async def del_conn(self):
+ pass
+
+ async def declare_queue(self, queue):
+ pass
+
+ @class_try_catch_async
+ async def put_subtask(self, subtask):
+ out_name = self.get_filename(subtask["queue"])
+ keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
+ msg = json.dumps({k: subtask[k] for k in keys}) + "\n"
+ logger.info(f"Local published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {subtask['queue']}")
+ with open(out_name, "a") as fout:
+ fout.write(msg)
+ return True
+
+ def read_first_line(self, queue):
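+        # Pop the first queued subtask and write the remaining lines back to the file.
+        # Not atomic, so it is only suitable for local single-consumer testing.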
+ out_name = self.get_filename(queue)
+ if not os.path.exists(out_name):
+ return None
+ lines = []
+ with open(out_name) as fin:
+ lines = fin.readlines()
+ if len(lines) <= 0:
+ return None
+ subtask = json.loads(lines[0])
+ msgs = "".join(lines[1:])
+ fout = open(out_name, "w")
+ fout.write(msgs)
+ fout.close()
+ return subtask
+
+ @class_try_catch_async
+ async def get_subtasks(self, queue, max_batch, timeout):
+ try:
+ t0 = time.time()
+ subtasks = []
+ while True:
+ subtask = self.read_first_line(queue)
+ if subtask:
+ subtasks.append(subtask)
+ if len(subtasks) >= max_batch:
+ return subtasks
+ else:
+ continue
+ else:
+ if len(subtasks) > 0:
+ return subtasks
+ if time.time() - t0 > timeout:
+ return None
+ await asyncio.sleep(1)
+ except asyncio.CancelledError:
+ logger.warning(f"local queue get_subtasks for {queue} cancelled")
+ return None
+ except: # noqa
+ logger.warning(f"local queue get_subtasks for {queue} failed: {traceback.format_exc()}")
+ return None
+
+ def get_filename(self, queue):
+ return os.path.join(self.local_dir, f"{queue}.jsonl")
+
+ @class_try_catch_async
+ async def pending_num(self, queue):
+ out_name = self.get_filename(queue)
+ if not os.path.exists(out_name):
+ return 0
+ lines = []
+ with open(out_name) as fin:
+ lines = fin.readlines()
+ return len(lines)
+
+
+async def test():
+ q = LocalQueueManager("/data/nvme1/liuliang1/lightx2v/local_queue")
+ await q.init()
+ subtask = {
+ "task_id": "test-subtask-id",
+ "queue": "test_queue",
+ "worker_name": "test_worker",
+ "inputs": {},
+ "outputs": {},
+ "params": {},
+ }
+ await q.put_subtask(subtask)
+ await asyncio.sleep(5)
+ for i in range(2):
+ subtask = await q.get_subtasks("test_queue", 3, 5)
+ print("get subtask:", subtask)
+
+
+if __name__ == "__main__":
+ asyncio.run(test())
diff --git a/lightx2v/deploy/queue_manager/rabbitmq_queue_manager.py b/lightx2v/deploy/queue_manager/rabbitmq_queue_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8c3f180b4cebe9bfb58381055c0d8bacd42372
--- /dev/null
+++ b/lightx2v/deploy/queue_manager/rabbitmq_queue_manager.py
@@ -0,0 +1,124 @@
+import asyncio
+import json
+import traceback
+
+import aio_pika
+from loguru import logger
+
+from lightx2v.deploy.common.utils import class_try_catch_async
+from lightx2v.deploy.queue_manager import BaseQueueManager
+
+
+class RabbitMQQueueManager(BaseQueueManager):
+ def __init__(self, conn_url, max_retries=3):
+ self.conn_url = conn_url
+ self.max_retries = max_retries
+ self.conn = None
+ self.chan = None
+ self.queues = set()
+
+ async def init(self):
+ await self.get_conn()
+
+ async def close(self):
+ await self.del_conn()
+
+ async def get_conn(self):
+ if self.chan and self.conn:
+ return
+ for i in range(self.max_retries):
+ try:
+ logger.info(f"Connect to RabbitMQ (attempt {i + 1}/{self.max_retries}..)")
+ self.conn = await aio_pika.connect_robust(self.conn_url)
+ self.chan = await self.conn.channel()
+ self.queues = set()
+ await self.chan.set_qos(prefetch_count=10)
+ logger.info("Successfully connected to RabbitMQ")
+ return
+ except Exception as e:
+ logger.warning(f"Failed to connect to RabbitMQ: {e}")
+ if i < self.max_retries - 1:
+ await asyncio.sleep(1)
+ else:
+ raise
+
+ async def declare_queue(self, queue):
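+        # Declare each queue only once per connection; durable queues survive broker restarts.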
+ if queue not in self.queues:
+ await self.get_conn()
+ await self.chan.declare_queue(queue, durable=True)
+ self.queues.add(queue)
+ return await self.chan.get_queue(queue)
+
+ @class_try_catch_async
+ async def put_subtask(self, subtask):
+ queue = subtask["queue"]
+ await self.declare_queue(queue)
+ keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
+ msg = json.dumps({k: subtask[k] for k in keys}).encode("utf-8")
+ message = aio_pika.Message(body=msg, delivery_mode=aio_pika.DeliveryMode.PERSISTENT, content_type="application/json")
+ await self.chan.default_exchange.publish(message, routing_key=queue)
+ logger.info(f"Rabbitmq published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {queue}")
+ return True
+
+ async def get_subtasks(self, queue, max_batch, timeout):
+ try:
+ q = await self.declare_queue(queue)
+ subtasks = []
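+            # Block on the iterator until the first message arrives, then drain any
+            # already-queued messages with get() until max_batch is reached or the queue is empty.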
+ async with q.iterator() as qiter:
+ async for message in qiter:
+ await message.ack()
+ subtask = json.loads(message.body.decode("utf-8"))
+ subtasks.append(subtask)
+ if len(subtasks) >= max_batch:
+ return subtasks
+ while True:
+ message = await q.get(no_ack=False, fail=False)
+ if message:
+ await message.ack()
+ subtask = json.loads(message.body.decode("utf-8"))
+ subtasks.append(subtask)
+ if len(subtasks) >= max_batch:
+ return subtasks
+ else:
+ return subtasks
+ except asyncio.CancelledError:
+ logger.warning(f"rabbitmq get_subtasks for {queue} cancelled")
+ return None
+ except: # noqa
+ logger.warning(f"rabbitmq get_subtasks for {queue} failed: {traceback.format_exc()}")
+ return None
+
+ @class_try_catch_async
+ async def pending_num(self, queue):
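+        # message_count is reported by the broker in the queue declaration result.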
+ q = await self.declare_queue(queue)
+ return q.declaration_result.message_count
+
+    async def del_conn(self):
+        if self.chan:
+            await self.chan.close()
+            self.chan = None
+        if self.conn:
+            await self.conn.close()
+            self.conn = None
+
+
+async def test():
+ conn_url = "amqp://username:password@127.0.0.1:5672"
+ q = RabbitMQQueueManager(conn_url)
+ await q.init()
+ subtask = {
+ "task_id": "test-subtask-id",
+ "queue": "test_queue",
+ "worker_name": "test_worker",
+ "inputs": {},
+ "outputs": {},
+ "params": {},
+ }
+ await q.put_subtask(subtask)
+ await asyncio.sleep(5)
+ for i in range(2):
+ subtask = await q.get_subtasks("test_queue", 3, 5)
+ print("get subtask:", subtask)
+ await q.close()
+
+
+if __name__ == "__main__":
+ asyncio.run(test())
diff --git a/lightx2v/deploy/server/__init__.py b/lightx2v/deploy/server/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/deploy/server/__main__.py b/lightx2v/deploy/server/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..348fe074b58ff0454e8b2bc6e18e074ff9a82d87
--- /dev/null
+++ b/lightx2v/deploy/server/__main__.py
@@ -0,0 +1,1490 @@
+import argparse
+import asyncio
+import base64
+import copy
+import json
+import mimetypes
+import os
+import re
+import tempfile
+import traceback
+import uuid
+from contextlib import asynccontextmanager
+
+import uvicorn
+from fastapi import Depends, FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastapi.staticfiles import StaticFiles
+from loguru import logger
+from pydantic import BaseModel
+
+from lightx2v.deploy.common.audio_separator import AudioSeparator
+from lightx2v.deploy.common.face_detector import FaceDetector
+from lightx2v.deploy.common.pipeline import Pipeline
+from lightx2v.deploy.common.podcasts import VolcEnginePodcastClient
+from lightx2v.deploy.common.utils import check_params, data_name, fetch_resource, format_image_data, load_inputs
+from lightx2v.deploy.common.volcengine_tts import VolcEngineTTSClient
+from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager
+from lightx2v.deploy.queue_manager import LocalQueueManager, RabbitMQQueueManager
+from lightx2v.deploy.server.auth import AuthManager
+from lightx2v.deploy.server.metrics import MetricMonitor
+from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus
+from lightx2v.deploy.server.redis_monitor import RedisServerMonitor
+from lightx2v.deploy.task_manager import FinishedStatus, LocalTaskManager, PostgresSQLTaskManager, TaskStatus
+from lightx2v.utils.service_utils import ProcessManager
+
+# =========================
+# Pydantic Models
+# =========================
+
+
+class TTSRequest(BaseModel):
+ text: str
+ voice_type: str
+ context_texts: str = ""
+ emotion: str = ""
+ emotion_scale: int = 3
+ speech_rate: int = 0
+ pitch: int = 0
+ loudness_rate: int = 0
+ resource_id: str = "seed-tts-1.0"
+
+
+class RefreshTokenRequest(BaseModel):
+ refresh_token: str
+
+
+# =========================
+# FastAPI Related Code
+# =========================
+
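+# Global singletons shared by the request handlers below. metrics_monitor is created here;
+# the others are expected to be assigned at startup, before the app begins serving
+# (the lifespan hook initializes the task/data/queue managers and the server monitor).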
+model_pipelines = None
+task_manager = None
+data_manager = None
+queue_manager = None
+server_monitor = None
+auth_manager = None
+metrics_monitor = MetricMonitor()
+volcengine_tts_client = None
+volcengine_podcast_client = None
+face_detector = None
+audio_separator = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ await task_manager.init()
+ await task_manager.mark_server_restart()
+ await data_manager.init()
+ await queue_manager.init()
+ await server_monitor.init()
+ asyncio.create_task(server_monitor.loop())
+ yield
+ await server_monitor.close()
+ await queue_manager.close()
+ await data_manager.close()
+ await task_manager.close()
+
+
+app = FastAPI(lifespan=lifespan)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+ logger.error(f"HTTP Exception: {exc.status_code} - {exc.detail} for {request.url}")
+ return JSONResponse(status_code=exc.status_code, content={"message": exc.detail})
+
+
+static_dir = os.path.join(os.path.dirname(__file__), "static")
+app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+# Serve static files from the assets directory
+assets_dir = os.path.join(os.path.dirname(__file__), "static", "assets")
+app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")
+security = HTTPBearer()
+
+
+async def verify_user_access(credentials: HTTPAuthorizationCredentials = Depends(security)):
+ token = credentials.credentials
+ payload = auth_manager.verify_jwt_token(token)
+ user_id = payload.get("user_id", None)
+ if not user_id:
+ raise HTTPException(status_code=401, detail="Invalid user")
+ user = await task_manager.query_user(user_id)
+ # logger.info(f"Verfiy user access: {payload}")
+ if user is None or user["user_id"] != user_id:
+ raise HTTPException(status_code=401, detail="Invalid user")
+ return user
+
+
+async def verify_user_access_from_query(request: Request):
+ """从查询参数中验证用户访问权限"""
+ # 首先尝试从 Authorization 头部获取 token
+ auth_header = request.headers.get("Authorization")
+ token = None
+
+ if auth_header and auth_header.startswith("Bearer "):
+        token = auth_header[7:]  # strip the "Bearer " prefix
+ else:
+        # Otherwise, fall back to the token query parameter
+ token = request.query_params.get("token")
+
+ payload = auth_manager.verify_jwt_token(token)
+ user_id = payload.get("user_id", None)
+ if not user_id:
+ raise HTTPException(status_code=401, detail="Invalid user")
+ user = await task_manager.query_user(user_id)
+ if user is None or user["user_id"] != user_id:
+ raise HTTPException(status_code=401, detail="Invalid user")
+ return user
+
+
+async def verify_worker_access(credentials: HTTPAuthorizationCredentials = Depends(security)):
+ token = credentials.credentials
+ if not auth_manager.verify_worker_token(token):
+ raise HTTPException(status_code=403, detail="Invalid worker token")
+ return True
+
+
+def error_response(e, code):
+ return JSONResponse({"message": f"error: {e}!"}, status_code=code)
+
+
+def format_user_response(user):
+ return {
+ "user_id": user.get("user_id"),
+ "id": user.get("id"),
+ "source": user.get("source"),
+ "username": user.get("username") or "",
+ "email": user.get("email") or "",
+ "homepage": user.get("homepage") or "",
+ "avatar_url": user.get("avatar_url") or "",
+ }
+
+
+def guess_file_type(name, default_type):
+ content_type, _ = mimetypes.guess_type(name)
+ if content_type is None:
+ content_type = default_type
+ return content_type
+
+
+@app.get("/", response_class=HTMLResponse)
+async def root():
+ with open(os.path.join(static_dir, "index.html"), "r", encoding="utf-8") as f:
+ return HTMLResponse(content=f.read())
+
+
+@app.get("/sitemap.xml", response_class=HTMLResponse)
+async def sitemap():
+ with open(os.path.join(os.path.dirname(__file__), "frontend", "dist", "sitemap.xml"), "r", encoding="utf-8") as f:
+ return HTMLResponse(content=f.read())
+
+
+@app.get("/auth/login/github")
+async def github_auth(request: Request):
+ client_id = auth_manager.github_client_id
+ redirect_uri = f"{request.base_url}"
+ auth_url = f"https://github.com/login/oauth/authorize?client_id={client_id}&redirect_uri={redirect_uri}"
+ return {"auth_url": auth_url}
+
+
+@app.get("/auth/callback/github")
+async def github_callback(request: Request):
+ try:
+ code = request.query_params.get("code")
+ if not code:
+ return error_response("Missing authorization code", 400)
+ user_info = await auth_manager.auth_github(code)
+ user_id = await task_manager.create_user(user_info)
+ user_info["user_id"] = user_id
+ user_response = format_user_response(user_info)
+ access_token, refresh_token = auth_manager.create_tokens(user_response)
+ logger.info(f"GitHub callback: user_info: {user_response}, access token issued")
+ return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_response}
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/auth/login/google")
+async def google_auth(request: Request):
+ client_id = auth_manager.google_client_id
+ redirect_uri = auth_manager.google_redirect_uri
+ auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?client_id={client_id}&redirect_uri={redirect_uri}&response_type=code&scope=openid%20email%20profile&access_type=offline"
+ logger.info(f"Google auth: auth_url: {auth_url}")
+ return {"auth_url": auth_url}
+
+
+@app.get("/auth/callback/google")
+async def google_callback(request: Request):
+ try:
+ code = request.query_params.get("code")
+ if not code:
+ return error_response("Missing authorization code", 400)
+ user_info = await auth_manager.auth_google(code)
+ user_id = await task_manager.create_user(user_info)
+ user_info["user_id"] = user_id
+ user_response = format_user_response(user_info)
+ access_token, refresh_token = auth_manager.create_tokens(user_response)
+ logger.info(f"Google callback: user_info: {user_response}, access token issued")
+ return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_response}
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/auth/login/sms")
+async def sms_auth(request: Request):
+ try:
+ phone_number = request.query_params.get("phone_number")
+ if not phone_number:
+ return error_response("Missing phone number", 400)
+ ok = await auth_manager.send_sms(phone_number)
+ if not ok:
+ return error_response("SMS send failed", 400)
+ return {"msg": "SMS send successfully"}
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/auth/callback/sms")
+async def sms_callback(request: Request):
+ try:
+ phone_number = request.query_params.get("phone_number")
+ verify_code = request.query_params.get("verify_code")
+ if not phone_number or not verify_code:
+ return error_response("Missing phone number or verify code", 400)
+ user_info = await auth_manager.check_sms(phone_number, verify_code)
+ if not user_info:
+ return error_response("SMS verify failed", 400)
+
+ user_id = await task_manager.create_user(user_info)
+ user_info["user_id"] = user_id
+ user_response = format_user_response(user_info)
+ access_token, refresh_token = auth_manager.create_tokens(user_response)
+ logger.info(f"SMS callback: user_info: {user_response}, access token issued")
+ return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_response}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.post("/auth/refresh")
+async def refresh_access_token(request: RefreshTokenRequest):
+ try:
+ payload = auth_manager.verify_refresh_token(request.refresh_token)
+ user_id = payload.get("user_id")
+ if not user_id:
+ raise HTTPException(status_code=401, detail="Invalid refresh token")
+ user = await task_manager.query_user(user_id)
+ if user is None or user.get("user_id") != user_id:
+ raise HTTPException(status_code=401, detail="Invalid user")
+ user_info = format_user_response(user)
+ access_token, refresh_token = auth_manager.create_tokens(user_info)
+ return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_info}
+ except HTTPException as exc:
+ raise exc
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+async def prepare_subtasks(task_id):
+ # schedule next subtasks and pend, put to message queue
+ subtasks = await task_manager.next_subtasks(task_id)
+ for sub in subtasks:
+ logger.info(f"Prepare ready subtask: ({task_id}, {sub['worker_name']})")
+ r = await queue_manager.put_subtask(sub)
+ assert r, "put subtask to queue error"
+ await server_monitor.pending_subtasks_add(sub["queue"], sub["task_id"])
+
+
+def format_task(task):
+ task["status"] = task["status"].name
+ task["model_cls"] = model_pipelines.outer_model_name(task["model_cls"])
+
+
+@app.get("/api/v1/model/list")
+async def api_v1_model_list(user=Depends(verify_user_access)):
+ try:
+ return {"models": model_pipelines.get_model_lists()}
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.post("/api/v1/task/submit")
+async def api_v1_task_submit(request: Request, user=Depends(verify_user_access)):
+ task_id = None
+ try:
+ msg = await server_monitor.check_user_busy(user["user_id"], active_new_task=True)
+ if msg is not True:
+ return error_response(msg, 400)
+ params = await request.json()
+ keys = [params.pop("task"), params.pop("model_cls"), params.pop("stage")]
+ keys[1] = model_pipelines.inner_model_name(keys[1])
+ assert len(params["prompt"]) > 0, "valid prompt is required"
+
+ # get worker infos, model input names
+ workers = model_pipelines.get_workers(keys)
+ inputs = model_pipelines.get_inputs(keys)
+ outputs = model_pipelines.get_outputs(keys)
+ types = model_pipelines.get_types(keys)
+ check_params(params, inputs, outputs, types)
+
+ # check if task can be published to queues
+ queues = [v["queue"] for v in workers.values()]
+ wait_time = await server_monitor.check_queue_busy(keys, queues)
+ if wait_time is None:
+ return error_response(f"Queue busy, please try again later", 500)
+
+ # process multimodal inputs data
+ inputs_data = await load_inputs(params, inputs, types)
+
+ # init task (we need task_id before preprocessing to save processed files)
+ task_id = await task_manager.create_task(keys, workers, params, inputs, outputs, user["user_id"])
+ logger.info(f"Submit task: {task_id} {params}")
+
+ # save multimodal inputs data
+ for inp, data in inputs_data.items():
+ await data_manager.save_bytes(data, data_name(inp, task_id))
+
+ await prepare_subtasks(task_id)
+ return {"task_id": task_id, "workers": workers, "params": params, "wait_time": wait_time}
+
+ except Exception as e:
+ traceback.print_exc()
+ if task_id:
+ await task_manager.finish_subtasks(task_id, TaskStatus.FAILED, fail_msg=f"submit failed: {e}")
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/task/query")
+async def api_v1_task_query(request: Request, user=Depends(verify_user_access)):
+ try:
+ if "task_ids" in request.query_params:
+ task_ids = request.query_params["task_ids"].split(",")
+ tasks = []
+ for task_id in task_ids:
+ task_id = task_id.strip()
+ if task_id:
+ task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False)
+ if task is not None:
+ task["subtasks"] = await server_monitor.format_subtask(subtasks)
+ format_task(task)
+ tasks.append(task)
+ return {"tasks": tasks}
+
+        # Single task query
+ task_id = request.query_params["task_id"]
+ task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False)
+ if task is None:
+ return error_response(f"Task {task_id} not found", 404)
+ task["subtasks"] = await server_monitor.format_subtask(subtasks)
+ format_task(task)
+ return task
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/task/list")
+async def api_v1_task_list(request: Request, user=Depends(verify_user_access)):
+ try:
+ user_id = user["user_id"]
+ page = int(request.query_params.get("page", 1))
+ page_size = int(request.query_params.get("page_size", 10))
+ assert page > 0 and page_size > 0, "page and page_size must be greater than 0"
+ status_filter = request.query_params.get("status", None)
+
+ query_params = {"user_id": user_id}
+ if status_filter and status_filter != "ALL":
+ query_params["status"] = TaskStatus[status_filter.upper()]
+
+ total_tasks = await task_manager.list_tasks(count=True, **query_params)
+ total_pages = (total_tasks + page_size - 1) // page_size
+ page_info = {"page": page, "page_size": page_size, "total": total_tasks, "total_pages": total_pages}
+ if page > total_pages:
+ return {"tasks": [], "pagination": page_info}
+
+ query_params["offset"] = (page - 1) * page_size
+ query_params["limit"] = page_size
+
+ tasks = await task_manager.list_tasks(**query_params)
+ for task in tasks:
+ format_task(task)
+
+ return {"tasks": tasks, "pagination": page_info}
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/task/result_url")
+async def api_v1_task_result_url(request: Request, user=Depends(verify_user_access)):
+ try:
+ name = request.query_params["name"]
+ task_id = request.query_params["task_id"]
+ task = await task_manager.query_task(task_id, user_id=user["user_id"])
+ assert task is not None, f"Task {task_id} not found"
+ assert task["status"] == TaskStatus.SUCCEED, f"Task {task_id} not succeed"
+ assert name in task["outputs"], f"Output {name} not found in task {task_id}"
+ assert name not in task["params"], f"Output {name} is a stream"
+
+ url = await data_manager.presign_url(task["outputs"][name])
+ if url is None:
+ url = f"./assets/task/result?task_id={task_id}&name={name}"
+ return {"url": url}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/task/input_url")
+async def api_v1_task_input_url(request: Request, user=Depends(verify_user_access)):
+ try:
+ name = request.query_params["name"]
+ task_id = request.query_params["task_id"]
+ filename = request.query_params.get("filename", None)
+
+ task = await task_manager.query_task(task_id, user_id=user["user_id"])
+ assert task is not None, f"Task {task_id} not found"
+ assert name in task["inputs"], f"Input {name} not found in task {task_id}"
+ if name in task["params"]:
+ return error_response(f"Input {name} is a stream", 400)
+
+        # e.g., multi-person audio directory input
+ if filename is not None:
+ extra_inputs = task["params"]["extra_inputs"][name]
+ name = f"{name}/{filename}"
+ assert name in task["inputs"], f"Extra input {name} not found in task {task_id}"
+ assert name in extra_inputs, f"Filename {filename} not found in extra inputs"
+
+ url = await data_manager.presign_url(task["inputs"][name])
+ if url is None:
+ url = f"./assets/task/input?task_id={task_id}&name={name}"
+ if filename is not None:
+ url += f"&filename={filename}"
+ return {"url": url}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/assets/task/result")
+async def assets_task_result(request: Request, user=Depends(verify_user_access_from_query)):
+ try:
+ name = request.query_params["name"]
+ task_id = request.query_params["task_id"]
+ task = await task_manager.query_task(task_id, user_id=user["user_id"])
+ assert task is not None, f"Task {task_id} not found"
+ assert task["status"] == TaskStatus.SUCCEED, f"Task {task_id} not succeed"
+ assert name in task["outputs"], f"Output {name} not found in task {task_id}"
+ assert name not in task["params"], f"Output {name} is a stream"
+ data = await data_manager.load_bytes(task["outputs"][name])
+
+ # set correct Content-Type
+ content_type = guess_file_type(name, "application/octet-stream")
+ headers = {"Content-Disposition": f'attachment; filename="{name}"'}
+ headers["Cache-Control"] = "public, max-age=3600"
+ return Response(content=data, media_type=content_type, headers=headers)
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/assets/task/input")
+async def assets_task_input(request: Request, user=Depends(verify_user_access_from_query)):
+ try:
+ name = request.query_params["name"]
+ task_id = request.query_params["task_id"]
+ filename = request.query_params.get("filename", None)
+
+ task = await task_manager.query_task(task_id, user_id=user["user_id"])
+ assert task is not None, f"Task {task_id} not found"
+ assert name in task["inputs"], f"Input {name} not found in task {task_id}"
+ if name in task["params"]:
+ return error_response(f"Input {name} is a stream", 400)
+
+        # e.g., multi-person audio directory input
+ if filename is not None:
+ extra_inputs = task["params"]["extra_inputs"][name]
+ name = f"{name}/{filename}"
+ assert name in task["inputs"], f"Extra input {name} not found in task {task_id}"
+ assert name in extra_inputs, f"Filename {filename} not found in extra inputs"
+ data = await data_manager.load_bytes(task["inputs"][name])
+
+ # set correct Content-Type
+ content_type = guess_file_type(name, "application/octet-stream")
+ headers = {"Content-Disposition": f'attachment; filename="{name}"'}
+ headers["Cache-Control"] = "public, max-age=3600"
+ return Response(content=data, media_type=content_type, headers=headers)
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/task/cancel")
+async def api_v1_task_cancel(request: Request, user=Depends(verify_user_access)):
+ try:
+ task_id = request.query_params["task_id"]
+ ret = await task_manager.cancel_task(task_id, user_id=user["user_id"])
+ logger.warning(f"Task {task_id} cancelled: {ret}")
+ if ret is True:
+ return {"msg": "Task cancelled successfully"}
+ else:
+ return error_response({"error": f"Task {task_id} cancel failed: {ret}"}, 400)
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/task/resume")
+async def api_v1_task_resume(request: Request, user=Depends(verify_user_access)):
+ try:
+ task_id = request.query_params["task_id"]
+
+ task = await task_manager.query_task(task_id, user_id=user["user_id"])
+ keys = [task["task_type"], task["model_cls"], task["stage"]]
+ if not model_pipelines.check_item_by_keys(keys):
+ return error_response(f"Model {keys} is not supported now, please submit a new task", 400)
+
+ ret = await task_manager.resume_task(task_id, user_id=user["user_id"], all_subtask=False)
+ if ret is True:
+ await prepare_subtasks(task_id)
+ return {"msg": "ok"}
+ else:
+ return error_response(f"Task {task_id} resume failed: {ret}", 400)
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.delete("/api/v1/task/delete")
+async def api_v1_task_delete(request: Request, user=Depends(verify_user_access)):
+ try:
+ task_id = request.query_params["task_id"]
+
+ task = await task_manager.query_task(task_id, user["user_id"], only_task=True)
+ if not task:
+ return error_response("Task not found", 404)
+
+ if task["status"] not in FinishedStatus:
+ return error_response("Only finished tasks can be deleted", 400)
+
+ # delete task record
+ success = await task_manager.delete_task(task_id, user["user_id"])
+ if success:
+ logger.info(f"Task {task_id} deleted by user {user['user_id']}")
+ return JSONResponse({"message": "Task deleted successfully"})
+ else:
+ return error_response("Failed to delete task", 400)
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.post("/api/v1/worker/fetch")
+async def api_v1_worker_fetch(request: Request, valid=Depends(verify_worker_access)):
+ try:
+ params = await request.json()
+ logger.info(f"Worker fetching: {params}")
+ keys = params.pop("worker_keys")
+ identity = params.pop("worker_identity")
+ max_batch = params.get("max_batch", 1)
+ timeout = params.get("timeout", 5)
+
+ # check client disconnected
+ async def check_client(request, fetch_task, identity, queue):
+ while True:
+ msg = await request.receive()
+ if msg["type"] == "http.disconnect":
+ logger.warning(f"Worker {identity} {queue} disconnected, req: {request.client}, {msg}")
+ fetch_task.cancel()
+ await server_monitor.worker_update(queue, identity, WorkerStatus.DISCONNECT)
+ return
+ await asyncio.sleep(1)
+
+ # get worker info
+ worker = model_pipelines.get_worker(keys)
+ await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.FETCHING)
+
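+        # Race the queue fetch against client disconnection: if the worker drops the HTTP
+        # connection while waiting, check_client cancels the in-flight fetch task.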
+ fetch_task = asyncio.create_task(queue_manager.get_subtasks(worker["queue"], max_batch, timeout))
+ check_task = asyncio.create_task(check_client(request, fetch_task, identity, worker["queue"]))
+ try:
+ subtasks = await asyncio.wait_for(fetch_task, timeout=timeout)
+ except asyncio.TimeoutError:
+ subtasks = []
+ fetch_task.cancel()
+ check_task.cancel()
+
+ subtasks = [] if subtasks is None else subtasks
+ for sub in subtasks:
+ await server_monitor.pending_subtasks_sub(sub["queue"], sub["task_id"])
+ valid_subtasks = await task_manager.run_subtasks(subtasks, identity)
+ valids = [sub["task_id"] for sub in valid_subtasks]
+
+ if len(valid_subtasks) > 0:
+ await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.FETCHED)
+ logger.info(f"Worker {identity} {keys} {request.client} fetched {valids}")
+ else:
+ await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.DISCONNECT)
+ return {"subtasks": valid_subtasks}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.post("/api/v1/worker/report")
+async def api_v1_worker_report(request: Request, valid=Depends(verify_worker_access)):
+ try:
+ params = await request.json()
+ logger.info(f"{params}")
+ task_id = params.pop("task_id")
+ worker_name = params.pop("worker_name")
+ status = TaskStatus[params.pop("status")]
+ identity = params.pop("worker_identity")
+ queue = params.pop("queue")
+ fail_msg = params.pop("fail_msg", None)
+ await server_monitor.worker_update(queue, identity, WorkerStatus.REPORT)
+
+ ret = await task_manager.finish_subtasks(task_id, status, worker_identity=identity, worker_name=worker_name, fail_msg=fail_msg, should_running=True)
+
+ # not all subtasks finished, prepare new ready subtasks
+ if ret not in [TaskStatus.SUCCEED, TaskStatus.FAILED]:
+ await prepare_subtasks(task_id)
+
+ # all subtasks succeed, delete temp data
+ elif ret == TaskStatus.SUCCEED:
+ logger.info(f"Task {task_id} succeed")
+ task = await task_manager.query_task(task_id)
+ keys = [task["task_type"], task["model_cls"], task["stage"]]
+ temps = model_pipelines.get_temps(keys)
+ for temp in temps:
+            data_type = model_pipelines.get_type(temp)
+            name = data_name(temp, task_id)
+            await data_manager.get_delete_func(data_type)(name)
+
+ elif ret == TaskStatus.FAILED:
+ logger.warning(f"Task {task_id} failed")
+
+ return {"msg": "ok"}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.post("/api/v1/worker/ping/subtask")
+async def api_v1_worker_ping_subtask(request: Request, valid=Depends(verify_worker_access)):
+ try:
+ params = await request.json()
+ logger.info(f"{params}")
+ task_id = params.pop("task_id")
+ worker_name = params.pop("worker_name")
+ identity = params.pop("worker_identity")
+ queue = params.pop("queue")
+
+ task = await task_manager.query_task(task_id)
+ if task is None or task["status"] != TaskStatus.RUNNING:
+ return {"msg": "delete"}
+
+ assert await task_manager.ping_subtask(task_id, worker_name, identity)
+ await server_monitor.worker_update(queue, identity, WorkerStatus.PING)
+ return {"msg": "ok"}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/metrics")
+async def api_v1_monitor_metrics():
+ try:
+ return Response(content=metrics_monitor.get_metrics(), media_type="text/plain")
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/template/asset_url/{template_type}/{filename}")
+async def api_v1_template_asset_url(template_type: str, filename: str):
+ """get template asset URL - no authentication required"""
+ try:
+ url = await data_manager.presign_template_url(template_type, filename)
+ if url is None:
+ url = f"./assets/template/{template_type}/{filename}"
+ headers = {"Cache-Control": "public, max-age=3600"}
+ return Response(content=json.dumps({"url": url}), media_type="application/json", headers=headers)
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+# Template API endpoints
+@app.get("/assets/template/{template_type}/{filename}")
+async def assets_template(template_type: str, filename: str):
+ """get template file - no authentication required"""
+ try:
+ if not await data_manager.template_file_exists(template_type, filename):
+ return error_response(f"template file {template_type} {filename} not found", 404)
+ data = await data_manager.load_template_file(template_type, filename)
+
+ # set media type according to file type
+ if template_type == "images":
+ if filename.lower().endswith(".png"):
+ media_type = "image/png"
+ elif filename.lower().endswith((".jpg", ".jpeg")):
+ media_type = "image/jpeg"
+ else:
+ media_type = "application/octet-stream"
+ elif template_type == "audios":
+ if filename.lower().endswith(".mp3"):
+ media_type = "audio/mpeg"
+ elif filename.lower().endswith(".wav"):
+ media_type = "audio/wav"
+ else:
+ media_type = "application/octet-stream"
+ elif template_type == "videos":
+ if filename.lower().endswith(".mp4"):
+ media_type = "video/mp4"
+ elif filename.lower().endswith(".webm"):
+ media_type = "video/webm"
+ elif filename.lower().endswith(".avi"):
+ media_type = "video/x-msvideo"
+ else:
+ media_type = "video/mp4" # default to mp4
+ else:
+ media_type = "application/octet-stream"
+
+ headers = {"Cache-Control": "public, max-age=3600"}
+ return Response(content=data, media_type=media_type, headers=headers)
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/template/list")
+async def api_v1_template_list(request: Request):
+ """get template file list (support pagination) - no authentication required"""
+ try:
+ # check page params
+ page = int(request.query_params.get("page", 1))
+ page_size = int(request.query_params.get("page_size", 12))
+ if page < 1 or page_size < 1:
+ return error_response("page and page_size must be greater than 0", 400)
+ # limit page size
+ page_size = min(page_size, 100)
+
+ all_images = await data_manager.list_template_files("images")
+ all_audios = await data_manager.list_template_files("audios")
+ all_videos = await data_manager.list_template_files("videos")
+ all_images = [] if all_images is None else all_images
+ all_audios = [] if all_audios is None else all_audios
+ all_videos = [] if all_videos is None else all_videos
+
+        # Build a map from image base name (without extension) to image info
+        all_images_sorted = sorted(all_images)
+        image_map = {}  # base name (without extension) -> {"filename": full filename, "url": URL}
+ for img_name in all_images_sorted:
+ img_name_without_ext = img_name.rsplit(".", 1)[0] if "." in img_name else img_name
+ url = await data_manager.presign_template_url("images", img_name)
+ if url is None:
+ url = f"./assets/template/images/{img_name}"
+ image_map[img_name_without_ext] = {"filename": img_name, "url": url}
+
+        # Build a map from audio base name (without extension) to audio info
+        all_audios_sorted = sorted(all_audios)
+        audio_map = {}  # base name (without extension) -> {"filename": full filename, "url": URL}
+ for audio_name in all_audios_sorted:
+ audio_name_without_ext = audio_name.rsplit(".", 1)[0] if "." in audio_name else audio_name
+ url = await data_manager.presign_template_url("audios", audio_name)
+ if url is None:
+ url = f"./assets/template/audios/{audio_name}"
+ audio_map[audio_name_without_ext] = {"filename": audio_name, "url": url}
+
+        # Merge audio and image templates by matching filename prefixes:
+        # collect all unique base names (without extension)
+ all_base_names = set(list(image_map.keys()) + list(audio_map.keys()))
+ all_base_names_sorted = sorted(all_base_names)
+
+        # Build the merged template list
+ merged_templates = []
+ for base_name in all_base_names_sorted:
+ template_item = {
+ "id": base_name, # 使用基础文件名作为ID
+ "image": image_map.get(base_name),
+ "audio": audio_map.get(base_name),
+ }
+ merged_templates.append(template_item)
+
+        # Pagination
+ total = len(merged_templates)
+ total_pages = (total + page_size - 1) // page_size if total > 0 else 1
+
+ paginated_templates = []
+ if page <= total_pages:
+ start_idx = (page - 1) * page_size
+ end_idx = start_idx + page_size
+ paginated_templates = merged_templates[start_idx:end_idx]
+
+        # For backward compatibility, still return the images and audios fields (possibly empty)
+        # and also add the new merged field
+ return {
+ "templates": {
+ "images": [], # 保持向后兼容,但设为空
+ "audios": [], # 保持向后兼容,但设为空
+ "videos": [], # 保持向后兼容
+ "merged": paginated_templates, # 新的合并列表
+ },
+ "pagination": {"page": page, "page_size": page_size, "total": total, "total_pages": total_pages},
+ }
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/template/tasks")
+async def api_v1_template_tasks(request: Request):
+ """get template task list (support pagination) - no authentication required"""
+ try:
+ # check page params
+ page = int(request.query_params.get("page", 1))
+ page_size = int(request.query_params.get("page_size", 12))
+ category = request.query_params.get("category", None)
+ search = request.query_params.get("search", None)
+ if page < 1 or page_size < 1:
+ return error_response("page and page_size must be greater than 0", 400)
+ # limit page size
+ page_size = min(page_size, 100)
+
+ all_templates = []
+ all_categories = set()
+ template_files = await data_manager.list_template_files("tasks")
+ template_files = [] if template_files is None else template_files
+
+ for template_file in template_files:
+ try:
+ bytes_data = await data_manager.load_template_file("tasks", template_file)
+ template_data = json.loads(bytes_data)
+ template_data["task"]["model_cls"] = model_pipelines.outer_model_name(template_data["task"]["model_cls"])
+ all_categories.update(template_data["task"]["tags"])
+ if category and category not in template_data["task"]["tags"]:
+ continue
+                task_info = template_data["task"]
+                searchable = task_info["params"]["prompt"] + task_info["params"]["negative_prompt"] + task_info["model_cls"] + task_info["stage"] + task_info["task_type"] + ",".join(task_info["tags"])
+                if search and search not in searchable:
+                    continue
+ all_templates.append(template_data["task"])
+ except Exception as e:
+ logger.warning(f"Failed to load template file {template_file}: {e}")
+
+ # page info
+ total_templates = len(all_templates)
+ total_pages = (total_templates + page_size - 1) // page_size
+ paginated_templates = []
+
+ if page <= total_pages:
+ start_idx = (page - 1) * page_size
+ end_idx = start_idx + page_size
+ paginated_templates = all_templates[start_idx:end_idx]
+
+ return {"templates": paginated_templates, "pagination": {"page": page, "page_size": page_size, "total": total_templates, "total_pages": total_pages}, "categories": list(all_categories)}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/template/{template_id}")
+async def api_v1_template_get(template_id: str, user=None):
+ try:
+ template_files = await data_manager.list_template_files("tasks")
+ template_files = [] if template_files is None else template_files
+
+ for template_file in template_files:
+ try:
+ bytes_data = await data_manager.load_template_file("tasks", template_file)
+ template_data = json.loads(bytes_data)
+ if template_data["task"]["task_id"] == template_id:
+ return template_data["task"]
+ except Exception as e:
+ logger.warning(f"Failed to load template file {template_file}: {e}")
+ continue
+ return error_response("Template not found", 404)
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.post("/api/v1/share/create")
+async def api_v1_share_create(request: Request, user=Depends(verify_user_access)):
+ try:
+ params = await request.json()
+ task_id = params["task_id"]
+ valid_days = params.get("valid_days", 7)
+ auth_type = params.get("auth_type", "public")
+ auth_value = params.get("auth_value", "")
+ share_type = params.get("share_type", "task")
+ assert auth_type == "public", "Only public share is supported now"
+
+ if share_type == "template":
+ template = await api_v1_template_get(task_id, user)
+ assert isinstance(template, dict) and template["task_id"] == task_id, f"Template {task_id} not found"
+ else:
+ task = await task_manager.query_task(task_id, user["user_id"], only_task=True)
+ assert task, f"Task {task_id} not found"
+
+ if auth_type == "user_id":
+ assert auth_value != "", "Target user is required for auth_type = user_id"
+ target_user = await task_manager.query_user(auth_value)
+ assert target_user and target_user["user_id"] == auth_value, f"Target user {auth_value} not found"
+
+ share_id = await task_manager.create_share(task_id, user["user_id"], share_type, valid_days, auth_type, auth_value)
+ return {"share_id": share_id, "share_url": f"/share/{share_id}"}
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/share/{share_id}")
+async def api_v1_share_get(share_id: str):
+ try:
+ share_data = await task_manager.query_share(share_id)
+ assert share_data, f"Share {share_id} not found or expired or deleted"
+ task_id = share_data["task_id"]
+ share_type = share_data["share_type"]
+ assert share_data["auth_type"] == "public", "Only public share is supported now"
+
+ if share_type == "template":
+ task = await api_v1_template_get(task_id, None)
+ assert isinstance(task, dict) and task["task_id"] == task_id, f"Template {task_id} not found"
+ else:
+ task = await task_manager.query_task(task_id, only_task=True)
+ assert task, f"Task {task_id} not found"
+
+ user_info = await task_manager.query_user(share_data["user_id"])
+ username = user_info.get("username", "用户") if user_info else "用户"
+
+ share_info = {
+ "task_id": task_id,
+ "share_type": share_type,
+ "user_id": share_data["user_id"],
+ "username": username,
+ "task_type": task["task_type"],
+ "model_cls": task["model_cls"],
+ "stage": task["stage"],
+ "prompt": task["params"].get("prompt", ""),
+ "negative_prompt": task["params"].get("negative_prompt", ""),
+ "inputs": task["inputs"],
+ "outputs": task["outputs"],
+ "create_t": task["create_t"],
+ "valid_days": share_data["valid_days"],
+ "valid_t": share_data["valid_t"],
+ "auth_type": share_data["auth_type"],
+ "auth_value": share_data["auth_value"],
+ "output_video_url": None,
+ "input_urls": {},
+ }
+
+ for input_name, input_filename in task["inputs"].items():
+ if share_type == "template":
+ template_type = "images" if "image" in input_name else "audios"
+ input_url = await data_manager.presign_template_url(template_type, input_filename)
+ else:
+ input_url = await data_manager.presign_url(input_filename)
+ share_info["input_urls"][input_name] = input_url
+
+ for output_name, output_filename in task["outputs"].items():
+ if share_type == "template":
+ assert "video" in output_name, "Only video output is supported for template share"
+ output_url = await data_manager.presign_template_url("videos", output_filename)
+ else:
+ output_url = await data_manager.presign_url(output_filename)
+ share_info["output_video_url"] = output_url
+
+ return share_info
+
+ except Exception as e:
+ traceback.print_exc()
+ return error_response(str(e), 500)
+
+
+@app.get("/api/v1/voices/list")
+async def api_v1_voices_list(request: Request):
+ try:
+ version = request.query_params.get("version", "all")
+ if volcengine_tts_client is None:
+ return error_response("Volcengine TTS client not loaded", 500)
+ voices = volcengine_tts_client.get_voice_list()
+ if voices is None:
+ return error_response("No voice list found", 404)
+ if version != "all":
+ voices = copy.deepcopy(voices)
+ voices["voices"] = [v for v in voices["voices"] if v["version"] == version]
+ return voices
+ except Exception as e:
+ traceback.print_exc()
+ return error_response("Failed to get voice list", 500)
+
+
+@app.post("/api/v1/tts/generate")
+async def api_v1_tts_generate(request: TTSRequest):
+ """Generate TTS audio from text"""
+ try:
+ # Validate parameters
+ if not request.text.strip():
+ return JSONResponse({"error": "Text cannot be empty"}, status_code=400)
+
+ if not request.voice_type:
+ return JSONResponse({"error": "Voice type is required"}, status_code=400)
+
+ # Generate unique output filename
+ output_filename = f"tts_output_{uuid.uuid4().hex}.mp3"
+ output_path = os.path.join(tempfile.gettempdir(), output_filename)
+
+ # Generate TTS
+ success = await volcengine_tts_client.tts_request(
+ text=request.text,
+ voice_type=request.voice_type,
+ context_texts=request.context_texts,
+ emotion=request.emotion,
+ emotion_scale=request.emotion_scale,
+ speech_rate=request.speech_rate,
+ loudness_rate=request.loudness_rate,
+ pitch=request.pitch,
+ output=output_path,
+ resource_id=request.resource_id,
+ )
+
+ if success and os.path.exists(output_path):
+ # Return the audio file
+ return FileResponse(output_path, media_type="audio/mpeg", filename=output_filename)
+ else:
+ return JSONResponse({"error": "TTS generation failed"}, status_code=500)
+
+ except Exception as e:
+ logger.error(f"TTS generation error: {e}")
+ return JSONResponse({"error": f"TTS generation failed: {str(e)}"}, status_code=500)
+
+
+@app.websocket("/api/v1/podcast/generate")
+async def api_v1_podcast_generate_ws(websocket: WebSocket):
+ await websocket.accept()
+
+ def ws_get_user_id():
+ token = websocket.query_params.get("token")
+ if not token:
+ token = websocket.headers.get("authorization") or websocket.headers.get("Authorization")
+ if token and token.startswith("Bearer "):
+ token = token[7:]
+ payload = auth_manager.verify_jwt_token(token)
+ user_id = payload["user_id"]
+ return user_id
+
+ async def safe_send_json(payload):
+ try:
+ await websocket.send_json(payload)
+ except (WebSocketDisconnect, RuntimeError) as e:
+ logger.warning(f"WebSocket send skipped: {e}")
+
+ try:
+ user_id = ws_get_user_id()
+ data = await websocket.receive_text()
+ request_data = json.loads(data)
+
+ # stop request
+ if request_data.get("type") == "stop":
+ logger.info("Received stop signal from client")
+ await safe_send_json({"type": "stopped"})
+ return
+
+ # user input prompt
+ input_text = request_data.get("input", "")
+ is_url = input_text.startswith(("http://", "https://"))
+ if not input_text:
+ await safe_send_json({"error": "输入不能为空"})
+ return
+
+ session_id = "session_" + str(uuid.uuid4())
+ params = {
+ "session_id": session_id,
+ "data_manager": data_manager,
+ "text": "" if is_url else input_text,
+ "input_url": input_text if is_url else "",
+ "action": 0,
+ "use_head_music": False,
+ "use_tail_music": False,
+ "skip_round_audio_save": False,
+ }
+ logger.info(f"WebSocket generating podcast with params: {params}")
+
+        # Push audio updates to the client in real time via a callback
+ async def on_round_complete(round_info):
+ await safe_send_json({"type": "audio_update", "data": round_info})
+
+ params["on_round_complete"] = on_round_complete
+
+        # Create a task to listen for the stop signal during generation
+ async def listen_for_stop(podcast_task):
+ while True:
+ try:
+ if podcast_task.done():
+ return
+ data = await asyncio.wait_for(websocket.receive_text(), timeout=0.1)
+ request = json.loads(data)
+ if request.get("type") == "stop":
+ logger.warning("Stop signal received during podcast generation")
+ podcast_task.cancel()
+ return
+ except asyncio.TimeoutError:
+ continue
+ except Exception as e:
+ logger.warning(f"Stop listener ended: {e}")
+ return
+
+ podcast_task = asyncio.create_task(volcengine_podcast_client.podcast_request(**params))
+ stop_listener_task = asyncio.create_task(listen_for_stop(podcast_task))
+ podcast_info = None
+
+ try:
+ podcast_info = await podcast_task
+ except asyncio.CancelledError:
+ logger.warning("Podcast generation cancelled by user")
+ await safe_send_json({"type": "stopped"})
+ return
+ finally:
+ stop_listener_task.cancel()
+ if podcast_info is None:
+ await safe_send_json({"error": "播客生成失败,请稍后重试"})
+ return
+
+ audio_path = podcast_info["audio_name"]
+ rounds = podcast_info["subtitles"]
+ await task_manager.create_podcast(session_id, user_id, input_text, audio_path, rounds)
+ audio_url = await data_manager.presign_podcast_output_url(audio_path)
+ if not audio_url:
+ audio_url = f"/api/v1/podcast/audio?session_id={session_id}&filename={audio_path}"
+ logger.info(f"completed podcast generation (session: {session_id})")
+
+ await safe_send_json(
+ {
+ "type": "complete",
+ "data": {
+ "audio_url": audio_url,
+ "subtitles": podcast_info["subtitles"],
+ "session_id": session_id,
+ "user_id": user_id,
+ },
+ }
+ )
+
+ except WebSocketDisconnect:
+ logger.info("WebSocket disconnected")
+
+ except Exception:
+ logger.error(f"Error in websocket: {traceback.format_exc()}")
+ await safe_send_json({"error": "websocket internal error, please try again later!"})
+
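+ # A minimal client sketch for the websocket endpoint above, assuming the third-party
+ # "websockets" package and a valid JWT passed via the "token" query parameter; the
+ # host/port and token value are placeholders, not part of this server.
+ # import asyncio, json, websockets
+ # async def demo():
+ #     async with websockets.connect("ws://localhost:8080/api/v1/podcast/generate?token=<JWT>") as ws:
+ #         await ws.send(json.dumps({"input": "https://example.com/article"}))
+ #         while True:
+ #             msg = json.loads(await ws.recv())
+ #             if msg.get("type") == "complete" or "error" in msg:
+ #                 break
+ # asyncio.run(demo())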
+
+@app.get("/api/v1/podcast/audio")
+async def api_v1_podcast_audio(request: Request, user=Depends(verify_user_access_from_query)):
+ try:
+ user_id = user["user_id"]
+ session_id = request.query_params.get("session_id")
+ filename = request.query_params.get("filename")
+ if not session_id or not filename:
+ return JSONResponse({"error": "session_id and filename are required"}, status_code=400)
+
+ ext = os.path.splitext(filename)[1].lower()
+ if ext != ".mp3":
+ return JSONResponse({"error": f"Unsupported file extension: {ext}"}, status_code=400)
+
+ # Parse the Range header; format: bytes=start-end or bytes=start-
+ range_header = request.headers.get("Range")
+ start_byte, end_byte = None, None
+ if range_header:
+ match = re.match(r"bytes=(\d+)-(\d*)", range_header)
+ if match:
+ start_byte = int(match.group(1))
+ end_byte = int(match.group(2)) + 1 if match.group(2) else None
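+ # e.g. "Range: bytes=0-1023" yields start_byte=0, end_byte=1024 (the +1 makes the slice end exclusive)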
+
+ podcast_data = await task_manager.query_podcast(session_id, user_id)
+ if podcast_data:
+ # generate is finished and save info to database
+ func = data_manager.load_podcast_output_file
+ filename = podcast_data["audio_path"]
+ func_args = (filename,)
+ else:
+ func = data_manager.load_podcast_temp_session_file
+ func_args = (session_id, filename)
+
+ logger.debug(f"Serving audio file from {func.__name__} with args: {func_args}, start_byte: {start_byte}, end_byte: {end_byte}")
+ file_bytes = await func(*func_args)
+ file_size = len(file_bytes)
+ file_bytes = file_bytes[start_byte:end_byte]
+
+ content_length = len(file_bytes)
+ media_type = "audio/mpeg"
+ status_code = 200
+ headers = {"Content-Length": str(content_length), "Accept-Ranges": "bytes", "Content-Type": media_type, "Content-Disposition": f'attachment; filename="{filename}"'}
+
+ if start_byte is not None:
+ # Any satisfiable Range request (including bytes=0-) gets 206 with a Content-Range header
+ status_code = 206 # Partial Content
+ headers["Content-Range"] = f"bytes {start_byte}-{start_byte + content_length - 1}/{file_size}"
+ return Response(content=file_bytes, media_type=media_type, status_code=status_code, headers=headers)
+
+ except Exception as e:
+ logger.error(f"Error serving audio: {e}")
+ traceback.print_exc()
+ return JSONResponse({"error": str(e)}, status_code=500)
+
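+ # Illustrative ranged request against the endpoint above; the auth query parameter name
+ # ("token") is an assumption based on verify_user_access_from_query, and the values are placeholders:
+ #   curl -H "Range: bytes=0-1023" \
+ #     "http://localhost:8080/api/v1/podcast/audio?session_id=session_xxx&filename=out.mp3&token=<JWT>"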
+
+@app.get("/api/v1/podcast/history")
+async def api_v1_podcast_history(request: Request, user=Depends(verify_user_access)):
+ try:
+ user_id = user["user_id"]
+ page = int(request.query_params.get("page", 1))
+ page_size = int(request.query_params.get("page_size", 10))
+ assert page > 0 and page_size > 0, "page and page_size must be greater than 0"
+ status = request.query_params.get("status", None) # has_audio, no_audio
+
+ query_params = {"user_id": user_id}
+ if status == "has_audio":
+ query_params["has_audio"] = True
+ elif status == "no_audio":
+ query_params["has_audio"] = False
+
+ total_tasks = await task_manager.list_podcasts(count=True, **query_params)
+ total_pages = (total_tasks + page_size - 1) // page_size
+ page_info = {"page": page, "page_size": page_size, "total": total_tasks, "total_pages": total_pages}
+ if page > total_pages:
+ return {"sessions": [], "pagination": page_info}
+
+ query_params["offset"] = (page - 1) * page_size
+ query_params["limit"] = page_size
+ sessions = await task_manager.list_podcasts(**query_params)
+ return {"sessions": sessions, "pagination": page_info}
+
+ except Exception as e:
+ logger.error(f"Error getting podcast history: {e}")
+ traceback.print_exc()
+ return {"sessions": []}
+
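+ # Illustrative request for the history endpoint above, assuming verify_user_access reads a
+ # Bearer token from the Authorization header; values are placeholders:
+ #   curl -H "Authorization: Bearer <JWT>" \
+ #     "http://localhost:8080/api/v1/podcast/history?page=1&page_size=10&status=has_audio"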
+
+@app.get("/api/v1/podcast/session/{session_id}/audio_url")
+async def api_v1_podcast_session_audio_url(session_id: str, user=Depends(verify_user_access)):
+ try:
+ user_id = user["user_id"]
+ podcast_data = await task_manager.query_podcast(session_id, user_id)
+ if not podcast_data:
+ return JSONResponse({"error": "Podcast session not found"}, status_code=404)
+
+ audio_path = podcast_data["audio_path"]
+ audio_url = await data_manager.presign_podcast_output_url(audio_path)
+ if not audio_url:
+ audio_url = f"/api/v1/podcast/audio?session_id={session_id}&filename={audio_path}"
+ return {"audio_url": audio_url}
+
+ except Exception as e:
+ logger.error(f"Error getting podcast session audio URL: {e}")
+ traceback.print_exc()
+ return JSONResponse({"error": str(e)}, status_code=500)
+
+
+class FaceDetectRequest(BaseModel):
+ image: str # Base64 encoded image
+
+
+class AudioSeparateRequest(BaseModel):
+ audio: str # Base64 encoded audio
+ num_speakers: int = None # Optional: number of speakers to separate
+
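+ # Illustrative request bodies for the two endpoints below (field names follow the Pydantic
+ # models above; the base64 payloads are truncated placeholders):
+ #   POST /api/v1/face/detect     {"image": "data:image/png;base64,iVBORw0KG..."}
+ #   POST /api/v1/audio/separate  {"audio": "data:audio/mpeg;base64,SUQzBAAA...", "num_speakers": 2}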
+
+@app.post("/api/v1/face/detect")
+async def api_v1_face_detect(request: FaceDetectRequest, user=Depends(verify_user_access)):
+ """Detect faces in image (only detection, no cropping - cropping is done on frontend)
+ Supports both base64 encoded images and URLs (blob URLs, http URLs, etc.)
+ """
+ try:
+ if not face_detector:
+ return error_response("Face detector not initialized", 500)
+
+ # Validate input
+ if not request.image or not request.image.strip():
+ logger.error("Face detection request: image is empty")
+ return error_response("Image input is empty", 400)
+
+ image_bytes = None
+ try:
+ # http(s) URLs are fetched remotely; anything else is treated as base64 (optionally a data: URL)
+ if request.image.startswith(("http://", "https://")):
+ timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+ image_bytes = await fetch_resource(request.image, timeout=timeout)
+ logger.debug(f"Fetched image from URL for face detection: {request.image[:100]}... (size: {len(image_bytes)} bytes)")
+ else:
+ encoded = request.image
+ # Data URL format: "data:image/png;base64,..."
+ if encoded.startswith("data:image"):
+ _, encoded = encoded.split(",", 1)
+ image_bytes = base64.b64decode(encoded)
+ logger.debug(f"Decoded base64 image: {request.image[:100]}... (size: {len(image_bytes)} bytes)")
+
+ # Validate image format before passing to face detector
+ image_bytes = await asyncio.to_thread(format_image_data, image_bytes)
+
+ except Exception as e:
+ logger.error(f"Failed to decode base64 image: {e}, image length: {len(request.image) if request.image else 0}")
+ return error_response(f"Invalid image format: {str(e)}", 400)
+
+ # Detect faces only (no cropping)
+ result = await asyncio.to_thread(face_detector.detect_faces, image_bytes, return_image=False)
+ faces_data = []
+ for i, face in enumerate(result["faces"]):
+ faces_data.append(
+ {
+ "index": i,
+ "bbox": face["bbox"], # [x1, y1, x2, y2] - absolute pixel coordinates in original image
+ "confidence": face["confidence"],
+ "class_id": face["class_id"],
+ "class_name": face["class_name"],
+ # Note: face_image is not included - frontend will crop it based on bbox
+ }
+ )
+ return {"faces": faces_data, "total": len(faces_data)}
+
+ except Exception as e:
+ logger.error(f"Face detection error: {traceback.format_exc()}")
+ return error_response(f"Face detection failed: {str(e)}", 500)
+
+
+@app.post("/api/v1/audio/separate")
+async def api_v1_audio_separate(request: AudioSeparateRequest, user=Depends(verify_user_access)):
+ """Separate different speakers in audio"""
+ try:
+ if not audio_separator:
+ return error_response("Audio separator not initialized", 500)
+ audio_bytes = None
+ try:
+ encoded = request.audio
+ if encoded.startswith("data:"):
+ # Remove data URL prefix (e.g., "data:audio/mpeg;base64," or "data:application/octet-stream;base64,")
+ _, encoded = encoded.split(",", 1)
+ audio_bytes = await asyncio.to_thread(base64.b64decode, encoded, validate=True)
+ logger.debug(f"Successfully decoded base64 audio, size: {len(audio_bytes)} bytes")
+
+ except Exception as e:
+ logger.error(f"Failed to decode base64 audio {request.audio[:100]}..., error: {str(e)}")
+ return error_response(f"Invalid base64 audio data", 400)
+
+ # Separate speakers
+ result = await asyncio.to_thread(audio_separator.separate_speakers, audio_bytes, num_speakers=request.num_speakers)
+
+ # Convert audio tensors to base64 strings (without saving to file)
+ speakers_data = []
+ for speaker in result["speakers"]:
+ # Convert audio tensor directly to base64
+ audio_base64 = await asyncio.to_thread(audio_separator.speaker_audio_to_base64, speaker["audio"], speaker["sample_rate"], format="wav")
+ speakers_data.append(
+ {
+ "speaker_id": speaker["speaker_id"],
+ "audio": audio_base64, # Base64 encoded audio
+ "segments": speaker["segments"],
+ "sample_rate": speaker["sample_rate"],
+ }
+ )
+ return {"speakers": speakers_data, "total": len(speakers_data), "method": result.get("method", "pyannote")}
+
+ except Exception as e:
+ logger.error(f"Audio separation error: {traceback.format_exc()}")
+ return error_response(f"Audio separation failed: {str(e)}", 500)
+
+
+ # Fall back to index.html for all unknown routes (must be registered after all API routes)
+@app.get("/{full_path:path}", response_class=HTMLResponse)
+async def vue_fallback(full_path: str):
+ index_path = os.path.join(static_dir, "index.html")
+ if os.path.exists(index_path):
+ return FileResponse(index_path)
+ return HTMLResponse("Frontend not found
", status_code=404)
+
+
+# =========================
+# Main Entry
+# =========================
+
+if __name__ == "__main__":
+ ProcessManager.register_signal_handler()
+ parser = argparse.ArgumentParser()
+
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ base_dir = os.path.abspath(os.path.join(cur_dir, "../../.."))
+ dft_pipeline_json = os.path.join(base_dir, "configs/model_pipeline.json")
+ dft_task_url = os.path.join(base_dir, "local_task")
+ dft_data_url = os.path.join(base_dir, "local_data")
+ dft_queue_url = os.path.join(base_dir, "local_queue")
+ dft_volcengine_tts_list_json = os.path.join(base_dir, "configs/volcengine_voices_list.json")
+
+ parser.add_argument("--pipeline_json", type=str, default=dft_pipeline_json)
+ parser.add_argument("--task_url", type=str, default=dft_task_url)
+ parser.add_argument("--data_url", type=str, default=dft_data_url)
+ parser.add_argument("--queue_url", type=str, default=dft_queue_url)
+ parser.add_argument("--redis_url", type=str, default="")
+ parser.add_argument("--template_dir", type=str, default="")
+ parser.add_argument("--volcengine_tts_list_json", type=str, default=dft_volcengine_tts_list_json)
+ parser.add_argument("--ip", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int, default=8080)
+ parser.add_argument("--face_detector_model_path", type=str, default=None)
+ parser.add_argument("--audio_separator_model_path", type=str, default="")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ model_pipelines = Pipeline(args.pipeline_json)
+ volcengine_tts_client = VolcEngineTTSClient(args.volcengine_tts_list_json)
+ volcengine_podcast_client = VolcEnginePodcastClient()
+ face_detector = FaceDetector(model_path=args.face_detector_model_path)
+ audio_separator = AudioSeparator(model_path=args.audio_separator_model_path)
+ auth_manager = AuthManager()
+ if args.task_url.startswith("/"):
+ task_manager = LocalTaskManager(args.task_url, metrics_monitor)
+ elif args.task_url.startswith("postgresql://"):
+ task_manager = PostgresSQLTaskManager(args.task_url, metrics_monitor)
+ else:
+ raise NotImplementedError
+ if args.data_url.startswith("/"):
+ data_manager = LocalDataManager(args.data_url, args.template_dir)
+ elif args.data_url.startswith("{"):
+ data_manager = S3DataManager(args.data_url, args.template_dir)
+ else:
+ raise NotImplementedError
+ if args.queue_url.startswith("/"):
+ queue_manager = LocalQueueManager(args.queue_url)
+ elif args.queue_url.startswith("amqp://"):
+ queue_manager = RabbitMQQueueManager(args.queue_url)
+ else:
+ raise NotImplementedError
+ if args.redis_url:
+ server_monitor = RedisServerMonitor(model_pipelines, task_manager, queue_manager, args.redis_url)
+ else:
+ server_monitor = ServerMonitor(model_pipelines, task_manager, queue_manager)
+
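+ # Illustrative launch (the script name below is a placeholder for this file; flags omitted
+ # here fall back to the defaults defined above):
+ #   python server.py --ip 0.0.0.0 --port 8080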
+ uvicorn.run(app, host=args.ip, port=args.port, reload=False, workers=1)
diff --git a/lightx2v/deploy/server/auth.py b/lightx2v/deploy/server/auth.py
new file mode 100644
index 0000000000000000000000000000000000000000..11801257463992bc102bed480314943db147e3eb
--- /dev/null
+++ b/lightx2v/deploy/server/auth.py
@@ -0,0 +1,205 @@
+import os
+import time
+import uuid
+
+import aiohttp
+import jwt
+from fastapi import HTTPException
+from loguru import logger
+
+from lightx2v.deploy.common.aliyun import AlibabaCloudClient
+
+
+class AuthManager:
+ def __init__(self):
+ # Worker access token
+ self.worker_secret_key = os.getenv("WORKER_SECRET_KEY", "worker-secret-key-change-in-production")
+
+ # GitHub OAuth
+ self.github_client_id = os.getenv("GITHUB_CLIENT_ID", "")
+ self.github_client_secret = os.getenv("GITHUB_CLIENT_SECRET", "")
+
+ # Google OAuth
+ self.google_client_id = os.getenv("GOOGLE_CLIENT_ID", "")
+ self.google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", "")
+ self.google_redirect_uri = os.getenv("GOOGLE_REDIRECT_URI", "")
+
+ self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")
+ self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
+ self.jwt_expiration_hours = int(os.getenv("JWT_EXPIRATION_HOURS", "168"))
+ self.refresh_token_expiration_days = int(os.getenv("REFRESH_TOKEN_EXPIRATION_DAYS", "30"))
+ self.refresh_jwt_secret_key = os.getenv("REFRESH_JWT_SECRET_KEY", self.jwt_secret_key)
+
+ # Aliyun SMS
+ self.aliyun_client = AlibabaCloudClient()
+
+ logger.info(f"AuthManager: GITHUB_CLIENT_ID: {self.github_client_id}")
+ logger.info(f"AuthManager: GITHUB_CLIENT_SECRET: {self.github_client_secret}")
+ logger.info(f"AuthManager: GOOGLE_CLIENT_ID: {self.google_client_id}")
+ logger.info(f"AuthManager: GOOGLE_CLIENT_SECRET: {self.google_client_secret}")
+ logger.info(f"AuthManager: GOOGLE_REDIRECT_URI: {self.google_redirect_uri}")
+ logger.info(f"AuthManager: JWT_SECRET_KEY: {self.jwt_secret_key}")
+ logger.info(f"AuthManager: WORKER_SECRET_KEY: {self.worker_secret_key}")
+
+ def _create_token(self, data, expires_in_seconds, token_type, secret_key):
+ now = int(time.time())
+ payload = {
+ "user_id": data["user_id"],
+ "username": data["username"],
+ "email": data["email"],
+ "homepage": data["homepage"],
+ "token_type": token_type,
+ "iat": now,
+ "exp": now + expires_in_seconds,
+ "jti": str(uuid.uuid4()),
+ }
+ return jwt.encode(payload, secret_key, algorithm=self.jwt_algorithm)
+
+ def create_access_token(self, data):
+ return self._create_token(data, self.jwt_expiration_hours * 3600, "access", self.jwt_secret_key)
+
+ def create_refresh_token(self, data):
+ return self._create_token(data, self.refresh_token_expiration_days * 24 * 3600, "refresh", self.refresh_jwt_secret_key)
+
+ def create_tokens(self, data):
+ return self.create_access_token(data), self.create_refresh_token(data)
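+ # Illustrative usage (assumed caller shape: "data" must provide the user_id, username,
+ # email and homepage fields consumed by _create_token above):
+ #   access, refresh = auth_manager.create_tokens({
+ #       "user_id": "u_123", "username": "alice", "email": "alice@example.com", "homepage": "",
+ #   })
+ #   payload = auth_manager.verify_jwt_token(access)  # raises HTTPException(401) on failure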
+
+ def create_jwt_token(self, data):
+ # Backwards compatibility for callers that still expect this name
+ return self.create_access_token(data)
+
+ async def auth_github(self, code):
+ try:
+ logger.info(f"GitHub OAuth code: {code}")
+ token_url = "https://github.com/login/oauth/access_token"
+ token_data = {"client_id": self.github_client_id, "client_secret": self.github_client_secret, "code": code}
+ headers = {"Accept": "application/json"}
+
+ proxy = os.getenv("auth_https_proxy", None)
+ if proxy:
+ logger.info(f"auth_github use proxy: {proxy}")
+ async with aiohttp.ClientSession() as session:
+ async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response:
+ response.raise_for_status()
+ token_info = await response.json()
+
+ if "error" in token_info:
+ raise HTTPException(status_code=400, detail=f"GitHub OAuth error: {token_info['error']}")
+
+ access_token = token_info.get("access_token")
+ if not access_token:
+ raise HTTPException(status_code=400, detail="Failed to get access token")
+
+ user_url = "https://api.github.com/user"
+ user_headers = {"Authorization": f"token {access_token}", "Accept": "application/vnd.github.v3+json"}
+ async with aiohttp.ClientSession() as session:
+ async with session.get(user_url, headers=user_headers, proxy=proxy) as response:
+ response.raise_for_status()
+ user_info = await response.json()
+
+ return {
+ "source": "github",
+ "id": str(user_info["id"]),
+ "username": user_info["login"],
+ "email": user_info.get("email", ""),
+ "homepage": user_info.get("html_url", ""),
+ "avatar_url": user_info.get("avatar_url", ""),
+ }
+
+ except aiohttp.ClientError as e:
+ logger.error(f"GitHub API request failed: {e}")
+ raise HTTPException(status_code=500, detail="Failed to authenticate with GitHub")
+
+ except Exception as e:
+ logger.error(f"Authentication error: {e}")
+ raise HTTPException(status_code=500, detail="Authentication failed")
+
+ async def auth_google(self, code):
+ try:
+ logger.info(f"Google OAuth code: {code}")
+ token_url = "https://oauth2.googleapis.com/token"
+ token_data = {
+ "client_id": self.google_client_id,
+ "client_secret": self.google_client_secret,
+ "code": code,
+ "redirect_uri": self.google_redirect_uri,
+ "grant_type": "authorization_code",
+ }
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+
+ proxy = os.getenv("auth_https_proxy", None)
+ if proxy:
+ logger.info(f"auth_google use proxy: {proxy}")
+ async with aiohttp.ClientSession() as session:
+ async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response:
+ response.raise_for_status()
+ token_info = await response.json()
+
+ if "error" in token_info:
+ raise HTTPException(status_code=400, detail=f"Google OAuth error: {token_info['error']}")
+
+ access_token = token_info.get("access_token")
+ if not access_token:
+ raise HTTPException(status_code=400, detail="Failed to get access token")
+
+ # get user info
+ user_url = "https://www.googleapis.com/oauth2/v2/userinfo"
+ user_headers = {"Authorization": f"Bearer {access_token}"}
+ async with aiohttp.ClientSession() as session:
+ async with session.get(user_url, headers=user_headers, proxy=proxy) as response:
+ response.raise_for_status()
+ user_info = await response.json()
+ return {
+ "source": "google",
+ "id": str(user_info["id"]),
+ "username": user_info.get("name", user_info.get("email", "")),
+ "email": user_info.get("email", ""),
+ "homepage": user_info.get("link", ""),
+ "avatar_url": user_info.get("picture", ""),
+ }
+
+ except aiohttp.ClientError as e:
+ logger.error(f"Google API request failed: {e}")
+ raise HTTPException(status_code=500, detail="Failed to authenticate with Google")
+
+ except Exception as e:
+ logger.error(f"Google authentication error: {e}")
+ raise HTTPException(status_code=500, detail="Google authentication failed")
+
+ async def send_sms(self, phone_number):
+ return await self.aliyun_client.send_sms(phone_number)
+
+ async def check_sms(self, phone_number, verify_code):
+ ok = await self.aliyun_client.check_sms(phone_number, verify_code)
+ if not ok:
+ return None
+ return {
+ "source": "phone",
+ "id": phone_number,
+ "username": phone_number,
+ "email": "",
+ "homepage": "",
+ "avatar_url": "",
+ }
+
+ def _verify_token(self, token, expected_type, secret_key):
+ try:
+ payload = jwt.decode(token, secret_key, algorithms=[self.jwt_algorithm])
+ token_type = payload.get("token_type")
+ if token_type and token_type != expected_type:
+ raise HTTPException(status_code=401, detail="Token type mismatch")
+ return payload
+ except jwt.ExpiredSignatureError:
+ raise HTTPException(status_code=401, detail="Token has expired")
+ except HTTPException:
+ # Re-raise so the specific detail (e.g. token type mismatch) is not masked by the generic handler
+ raise
+ except Exception as e:
+ logger.error(f"_verify_token error: {e}")
+ raise HTTPException(status_code=401, detail="Could not validate credentials")
+
+ def verify_jwt_token(self, token):
+ return self._verify_token(token, "access", self.jwt_secret_key)
+
+ def verify_refresh_token(self, token):
+ return self._verify_token(token, "refresh", self.refresh_jwt_secret_key)
+
+ def verify_worker_token(self, token):
+ return token == self.worker_secret_key
diff --git a/lightx2v/deploy/server/frontend/.gitignore b/lightx2v/deploy/server/frontend/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a547bf36d8d11a4f89c59c144f24795749086dd1
--- /dev/null
+++ b/lightx2v/deploy/server/frontend/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/lightx2v/deploy/server/frontend/README.md b/lightx2v/deploy/server/frontend/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1511959c22b74118c06679739cf9abe2ceeea48c
--- /dev/null
+++ b/lightx2v/deploy/server/frontend/README.md
@@ -0,0 +1,5 @@
+# Vue 3 + Vite
+
+This template should help get you started developing with Vue 3 in Vite. The template uses Vue 3 `<script setup>` SFCs; check out the script setup docs to learn more.
+
+ Feature Highlights
+
+ - Cinematic digital-human video generation
+ - 20x faster generation
+ - Ultra-low generation cost
+ - Accurate lip-sync alignment
+ - Minute-level video duration
+ - Multi-scenario applications
+ - Latest TTS speech synthesis: multiple languages, 100+ voices, and voice-instruction control over synthesis details
+
+ Quick Start
+
+ - Upload an image and an audio clip, enter a video-generation prompt, and click Generate
+ - Generate and download the video
+ - Apply a template to create the same style of digital-human video in one click
+