diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..77b7ba25dcb92f592bf0df5c8f500332d13187ab --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +*.pth +*.pt +*.onnx +*.pk +*.model +*.zip +*.tar +*.pyc +*.log +*.o +*.so +*.a +*.exe +*.out +.idea +**.DS_Store** +**/__pycache__/** +**.swp +.vscode/ +.env +.log +*.pid +*.ipynb* +*.mp4 +build/ +dist/ +.cache/ +server_cache/ +app/.gradio/ +*.pkl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bdb356ab0608ea3a6145926c0d6906ad4217ee5c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,22 @@ +# Follow https://verdantfox.com/blog/how-to-use-git-pre-commit-hooks-the-hard-way-and-the-easy-way +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.0 + hooks: + - id: ruff + args: [--fix, --respect-gitignore, --config=pyproject.toml] + - id: ruff-format + args: [--config=pyproject.toml] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-added-large-files + args: ['--maxkb=3000'] # Allow files up to 3MB + - id: check-case-conflict + - id: check-merge-conflict + - id: debug-statements diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 6b94e2481927b013b3776ef7cc011e9094341871..35408da19ac0a419ed5692305855e51124358fa7 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,295 @@ -# LightX2V +
# ⚡️ LightX2V: 轻量级视频生成推理框架
+
+logo + +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/ModelTC/lightx2v) +[![Doc](https://img.shields.io/badge/docs-English-99cc2)](https://lightx2v-en.readthedocs.io/en/latest) +[![Doc](https://img.shields.io/badge/文档-中文-99cc2)](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest) +[![Papers](https://img.shields.io/badge/论文集-中文-99cc2)](https://lightx2v-papers-zhcn.readthedocs.io/zh-cn/latest) +[![Docker](https://img.shields.io/badge/Docker-2496ED?style=flat&logo=docker&logoColor=white)](https://hub.docker.com/r/lightx2v/lightx2v/tags) + +**\[ [English](README.md) | 中文 \]** + +
+ +-------------------------------------------------------------------------------- + +**LightX2V** 是一个先进的轻量级视频生成推理框架,专为提供高效、高性能的视频合成解决方案而设计。该统一平台集成了多种前沿的视频生成技术,支持文本生成视频(T2V)和图像生成视频(I2V)等多样化生成任务。**X2V 表示将不同的输入模态(X,如文本或图像)转换为视频输出(V)**。 + +> 🌐 **立即在线体验!** 无需安装即可体验 LightX2V:**[LightX2V 在线服务](https://x2v.light-ai.top/login)** - 免费、轻量、快速的AI数字人视频生成平台。 + +## :fire: 最新动态 + +- **2025年12月4日:** 🚀 支持 GGUF 格式模型推理,以及在寒武纪 MLU590、MetaX C500 硬件上的部署。 + +- **2025年11月24日:** 🚀 我们发布了HunyuanVideo-1.5的4步蒸馏模型!这些模型支持**超快速4步推理**,无需CFG配置,相比标准50步推理可实现约**25倍加速**。现已提供基础版本和FP8量化版本:[Hy1.5-Distill-Models](https://huggingface.co/lightx2v/Hy1.5-Distill-Models)。 + +- **2025年11月21日:** 🚀 我们Day0支持了[HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)的视频生成模型,同样GPU数量,LightX2V可带来约2倍以上的速度提升,并支持更低显存GPU部署(如24G RTX4090)。支持CFG并行/Ulysses并行,高效Offload,TeaCache/MagCache等技术。同时支持沐曦,寒武纪等国产芯片部署。我们很快将在我们的[HuggingFace主页](https://huggingface.co/lightx2v)更新更多模型,包括步数蒸馏,VAE蒸馏等相关模型。量化模型和轻量VAE模型现已可用:[Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models)用于量化推理,[HunyuanVideo-1.5轻量TAE](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)用于快速VAE解码。使用教程参考[这里](https://github.com/ModelTC/LightX2V/tree/main/scripts/hunyuan_video_15),或查看[示例目录](https://github.com/ModelTC/LightX2V/tree/main/examples)获取代码示例。 + + +## 🏆 性能测试数据 (更新于 2025.12.01) + +### 📊 推理框架之间性能对比 (H100) + +| Framework | GPUs | Step Time | Speedup | +|-----------|---------|---------|---------| +| Diffusers | 1 | 9.77s/it | 1x | +| xDiT | 1 | 8.93s/it | 1.1x | +| FastVideo | 1 | 7.35s/it | 1.3x | +| SGL-Diffusion | 1 | 6.13s/it | 1.6x | +| **LightX2V** | 1 | **5.18s/it** | **1.9x** 🚀 | +| FastVideo | 8 | 2.94s/it | 1x | +| xDiT | 8 | 2.70s/it | 1.1x | +| SGL-Diffusion | 8 | 1.19s/it | 2.5x | +| **LightX2V** | 8 | **0.75s/it** | **3.9x** 🚀 | + +### 📊 推理框架之间性能对比 (RTX 4090D) + +| Framework | GPUs | Step Time | Speedup | +|-----------|---------|---------|---------| +| Diffusers | 1 | 30.50s/it | 1x | +| FastVideo | 1 | 22.66s/it | 1.3x | +| xDiT | 1 | OOM | OOM | +| SGL-Diffusion | 1 | OOM | OOM | +| **LightX2V** | 1 | **20.26s/it** | **1.5x** 🚀 | +| FastVideo | 8 | 15.48s/it | 1x | +| xDiT | 8 | OOM | OOM | +| SGL-Diffusion | 8 | OOM | OOM | +| **LightX2V** | 8 | **4.75s/it** | **3.3x** 🚀 | + +### 📊 LightX2V不同配置之间性能对比 + +| Framework | GPU | Configuration | Step Time | Speedup | +|-----------|-----|---------------|-----------|---------------| +| **LightX2V** | H100 | 8 GPUs + cfg | 0.75s/it | 1x | +| **LightX2V** | H100 | 8 GPUs + no cfg | 0.39s/it | 1.9x | +| **LightX2V** | H100 | **8 GPUs + no cfg + fp8** | **0.35s/it** | **2.1x** 🚀 | +| **LightX2V** | 4090D | 8 GPUs + cfg | 4.75s/it | 1x | +| **LightX2V** | 4090D | 8 GPUs + no cfg | 3.13s/it | 1.5x | +| **LightX2V** | 4090D | **8 GPUs + no cfg + fp8** | **2.35s/it** | **2.0x** 🚀 | + +**注意**: 所有以上性能数据均在 Wan2.1-I2V-14B-480P(40 steps, 81 frames) 上测试。此外,我们[HuggingFace 主页](https://huggingface.co/lightx2v)还提供了4步蒸馏模型。 + + +## 💡 快速开始 + + +详细使用说明请参考我们的文档:**[英文文档](https://lightx2v-en.readthedocs.io/en/latest/) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)** + +**我们强烈推荐使用 Docker 环境,这是最简单快捷的环境安装方式。具体参考:文档中的快速入门章节。** + +### 从 Git 安装 +```bash +pip install -v git+https://github.com/ModelTC/LightX2V.git +``` + +### 从源码构建 +```bash +git clone https://github.com/ModelTC/LightX2V.git +cd LightX2V +uv pip install -v . # pip install -v . 
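+# For development, an editable install should also work: uv pip install -v -e .  (or: pip install -v -e .)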
+``` + +### (可选)安装注意力/量化算子 +注意力算子安装说明请参考我们的文档:**[英文文档](https://lightx2v-en.readthedocs.io/en/latest/getting_started/quickstart.html#step-4-install-attention-operators) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/quickstart.html#id9)** + +### 使用示例 +```python +# examples/wan/wan_i2v.py +""" +Wan2.2 image-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for Wan2.2 I2V task +# For wan2.1, use model_cls="wan2.1" +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe", + task="i2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="configs/wan22/wan_moe_i2v.json" +# ) + +# Enable offloading to significantly reduce VRAM usage with minimal speed impact +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For Wan models, supports both "block" and "phase" + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Create generator manually with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=40, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0) + sample_shift=5.0, +) + +# Generation parameters +seed = 42 +prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
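+# The Chinese negative prompt below is the commonly used Wan default; an English equivalent is used as the default in app/gradio_demo.py.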
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +image_path="/path/to/img_0.jpg" +save_result_path = "/path/to/save_results/output.mp4" + +# Generate video +pipe.generate( + seed=seed, + image_path=image_path, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) + +``` + +> 💡 **更多示例**: 更多使用案例,包括量化、卸载、缓存等进阶配置,请参考 [examples 目录](https://github.com/ModelTC/LightX2V/tree/main/examples)。 + +## 🤖 支持的模型生态 + +### 官方开源模型 +- ✅ [HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5) +- ✅ [Wan2.1 & Wan2.2](https://huggingface.co/Wan-AI/) +- ✅ [Qwen-Image](https://huggingface.co/Qwen/Qwen-Image) +- ✅ [Qwen-Image-Edit](https://huggingface.co/spaces/Qwen/Qwen-Image-Edit) +- ✅ [Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) + +### 量化模型和蒸馏模型/Lora (**🚀 推荐:4步推理**) +- ✅ [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- ✅ [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) +- ✅ [Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras) +- ✅ [Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras) + +### 轻量级自编码器模型(**🚀 推荐:推理快速 + 内存占用低**) +- ✅ [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) + +### 自回归模型 +- ✅ [Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid) +- ✅ [Self-Forcing](https://github.com/guandeh17/Self-Forcing) +- ✅ [Matrix-Game-2.0](https://huggingface.co/Skywork/Matrix-Game-2.0) + +🔔 可以关注我们的[HuggingFace主页](https://huggingface.co/lightx2v),及时获取我们团队的模型。 + +💡 参考[模型结构文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/model_structure.html)快速上手 LightX2V + +## 🚀 前端展示 + +我们提供了多种前端界面部署方式: + +- **🎨 Gradio界面**: 简洁易用的Web界面,适合快速体验和原型开发 + - 📖 [Gradio部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) +- **🎯 ComfyUI界面**: 强大的节点式工作流界面,支持复杂的视频生成任务 + - 📖 [ComfyUI部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_comfyui.html) +- **🚀 Windows一键部署**: 专为Windows用户设计的便捷部署方案,支持自动环境配置和智能参数优化 + - 📖 [Windows一键部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html) + +**💡 推荐方案**: +- **首次使用**: 建议选择Windows一键部署方案 +- **高级用户**: 推荐使用ComfyUI界面获得更多自定义选项 +- **快速体验**: Gradio界面提供最直观的操作体验 + +## 🚀 核心特性 + +### 🎯 **极致性能优化** +- **🔥 SOTA推理速度**: 通过步数蒸馏和系统优化实现**20倍**极速加速(单GPU) +- **⚡️ 革命性4步蒸馏**: 将原始40-50步推理压缩至仅需4步,且无需CFG配置 +- **🛠️ 先进算子支持**: 集成顶尖算子,包括[Sage Attention](https://github.com/thu-ml/SageAttention)、[Flash Attention](https://github.com/Dao-AILab/flash-attention)、[Radial Attention](https://github.com/mit-han-lab/radial-attention)、[q8-kernel](https://github.com/KONAKONA666/q8_kernels)、[sgl-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)、[vllm](https://github.com/vllm-project/vllm) + +### 💾 **资源高效部署** +- **💡 突破硬件限制**: **仅需8GB显存 + 16GB内存**即可运行14B模型生成480P/720P视频 +- **🔧 智能参数卸载**: 先进的磁盘-CPU-GPU三级卸载架构,支持阶段/块级别的精细化管理 +- **⚙️ 全面量化支持**: 支持`w8a8-int8`、`w8a8-fp8`、`w4a4-nvfp4`等多种量化策略 + +### 🎨 **丰富功能生态** +- **📈 智能特征缓存**: 智能缓存机制,消除冗余计算,提升效率 +- **🔄 并行推理加速**: 多GPU并行处理,显著提升性能表现 +- **📱 灵活部署选择**: 支持Gradio、服务化部署、ComfyUI等多种部署方式 +- **🎛️ 动态分辨率推理**: 自适应分辨率调整,优化生成质量 +- **🎞️ 视频帧插值**: 基于RIFE的帧插值技术,实现流畅的帧率提升 + + +## 📚 技术文档 + +### 📖 **方法教程** +- [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) - 量化策略全面指南 +- 
[特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) - 智能缓存机制详解 +- [注意力机制](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) - 前沿注意力算子 +- [参数卸载](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/offload.html) - 三级存储架构 +- [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) - 多GPU加速策略 +- [变分辨率推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) - U型分辨率策略 +- [步数蒸馏](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html) - 4步推理技术 +- [视频帧插值](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/video_frame_interpolation.html) - 基于RIFE的帧插值技术 + +### 🛠️ **部署指南** +- [低资源场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_resource.html) - 优化的8GB显存解决方案 +- [低延迟场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_latency.html) - 极速推理优化 +- [Gradio部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) - Web界面搭建 +- [服务化部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html) - 生产级API服务部署 +- [Lora模型部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/lora_deploy.html) - Lora灵活部署 + +## 🧾 代码贡献指南 + +我们通过自动化的预提交钩子来保证代码质量,确保项目代码格式的一致性。 + +> [!TIP] +> **安装说明:** +> +> 1. 安装必要的依赖: +> ```shell +> pip install ruff pre-commit +> ``` +> +> 2. 提交前运行: +> ```shell +> pre-commit run --all-files +> ``` + +感谢您为LightX2V的改进做出贡献! + +## 🤝 致谢 + +我们向所有启发和促进LightX2V开发的模型仓库和研究社区表示诚挚的感谢。此框架基于开源社区的集体努力而构建。 + +## 🌟 Star 历史 + +[![Star History Chart](https://api.star-history.com/svg?repos=ModelTC/lightx2v&type=Timeline)](https://star-history.com/#ModelTC/lightx2v&Timeline) + +## ✏️ 引用 + +如果您发现LightX2V对您的研究有用,请考虑引用我们的工作: + +```bibtex +@misc{lightx2v, + author = {LightX2V Contributors}, + title = {LightX2V: Light Video Generation Inference Framework}, + year = {2025}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/ModelTC/lightx2v}}, +} +``` + +## 📞 联系与支持 + +如有任何问题、建议或需要支持,欢迎通过以下方式联系我们: +- 🐛 [GitHub Issues](https://github.com/ModelTC/lightx2v/issues) - 错误报告和功能请求 + +--- + +
+由 LightX2V 团队用 ❤️ 构建 +
diff --git a/app/README.md b/app/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2be884b7e04244e2fb210d27ea7b7978abc072b7 --- /dev/null +++ b/app/README.md @@ -0,0 +1,13 @@ +# Gradio Demo + +Please refer our gradio deployment doc: + +[English doc: Gradio Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html) + +[中文文档: Gradio 部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) + +## 🚀 Quick Start (快速开始) + +For Windows users, we provide a convenient one-click deployment solution with automatic environment configuration and intelligent parameter optimization. Please refer to the [One-Click Gradio Launch](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_local_windows.html) section for detailed instructions. + +对于Windows用户,我们提供了便捷的一键部署方式,支持自动环境配置和智能参数优化。详细操作请参考[一键启动Gradio](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)章节。 diff --git a/app/gradio_demo.py b/app/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..448ec3363692bae777afeccf3ef82c9450ffc53b --- /dev/null +++ b/app/gradio_demo.py @@ -0,0 +1,1442 @@ +import argparse +import gc +import glob +import importlib.util +import json +import os + +os.environ["PROFILING_DEBUG_LEVEL"] = "2" +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["DTYPE"] = "BF16" +import random +from datetime import datetime + +import gradio as gr +import psutil +import torch +from loguru import logger + +from lightx2v.utils.input_info import set_input_info +from lightx2v.utils.set_config import get_default_config + +try: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace +except ImportError: + apply_rope_with_cos_sin_cache_inplace = None + + +logger.add( + "inference_logs.log", + rotation="100 MB", + encoding="utf-8", + enqueue=True, + backtrace=True, + diagnose=True, +) + +MAX_NUMPY_SEED = 2**32 - 1 + + +def scan_model_path_contents(model_path): + """Scan model_path directory and return available files and subdirectories""" + if not model_path or not os.path.exists(model_path): + return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []} + + dirs = [] + files = [] + safetensors_dirs = [] + pth_files = [] + + try: + for item in os.listdir(model_path): + item_path = os.path.join(model_path, item) + if os.path.isdir(item_path): + dirs.append(item) + # Check if directory contains safetensors files + if glob.glob(os.path.join(item_path, "*.safetensors")): + safetensors_dirs.append(item) + elif os.path.isfile(item_path): + files.append(item) + if item.endswith(".pth"): + pth_files.append(item) + except Exception as e: + logger.warning(f"Failed to scan directory: {e}") + + return { + "dirs": sorted(dirs), + "files": sorted(files), + "safetensors_dirs": sorted(safetensors_dirs), + "pth_files": sorted(pth_files), + } + + +def get_dit_choices(model_path, model_type="wan2.1"): + """Get Diffusion model options (filtered by model type)""" + contents = scan_model_path_contents(model_path) + excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"] + fp8_supported = is_fp8_supported_gpu() + + if model_type == "wan2.1": + # wan2.1: filter files/dirs containing wan2.1 or Wan2.1 + def is_valid(name): + name_lower = name.lower() + if "wan2.1" not in name_lower: + return False + if not fp8_supported and "fp8" in name_lower: + return False + return not any(kw in name_lower for kw in excluded_keywords) + else: + # 
wan2.2: filter files/dirs containing wan2.2 or Wan2.2 + def is_valid(name): + name_lower = name.lower() + if "wan2.2" not in name_lower: + return False + if not fp8_supported and "fp8" in name_lower: + return False + return not any(kw in name_lower for kw in excluded_keywords) + + # Filter matching directories and files + dir_choices = [d for d in contents["dirs"] if is_valid(d)] + file_choices = [f for f in contents["files"] if is_valid(f)] + choices = dir_choices + file_choices + return choices if choices else [""] + + +def get_high_noise_choices(model_path): + """Get high noise model options (files/dirs containing high_noise)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + def is_valid(name): + name_lower = name.lower() + if not fp8_supported and "fp8" in name_lower: + return False + return "high_noise" in name_lower or "high-noise" in name_lower + + dir_choices = [d for d in contents["dirs"] if is_valid(d)] + file_choices = [f for f in contents["files"] if is_valid(f)] + choices = dir_choices + file_choices + return choices if choices else [""] + + +def get_low_noise_choices(model_path): + """Get low noise model options (files/dirs containing low_noise)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + def is_valid(name): + name_lower = name.lower() + if not fp8_supported and "fp8" in name_lower: + return False + return "low_noise" in name_lower or "low-noise" in name_lower + + dir_choices = [d for d in contents["dirs"] if is_valid(d)] + file_choices = [f for f in contents["files"] if is_valid(f)] + choices = dir_choices + file_choices + return choices if choices else [""] + + +def get_t5_choices(model_path): + """Get T5 model options (.pth or .safetensors files containing t5 keyword)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + # Filter from .pth files + pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # Filter from .safetensors files + safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # Filter from directories containing safetensors + safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())] + + choices = pth_choices + safetensors_choices + safetensors_dir_choices + return choices if choices else [""] + + +def get_clip_choices(model_path): + """Get CLIP model options (.pth or .safetensors files containing clip keyword)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + # Filter from .pth files + pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # Filter from .safetensors files + safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # Filter from directories containing safetensors + safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())] + + choices = pth_choices + safetensors_choices + safetensors_dir_choices + return choices if choices else [""] + + +def get_vae_choices(model_path): + """Get VAE model options (.pth or .safetensors files containing vae/VAE/tae keyword)""" + contents = 
scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + # Filter from .pth files + pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())] + + # Filter from .safetensors files + safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())] + + # Filter from directories containing safetensors + safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())] + + choices = pth_choices + safetensors_choices + safetensors_dir_choices + return choices if choices else [""] + + +def detect_quant_scheme(model_name): + """Automatically detect quantization scheme from model name + - If model name contains "int8" → "int8" + - If model name contains "fp8" and device supports → "fp8" + - Otherwise return None (no quantization) + """ + if not model_name: + return None + name_lower = model_name.lower() + if "int8" in name_lower: + return "int8" + elif "fp8" in name_lower: + if is_fp8_supported_gpu(): + return "fp8" + else: + # Device doesn't support fp8, return None (use default precision) + return None + return None + + +def update_model_path_options(model_path, model_type="wan2.1"): + """Update all model path selectors when model_path or model_type changes""" + dit_choices = get_dit_choices(model_path, model_type) + high_noise_choices = get_high_noise_choices(model_path) + low_noise_choices = get_low_noise_choices(model_path) + t5_choices = get_t5_choices(model_path) + clip_choices = get_clip_choices(model_path) + vae_choices = get_vae_choices(model_path) + + return ( + gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), + gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""), + gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""), + gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""), + gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""), + gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), + ) + + +def generate_random_seed(): + return random.randint(0, MAX_NUMPY_SEED) + + +def is_module_installed(module_name): + try: + spec = importlib.util.find_spec(module_name) + return spec is not None + except ModuleNotFoundError: + return False + + +def get_available_quant_ops(): + available_ops = [] + + vllm_installed = is_module_installed("vllm") + if vllm_installed: + available_ops.append(("vllm", True)) + else: + available_ops.append(("vllm", False)) + + sgl_installed = is_module_installed("sgl_kernel") + if sgl_installed: + available_ops.append(("sgl", True)) + else: + available_ops.append(("sgl", False)) + + q8f_installed = is_module_installed("q8_kernels") + if q8f_installed: + available_ops.append(("q8f", True)) + else: + available_ops.append(("q8f", False)) + + return available_ops + + +def get_available_attn_ops(): + available_ops = [] + + vllm_installed = is_module_installed("flash_attn") + if vllm_installed: + available_ops.append(("flash_attn2", True)) + else: + available_ops.append(("flash_attn2", False)) + + sgl_installed = is_module_installed("flash_attn_interface") + if sgl_installed: + available_ops.append(("flash_attn3", True)) + else: + available_ops.append(("flash_attn3", False)) + + 
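# Probe the SageAttention backends the same way: sage_attn2 comes from the "sageattention" module, sage_attn3 from "sageattn3". +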
sage_installed = is_module_installed("sageattention") + if sage_installed: + available_ops.append(("sage_attn2", True)) + else: + available_ops.append(("sage_attn2", False)) + + sage3_installed = is_module_installed("sageattn3") + if sage3_installed: + available_ops.append(("sage_attn3", True)) + else: + available_ops.append(("sage_attn3", False)) + + torch_installed = is_module_installed("torch") + if torch_installed: + available_ops.append(("torch_sdpa", True)) + else: + available_ops.append(("torch_sdpa", False)) + + return available_ops + + +def get_gpu_memory(gpu_idx=0): + if not torch.cuda.is_available(): + return 0 + try: + with torch.cuda.device(gpu_idx): + memory_info = torch.cuda.mem_get_info() + total_memory = memory_info[1] / (1024**3) # Convert bytes to GB + return total_memory + except Exception as e: + logger.warning(f"Failed to get GPU memory: {e}") + return 0 + + +def get_cpu_memory(): + available_bytes = psutil.virtual_memory().available + return available_bytes / 1024**3 + + +def cleanup_memory(): + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + try: + import psutil + + if hasattr(psutil, "virtual_memory"): + if os.name == "posix": + try: + os.system("sync") + except: # noqa + pass + except: # noqa + pass + + +def generate_unique_filename(output_dir): + os.makedirs(output_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return os.path.join(output_dir, f"{timestamp}.mp4") + + +def is_fp8_supported_gpu(): + if not torch.cuda.is_available(): + return False + compute_capability = torch.cuda.get_device_capability(0) + major, minor = compute_capability + return (major == 8 and minor == 9) or (major >= 9) + + +def is_ada_architecture_gpu(): + if not torch.cuda.is_available(): + return False + try: + gpu_name = torch.cuda.get_device_name(0).upper() + ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"] + return any(keyword in gpu_name for keyword in ada_keywords) + except Exception as e: + logger.warning(f"Failed to get GPU name: {e}") + return False + + +def get_quantization_options(model_path): + """Get quantization options dynamically based on model_path""" + import os + + # Check subdirectories + subdirs = ["original", "fp8", "int8"] + has_subdirs = {subdir: os.path.exists(os.path.join(model_path, subdir)) for subdir in subdirs} + + # Check original files in root directory + t5_bf16_exists = os.path.exists(os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")) + clip_fp16_exists = os.path.exists(os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")) + + # Generate options + def get_choices(has_subdirs, original_type, fp8_type, int8_type, fallback_type, has_original_file=False): + choices = [] + if has_subdirs["original"]: + choices.append(original_type) + if has_subdirs["fp8"]: + choices.append(fp8_type) + if has_subdirs["int8"]: + choices.append(int8_type) + + # If no subdirectories but original file exists, add original type + if has_original_file: + if not choices or "original" not in choices: + choices.append(original_type) + + # If no options at all, use default value + if not choices: + choices = [fallback_type] + + return choices, choices[0] + + # DIT options + dit_choices, dit_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16") + + # T5 options - check if original file exists + t5_choices, t5_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16", t5_bf16_exists) + + # CLIP options - check if original file 
exists + clip_choices, clip_default = get_choices(has_subdirs, "fp16", "fp8", "int8", "fp16", clip_fp16_exists) + + return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default} + + +def determine_model_cls(model_type, dit_name, high_noise_name): + """Determine model_cls based on model type and file name""" + # Determine file name to check + if model_type == "wan2.1": + check_name = dit_name.lower() if dit_name else "" + is_distill = "4step" in check_name + return "wan2.1_distill" if is_distill else "wan2.1" + else: + # wan2.2 + check_name = high_noise_name.lower() if high_noise_name else "" + is_distill = "4step" in check_name + return "wan2.2_moe_distill" if is_distill else "wan2.2_moe" + + +global_runner = None +current_config = None +cur_dit_path = None +cur_t5_path = None +cur_clip_path = None + +available_quant_ops = get_available_quant_ops() +quant_op_choices = [] +for op_name, is_installed in available_quant_ops: + status_text = "✅ Installed" if is_installed else "❌ Not Installed" + display_text = f"{op_name} ({status_text})" + quant_op_choices.append((op_name, display_text)) + +available_attn_ops = get_available_attn_ops() +# Priority order +attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"] +# Sort by priority, installed ones first, uninstalled ones last +attn_op_choices = [] +attn_op_dict = dict(available_attn_ops) + +# Add installed ones first (by priority) +for op_name in attn_priority: + if op_name in attn_op_dict and attn_op_dict[op_name]: + status_text = "✅ Installed" + display_text = f"{op_name} ({status_text})" + attn_op_choices.append((op_name, display_text)) + +# Add uninstalled ones (by priority) +for op_name in attn_priority: + if op_name in attn_op_dict and not attn_op_dict[op_name]: + status_text = "❌ Not Installed" + display_text = f"{op_name} ({status_text})" + attn_op_choices.append((op_name, display_text)) + +# Add other operators not in priority list (installed ones first) +other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority] +for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # Installed ones first + status_text = "✅ Installed" if is_installed else "❌ Not Installed" + display_text = f"{op_name} ({status_text})" + attn_op_choices.append((op_name, display_text)) + + +def run_inference( + prompt, + negative_prompt, + save_result_path, + infer_steps, + num_frames, + resolution, + seed, + sample_shift, + enable_cfg, + cfg_scale, + fps, + use_tiling_vae, + lazy_load, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + model_path_input, + model_type_input, + task_type_input, + dit_path_input, + high_noise_path_input, + low_noise_path_input, + t5_path_input, + clip_path_input, + vae_path_input, + image_path=None, +): + cleanup_memory() + + quant_op = quant_op.split("(")[0].strip() + attention_type = attention_type.split("(")[0].strip() + + global global_runner, current_config, model_path, model_cls + global cur_dit_path, cur_t5_path, cur_clip_path + + task = task_type_input + model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input) + logger.info(f"Auto-determined model_cls: {model_cls} (Model type: {model_type_input})") + + if model_type_input == "wan2.1": + 
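# wan2.1 reads the quantization scheme from its single DiT checkpoint name; wan2.2 derives it from the high-noise expert selected below. +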
dit_quant_detected = detect_quant_scheme(dit_path_input) + else: + dit_quant_detected = detect_quant_scheme(high_noise_path_input) + t5_quant_detected = detect_quant_scheme(t5_path_input) + clip_quant_detected = detect_quant_scheme(clip_path_input) + logger.info(f"Auto-detected quantization scheme - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}") + + if model_path_input and model_path_input.strip(): + model_path = model_path_input.strip() + + if os.path.exists(os.path.join(model_path, "config.json")): + with open(os.path.join(model_path, "config.json"), "r") as f: + model_config = json.load(f) + else: + model_config = {} + + save_result_path = generate_unique_filename(output_dir) + + is_dit_quant = dit_quant_detected != "bf16" + is_t5_quant = t5_quant_detected != "bf16" + is_clip_quant = clip_quant_detected != "fp16" + + dit_quantized_ckpt = None + dit_original_ckpt = None + high_noise_quantized_ckpt = None + low_noise_quantized_ckpt = None + high_noise_original_ckpt = None + low_noise_original_ckpt = None + + if is_dit_quant: + dit_quant_scheme = f"{dit_quant_detected}-{quant_op}" + if "wan2.1" in model_cls: + dit_quantized_ckpt = os.path.join(model_path, dit_path_input) + else: + high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input) + low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input) + else: + dit_quantized_ckpt = "Default" + if "wan2.1" in model_cls: + dit_original_ckpt = os.path.join(model_path, dit_path_input) + else: + high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input) + low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input) + + # Use frontend-selected T5 path + if is_t5_quant: + t5_quantized_ckpt = os.path.join(model_path, t5_path_input) + t5_quant_scheme = f"{t5_quant_detected}-{quant_op}" + t5_original_ckpt = None + else: + t5_quantized_ckpt = None + t5_quant_scheme = None + t5_original_ckpt = os.path.join(model_path, t5_path_input) + + # Use frontend-selected CLIP path + if is_clip_quant: + clip_quantized_ckpt = os.path.join(model_path, clip_path_input) + clip_quant_scheme = f"{clip_quant_detected}-{quant_op}" + clip_original_ckpt = None + else: + clip_quantized_ckpt = None + clip_quant_scheme = None + clip_original_ckpt = os.path.join(model_path, clip_path_input) + + if model_type_input == "wan2.1": + current_dit_path = dit_path_input + else: + current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None + + current_t5_path = t5_path_input + current_clip_path = clip_path_input + + needs_reinit = ( + lazy_load + or unload_modules + or global_runner is None + or current_config is None + or cur_dit_path is None + or cur_dit_path != current_dit_path + or cur_t5_path is None + or cur_t5_path != current_t5_path + or cur_clip_path is None + or cur_clip_path != current_clip_path + ) + + if cfg_scale == 1: + enable_cfg = False + else: + enable_cfg = True + + vae_name_lower = vae_path_input.lower() if vae_path_input else "" + use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower + use_lightvae = "lightvae" in vae_name_lower + need_scaled = "lighttae" in vae_name_lower + + logger.info(f"VAE configuration - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})") + + config_graio = { + "infer_steps": infer_steps, + "target_video_length": num_frames, + "target_width": int(resolution.split("x")[0]), + "target_height": int(resolution.split("x")[1]), 
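+ # Apply the selected attention operator to all self- and cross-attention modules.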
+ "self_attn_1_type": attention_type, + "cross_attn_1_type": attention_type, + "cross_attn_2_type": attention_type, + "enable_cfg": enable_cfg, + "sample_guide_scale": cfg_scale, + "sample_shift": sample_shift, + "fps": fps, + "feature_caching": "NoCaching", + "do_mm_calib": False, + "parallel_attn_type": None, + "parallel_vae": False, + "max_area": False, + "vae_stride": (4, 8, 8), + "patch_size": (1, 2, 2), + "lora_path": None, + "strength_model": 1.0, + "use_prompt_enhancer": False, + "text_len": 512, + "denoising_step_list": [1000, 750, 500, 250], + "cpu_offload": True if "wan2.2" in model_cls else cpu_offload, + "offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity, + "t5_cpu_offload": t5_cpu_offload, + "clip_cpu_offload": clip_cpu_offload, + "vae_cpu_offload": vae_cpu_offload, + "dit_quantized": is_dit_quant, + "dit_quant_scheme": dit_quant_scheme, + "dit_quantized_ckpt": dit_quantized_ckpt, + "dit_original_ckpt": dit_original_ckpt, + "high_noise_quantized_ckpt": high_noise_quantized_ckpt, + "low_noise_quantized_ckpt": low_noise_quantized_ckpt, + "high_noise_original_ckpt": high_noise_original_ckpt, + "low_noise_original_ckpt": low_noise_original_ckpt, + "t5_original_ckpt": t5_original_ckpt, + "t5_quantized": is_t5_quant, + "t5_quantized_ckpt": t5_quantized_ckpt, + "t5_quant_scheme": t5_quant_scheme, + "clip_original_ckpt": clip_original_ckpt, + "clip_quantized": is_clip_quant, + "clip_quantized_ckpt": clip_quantized_ckpt, + "clip_quant_scheme": clip_quant_scheme, + "vae_path": os.path.join(model_path, vae_path_input), + "use_tiling_vae": use_tiling_vae, + "use_tae": use_tae, + "use_lightvae": use_lightvae, + "need_scaled": need_scaled, + "lazy_load": lazy_load, + "rope_chunk": rope_chunk, + "rope_chunk_size": rope_chunk_size, + "clean_cuda_cache": clean_cuda_cache, + "unload_modules": unload_modules, + "seq_parallel": False, + "warm_up_cpu_buffers": False, + "boundary_step_index": 2, + "boundary": 0.900, + "use_image_encoder": False if "wan2.2" in model_cls else True, + "rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch", + } + + args = argparse.Namespace( + model_cls=model_cls, + seed=seed, + task=task, + model_path=model_path, + prompt_enhancer=None, + prompt=prompt, + negative_prompt=negative_prompt, + image_path=image_path, + save_result_path=save_result_path, + return_result_tensor=False, + ) + + config = get_default_config() + config.update({k: v for k, v in vars(args).items()}) + config.update(model_config) + config.update(config_graio) + + logger.info(f"Using model: {model_path}") + logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}") + + # Initialize or reuse the runner + runner = global_runner + if needs_reinit: + if runner is not None: + del runner + torch.cuda.empty_cache() + gc.collect() + + from lightx2v.infer import init_runner # noqa + + runner = init_runner(config) + input_info = set_input_info(args) + + current_config = config + cur_dit_path = current_dit_path + cur_t5_path = current_t5_path + cur_clip_path = current_clip_path + + if not lazy_load: + global_runner = runner + else: + runner.config = config + + runner.run_pipeline(input_info) + cleanup_memory() + + return save_result_path + + +def handle_lazy_load_change(lazy_load_enabled): + """Handle lazy_load checkbox change to automatically enable unload_modules""" + return gr.update(value=lazy_load_enabled) + + +def auto_configure(resolution): + """Auto-configure inference options based on machine configuration and 
resolution""" + default_config = { + "lazy_load_val": False, + "rope_chunk_val": False, + "rope_chunk_size_val": 100, + "clean_cuda_cache_val": False, + "cpu_offload_val": False, + "offload_granularity_val": "block", + "t5_cpu_offload_val": False, + "clip_cpu_offload_val": False, + "vae_cpu_offload_val": False, + "unload_modules_val": False, + "attention_type_val": attn_op_choices[0][1], + "quant_op_val": quant_op_choices[0][1], + "use_tiling_vae_val": False, + } + + gpu_memory = round(get_gpu_memory()) + cpu_memory = round(get_cpu_memory()) + + attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"] + + if is_ada_architecture_gpu(): + quant_op_priority = ["q8f", "vllm", "sgl"] + else: + quant_op_priority = ["vllm", "sgl", "q8f"] + + for op in attn_priority: + if dict(available_attn_ops).get(op): + default_config["attention_type_val"] = dict(attn_op_choices)[op] + break + + for op in quant_op_priority: + if dict(available_quant_ops).get(op): + default_config["quant_op_val"] = dict(quant_op_choices)[op] + break + + if resolution in [ + "1280x720", + "720x1280", + "1280x544", + "544x1280", + "1104x832", + "832x1104", + "960x960", + ]: + res = "720p" + elif resolution in [ + "960x544", + "544x960", + ]: + res = "540p" + else: + res = "480p" + + if res == "720p": + gpu_rules = [ + (80, {}), + (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}), + (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}), + ( + 24, + { + "cpu_offload_val": True, + "use_tiling_vae_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + }, + ), + ( + 16, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + "rope_chunk_val": True, + "rope_chunk_size_val": 100, + }, + ), + ( + 8, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + "rope_chunk_val": True, + "rope_chunk_size_val": 100, + "clean_cuda_cache_val": True, + }, + ), + ] + + else: + gpu_rules = [ + (80, {}), + (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}), + (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}), + ( + 24, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + }, + ), + ( + 16, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + }, + ), + ( + 8, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + }, + ), + ] + + cpu_rules = [ + (128, {}), + (64, {}), + (32, {"unload_modules_val": True}), + ( + 16, + { + "lazy_load_val": True, + "unload_modules_val": True, + }, + ), + ] + + for threshold, updates in gpu_rules: + if gpu_memory >= threshold: + default_config.update(updates) + break + + for threshold, updates in cpu_rules: + if cpu_memory 
>= threshold: + default_config.update(updates) + break + + return ( + gr.update(value=default_config["lazy_load_val"]), + gr.update(value=default_config["rope_chunk_val"]), + gr.update(value=default_config["rope_chunk_size_val"]), + gr.update(value=default_config["clean_cuda_cache_val"]), + gr.update(value=default_config["cpu_offload_val"]), + gr.update(value=default_config["offload_granularity_val"]), + gr.update(value=default_config["t5_cpu_offload_val"]), + gr.update(value=default_config["clip_cpu_offload_val"]), + gr.update(value=default_config["vae_cpu_offload_val"]), + gr.update(value=default_config["unload_modules_val"]), + gr.update(value=default_config["attention_type_val"]), + gr.update(value=default_config["quant_op_val"]), + gr.update(value=default_config["use_tiling_vae_val"]), + ) + + +css = """ + .main-content { max-width: 1600px; margin: auto; padding: 20px; } + .warning { color: #ff6b6b; font-weight: bold; } + + /* Model configuration area styles */ + .model-config { + margin-bottom: 20px !important; + border: 1px solid #e0e0e0; + border-radius: 12px; + padding: 15px; + background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); + } + + /* Input parameters area styles */ + .input-params { + margin-bottom: 20px !important; + border: 1px solid #e0e0e0; + border-radius: 12px; + padding: 15px; + background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%); + } + + /* Output video area styles */ + .output-video { + border: 1px solid #e0e0e0; + border-radius: 12px; + padding: 20px; + background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%); + min-height: 400px; + } + + /* Generate button styles */ + .generate-btn { + width: 100%; + margin-top: 20px; + padding: 15px 30px !important; + font-size: 18px !important; + font-weight: bold !important; + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; + border: none !important; + border-radius: 10px !important; + box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; + transition: all 0.3s ease !important; + } + .generate-btn:hover { + transform: translateY(-2px); + box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important; + } + + /* Accordion header styles */ + .model-config .gr-accordion-header, + .input-params .gr-accordion-header, + .output-video .gr-accordion-header { + font-size: 20px !important; + font-weight: bold !important; + padding: 15px !important; + } + + /* Optimize spacing */ + .gr-row { + margin-bottom: 15px; + } + + /* Video player styles */ + .output-video video { + border-radius: 10px; + box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); + } + """ + + +def main(): + with gr.Blocks(title="Lightx2v (Lightweight Video Inference and Generation Engine)") as demo: + gr.Markdown(f"# 🎬 LightX2V Video Generator") + gr.HTML(f"") + # Main layout: left and right columns + with gr.Row(): + # Left: configuration and input area + with gr.Column(scale=5): + # Model configuration area + with gr.Accordion("🗂️ Model Configuration", open=True, elem_classes=["model-config"]): + # FP8 support notice + if not is_fp8_supported_gpu(): + gr.Markdown("⚠️ **Your device does not support FP8 inference**. 
Models containing FP8 have been automatically hidden.") + + # Hidden state components + model_path_input = gr.Textbox(value=model_path, visible=False) + + # Model type + Task type + with gr.Row(): + model_type_input = gr.Radio( + label="Model Type", + choices=["wan2.1", "wan2.2"], + value="wan2.1", + info="wan2.2 requires separate high noise and low noise models", + ) + task_type_input = gr.Radio( + label="Task Type", + choices=["i2v", "t2v"], + value="i2v", + info="i2v: Image-to-video, t2v: Text-to-video", + ) + + # wan2.1: Diffusion model (single row) + with gr.Row() as wan21_row: + dit_path_input = gr.Dropdown( + label="🎨 Diffusion Model", + choices=get_dit_choices(model_path, "wan2.1"), + value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "", + allow_custom_value=True, + visible=True, + ) + + # wan2.2 specific: high noise model + low noise model (hidden by default) + with gr.Row(visible=False) as wan22_row: + high_noise_path_input = gr.Dropdown( + label="🔊 High Noise Model", + choices=get_high_noise_choices(model_path), + value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "", + allow_custom_value=True, + ) + low_noise_path_input = gr.Dropdown( + label="🔇 Low Noise Model", + choices=get_low_noise_choices(model_path), + value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "", + allow_custom_value=True, + ) + + # Text encoder (single row) + with gr.Row(): + t5_path_input = gr.Dropdown( + label="📝 Text Encoder", + choices=get_t5_choices(model_path), + value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "", + allow_custom_value=True, + ) + + # Image encoder + VAE decoder + with gr.Row(): + clip_path_input = gr.Dropdown( + label="🖼️ Image Encoder", + choices=get_clip_choices(model_path), + value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "", + allow_custom_value=True, + ) + vae_path_input = gr.Dropdown( + label="🎞️ VAE Decoder", + choices=get_vae_choices(model_path), + value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "", + allow_custom_value=True, + ) + + # Attention operator and quantization matrix multiplication operator + with gr.Row(): + attention_type = gr.Dropdown( + label="⚡ Attention Operator", + choices=[op[1] for op in attn_op_choices], + value=attn_op_choices[0][1] if attn_op_choices else "", + info="Use appropriate attention operators to accelerate inference", + ) + quant_op = gr.Dropdown( + label="Quantization Matmul Operator", + choices=[op[1] for op in quant_op_choices], + value=quant_op_choices[0][1], + info="Select quantization matrix multiplication operator to accelerate inference", + interactive=True, + ) + + # Determine if model is distill version + def is_distill_model(model_type, dit_path, high_noise_path): + """Determine if model is distill version based on model type and path""" + if model_type == "wan2.1": + check_name = dit_path.lower() if dit_path else "" + else: + check_name = high_noise_path.lower() if high_noise_path else "" + return "4step" in check_name + + # Model type change event + def on_model_type_change(model_type, model_path_val): + if model_type == "wan2.2": + return gr.update(visible=False), gr.update(visible=True), gr.update() + else: + # Update wan2.1 Diffusion model options + dit_choices = get_dit_choices(model_path_val, "wan2.1") + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), 
+ ) + + model_type_input.change( + fn=on_model_type_change, + inputs=[model_type_input, model_path_input], + outputs=[wan21_row, wan22_row, dit_path_input], + ) + + # Input parameters area + with gr.Accordion("📥 Input Parameters", open=True, elem_classes=["input-params"]): + # Image input (shown for i2v) + with gr.Row(visible=True) as image_input_row: + image_path = gr.Image( + label="Input Image", + type="filepath", + height=300, + interactive=True, + ) + + # Task type change event + def on_task_type_change(task_type): + return gr.update(visible=(task_type == "i2v")) + + task_type_input.change( + fn=on_task_type_change, + inputs=[task_type_input], + outputs=[image_input_row], + ) + + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox( + label="Prompt", + lines=3, + placeholder="Describe the video content...", + max_lines=5, + ) + with gr.Column(): + negative_prompt = gr.Textbox( + label="Negative Prompt", + lines=3, + placeholder="What you don't want to appear in the video...", + max_lines=5, + value="Camera shake, bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + ) + with gr.Column(): + resolution = gr.Dropdown( + choices=[ + # 720p + ("1280x720 (16:9, 720p)", "1280x720"), + ("720x1280 (9:16, 720p)", "720x1280"), + ("1280x544 (21:9, 720p)", "1280x544"), + ("544x1280 (9:21, 720p)", "544x1280"), + ("1104x832 (4:3, 720p)", "1104x832"), + ("832x1104 (3:4, 720p)", "832x1104"), + ("960x960 (1:1, 720p)", "960x960"), + # 480p + ("960x544 (16:9, 540p)", "960x544"), + ("544x960 (9:16, 540p)", "544x960"), + ("832x480 (16:9, 480p)", "832x480"), + ("480x832 (9:16, 480p)", "480x832"), + ("832x624 (4:3, 480p)", "832x624"), + ("624x832 (3:4, 480p)", "624x832"), + ("720x720 (1:1, 480p)", "720x720"), + ("512x512 (1:1, 480p)", "512x512"), + ], + value="832x480", + label="Maximum Resolution", + ) + + with gr.Column(scale=9): + seed = gr.Slider( + label="Random Seed", + minimum=0, + maximum=MAX_NUMPY_SEED, + step=1, + value=generate_random_seed(), + ) + with gr.Column(): + default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "" + default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "" + default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise) + + if default_is_distill: + infer_steps = gr.Slider( + label="Inference Steps", + minimum=1, + maximum=100, + step=1, + value=4, + info="Distill model inference steps default to 4.", + ) + else: + infer_steps = gr.Slider( + label="Inference Steps", + minimum=1, + maximum=100, + step=1, + value=40, + info="Number of inference steps for video generation. 
Increasing steps may improve quality but reduce speed.", + ) + + # Dynamically update inference steps when model path changes + def update_infer_steps(model_type, dit_path, high_noise_path): + is_distill = is_distill_model(model_type, dit_path, high_noise_path) + if is_distill: + return gr.update(minimum=1, maximum=100, value=4, interactive=True) + else: + return gr.update(minimum=1, maximum=100, value=40, interactive=True) + + # Listen to model path changes + dit_path_input.change( + fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[infer_steps], + ) + high_noise_path_input.change( + fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[infer_steps], + ) + model_type_input.change( + fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[infer_steps], + ) + + # Set default CFG based on model class + # CFG scale factor: default to 1 for distill, otherwise 5 + default_cfg_scale = 1 if default_is_distill else 5 + # enable_cfg is not exposed to frontend, automatically set based on cfg_scale + # If cfg_scale == 1, then enable_cfg = False, otherwise enable_cfg = True + default_enable_cfg = False if default_cfg_scale == 1 else True + enable_cfg = gr.Checkbox( + label="Enable Classifier-Free Guidance", + value=default_enable_cfg, + visible=False, # Hidden, not exposed to frontend + ) + + with gr.Row(): + sample_shift = gr.Slider( + label="Distribution Shift", + value=5, + minimum=0, + maximum=10, + step=1, + info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.", + ) + cfg_scale = gr.Slider( + label="CFG Scale Factor", + minimum=1, + maximum=10, + step=1, + value=default_cfg_scale, + info="Controls the influence strength of the prompt. Higher values give more influence to the prompt. When value is 1, CFG is automatically disabled.", + ) + + # Update enable_cfg based on cfg_scale + def update_enable_cfg(cfg_scale_val): + """Automatically set enable_cfg based on cfg_scale value""" + if cfg_scale_val == 1: + return gr.update(value=False) + else: + return gr.update(value=True) + + # Dynamically update CFG scale factor and enable_cfg when model path changes + def update_cfg_scale(model_type, dit_path, high_noise_path): + is_distill = is_distill_model(model_type, dit_path, high_noise_path) + if is_distill: + new_cfg_scale = 1 + else: + new_cfg_scale = 5 + new_enable_cfg = False if new_cfg_scale == 1 else True + return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg) + + dit_path_input.change( + fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[cfg_scale, enable_cfg], + ) + high_noise_path_input.change( + fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[cfg_scale, enable_cfg], + ) + model_type_input.change( + fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[cfg_scale, enable_cfg], + ) + + cfg_scale.change( + fn=update_enable_cfg, + inputs=[cfg_scale], + outputs=[enable_cfg], + ) + + with gr.Row(): + fps = gr.Slider( + label="Frames Per Second (FPS)", + minimum=8, + maximum=30, + step=1, + value=16, + info="Frames per second of the video. 
Higher FPS results in smoother videos.", + ) + num_frames = gr.Slider( + label="Total Frames", + minimum=16, + maximum=120, + step=1, + value=81, + info="Total number of frames in the video. More frames result in longer videos.", + ) + + save_result_path = gr.Textbox( + label="Output Video Path", + value=generate_unique_filename(output_dir), + info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.", + visible=False, # Hide output path, auto-generated + ) + + with gr.Column(scale=4): + with gr.Accordion("📤 Generated Video", open=True, elem_classes=["output-video"]): + output_video = gr.Video( + label="", + height=600, + autoplay=True, + show_label=False, + ) + + infer_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg", elem_classes=["generate-btn"]) + + rope_chunk = gr.Checkbox(label="Chunked Rotary Position Embedding", value=False, visible=False) + rope_chunk_size = gr.Slider(label="Rotary Embedding Chunk Size", value=100, minimum=100, maximum=10000, step=100, visible=False) + unload_modules = gr.Checkbox(label="Unload Modules", value=False, visible=False) + clean_cuda_cache = gr.Checkbox(label="Clean CUDA Memory Cache", value=False, visible=False) + cpu_offload = gr.Checkbox(label="CPU Offloading", value=False, visible=False) + lazy_load = gr.Checkbox(label="Enable Lazy Loading", value=False, visible=False) + offload_granularity = gr.Dropdown(label="Dit Offload Granularity", choices=["block", "phase"], value="phase", visible=False) + t5_cpu_offload = gr.Checkbox(label="T5 CPU Offloading", value=False, visible=False) + clip_cpu_offload = gr.Checkbox(label="CLIP CPU Offloading", value=False, visible=False) + vae_cpu_offload = gr.Checkbox(label="VAE CPU Offloading", value=False, visible=False) + use_tiling_vae = gr.Checkbox(label="VAE Tiling Inference", value=False, visible=False) + + resolution.change( + fn=auto_configure, + inputs=[resolution], + outputs=[ + lazy_load, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + use_tiling_vae, + ], + ) + + demo.load( + fn=lambda res: auto_configure(res), + inputs=[resolution], + outputs=[ + lazy_load, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + use_tiling_vae, + ], + ) + + infer_btn.click( + fn=run_inference, + inputs=[ + prompt, + negative_prompt, + save_result_path, + infer_steps, + num_frames, + resolution, + seed, + sample_shift, + enable_cfg, + cfg_scale, + fps, + use_tiling_vae, + lazy_load, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + model_path_input, + model_type_input, + task_type_input, + dit_path_input, + high_noise_path_input, + low_noise_path_input, + t5_path_input, + clip_path_input, + vae_path_input, + image_path, + ], + outputs=output_video, + ) + + demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Lightweight Video Generation") + parser.add_argument("--model_path", type=str, required=True, help="Model folder path") + 
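+    # Illustrative launch command (example path borrowed from run_gradio.sh; adjust to your environment):
+    #   python app/gradio_demo.py --model_path /path/to/Wan2.1-I2V-14B-720P-Lightx2v --server_port 7862 --output_dir ./outputs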
parser.add_argument("--server_port", type=int, default=7862, help="Server port") + parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server IP") + parser.add_argument("--output_dir", type=str, default="./outputs", help="Output video save directory") + args = parser.parse_args() + + global model_path, model_cls, output_dir + model_path = args.model_path + model_cls = "wan2.1" + output_dir = args.output_dir + + main() diff --git a/app/gradio_demo_zh.py b/app/gradio_demo_zh.py new file mode 100644 index 0000000000000000000000000000000000000000..229409f7bf8df4cce3c056fbe334e680abd8fc7f --- /dev/null +++ b/app/gradio_demo_zh.py @@ -0,0 +1,1442 @@ +import argparse +import gc +import glob +import importlib.util +import json +import os + +os.environ["PROFILING_DEBUG_LEVEL"] = "2" +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["DTYPE"] = "BF16" +import random +from datetime import datetime + +import gradio as gr +import psutil +import torch +from loguru import logger + +from lightx2v.utils.input_info import set_input_info +from lightx2v.utils.set_config import get_default_config + +try: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace +except ImportError: + apply_rope_with_cos_sin_cache_inplace = None + + +logger.add( + "inference_logs.log", + rotation="100 MB", + encoding="utf-8", + enqueue=True, + backtrace=True, + diagnose=True, +) + +MAX_NUMPY_SEED = 2**32 - 1 + + +def scan_model_path_contents(model_path): + """扫描 model_path 目录,返回可用的文件和子目录""" + if not model_path or not os.path.exists(model_path): + return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []} + + dirs = [] + files = [] + safetensors_dirs = [] + pth_files = [] + + try: + for item in os.listdir(model_path): + item_path = os.path.join(model_path, item) + if os.path.isdir(item_path): + dirs.append(item) + # 检查目录是否包含 safetensors 文件 + if glob.glob(os.path.join(item_path, "*.safetensors")): + safetensors_dirs.append(item) + elif os.path.isfile(item_path): + files.append(item) + if item.endswith(".pth"): + pth_files.append(item) + except Exception as e: + logger.warning(f"扫描目录失败: {e}") + + return { + "dirs": sorted(dirs), + "files": sorted(files), + "safetensors_dirs": sorted(safetensors_dirs), + "pth_files": sorted(pth_files), + } + + +def get_dit_choices(model_path, model_type="wan2.1"): + """获取 Diffusion 模型可选项(根据模型类型筛选)""" + contents = scan_model_path_contents(model_path) + excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"] + fp8_supported = is_fp8_supported_gpu() + + if model_type == "wan2.1": + # wan2.1: 筛选包含 wan2.1 或 Wan2.1 的文件/目录 + def is_valid(name): + name_lower = name.lower() + if "wan2.1" not in name_lower: + return False + if not fp8_supported and "fp8" in name_lower: + return False + return not any(kw in name_lower for kw in excluded_keywords) + else: + # wan2.2: 筛选包含 wan2.2 或 Wan2.2 的文件/目录 + def is_valid(name): + name_lower = name.lower() + if "wan2.2" not in name_lower: + return False + if not fp8_supported and "fp8" in name_lower: + return False + return not any(kw in name_lower for kw in excluded_keywords) + + # 筛选符合条件的目录和文件 + dir_choices = [d for d in contents["dirs"] if is_valid(d)] + file_choices = [f for f in contents["files"] if is_valid(f)] + choices = dir_choices + file_choices + return choices if choices else [""] + + +def get_high_noise_choices(model_path): + """获取高噪模型可选项(包含 high_noise 的文件/目录)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + def 
is_valid(name): + name_lower = name.lower() + if not fp8_supported and "fp8" in name_lower: + return False + return "high_noise" in name_lower or "high-noise" in name_lower + + dir_choices = [d for d in contents["dirs"] if is_valid(d)] + file_choices = [f for f in contents["files"] if is_valid(f)] + choices = dir_choices + file_choices + return choices if choices else [""] + + +def get_low_noise_choices(model_path): + """获取低噪模型可选项(包含 low_noise 的文件/目录)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + def is_valid(name): + name_lower = name.lower() + if not fp8_supported and "fp8" in name_lower: + return False + return "low_noise" in name_lower or "low-noise" in name_lower + + dir_choices = [d for d in contents["dirs"] if is_valid(d)] + file_choices = [f for f in contents["files"] if is_valid(f)] + choices = dir_choices + file_choices + return choices if choices else [""] + + +def get_t5_choices(model_path): + """获取 T5 模型可选项(.pth 或 .safetensors 文件,包含 t5 关键字)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + # 从 .pth 文件中筛选 + pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # 从 .safetensors 文件中筛选 + safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # 从包含 safetensors 的目录中筛选 + safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())] + + choices = pth_choices + safetensors_choices + safetensors_dir_choices + return choices if choices else [""] + + +def get_clip_choices(model_path): + """获取 CLIP 模型可选项(.pth 或 .safetensors 文件,包含 clip 关键字)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + # 从 .pth 文件中筛选 + pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # 从 .safetensors 文件中筛选 + safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())] + + # 从包含 safetensors 的目录中筛选 + safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())] + + choices = pth_choices + safetensors_choices + safetensors_dir_choices + return choices if choices else [""] + + +def get_vae_choices(model_path): + """获取 VAE 模型可选项(.pth 或 .safetensors 文件,包含 vae/VAE/tae 关键字)""" + contents = scan_model_path_contents(model_path) + fp8_supported = is_fp8_supported_gpu() + + # 从 .pth 文件中筛选 + pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())] + + # 从 .safetensors 文件中筛选 + safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())] + + # 从包含 safetensors 的目录中筛选 + safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())] + + choices = pth_choices + safetensors_choices + safetensors_dir_choices + return choices if choices else [""] + + +def detect_quant_scheme(model_name): + """根据模型名字自动检测量化精度 + - 如果模型名字包含 "int8" → "int8" + - 如果模型名字包含 "fp8" 且设备支持 → "fp8" + - 否则返回 None(表示不使用量化) + """ + if not model_name: + return 
None + name_lower = model_name.lower() + if "int8" in name_lower: + return "int8" + elif "fp8" in name_lower: + if is_fp8_supported_gpu(): + return "fp8" + else: + # 设备不支持fp8,返回None(使用默认精度) + return None + return None + + +def update_model_path_options(model_path, model_type="wan2.1"): + """当 model_path 或 model_type 改变时,更新所有模型路径选择器""" + dit_choices = get_dit_choices(model_path, model_type) + high_noise_choices = get_high_noise_choices(model_path) + low_noise_choices = get_low_noise_choices(model_path) + t5_choices = get_t5_choices(model_path) + clip_choices = get_clip_choices(model_path) + vae_choices = get_vae_choices(model_path) + + return ( + gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), + gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""), + gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""), + gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""), + gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""), + gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), + ) + + +def generate_random_seed(): + return random.randint(0, MAX_NUMPY_SEED) + + +def is_module_installed(module_name): + try: + spec = importlib.util.find_spec(module_name) + return spec is not None + except ModuleNotFoundError: + return False + + +def get_available_quant_ops(): + available_ops = [] + + vllm_installed = is_module_installed("vllm") + if vllm_installed: + available_ops.append(("vllm", True)) + else: + available_ops.append(("vllm", False)) + + sgl_installed = is_module_installed("sgl_kernel") + if sgl_installed: + available_ops.append(("sgl", True)) + else: + available_ops.append(("sgl", False)) + + q8f_installed = is_module_installed("q8_kernels") + if q8f_installed: + available_ops.append(("q8f", True)) + else: + available_ops.append(("q8f", False)) + + return available_ops + + +def get_available_attn_ops(): + available_ops = [] + + vllm_installed = is_module_installed("flash_attn") + if vllm_installed: + available_ops.append(("flash_attn2", True)) + else: + available_ops.append(("flash_attn2", False)) + + sgl_installed = is_module_installed("flash_attn_interface") + if sgl_installed: + available_ops.append(("flash_attn3", True)) + else: + available_ops.append(("flash_attn3", False)) + + sage_installed = is_module_installed("sageattention") + if sage_installed: + available_ops.append(("sage_attn2", True)) + else: + available_ops.append(("sage_attn2", False)) + + sage3_installed = is_module_installed("sageattn3") + if sage3_installed: + available_ops.append(("sage_attn3", True)) + else: + available_ops.append(("sage_attn3", False)) + + torch_installed = is_module_installed("torch") + if torch_installed: + available_ops.append(("torch_sdpa", True)) + else: + available_ops.append(("torch_sdpa", False)) + + return available_ops + + +def get_gpu_memory(gpu_idx=0): + if not torch.cuda.is_available(): + return 0 + try: + with torch.cuda.device(gpu_idx): + memory_info = torch.cuda.mem_get_info() + total_memory = memory_info[1] / (1024**3) # Convert bytes to GB + return total_memory + except Exception as e: + logger.warning(f"获取GPU内存失败: {e}") + return 0 + + +def get_cpu_memory(): + available_bytes = psutil.virtual_memory().available + return available_bytes / 1024**3 + + +def cleanup_memory(): + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + try: + import psutil + + if 
hasattr(psutil, "virtual_memory"): + if os.name == "posix": + try: + os.system("sync") + except: # noqa + pass + except: # noqa + pass + + +def generate_unique_filename(output_dir): + os.makedirs(output_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return os.path.join(output_dir, f"{timestamp}.mp4") + + +def is_fp8_supported_gpu(): + if not torch.cuda.is_available(): + return False + compute_capability = torch.cuda.get_device_capability(0) + major, minor = compute_capability + return (major == 8 and minor == 9) or (major >= 9) + + +def is_ada_architecture_gpu(): + if not torch.cuda.is_available(): + return False + try: + gpu_name = torch.cuda.get_device_name(0).upper() + ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"] + return any(keyword in gpu_name for keyword in ada_keywords) + except Exception as e: + logger.warning(f"Failed to get GPU name: {e}") + return False + + +def get_quantization_options(model_path): + """根据model_path动态获取量化选项""" + import os + + # 检查子目录 + subdirs = ["original", "fp8", "int8"] + has_subdirs = {subdir: os.path.exists(os.path.join(model_path, subdir)) for subdir in subdirs} + + # 检查根目录下的原始文件 + t5_bf16_exists = os.path.exists(os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")) + clip_fp16_exists = os.path.exists(os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")) + + # 生成选项 + def get_choices(has_subdirs, original_type, fp8_type, int8_type, fallback_type, has_original_file=False): + choices = [] + if has_subdirs["original"]: + choices.append(original_type) + if has_subdirs["fp8"]: + choices.append(fp8_type) + if has_subdirs["int8"]: + choices.append(int8_type) + + # 如果没有子目录但有原始文件,添加原始类型 + if has_original_file: + if not choices or "original" not in choices: + choices.append(original_type) + + # 如果没有任何选项,使用默认值 + if not choices: + choices = [fallback_type] + + return choices, choices[0] + + # DIT选项 + dit_choices, dit_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16") + + # T5选项 - 检查是否有原始文件 + t5_choices, t5_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16", t5_bf16_exists) + + # CLIP选项 - 检查是否有原始文件 + clip_choices, clip_default = get_choices(has_subdirs, "fp16", "fp8", "int8", "fp16", clip_fp16_exists) + + return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default} + + +def determine_model_cls(model_type, dit_name, high_noise_name): + """根据模型类型和文件名确定 model_cls""" + # 确定要检查的文件名 + if model_type == "wan2.1": + check_name = dit_name.lower() if dit_name else "" + is_distill = "4step" in check_name + return "wan2.1_distill" if is_distill else "wan2.1" + else: + # wan2.2 + check_name = high_noise_name.lower() if high_noise_name else "" + is_distill = "4step" in check_name + return "wan2.2_moe_distill" if is_distill else "wan2.2_moe" + + +global_runner = None +current_config = None +cur_dit_path = None +cur_t5_path = None +cur_clip_path = None + +available_quant_ops = get_available_quant_ops() +quant_op_choices = [] +for op_name, is_installed in available_quant_ops: + status_text = "✅ 已安装" if is_installed else "❌ 未安装" + display_text = f"{op_name} ({status_text})" + quant_op_choices.append((op_name, display_text)) + +available_attn_ops = get_available_attn_ops() +# 优先级顺序 +attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"] +# 按优先级排序,已安装的在前,未安装的在后 +attn_op_choices = [] +attn_op_dict = dict(available_attn_ops) + +# 
先添加已安装的(按优先级) +for op_name in attn_priority: + if op_name in attn_op_dict and attn_op_dict[op_name]: + status_text = "✅ 已安装" + display_text = f"{op_name} ({status_text})" + attn_op_choices.append((op_name, display_text)) + +# 再添加未安装的(按优先级) +for op_name in attn_priority: + if op_name in attn_op_dict and not attn_op_dict[op_name]: + status_text = "❌ 未安装" + display_text = f"{op_name} ({status_text})" + attn_op_choices.append((op_name, display_text)) + +# 添加其他不在优先级列表中的算子(已安装的在前) +other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority] +for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # 已安装的在前 + status_text = "✅ 已安装" if is_installed else "❌ 未安装" + display_text = f"{op_name} ({status_text})" + attn_op_choices.append((op_name, display_text)) + + +def run_inference( + prompt, + negative_prompt, + save_result_path, + infer_steps, + num_frames, + resolution, + seed, + sample_shift, + enable_cfg, + cfg_scale, + fps, + use_tiling_vae, + lazy_load, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + model_path_input, + model_type_input, + task_type_input, + dit_path_input, + high_noise_path_input, + low_noise_path_input, + t5_path_input, + clip_path_input, + vae_path_input, + image_path=None, +): + cleanup_memory() + + quant_op = quant_op.split("(")[0].strip() + attention_type = attention_type.split("(")[0].strip() + + global global_runner, current_config, model_path, model_cls + global cur_dit_path, cur_t5_path, cur_clip_path + + task = task_type_input + model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input) + logger.info(f"自动确定 model_cls: {model_cls} (模型类型: {model_type_input})") + + if model_type_input == "wan2.1": + dit_quant_detected = detect_quant_scheme(dit_path_input) + else: + dit_quant_detected = detect_quant_scheme(high_noise_path_input) + t5_quant_detected = detect_quant_scheme(t5_path_input) + clip_quant_detected = detect_quant_scheme(clip_path_input) + logger.info(f"自动检测量化精度 - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}") + + if model_path_input and model_path_input.strip(): + model_path = model_path_input.strip() + + if os.path.exists(os.path.join(model_path, "config.json")): + with open(os.path.join(model_path, "config.json"), "r") as f: + model_config = json.load(f) + else: + model_config = {} + + save_result_path = generate_unique_filename(output_dir) + + is_dit_quant = dit_quant_detected != "bf16" + is_t5_quant = t5_quant_detected != "bf16" + is_clip_quant = clip_quant_detected != "fp16" + + dit_quantized_ckpt = None + dit_original_ckpt = None + high_noise_quantized_ckpt = None + low_noise_quantized_ckpt = None + high_noise_original_ckpt = None + low_noise_original_ckpt = None + + if is_dit_quant: + dit_quant_scheme = f"{dit_quant_detected}-{quant_op}" + if "wan2.1" in model_cls: + dit_quantized_ckpt = os.path.join(model_path, dit_path_input) + else: + high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input) + low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input) + else: + dit_quantized_ckpt = "Default" + if "wan2.1" in model_cls: + dit_original_ckpt = os.path.join(model_path, dit_path_input) + else: + high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input) + low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input) + 
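+    # The quantization scheme strings assembled above follow the "<precision>-<operator>" pattern,
+    # e.g. "int8-vllm" or "fp8-sgl" (illustrative combinations; the operator comes from the UI dropdown).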
+ # 使用前端选择的 T5 路径 + if is_t5_quant: + t5_quantized_ckpt = os.path.join(model_path, t5_path_input) + t5_quant_scheme = f"{t5_quant_detected}-{quant_op}" + t5_original_ckpt = None + else: + t5_quantized_ckpt = None + t5_quant_scheme = None + t5_original_ckpt = os.path.join(model_path, t5_path_input) + + # 使用前端选择的 CLIP 路径 + if is_clip_quant: + clip_quantized_ckpt = os.path.join(model_path, clip_path_input) + clip_quant_scheme = f"{clip_quant_detected}-{quant_op}" + clip_original_ckpt = None + else: + clip_quantized_ckpt = None + clip_quant_scheme = None + clip_original_ckpt = os.path.join(model_path, clip_path_input) + + if model_type_input == "wan2.1": + current_dit_path = dit_path_input + else: + current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None + + current_t5_path = t5_path_input + current_clip_path = clip_path_input + + needs_reinit = ( + lazy_load + or unload_modules + or global_runner is None + or current_config is None + or cur_dit_path is None + or cur_dit_path != current_dit_path + or cur_t5_path is None + or cur_t5_path != current_t5_path + or cur_clip_path is None + or cur_clip_path != current_clip_path + ) + + if cfg_scale == 1: + enable_cfg = False + else: + enable_cfg = True + + vae_name_lower = vae_path_input.lower() if vae_path_input else "" + use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower + use_lightvae = "lightvae" in vae_name_lower + need_scaled = "lighttae" in vae_name_lower + + logger.info(f"VAE 配置 - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})") + + config_graio = { + "infer_steps": infer_steps, + "target_video_length": num_frames, + "target_width": int(resolution.split("x")[0]), + "target_height": int(resolution.split("x")[1]), + "self_attn_1_type": attention_type, + "cross_attn_1_type": attention_type, + "cross_attn_2_type": attention_type, + "enable_cfg": enable_cfg, + "sample_guide_scale": cfg_scale, + "sample_shift": sample_shift, + "fps": fps, + "feature_caching": "NoCaching", + "do_mm_calib": False, + "parallel_attn_type": None, + "parallel_vae": False, + "max_area": False, + "vae_stride": (4, 8, 8), + "patch_size": (1, 2, 2), + "lora_path": None, + "strength_model": 1.0, + "use_prompt_enhancer": False, + "text_len": 512, + "denoising_step_list": [1000, 750, 500, 250], + "cpu_offload": True if "wan2.2" in model_cls else cpu_offload, + "offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity, + "t5_cpu_offload": t5_cpu_offload, + "clip_cpu_offload": clip_cpu_offload, + "vae_cpu_offload": vae_cpu_offload, + "dit_quantized": is_dit_quant, + "dit_quant_scheme": dit_quant_scheme, + "dit_quantized_ckpt": dit_quantized_ckpt, + "dit_original_ckpt": dit_original_ckpt, + "high_noise_quantized_ckpt": high_noise_quantized_ckpt, + "low_noise_quantized_ckpt": low_noise_quantized_ckpt, + "high_noise_original_ckpt": high_noise_original_ckpt, + "low_noise_original_ckpt": low_noise_original_ckpt, + "t5_original_ckpt": t5_original_ckpt, + "t5_quantized": is_t5_quant, + "t5_quantized_ckpt": t5_quantized_ckpt, + "t5_quant_scheme": t5_quant_scheme, + "clip_original_ckpt": clip_original_ckpt, + "clip_quantized": is_clip_quant, + "clip_quantized_ckpt": clip_quantized_ckpt, + "clip_quant_scheme": clip_quant_scheme, + "vae_path": os.path.join(model_path, vae_path_input), + "use_tiling_vae": use_tiling_vae, + "use_tae": use_tae, + "use_lightvae": use_lightvae, + "need_scaled": need_scaled, + "lazy_load": lazy_load, 
+ "rope_chunk": rope_chunk, + "rope_chunk_size": rope_chunk_size, + "clean_cuda_cache": clean_cuda_cache, + "unload_modules": unload_modules, + "seq_parallel": False, + "warm_up_cpu_buffers": False, + "boundary_step_index": 2, + "boundary": 0.900, + "use_image_encoder": False if "wan2.2" in model_cls else True, + "rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch", + } + + args = argparse.Namespace( + model_cls=model_cls, + seed=seed, + task=task, + model_path=model_path, + prompt_enhancer=None, + prompt=prompt, + negative_prompt=negative_prompt, + image_path=image_path, + save_result_path=save_result_path, + return_result_tensor=False, + ) + + config = get_default_config() + config.update({k: v for k, v in vars(args).items()}) + config.update(model_config) + config.update(config_graio) + + logger.info(f"使用模型: {model_path}") + logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}") + + # Initialize or reuse the runner + runner = global_runner + if needs_reinit: + if runner is not None: + del runner + torch.cuda.empty_cache() + gc.collect() + + from lightx2v.infer import init_runner # noqa + + runner = init_runner(config) + input_info = set_input_info(args) + + current_config = config + cur_dit_path = current_dit_path + cur_t5_path = current_t5_path + cur_clip_path = current_clip_path + + if not lazy_load: + global_runner = runner + else: + runner.config = config + + runner.run_pipeline(input_info) + cleanup_memory() + + return save_result_path + + +def handle_lazy_load_change(lazy_load_enabled): + """Handle lazy_load checkbox change to automatically enable unload_modules""" + return gr.update(value=lazy_load_enabled) + + +def auto_configure(resolution): + """根据机器配置和分辨率自动设置推理选项""" + default_config = { + "lazy_load_val": False, + "rope_chunk_val": False, + "rope_chunk_size_val": 100, + "clean_cuda_cache_val": False, + "cpu_offload_val": False, + "offload_granularity_val": "block", + "t5_cpu_offload_val": False, + "clip_cpu_offload_val": False, + "vae_cpu_offload_val": False, + "unload_modules_val": False, + "attention_type_val": attn_op_choices[0][1], + "quant_op_val": quant_op_choices[0][1], + "use_tiling_vae_val": False, + } + + gpu_memory = round(get_gpu_memory()) + cpu_memory = round(get_cpu_memory()) + + attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"] + + if is_ada_architecture_gpu(): + quant_op_priority = ["q8f", "vllm", "sgl"] + else: + quant_op_priority = ["vllm", "sgl", "q8f"] + + for op in attn_priority: + if dict(available_attn_ops).get(op): + default_config["attention_type_val"] = dict(attn_op_choices)[op] + break + + for op in quant_op_priority: + if dict(available_quant_ops).get(op): + default_config["quant_op_val"] = dict(quant_op_choices)[op] + break + + if resolution in [ + "1280x720", + "720x1280", + "1280x544", + "544x1280", + "1104x832", + "832x1104", + "960x960", + ]: + res = "720p" + elif resolution in [ + "960x544", + "544x960", + ]: + res = "540p" + else: + res = "480p" + + if res == "720p": + gpu_rules = [ + (80, {}), + (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}), + (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}), + ( + 24, + { + "cpu_offload_val": True, + "use_tiling_vae_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + }, + ), + ( + 16, + { + "cpu_offload_val": True, + 
"t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + "rope_chunk_val": True, + "rope_chunk_size_val": 100, + }, + ), + ( + 8, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + "rope_chunk_val": True, + "rope_chunk_size_val": 100, + "clean_cuda_cache_val": True, + }, + ), + ] + + else: + gpu_rules = [ + (80, {}), + (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}), + (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}), + ( + 24, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + }, + ), + ( + 16, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + }, + ), + ( + 8, + { + "cpu_offload_val": True, + "t5_cpu_offload_val": True, + "vae_cpu_offload_val": True, + "clip_cpu_offload_val": True, + "use_tiling_vae_val": True, + "offload_granularity_val": "phase", + }, + ), + ] + + cpu_rules = [ + (128, {}), + (64, {}), + (32, {"unload_modules_val": True}), + ( + 16, + { + "lazy_load_val": True, + "unload_modules_val": True, + }, + ), + ] + + for threshold, updates in gpu_rules: + if gpu_memory >= threshold: + default_config.update(updates) + break + + for threshold, updates in cpu_rules: + if cpu_memory >= threshold: + default_config.update(updates) + break + + return ( + gr.update(value=default_config["lazy_load_val"]), + gr.update(value=default_config["rope_chunk_val"]), + gr.update(value=default_config["rope_chunk_size_val"]), + gr.update(value=default_config["clean_cuda_cache_val"]), + gr.update(value=default_config["cpu_offload_val"]), + gr.update(value=default_config["offload_granularity_val"]), + gr.update(value=default_config["t5_cpu_offload_val"]), + gr.update(value=default_config["clip_cpu_offload_val"]), + gr.update(value=default_config["vae_cpu_offload_val"]), + gr.update(value=default_config["unload_modules_val"]), + gr.update(value=default_config["attention_type_val"]), + gr.update(value=default_config["quant_op_val"]), + gr.update(value=default_config["use_tiling_vae_val"]), + ) + + +css = """ + .main-content { max-width: 1600px; margin: auto; padding: 20px; } + .warning { color: #ff6b6b; font-weight: bold; } + + /* 模型配置区域样式 */ + .model-config { + margin-bottom: 20px !important; + border: 1px solid #e0e0e0; + border-radius: 12px; + padding: 15px; + background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); + } + + /* 输入参数区域样式 */ + .input-params { + margin-bottom: 20px !important; + border: 1px solid #e0e0e0; + border-radius: 12px; + padding: 15px; + background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%); + } + + /* 输出视频区域样式 */ + .output-video { + border: 1px solid #e0e0e0; + border-radius: 12px; + padding: 20px; + background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%); + min-height: 400px; + } + + /* 生成按钮样式 */ + .generate-btn { + width: 100%; + margin-top: 20px; + padding: 15px 30px !important; + font-size: 18px !important; + font-weight: bold !important; + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; + border: 
none !important; + border-radius: 10px !important; + box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; + transition: all 0.3s ease !important; + } + .generate-btn:hover { + transform: translateY(-2px); + box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important; + } + + /* Accordion 标题样式 */ + .model-config .gr-accordion-header, + .input-params .gr-accordion-header, + .output-video .gr-accordion-header { + font-size: 20px !important; + font-weight: bold !important; + padding: 15px !important; + } + + /* 优化间距 */ + .gr-row { + margin-bottom: 15px; + } + + /* 视频播放器样式 */ + .output-video video { + border-radius: 10px; + box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); + } + """ + + +def main(): + with gr.Blocks(title="Lightx2v (轻量级视频推理和生成引擎)") as demo: + gr.Markdown(f"# 🎬 LightX2V 视频生成器") + gr.HTML(f"") + # 主布局:左右分栏 + with gr.Row(): + # 左侧:配置和输入区域 + with gr.Column(scale=5): + # 模型配置区域 + with gr.Accordion("🗂️ 模型配置", open=True, elem_classes=["model-config"]): + # FP8 支持提示 + if not is_fp8_supported_gpu(): + gr.Markdown("⚠️ **您的设备不支持fp8推理**,已自动隐藏包含fp8的模型选项。") + + # 隐藏的状态组件 + model_path_input = gr.Textbox(value=model_path, visible=False) + + # 模型类型 + 任务类型 + with gr.Row(): + model_type_input = gr.Radio( + label="模型类型", + choices=["wan2.1", "wan2.2"], + value="wan2.1", + info="wan2.2 需要分别指定高噪模型和低噪模型", + ) + task_type_input = gr.Radio( + label="任务类型", + choices=["i2v", "t2v"], + value="i2v", + info="i2v: 图生视频, t2v: 文生视频", + ) + + # wan2.1:Diffusion模型(单独一行) + with gr.Row() as wan21_row: + dit_path_input = gr.Dropdown( + label="🎨 Diffusion模型", + choices=get_dit_choices(model_path, "wan2.1"), + value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "", + allow_custom_value=True, + visible=True, + ) + + # wan2.2 专用:高噪模型 + 低噪模型(默认隐藏) + with gr.Row(visible=False) as wan22_row: + high_noise_path_input = gr.Dropdown( + label="🔊 高噪模型", + choices=get_high_noise_choices(model_path), + value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "", + allow_custom_value=True, + ) + low_noise_path_input = gr.Dropdown( + label="🔇 低噪模型", + choices=get_low_noise_choices(model_path), + value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "", + allow_custom_value=True, + ) + + # 文本编码器(单独一行) + with gr.Row(): + t5_path_input = gr.Dropdown( + label="📝 文本编码器", + choices=get_t5_choices(model_path), + value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "", + allow_custom_value=True, + ) + + # 图像编码器 + VAE解码器 + with gr.Row(): + clip_path_input = gr.Dropdown( + label="🖼️ 图像编码器", + choices=get_clip_choices(model_path), + value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "", + allow_custom_value=True, + ) + vae_path_input = gr.Dropdown( + label="🎞️ VAE解码器", + choices=get_vae_choices(model_path), + value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "", + allow_custom_value=True, + ) + + # 注意力算子和量化矩阵乘法算子 + with gr.Row(): + attention_type = gr.Dropdown( + label="⚡ 注意力算子", + choices=[op[1] for op in attn_op_choices], + value=attn_op_choices[0][1] if attn_op_choices else "", + info="使用适当的注意力算子加速推理", + ) + quant_op = gr.Dropdown( + label="量化矩阵乘法算子", + choices=[op[1] for op in quant_op_choices], + value=quant_op_choices[0][1], + info="选择量化矩阵乘法算子以加速推理", + interactive=True, + ) + + # 判断模型是否是 distill 版本 + def is_distill_model(model_type, dit_path, high_noise_path): + """根据模型类型和路径判断是否是 distill 版本""" + if model_type == "wan2.1": + check_name = dit_path.lower() if dit_path else "" 
+ else: + check_name = high_noise_path.lower() if high_noise_path else "" + return "4step" in check_name + + # 模型类型切换事件 + def on_model_type_change(model_type, model_path_val): + if model_type == "wan2.2": + return gr.update(visible=False), gr.update(visible=True), gr.update() + else: + # 更新 wan2.1 的 Diffusion 模型选项 + dit_choices = get_dit_choices(model_path_val, "wan2.1") + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), + ) + + model_type_input.change( + fn=on_model_type_change, + inputs=[model_type_input, model_path_input], + outputs=[wan21_row, wan22_row, dit_path_input], + ) + + # 输入参数区域 + with gr.Accordion("📥 输入参数", open=True, elem_classes=["input-params"]): + # 图片输入(i2v 时显示) + with gr.Row(visible=True) as image_input_row: + image_path = gr.Image( + label="输入图像", + type="filepath", + height=300, + interactive=True, + ) + + # 任务类型切换事件 + def on_task_type_change(task_type): + return gr.update(visible=(task_type == "i2v")) + + task_type_input.change( + fn=on_task_type_change, + inputs=[task_type_input], + outputs=[image_input_row], + ) + + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox( + label="提示词", + lines=3, + placeholder="描述视频内容...", + max_lines=5, + ) + with gr.Column(): + negative_prompt = gr.Textbox( + label="负向提示词", + lines=3, + placeholder="不希望出现在视频中的内容...", + max_lines=5, + value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + ) + with gr.Column(): + resolution = gr.Dropdown( + choices=[ + # 720p + ("1280x720 (16:9, 720p)", "1280x720"), + ("720x1280 (9:16, 720p)", "720x1280"), + ("1280x544 (21:9, 720p)", "1280x544"), + ("544x1280 (9:21, 720p)", "544x1280"), + ("1104x832 (4:3, 720p)", "1104x832"), + ("832x1104 (3:4, 720p)", "832x1104"), + ("960x960 (1:1, 720p)", "960x960"), + # 480p + ("960x544 (16:9, 540p)", "960x544"), + ("544x960 (9:16, 540p)", "544x960"), + ("832x480 (16:9, 480p)", "832x480"), + ("480x832 (9:16, 480p)", "480x832"), + ("832x624 (4:3, 480p)", "832x624"), + ("624x832 (3:4, 480p)", "624x832"), + ("720x720 (1:1, 480p)", "720x720"), + ("512x512 (1:1, 480p)", "512x512"), + ], + value="832x480", + label="最大分辨率", + ) + + with gr.Column(scale=9): + seed = gr.Slider( + label="随机种子", + minimum=0, + maximum=MAX_NUMPY_SEED, + step=1, + value=generate_random_seed(), + ) + with gr.Column(): + default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "" + default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "" + default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise) + + if default_is_distill: + infer_steps = gr.Slider( + label="推理步数", + minimum=1, + maximum=100, + step=1, + value=4, + info="蒸馏模型推理步数默认为4。", + ) + else: + infer_steps = gr.Slider( + label="推理步数", + minimum=1, + maximum=100, + step=1, + value=40, + info="视频生成的推理步数。增加步数可能提高质量但降低速度。", + ) + + # 当模型路径改变时,动态更新推理步数 + def update_infer_steps(model_type, dit_path, high_noise_path): + is_distill = is_distill_model(model_type, dit_path, high_noise_path) + if is_distill: + return gr.update(minimum=1, maximum=100, value=4, interactive=True) + else: + return gr.update(minimum=1, maximum=100, value=40, interactive=True) + + # 监听模型路径变化 + dit_path_input.change( + fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[infer_steps], + 
) + high_noise_path_input.change( + fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[infer_steps], + ) + model_type_input.change( + fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[infer_steps], + ) + + # 根据模型类别设置默认CFG + # CFG缩放因子:distill 时默认为 1,否则默认为 5 + default_cfg_scale = 1 if default_is_distill else 5 + # enable_cfg 不暴露到前端,根据 cfg_scale 自动设置 + # 如果 cfg_scale == 1,则 enable_cfg = False,否则 enable_cfg = True + default_enable_cfg = False if default_cfg_scale == 1 else True + enable_cfg = gr.Checkbox( + label="启用无分类器引导", + value=default_enable_cfg, + visible=False, # 隐藏,不暴露到前端 + ) + + with gr.Row(): + sample_shift = gr.Slider( + label="分布偏移", + value=5, + minimum=0, + maximum=10, + step=1, + info="控制样本分布偏移的程度。值越大表示偏移越明显。", + ) + cfg_scale = gr.Slider( + label="CFG缩放因子", + minimum=1, + maximum=10, + step=1, + value=default_cfg_scale, + info="控制提示词的影响强度。值越高,提示词的影响越大。当值为1时,自动禁用CFG。", + ) + + # 根据 cfg_scale 更新 enable_cfg + def update_enable_cfg(cfg_scale_val): + """根据 cfg_scale 的值自动设置 enable_cfg""" + if cfg_scale_val == 1: + return gr.update(value=False) + else: + return gr.update(value=True) + + # 当模型路径改变时,动态更新 CFG 缩放因子和 enable_cfg + def update_cfg_scale(model_type, dit_path, high_noise_path): + is_distill = is_distill_model(model_type, dit_path, high_noise_path) + if is_distill: + new_cfg_scale = 1 + else: + new_cfg_scale = 5 + new_enable_cfg = False if new_cfg_scale == 1 else True + return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg) + + dit_path_input.change( + fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[cfg_scale, enable_cfg], + ) + high_noise_path_input.change( + fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[cfg_scale, enable_cfg], + ) + model_type_input.change( + fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp), + inputs=[model_type_input, dit_path_input, high_noise_path_input], + outputs=[cfg_scale, enable_cfg], + ) + + cfg_scale.change( + fn=update_enable_cfg, + inputs=[cfg_scale], + outputs=[enable_cfg], + ) + + with gr.Row(): + fps = gr.Slider( + label="每秒帧数(FPS)", + minimum=8, + maximum=30, + step=1, + value=16, + info="视频的每秒帧数。较高的FPS会产生更流畅的视频。", + ) + num_frames = gr.Slider( + label="总帧数", + minimum=16, + maximum=120, + step=1, + value=81, + info="视频中的总帧数。更多帧数会产生更长的视频。", + ) + + save_result_path = gr.Textbox( + label="输出视频路径", + value=generate_unique_filename(output_dir), + info="必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。", + visible=False, # 隐藏输出路径,自动生成 + ) + + with gr.Column(scale=4): + with gr.Accordion("📤 生成的视频", open=True, elem_classes=["output-video"]): + output_video = gr.Video( + label="", + height=600, + autoplay=True, + show_label=False, + ) + + infer_btn = gr.Button("🎬 生成视频", variant="primary", size="lg", elem_classes=["generate-btn"]) + + rope_chunk = gr.Checkbox(label="分块旋转位置编码", value=False, visible=False) + rope_chunk_size = gr.Slider(label="旋转编码块大小", value=100, minimum=100, maximum=10000, step=100, visible=False) + unload_modules = gr.Checkbox(label="卸载模块", value=False, visible=False) + clean_cuda_cache = gr.Checkbox(label="清理CUDA内存缓存", value=False, visible=False) + cpu_offload = gr.Checkbox(label="CPU卸载", value=False, visible=False) + lazy_load = gr.Checkbox(label="启用延迟加载", value=False, visible=False) + 
offload_granularity = gr.Dropdown(label="Dit卸载粒度", choices=["block", "phase"], value="phase", visible=False) + t5_cpu_offload = gr.Checkbox(label="T5 CPU卸载", value=False, visible=False) + clip_cpu_offload = gr.Checkbox(label="CLIP CPU卸载", value=False, visible=False) + vae_cpu_offload = gr.Checkbox(label="VAE CPU卸载", value=False, visible=False) + use_tiling_vae = gr.Checkbox(label="VAE分块推理", value=False, visible=False) + + resolution.change( + fn=auto_configure, + inputs=[resolution], + outputs=[ + lazy_load, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + use_tiling_vae, + ], + ) + + demo.load( + fn=lambda res: auto_configure(res), + inputs=[resolution], + outputs=[ + lazy_load, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + use_tiling_vae, + ], + ) + + infer_btn.click( + fn=run_inference, + inputs=[ + prompt, + negative_prompt, + save_result_path, + infer_steps, + num_frames, + resolution, + seed, + sample_shift, + enable_cfg, + cfg_scale, + fps, + use_tiling_vae, + lazy_load, + cpu_offload, + offload_granularity, + t5_cpu_offload, + clip_cpu_offload, + vae_cpu_offload, + unload_modules, + attention_type, + quant_op, + rope_chunk, + rope_chunk_size, + clean_cuda_cache, + model_path_input, + model_type_input, + task_type_input, + dit_path_input, + high_noise_path_input, + low_noise_path_input, + t5_path_input, + clip_path_input, + vae_path_input, + image_path, + ], + outputs=output_video, + ) + + demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="轻量级视频生成") + parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径") + parser.add_argument("--server_port", type=int, default=7862, help="服务器端口") + parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP") + parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录") + args = parser.parse_args() + + global model_path, model_cls, output_dir + model_path = args.model_path + model_cls = "wan2.1" + output_dir = args.output_dir + + main() diff --git a/app/run_gradio.sh b/app/run_gradio.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5dbcc55605f33c7d6504a7e3f7e1da372236eb2 --- /dev/null +++ b/app/run_gradio.sh @@ -0,0 +1,181 @@ +#!/bin/bash + +# Lightx2v Gradio Demo Startup Script +# Supports both Image-to-Video (i2v) and Text-to-Video (t2v) modes + +# ==================== Configuration Area ==================== +# ⚠️ Important: Please modify the following paths according to your actual environment + +# 🚨 Storage Performance Tips 🚨 +# 💾 Strongly recommend storing model files on SSD solid-state drives! 
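+# 💡 Tip: on most Linux systems you can check whether a disk is an SSD with "lsblk -d -o NAME,ROTA" (ROTA=0 usually indicates an SSD)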
+# 📈 SSD can significantly improve model loading speed and inference performance +# 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience + + +# Lightx2v project root directory path +# Example: /home/user/lightx2v or /data/video_gen/lightx2v +lightx2v_path=/data/video_gen/lightx2v_debug/LightX2V + +# Model path configuration +# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v +model_path=/models/ + +# Server configuration +server_name="0.0.0.0" +server_port=8033 + +# Output directory configuration +output_dir="./outputs" + +# GPU configuration +gpu_id=0 + +# ==================== Environment Variables Setup ==================== +export CUDA_VISIBLE_DEVICES=$gpu_id +export CUDA_LAUNCH_BLOCKING=1 +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH +export PROFILING_DEBUG_LEVEL=2 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# ==================== Parameter Parsing ==================== +# Default interface language +lang="zh" + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + --lang) + lang="$2" + shift 2 + ;; + --port) + server_port="$2" + shift 2 + ;; + --gpu) + gpu_id="$2" + export CUDA_VISIBLE_DEVICES=$gpu_id + shift 2 + ;; + --output_dir) + output_dir="$2" + shift 2 + ;; + --model_path) + model_path="$2" + shift 2 + ;; + --help) + echo "🎬 Lightx2v Gradio Demo Startup Script" + echo "==========================================" + echo "Usage: $0 [options]" + echo "" + echo "📋 Available options:" + echo " --lang zh|en Interface language (default: zh)" + echo " zh: Chinese interface" + echo " en: English interface" + echo " --port PORT Server port (default: 8032)" + echo " --gpu GPU_ID GPU device ID (default: 0)" + echo " --model_path PATH Model path (default: configured in script)" + echo " --output_dir DIR Output video save directory (default: ./outputs)" + echo " --help Show this help message" + echo "" + echo "📝 Notes:" + echo " - Task type (i2v/t2v) and model type are selected in the web UI" + echo " - Model class is auto-detected based on selected diffusion model" + echo " - Edit script to configure model paths before first use" + echo " - Ensure required Python dependencies are installed" + echo " - Recommended to use GPU with 8GB+ VRAM" + echo " - 🚨 Strongly recommend storing models on SSD for better performance" + exit 0 + ;; + *) + echo "Unknown parameter: $1" + echo "Use --help to see help information" + exit 1 + ;; + esac +done + +# ==================== Parameter Validation ==================== +if [[ "$lang" != "zh" && "$lang" != "en" ]]; then + echo "Error: Language must be 'zh' or 'en'" + exit 1 +fi + +# Check if model path exists +if [[ ! -d "$model_path" ]]; then + echo "❌ Error: Model path does not exist" + echo "📁 Path: $model_path" + echo "🔧 Solutions:" + echo " 1. Check model path configuration in script" + echo " 2. Ensure model files are properly downloaded" + echo " 3. Verify path permissions are correct" + echo " 4. 💾 Recommend storing models on SSD for faster loading" + exit 1 +fi + +# Select demo file based on language +if [[ "$lang" == "zh" ]]; then + demo_file="gradio_demo_zh.py" + echo "🌏 Using Chinese interface" +else + demo_file="gradio_demo.py" + echo "🌏 Using English interface" +fi + +# Check if demo file exists +if [[ ! -f "$demo_file" ]]; then + echo "❌ Error: Demo file does not exist" + echo "📄 File: $demo_file" + echo "🔧 Solutions:" + echo " 1. Ensure script is run in the correct directory" + echo " 2. Check if file has been renamed or moved" + echo " 3. 
Re-clone or download project files" + exit 1 +fi + +# ==================== System Information Display ==================== +echo "==========================================" +echo "🚀 Lightx2v Gradio Demo Starting..." +echo "==========================================" +echo "📁 Project path: $lightx2v_path" +echo "🤖 Model path: $model_path" +echo "🌏 Interface language: $lang" +echo "🖥️ GPU device: $gpu_id" +echo "🌐 Server address: $server_name:$server_port" +echo "📁 Output directory: $output_dir" +echo "📝 Note: Task type and model class are selected in web UI" +echo "==========================================" + +# Display system resource information +echo "💻 System resource information:" +free -h | grep -E "Mem|Swap" +echo "" + +# Display GPU information +if command -v nvidia-smi &> /dev/null; then + echo "🎮 GPU information:" + nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits | head -1 + echo "" +fi + +# ==================== Start Demo ==================== +echo "🎬 Starting Gradio demo..." +echo "📱 Please access in browser: http://$server_name:$server_port" +echo "⏹️ Press Ctrl+C to stop service" +echo "🔄 First startup may take several minutes to load resources..." +echo "==========================================" + +# Start Python demo +python $demo_file \ + --model_path "$model_path" \ + --server_name "$server_name" \ + --server_port "$server_port" \ + --output_dir "$output_dir" + +# Display final system resource usage +echo "" +echo "==========================================" +echo "📊 Final system resource usage:" +free -h | grep -E "Mem|Swap" diff --git a/app/run_gradio_win.bat b/app/run_gradio_win.bat new file mode 100644 index 0000000000000000000000000000000000000000..7ea43cd8ba0afa3e243731dd2c2d0acfc2c9e770 --- /dev/null +++ b/app/run_gradio_win.bat @@ -0,0 +1,196 @@ +@echo off +chcp 65001 >nul +echo 🎬 LightX2V Gradio Windows Startup Script +echo ========================================== + +REM ==================== Configuration Area ==================== +REM ⚠️ Important: Please modify the following paths according to your actual environment + +REM 🚨 Storage Performance Tips 🚨 +REM 💾 Strongly recommend storing model files on SSD solid-state drives! 
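+REM 💡 Tip: on Windows you can check the drive type with the PowerShell cmdlet "Get-PhysicalDisk" (its MediaType column reports SSD or HDD)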
+REM 📈 SSD can significantly improve model loading speed and inference performance +REM 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience + +REM LightX2V project root directory path +REM Example: D:\LightX2V +set lightx2v_path=/path/to/LightX2V + +REM Model path configuration +REM Model root directory path +REM Example: D:\models\LightX2V +set model_path=/path/to/LightX2V + +REM Server configuration +set server_name=127.0.0.1 +set server_port=8032 + +REM Output directory configuration +set output_dir=./outputs + +REM GPU configuration +set gpu_id=0 + +REM ==================== Environment Variables Setup ==================== +set CUDA_VISIBLE_DEVICES=%gpu_id% +set PYTHONPATH=%lightx2v_path%;%PYTHONPATH% +set PROFILING_DEBUG_LEVEL=2 +set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +REM ==================== Parameter Parsing ==================== +REM Default interface language +set lang=zh + +REM Parse command line arguments +:parse_args +if "%1"=="" goto :end_parse +if "%1"=="--lang" ( + set lang=%2 + shift + shift + goto :parse_args +) +if "%1"=="--port" ( + set server_port=%2 + shift + shift + goto :parse_args +) +if "%1"=="--gpu" ( + set gpu_id=%2 + set CUDA_VISIBLE_DEVICES=%gpu_id% + shift + shift + goto :parse_args +) +if "%1"=="--output_dir" ( + set output_dir=%2 + shift + shift + goto :parse_args +) +if "%1"=="--help" ( + echo 🎬 LightX2V Gradio Windows Startup Script + echo ========================================== + echo Usage: %0 [options] + echo. + echo 📋 Available options: + echo --lang zh^|en Interface language (default: zh) + echo zh: Chinese interface + echo en: English interface + echo --port PORT Server port (default: 8032) + echo --gpu GPU_ID GPU device ID (default: 0) + echo --output_dir OUTPUT_DIR + echo Output video save directory (default: ./outputs) + echo --help Show this help message + echo. + echo 🚀 Usage examples: + echo %0 # Default startup + echo %0 --lang zh --port 8032 # Start with specified parameters + echo %0 --lang en --port 7860 # English interface + echo %0 --gpu 1 --port 8032 # Use GPU 1 + echo %0 --output_dir ./custom_output # Use custom output directory + echo. + echo 📝 Notes: + echo - Edit script to configure model path before first use + echo - Ensure required Python dependencies are installed + echo - Recommended to use GPU with 8GB+ VRAM + echo - 🚨 Strongly recommend storing models on SSD for better performance + pause + exit /b 0 +) +echo Unknown parameter: %1 +echo Use --help to see help information +pause +exit /b 1 + +:end_parse + +REM ==================== Parameter Validation ==================== +if "%lang%"=="zh" goto :valid_lang +if "%lang%"=="en" goto :valid_lang +echo Error: Language must be 'zh' or 'en' +pause +exit /b 1 + +:valid_lang + +REM Check if model path exists +if not exist "%model_path%" ( + echo ❌ Error: Model path does not exist + echo 📁 Path: %model_path% + echo 🔧 Solutions: + echo 1. Check model path configuration in script + echo 2. Ensure model files are properly downloaded + echo 3. Verify path permissions are correct + echo 4. 💾 Recommend storing models on SSD for faster loading + pause + exit /b 1 +) + +REM Select demo file based on language +if "%lang%"=="zh" ( + set demo_file=gradio_demo_zh.py + echo 🌏 Using Chinese interface +) else ( + set demo_file=gradio_demo.py + echo 🌏 Using English interface +) + +REM Check if demo file exists +if not exist "%demo_file%" ( + echo ❌ Error: Demo file does not exist + echo 📄 File: %demo_file% + echo 🔧 Solutions: + echo 1. 
Ensure script is run in the correct directory + echo 2. Check if file has been renamed or moved + echo 3. Re-clone or download project files + pause + exit /b 1 +) + +REM ==================== System Information Display ==================== +echo ========================================== +echo 🚀 LightX2V Gradio Starting... +echo ========================================== +echo 📁 Project path: %lightx2v_path% +echo 🤖 Model path: %model_path% +echo 🌏 Interface language: %lang% +echo 🖥️ GPU device: %gpu_id% +echo 🌐 Server address: %server_name%:%server_port% +echo 📁 Output directory: %output_dir% +echo ========================================== + +REM Display system resource information +echo 💻 System resource information: +wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table + +REM Display GPU information +nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits 2>nul +if errorlevel 1 ( + echo 🎮 GPU information: Unable to get GPU info +) else ( + echo 🎮 GPU information: + nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits +) + +REM ==================== Start Demo ==================== +echo 🎬 Starting Gradio demo... +echo 📱 Please access in browser: http://%server_name%:%server_port% +echo ⏹️ Press Ctrl+C to stop service +echo 🔄 First startup may take several minutes to load resources... +echo ========================================== + +REM Start Python demo +python %demo_file% ^ + --model_path "%model_path%" ^ + --server_name %server_name% ^ + --server_port %server_port% ^ + --output_dir "%output_dir%" + +REM Display final system resource usage +echo. +echo ========================================== +echo 📊 Final system resource usage: +wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table + +pause diff --git a/assets/figs/offload/fig1_en.png b/assets/figs/offload/fig1_en.png new file mode 100644 index 0000000000000000000000000000000000000000..acee7c1fa89bf9f8564dd47a844e8553480be9fe Binary files /dev/null and b/assets/figs/offload/fig1_en.png differ diff --git a/assets/figs/offload/fig1_zh.png b/assets/figs/offload/fig1_zh.png new file mode 100644 index 0000000000000000000000000000000000000000..1e1bbe2d92787d27c34c8fe742e5076e70b19161 Binary files /dev/null and b/assets/figs/offload/fig1_zh.png differ diff --git a/assets/figs/offload/fig2_en.png b/assets/figs/offload/fig2_en.png new file mode 100644 index 0000000000000000000000000000000000000000..606523e857d192c9c21245b149eae4b63062e4c0 Binary files /dev/null and b/assets/figs/offload/fig2_en.png differ diff --git a/assets/figs/offload/fig2_zh.png b/assets/figs/offload/fig2_zh.png new file mode 100644 index 0000000000000000000000000000000000000000..bdabf44507c69b6ff099734e4d8eadf70edcf1f2 Binary files /dev/null and b/assets/figs/offload/fig2_zh.png differ diff --git a/assets/figs/offload/fig3_en.png b/assets/figs/offload/fig3_en.png new file mode 100644 index 0000000000000000000000000000000000000000..585c27930d499615d1809d5586c1f0f71554fc35 Binary files /dev/null and b/assets/figs/offload/fig3_en.png differ diff --git a/assets/figs/offload/fig3_zh.png b/assets/figs/offload/fig3_zh.png new file mode 100644 index 0000000000000000000000000000000000000000..b6147797b57b3a4c9fedd69dc1713a1a1e12e54e Binary files /dev/null and b/assets/figs/offload/fig3_zh.png differ diff --git a/assets/figs/offload/fig4_en.png b/assets/figs/offload/fig4_en.png new file mode 100644 index 0000000000000000000000000000000000000000..9cc2a32c3a132c7ae2c174dd9de5f1852c0d933e Binary 
files /dev/null and b/assets/figs/offload/fig4_en.png differ diff --git a/assets/figs/offload/fig4_zh.png b/assets/figs/offload/fig4_zh.png new file mode 100644 index 0000000000000000000000000000000000000000..7e3442295196f81522ab030e559f444e3288e1be Binary files /dev/null and b/assets/figs/offload/fig4_zh.png differ diff --git a/assets/figs/offload/fig5_en.png b/assets/figs/offload/fig5_en.png new file mode 100644 index 0000000000000000000000000000000000000000..489e20f5b6a9574689a125cf7fecf7dbd82336de Binary files /dev/null and b/assets/figs/offload/fig5_en.png differ diff --git a/assets/figs/offload/fig5_zh.png b/assets/figs/offload/fig5_zh.png new file mode 100644 index 0000000000000000000000000000000000000000..3cd30e007d973774b799c47ed83c1bcfce28dc27 Binary files /dev/null and b/assets/figs/offload/fig5_zh.png differ diff --git a/assets/figs/portabl_windows/pic1.png b/assets/figs/portabl_windows/pic1.png new file mode 100644 index 0000000000000000000000000000000000000000..04721dafb8ccd1f5ce5041dbbcc5b2c22c83ffc4 Binary files /dev/null and b/assets/figs/portabl_windows/pic1.png differ diff --git a/assets/figs/portabl_windows/pic_gradio_en.png b/assets/figs/portabl_windows/pic_gradio_en.png new file mode 100644 index 0000000000000000000000000000000000000000..f08fa4cf1b0797045ad24ae9e52a70577396f6f2 Binary files /dev/null and b/assets/figs/portabl_windows/pic_gradio_en.png differ diff --git a/assets/figs/portabl_windows/pic_gradio_zh.png b/assets/figs/portabl_windows/pic_gradio_zh.png new file mode 100644 index 0000000000000000000000000000000000000000..86441fa54ad0d12140a083439c608eb41d2b8a0b Binary files /dev/null and b/assets/figs/portabl_windows/pic_gradio_zh.png differ diff --git a/assets/figs/step_distill/fig_01.png b/assets/figs/step_distill/fig_01.png new file mode 100644 index 0000000000000000000000000000000000000000..cdbca0a24540f1829376f132dbb66f4332d948a0 Binary files /dev/null and b/assets/figs/step_distill/fig_01.png differ diff --git a/assets/img_lightx2v.png b/assets/img_lightx2v.png new file mode 100644 index 0000000000000000000000000000000000000000..1067b67b49899403b4cb44b74de4c00683f63e98 Binary files /dev/null and b/assets/img_lightx2v.png differ diff --git a/assets/inputs/audio/multi_person/config.json b/assets/inputs/audio/multi_person/config.json new file mode 100644 index 0000000000000000000000000000000000000000..3e0cf1e4fbbccacfa099e4559fa6a31288d3108c --- /dev/null +++ b/assets/inputs/audio/multi_person/config.json @@ -0,0 +1,12 @@ +{ + "talk_objects": [ + { + "audio": "p1.mp3", + "mask": "p1_mask.png" + }, + { + "audio": "p2.mp3", + "mask": "p2_mask.png" + } + ] +} diff --git a/assets/inputs/audio/multi_person/config_multi_template.json b/assets/inputs/audio/multi_person/config_multi_template.json new file mode 100644 index 0000000000000000000000000000000000000000..3e0cf1e4fbbccacfa099e4559fa6a31288d3108c --- /dev/null +++ b/assets/inputs/audio/multi_person/config_multi_template.json @@ -0,0 +1,12 @@ +{ + "talk_objects": [ + { + "audio": "p1.mp3", + "mask": "p1_mask.png" + }, + { + "audio": "p2.mp3", + "mask": "p2_mask.png" + } + ] +} diff --git a/assets/inputs/audio/multi_person/config_single_template.json b/assets/inputs/audio/multi_person/config_single_template.json new file mode 100644 index 0000000000000000000000000000000000000000..306a61b465058f8265a3596acab24306622d0f30 --- /dev/null +++ b/assets/inputs/audio/multi_person/config_single_template.json @@ -0,0 +1,8 @@ +{ + "talk_objects": [ + { + "audio": "p1.mp3", + "mask": "p1_mask.png" + } + ] +} diff --git 
a/assets/inputs/audio/multi_person/p1.mp3 b/assets/inputs/audio/multi_person/p1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..41c1042f79cdecc04d2e87493067ea43142b2cad Binary files /dev/null and b/assets/inputs/audio/multi_person/p1.mp3 differ diff --git a/assets/inputs/audio/multi_person/p1_mask.png b/assets/inputs/audio/multi_person/p1_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..7cb8d84541a08927a71556f004b2b077a3d5a27a Binary files /dev/null and b/assets/inputs/audio/multi_person/p1_mask.png differ diff --git a/assets/inputs/audio/multi_person/p2.mp3 b/assets/inputs/audio/multi_person/p2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..a2ce0dd648bf43852ab9baaa1bb9dc45f095df98 Binary files /dev/null and b/assets/inputs/audio/multi_person/p2.mp3 differ diff --git a/assets/inputs/audio/multi_person/p2_mask.png b/assets/inputs/audio/multi_person/p2_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..05c1265091bce025deb000e066263e7fe8e733f6 Binary files /dev/null and b/assets/inputs/audio/multi_person/p2_mask.png differ diff --git a/assets/inputs/audio/multi_person/seko_input.png b/assets/inputs/audio/multi_person/seko_input.png new file mode 100644 index 0000000000000000000000000000000000000000..db8bf104725b83ef60894f50b3211aa85847f373 Binary files /dev/null and b/assets/inputs/audio/multi_person/seko_input.png differ diff --git a/assets/inputs/audio/seko_input.mp3 b/assets/inputs/audio/seko_input.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..a09e2c90448a64dc8501b3eca6a9ceda376ef2e1 Binary files /dev/null and b/assets/inputs/audio/seko_input.mp3 differ diff --git a/assets/inputs/audio/seko_input.png b/assets/inputs/audio/seko_input.png new file mode 100644 index 0000000000000000000000000000000000000000..cefda29965efd3e3a3c68e142b1ce8f82dfc93da Binary files /dev/null and b/assets/inputs/audio/seko_input.png differ diff --git a/assets/inputs/imgs/flf2v_input_first_frame-fs8.png b/assets/inputs/imgs/flf2v_input_first_frame-fs8.png new file mode 100644 index 0000000000000000000000000000000000000000..4c67ec443ea2c609ca9576be4168fd879c2d215b Binary files /dev/null and b/assets/inputs/imgs/flf2v_input_first_frame-fs8.png differ diff --git a/assets/inputs/imgs/flf2v_input_last_frame-fs8.png b/assets/inputs/imgs/flf2v_input_last_frame-fs8.png new file mode 100644 index 0000000000000000000000000000000000000000..51f3fbc5233e4d4215cb4c536c6374a5058cd40a Binary files /dev/null and b/assets/inputs/imgs/flf2v_input_last_frame-fs8.png differ diff --git a/assets/inputs/imgs/girl.png b/assets/inputs/imgs/girl.png new file mode 100644 index 0000000000000000000000000000000000000000..8174bf2122301340e7aa9ceb005e1964def965d2 Binary files /dev/null and b/assets/inputs/imgs/girl.png differ diff --git a/assets/inputs/imgs/img_0.jpg b/assets/inputs/imgs/img_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f6dcced590cbd18f5eda26bc2362b76ed714847 Binary files /dev/null and b/assets/inputs/imgs/img_0.jpg differ diff --git a/assets/inputs/imgs/img_1.jpg b/assets/inputs/imgs/img_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2dee840820801bab1ebc05300952bbb9a4945d6a Binary files /dev/null and b/assets/inputs/imgs/img_1.jpg differ diff --git a/assets/inputs/imgs/img_2.jpg b/assets/inputs/imgs/img_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..944f15ab131c77613c11c668a0ced4d531c5f9f5 Binary files /dev/null 
and b/assets/inputs/imgs/img_2.jpg differ diff --git a/assets/inputs/imgs/snake.png b/assets/inputs/imgs/snake.png new file mode 100644 index 0000000000000000000000000000000000000000..19085802f5b04656357bf8de812181f3ed161152 Binary files /dev/null and b/assets/inputs/imgs/snake.png differ diff --git a/configs/attentions/wan_i2v_flash.json b/configs/attentions/wan_i2v_flash.json new file mode 100644 index 0000000000000000000000000000000000000000..24c7a57e9fb24deebb0b54632e99ab4b585321af --- /dev/null +++ b/configs/attentions/wan_i2v_flash.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_i2v_nbhd_480p.json b/configs/attentions/wan_i2v_nbhd_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..2ca3bca552ae6df8085d78ced97b95c5c7540d9f --- /dev/null +++ b/configs/attentions/wan_i2v_nbhd_480p.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "nbhd_attn", + "nbhd_attn_setting": { + "coefficient": [1.0, 0.5, 0.25, 0.25], + "min_width": 2.0 + }, + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_i2v_nbhd_720p.json b/configs/attentions/wan_i2v_nbhd_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..10a99cf101849807e53939898936d33d099cbf4f --- /dev/null +++ b/configs/attentions/wan_i2v_nbhd_720p.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "nbhd_attn", + "nbhd_attn_setting": { + "coefficient": [1.0, 0.5, 0.25, 0.25], + "min_width": 2.0 + }, + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_i2v_radial.json b/configs/attentions/wan_i2v_radial.json new file mode 100644 index 0000000000000000000000000000000000000000..428517a9a0a9aa7e107f81814a7dd86063d7d403 --- /dev/null +++ b/configs/attentions/wan_i2v_radial.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "radial_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_i2v_sage.json b/configs/attentions/wan_i2v_sage.json new file mode 100644 index 0000000000000000000000000000000000000000..9a278a507854d5bf33c79a8d601c5eaff9b848b7 --- /dev/null +++ b/configs/attentions/wan_i2v_sage.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_i2v_svg.json b/configs/attentions/wan_i2v_svg.json new file mode 100644 index 
0000000000000000000000000000000000000000..e3d677d3a50c9c373ef13bb1b1d8aeac99019973 --- /dev/null +++ b/configs/attentions/wan_i2v_svg.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "svg_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_i2v_svg2.json b/configs/attentions/wan_i2v_svg2.json new file mode 100644 index 0000000000000000000000000000000000000000..4f707076b7a9fa88872a5a8be7c1c3a33d4d988b --- /dev/null +++ b/configs/attentions/wan_i2v_svg2.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "svg2_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/attentions/wan_t2v_sparge.json b/configs/attentions/wan_t2v_sparge.json new file mode 100644 index 0000000000000000000000000000000000000000..64cce21b769c62bbb77dc509c3bc1abdcfaf948d --- /dev/null +++ b/configs/attentions/wan_t2v_sparge.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "sparge": true, + "sparge_ckpt": "/path/to/sparge_wan2.1_t2v_1.3B.pt" +} diff --git a/configs/bench/lightx2v_1.json b/configs/bench/lightx2v_1.json new file mode 100644 index 0000000000000000000000000000000000000000..9a278a507854d5bf33c79a8d601c5eaff9b848b7 --- /dev/null +++ b/configs/bench/lightx2v_1.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/bench/lightx2v_2.json b/configs/bench/lightx2v_2.json new file mode 100644 index 0000000000000000000000000000000000000000..9a278a507854d5bf33c79a8d601c5eaff9b848b7 --- /dev/null +++ b/configs/bench/lightx2v_2.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/bench/lightx2v_3.json b/configs/bench/lightx2v_3.json new file mode 100644 index 0000000000000000000000000000000000000000..2ef47231ff608ec6214e753accf7b9b04a980cfa --- /dev/null +++ b/configs/bench/lightx2v_3.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "use_tiling_vae": true +} diff --git a/configs/bench/lightx2v_3_distill.json 
b/configs/bench/lightx2v_3_distill.json new file mode 100644 index 0000000000000000000000000000000000000000..7cc8d37b9d99646740ed67a872feac8ad67a720b --- /dev/null +++ b/configs/bench/lightx2v_3_distill.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "use_tiling_vae": true +} diff --git a/configs/bench/lightx2v_4.json b/configs/bench/lightx2v_4.json new file mode 100644 index 0000000000000000000000000000000000000000..7d83a749111684490a6ea14210ef93f5bc9fd226 --- /dev/null +++ b/configs/bench/lightx2v_4.json @@ -0,0 +1,35 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "feature_caching": "Tea", + "coefficients": [ + [ + 2.57151496e05, + -3.54229917e04, + 1.40286849e03, + -1.35890334e01, + 1.32517977e-01 + ], + [ + -3.02331670e02, + 2.23948934e02, + -5.25463970e01, + 5.87348440e00, + -2.01973289e-01 + ] + ], + "use_ret_steps": false, + "teacache_thresh": 0.2, + "use_tiling_vae": true +} diff --git a/configs/bench/lightx2v_5.json b/configs/bench/lightx2v_5.json new file mode 100644 index 0000000000000000000000000000000000000000..a9a72814b490f761e7045f27d7fa062fbf454b8a --- /dev/null +++ b/configs/bench/lightx2v_5.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 0.8, + "t5_cpu_offload": true, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "use_tiling_vae": true +} diff --git a/configs/bench/lightx2v_5_distill.json b/configs/bench/lightx2v_5_distill.json new file mode 100644 index 0000000000000000000000000000000000000000..8d215c9a9b2ca1a7b28f9d596ad041147afe7956 --- /dev/null +++ b/configs/bench/lightx2v_5_distill.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 0.8, + "t5_cpu_offload": true, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "use_tiling_vae": true +} diff --git a/configs/bench/lightx2v_6.json b/configs/bench/lightx2v_6.json new file mode 100644 index 0000000000000000000000000000000000000000..a9a72814b490f761e7045f27d7fa062fbf454b8a --- /dev/null +++ b/configs/bench/lightx2v_6.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + 
"cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 0.8, + "t5_cpu_offload": true, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "use_tiling_vae": true +} diff --git a/configs/bench/lightx2v_6_distill.json b/configs/bench/lightx2v_6_distill.json new file mode 100644 index 0000000000000000000000000000000000000000..8d215c9a9b2ca1a7b28f9d596ad041147afe7956 --- /dev/null +++ b/configs/bench/lightx2v_6_distill.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 0.8, + "t5_cpu_offload": true, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "use_tiling_vae": true +} diff --git a/configs/caching/adacache/wan_i2v_ada.json b/configs/caching/adacache/wan_i2v_ada.json new file mode 100644 index 0000000000000000000000000000000000000000..90e635c1eddf00aa0d1d922db8062222e04baa22 --- /dev/null +++ b/configs/caching/adacache/wan_i2v_ada.json @@ -0,0 +1,15 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "seed": 442, + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Ada" +} diff --git a/configs/caching/adacache/wan_t2v_ada.json b/configs/caching/adacache/wan_t2v_ada.json new file mode 100644 index 0000000000000000000000000000000000000000..88ea1781cd26671dd3da2192a067dbd475fd99f0 --- /dev/null +++ b/configs/caching/adacache/wan_t2v_ada.json @@ -0,0 +1,15 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Ada" +} diff --git a/configs/caching/custom/wan_i2v_custom_480p.json b/configs/caching/custom/wan_i2v_custom_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..4980e8fdf84f4792948d647234d34448429e2660 --- /dev/null +++ b/configs/caching/custom/wan_i2v_custom_480p.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "seed": 442, + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Custom", + "coefficients": [ + [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01], + [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01] + ], + "use_ret_steps": false, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/custom/wan_i2v_custom_720p.json b/configs/caching/custom/wan_i2v_custom_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..7a72be1a3eb0793a099dbedeb338ef0cefbc5594 
--- /dev/null +++ b/configs/caching/custom/wan_i2v_custom_720p.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 1280, + "target_width": 720, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Custom", + "coefficients": [ + [ + 8.10705460e03, + 2.13393892e03, + -3.72934672e02, + 1.66203073e01, + -4.17769401e-02 + ], + [ + -114.36346466, + 65.26524496, + -18.82220707, + 4.91518089, + -0.23412683 + ] + ], + "use_ret_steps": false, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/custom/wan_t2v_custom_14b.json b/configs/caching/custom/wan_t2v_custom_14b.json new file mode 100644 index 0000000000000000000000000000000000000000..cdd24283810ac1b9cf2b55bc3d1a50cd9924bea0 --- /dev/null +++ b/configs/caching/custom/wan_t2v_custom_14b.json @@ -0,0 +1,33 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Custom", + "coefficients": [ + [ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01 + ], + [ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404 + ] + ], + "use_ret_steps": false, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/custom/wan_t2v_custom_1_3b.json b/configs/caching/custom/wan_t2v_custom_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..1412ba2544c25fb0659bc3bae7ebf6fa7b2674f5 --- /dev/null +++ b/configs/caching/custom/wan_t2v_custom_1_3b.json @@ -0,0 +1,33 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Custom", + "coefficients": [ + [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02 + ], + [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01 + ] + ], + "use_ret_steps": false, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/dualblock/wan_t2v_1_3b.json b/configs/caching/dualblock/wan_t2v_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..f81f56ce85ef7fa3ac425e89abccce1ff9d52896 --- /dev/null +++ b/configs/caching/dualblock/wan_t2v_1_3b.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "DualBlock", + "residual_diff_threshold": 0.03, + "downsample_factor": 2 +} diff --git a/configs/caching/dynamicblock/wan_t2v_1_3b.json b/configs/caching/dynamicblock/wan_t2v_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..fa24ba9e8a66b303fff19ae49f0e1d93cd444ce2 --- /dev/null +++ b/configs/caching/dynamicblock/wan_t2v_1_3b.json @@ 
-0,0 +1,17 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "DynamicBlock", + "residual_diff_threshold": 0.003, + "downsample_factor": 2 +} diff --git a/configs/caching/firstblock/wan_t2v_1_3b.json b/configs/caching/firstblock/wan_t2v_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..7dd393d735831fcab6b0d96907a8f50c74739c10 --- /dev/null +++ b/configs/caching/firstblock/wan_t2v_1_3b.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "FirstBlock", + "residual_diff_threshold": 0.02, + "downsample_factor": 2 +} diff --git a/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json b/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..0c7329e19b9f2d3da1e3c14777fc2057f6726cea --- /dev/null +++ b/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json @@ -0,0 +1,109 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + }, + "feature_caching": "Mag", + "magcache_calibration": false, + "magcache_K": 6, + "magcache_thresh": 0.24, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [ + [ + 1.0, + 0.98783, + 0.97559, + 0.98311, + 0.98202, + 0.9888, + 0.98762, + 0.98957, + 0.99052, + 0.99383, + 0.98857, + 0.99065, + 0.98845, + 0.99057, + 0.98957, + 0.98601, + 0.98823, + 0.98756, + 0.98808, + 0.98721, + 0.98571, + 0.98543, + 0.98157, + 0.98411, + 0.97952, + 0.98149, + 0.9774, + 0.97825, + 0.97355, + 0.97085, + 0.97056, + 0.96588, + 0.96113, + 0.9567, + 0.94961, + 0.93973, + 0.93217, + 0.91878, + 0.90955, + 0.92617 + ], + [ + 1.0, + 0.98993, + 0.97593, + 0.98319, + 0.98225, + 0.98878, + 0.98759, + 0.98971, + 0.99043, + 0.99384, + 0.9886, + 0.99068, + 0.98847, + 0.99057, + 0.98961, + 0.9861, + 0.98823, + 0.98759, + 0.98814, + 0.98724, + 0.98572, + 0.98544, + 0.98165, + 0.98413, + 0.97953, + 0.9815, + 0.97742, + 0.97826, + 0.97361, + 0.97087, + 0.97055, + 0.96587, + 0.96124, + 0.95681, + 0.94969, + 0.93988, + 0.93224, + 0.91896, + 0.90954, + 0.92616 + ] + ] +} diff --git a/configs/caching/magcache/wan_i2v_mag_480p.json b/configs/caching/magcache/wan_i2v_mag_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..c8588ac2a8124f57691cf7658740bf06fe4416bf --- /dev/null +++ b/configs/caching/magcache/wan_i2v_mag_480p.json @@ -0,0 +1,104 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, 
+ "cpu_offload": false, + "feature_caching": "Mag", + "magcache_calibration": false, + "magcache_K": 6, + "magcache_thresh": 0.24, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [ + [ + 1.0, + 0.98783, + 0.97559, + 0.98311, + 0.98202, + 0.9888, + 0.98762, + 0.98957, + 0.99052, + 0.99383, + 0.98857, + 0.99065, + 0.98845, + 0.99057, + 0.98957, + 0.98601, + 0.98823, + 0.98756, + 0.98808, + 0.98721, + 0.98571, + 0.98543, + 0.98157, + 0.98411, + 0.97952, + 0.98149, + 0.9774, + 0.97825, + 0.97355, + 0.97085, + 0.97056, + 0.96588, + 0.96113, + 0.9567, + 0.94961, + 0.93973, + 0.93217, + 0.91878, + 0.90955, + 0.92617 + ], + [ + 1.0, + 0.98993, + 0.97593, + 0.98319, + 0.98225, + 0.98878, + 0.98759, + 0.98971, + 0.99043, + 0.99384, + 0.9886, + 0.99068, + 0.98847, + 0.99057, + 0.98961, + 0.9861, + 0.98823, + 0.98759, + 0.98814, + 0.98724, + 0.98572, + 0.98544, + 0.98165, + 0.98413, + 0.97953, + 0.9815, + 0.97742, + 0.97826, + 0.97361, + 0.97087, + 0.97055, + 0.96587, + 0.96124, + 0.95681, + 0.94969, + 0.93988, + 0.93224, + 0.91896, + 0.90954, + 0.92616 + ] + ] +} diff --git a/configs/caching/magcache/wan_i2v_mag_calibration_480p.json b/configs/caching/magcache/wan_i2v_mag_calibration_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..d331d33abddb417661b0ef68112b73e3377ca5ae --- /dev/null +++ b/configs/caching/magcache/wan_i2v_mag_calibration_480p.json @@ -0,0 +1,104 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Mag", + "magcache_calibration": true, + "magcache_K": 6, + "magcache_thresh": 0.24, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [ + [ + 1.0, + 0.98783, + 0.97559, + 0.98311, + 0.98202, + 0.9888, + 0.98762, + 0.98957, + 0.99052, + 0.99383, + 0.98857, + 0.99065, + 0.98845, + 0.99057, + 0.98957, + 0.98601, + 0.98823, + 0.98756, + 0.98808, + 0.98721, + 0.98571, + 0.98543, + 0.98157, + 0.98411, + 0.97952, + 0.98149, + 0.9774, + 0.97825, + 0.97355, + 0.97085, + 0.97056, + 0.96588, + 0.96113, + 0.9567, + 0.94961, + 0.93973, + 0.93217, + 0.91878, + 0.90955, + 0.92617 + ], + [ + 1.0, + 0.98993, + 0.97593, + 0.98319, + 0.98225, + 0.98878, + 0.98759, + 0.98971, + 0.99043, + 0.99384, + 0.9886, + 0.99068, + 0.98847, + 0.99057, + 0.98961, + 0.9861, + 0.98823, + 0.98759, + 0.98814, + 0.98724, + 0.98572, + 0.98544, + 0.98165, + 0.98413, + 0.97953, + 0.9815, + 0.97742, + 0.97826, + 0.97361, + 0.97087, + 0.97055, + 0.96587, + 0.96124, + 0.95681, + 0.94969, + 0.93988, + 0.93224, + 0.91896, + 0.90954, + 0.92616 + ] + ] +} diff --git a/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json b/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..3c5c5f8710b3f1c62cc7712b0357f08d4b4f3237 --- /dev/null +++ b/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json @@ -0,0 +1,130 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + }, + 
"feature_caching": "Mag", + "magcache_calibration": false, + "magcache_K": 4, + "magcache_thresh": 0.12, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [ + [ + 1.0, + 1.0124, + 1.00166, + 0.99791, + 0.99682, + 0.99634, + 0.99567, + 0.99416, + 0.99578, + 0.9957, + 0.99511, + 0.99535, + 0.99552, + 0.99541, + 0.9954, + 0.99489, + 0.99518, + 0.99484, + 0.99481, + 0.99415, + 0.99419, + 0.99396, + 0.99388, + 0.99349, + 0.99309, + 0.9927, + 0.99228, + 0.99171, + 0.99137, + 0.99068, + 0.99005, + 0.98944, + 0.98849, + 0.98758, + 0.98644, + 0.98504, + 0.9836, + 0.98202, + 0.97977, + 0.97717, + 0.9741, + 0.97003, + 0.96538, + 0.9593, + 0.95086, + 0.94013, + 0.92402, + 0.90241, + 0.86821, + 0.81838 + ], + [ + 1.0, + 1.02213, + 1.0041, + 1.00061, + 0.99762, + 0.99685, + 0.99586, + 0.99422, + 0.99575, + 0.99563, + 0.99506, + 0.99531, + 0.99549, + 0.99539, + 0.99536, + 0.99485, + 0.99514, + 0.99478, + 0.99479, + 0.99413, + 0.99416, + 0.99393, + 0.99386, + 0.99349, + 0.99304, + 0.9927, + 0.99226, + 0.9917, + 0.99135, + 0.99063, + 0.99003, + 0.98942, + 0.98849, + 0.98757, + 0.98643, + 0.98503, + 0.98359, + 0.98201, + 0.97978, + 0.97718, + 0.97411, + 0.97002, + 0.96541, + 0.95933, + 0.95089, + 0.94019, + 0.92414, + 0.9026, + 0.86868, + 0.81939 + ] + ] +} diff --git a/configs/caching/magcache/wan_t2v_mag_1_3b.json b/configs/caching/magcache/wan_t2v_mag_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..aaec17e20fb7d59ce854bbc611efcbf72413cf9f --- /dev/null +++ b/configs/caching/magcache/wan_t2v_mag_1_3b.json @@ -0,0 +1,125 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Mag", + "magcache_calibration": false, + "magcache_K": 4, + "magcache_thresh": 0.12, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [ + [ + 1.0, + 1.0124, + 1.00166, + 0.99791, + 0.99682, + 0.99634, + 0.99567, + 0.99416, + 0.99578, + 0.9957, + 0.99511, + 0.99535, + 0.99552, + 0.99541, + 0.9954, + 0.99489, + 0.99518, + 0.99484, + 0.99481, + 0.99415, + 0.99419, + 0.99396, + 0.99388, + 0.99349, + 0.99309, + 0.9927, + 0.99228, + 0.99171, + 0.99137, + 0.99068, + 0.99005, + 0.98944, + 0.98849, + 0.98758, + 0.98644, + 0.98504, + 0.9836, + 0.98202, + 0.97977, + 0.97717, + 0.9741, + 0.97003, + 0.96538, + 0.9593, + 0.95086, + 0.94013, + 0.92402, + 0.90241, + 0.86821, + 0.81838 + ], + [ + 1.0, + 1.02213, + 1.0041, + 1.00061, + 0.99762, + 0.99685, + 0.99586, + 0.99422, + 0.99575, + 0.99563, + 0.99506, + 0.99531, + 0.99549, + 0.99539, + 0.99536, + 0.99485, + 0.99514, + 0.99478, + 0.99479, + 0.99413, + 0.99416, + 0.99393, + 0.99386, + 0.99349, + 0.99304, + 0.9927, + 0.99226, + 0.9917, + 0.99135, + 0.99063, + 0.99003, + 0.98942, + 0.98849, + 0.98757, + 0.98643, + 0.98503, + 0.98359, + 0.98201, + 0.97978, + 0.97718, + 0.97411, + 0.97002, + 0.96541, + 0.95933, + 0.95089, + 0.94019, + 0.92414, + 0.9026, + 0.86868, + 0.81939 + ] + ] +} diff --git a/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json b/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5b5fc2b0fe73751ac4fb617077dde43c4cf95c --- /dev/null +++ b/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json @@ -0,0 +1,125 @@ +{ + "infer_steps": 50, + "target_video_length": 
81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Mag", + "magcache_calibration": true, + "magcache_K": 4, + "magcache_thresh": 0.12, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [ + [ + 1.0, + 1.0124, + 1.00166, + 0.99791, + 0.99682, + 0.99634, + 0.99567, + 0.99416, + 0.99578, + 0.9957, + 0.99511, + 0.99535, + 0.99552, + 0.99541, + 0.9954, + 0.99489, + 0.99518, + 0.99484, + 0.99481, + 0.99415, + 0.99419, + 0.99396, + 0.99388, + 0.99349, + 0.99309, + 0.9927, + 0.99228, + 0.99171, + 0.99137, + 0.99068, + 0.99005, + 0.98944, + 0.98849, + 0.98758, + 0.98644, + 0.98504, + 0.9836, + 0.98202, + 0.97977, + 0.97717, + 0.9741, + 0.97003, + 0.96538, + 0.9593, + 0.95086, + 0.94013, + 0.92402, + 0.90241, + 0.86821, + 0.81838 + ], + [ + 1.0, + 1.02213, + 1.0041, + 1.00061, + 0.99762, + 0.99685, + 0.99586, + 0.99422, + 0.99575, + 0.99563, + 0.99506, + 0.99531, + 0.99549, + 0.99539, + 0.99536, + 0.99485, + 0.99514, + 0.99478, + 0.99479, + 0.99413, + 0.99416, + 0.99393, + 0.99386, + 0.99349, + 0.99304, + 0.9927, + 0.99226, + 0.9917, + 0.99135, + 0.99063, + 0.99003, + 0.98942, + 0.98849, + 0.98757, + 0.98643, + 0.98503, + 0.98359, + 0.98201, + 0.97978, + 0.97718, + 0.97411, + 0.97002, + 0.96541, + 0.95933, + 0.95089, + 0.94019, + 0.92414, + 0.9026, + 0.86868, + 0.81939 + ] + ] +} diff --git a/configs/caching/taylorseer/wan_t2v_taylorseer.json b/configs/caching/taylorseer/wan_t2v_taylorseer.json new file mode 100644 index 0000000000000000000000000000000000000000..bb4b4414eaab046ce4452b4df7bfb79d6a6e60df --- /dev/null +++ b/configs/caching/taylorseer/wan_t2v_taylorseer.json @@ -0,0 +1,15 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "TaylorSeer" +} diff --git a/configs/caching/teacache/wan_i2v_tea_480p.json b/configs/caching/teacache/wan_i2v_tea_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..cde8fc6b465ac97743e6a0fd32c10b4fa6545f2a --- /dev/null +++ b/configs/caching/teacache/wan_i2v_tea_480p.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Tea", + "coefficients": [ + [ + 2.57151496e05, + -3.54229917e04, + 1.40286849e03, + -1.35890334e01, + 1.32517977e-01 + ], + [ + -3.02331670e02, + 2.23948934e02, + -5.25463970e01, + 5.87348440e00, + -2.01973289e-01 + ] + ], + "use_ret_steps": true, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/teacache/wan_i2v_tea_720p.json b/configs/caching/teacache/wan_i2v_tea_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..2ff56d0135e7412c4cf0ec6372057e30bc5c6a3e --- /dev/null +++ b/configs/caching/teacache/wan_i2v_tea_720p.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 1280, + "target_width": 720, + 
"self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Tea", + "coefficients": [ + [ + 8.10705460e03, + 2.13393892e03, + -3.72934672e02, + 1.66203073e01, + -4.17769401e-02 + ], + [ + -114.36346466, + 65.26524496, + -18.82220707, + 4.91518089, + -0.23412683 + ] + ], + "use_ret_steps": true, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json b/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..cdb266ecd46a79f90519e690c8e37fa4511427bd --- /dev/null +++ b/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json @@ -0,0 +1,33 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "feature_caching": "Tea", + "coefficients": [ + [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02 + ], + [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01 + ] + ], + "use_ret_steps": true, + "teacache_thresh": 0.26 +} diff --git a/configs/caching/teacache/wan_ti2v_tea_720p.json b/configs/caching/teacache/wan_ti2v_tea_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..1ea52413d71cb0f78be6a873690b95ee7c1a5659 --- /dev/null +++ b/configs/caching/teacache/wan_ti2v_tea_720p.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [4, 16, 16], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "feature_caching": "Tea", + "coefficients": [ + [], + [ 1.57472669e+05, -1.15702395e+05, 3.10761669e+04, -3.83116651e+03, 2.21608777e+02, -4.81179567e+00] + ], + "use_ret_steps": false, + "teacache_thresh": 0.26, + "use_image_encoder": false +} diff --git a/configs/causvid/wan_i2v_causvid.json b/configs/causvid/wan_i2v_causvid.json new file mode 100644 index 0000000000000000000000000000000000000000..3a1acf127754861573e99d9b8e6d074abc56cd9f --- /dev/null +++ b/configs/causvid/wan_i2v_causvid.json @@ -0,0 +1,29 @@ +{ + "infer_steps": 20, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "num_fragments": 3, + "num_frames": 21, + "num_frame_per_block": 7, + "num_blocks": 3, + "frame_seq_length": 1560, + "denoising_step_list": [ + 999, + 934, + 862, + 756, + 603, + 410, + 250, + 140, + 74 + ] +} diff --git a/configs/causvid/wan_t2v_causvid.json b/configs/causvid/wan_t2v_causvid.json new file mode 100644 index 0000000000000000000000000000000000000000..253f453b44c1046aa496c109c49204d71e013ca7 --- /dev/null +++ b/configs/causvid/wan_t2v_causvid.json @@ -0,0 +1,29 @@ +{ + "infer_steps": 9, + "target_video_length": 81, + "target_height": 480, + "target_width": 
832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "num_fragments": 3, + "num_frames": 21, + "num_frame_per_block": 3, + "num_blocks": 7, + "frame_seq_length": 1560, + "denoising_step_list": [ + 999, + 934, + 862, + 756, + 603, + 410, + 250, + 140, + 74 + ] +} diff --git a/configs/changing_resolution/wan_i2v.json b/configs/changing_resolution/wan_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..7b828d461e7f8e13b6883dc1bf3fd0a83ca847e4 --- /dev/null +++ b/configs/changing_resolution/wan_i2v.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false, + "changing_resolution": true, + "resolution_rate": [ + 0.75 + ], + "changing_resolution_steps": [ + 20 + ] +} diff --git a/configs/changing_resolution/wan_i2v_U.json b/configs/changing_resolution/wan_i2v_U.json new file mode 100644 index 0000000000000000000000000000000000000000..0f8077504b18376115300d505e527e8e6ad2403c --- /dev/null +++ b/configs/changing_resolution/wan_i2v_U.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false, + "changing_resolution": true, + "resolution_rate": [ + 1.0, + 0.75 + ], + "changing_resolution_steps": [ + 5, + 25 + ] +} diff --git a/configs/changing_resolution/wan_t2v.json b/configs/changing_resolution/wan_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..78782775cedb9e340bf00ba5e17903583635ffa7 --- /dev/null +++ b/configs/changing_resolution/wan_t2v.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "changing_resolution": true, + "resolution_rate": [ + 0.75 + ], + "changing_resolution_steps": [ + 25 + ] +} diff --git a/configs/changing_resolution/wan_t2v_U.json b/configs/changing_resolution/wan_t2v_U.json new file mode 100644 index 0000000000000000000000000000000000000000..68531e6639263ba7f32e90c0612b7f7836d7ceab --- /dev/null +++ b/configs/changing_resolution/wan_t2v_U.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "changing_resolution": true, + "resolution_rate": [ + 1.0, + 0.75 + ], + "changing_resolution_steps": [ + 10, + 35 + ] +} diff --git a/configs/changing_resolution/wan_t2v_U_teacache.json b/configs/changing_resolution/wan_t2v_U_teacache.json new file mode 100644 index 0000000000000000000000000000000000000000..60ccc67a30176fd66cfbdeae2e568941c35984c3 --- /dev/null 
+++ b/configs/changing_resolution/wan_t2v_U_teacache.json @@ -0,0 +1,42 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "changing_resolution": true, + "resolution_rate": [ + 1.0, + 0.75 + ], + "changing_resolution_steps": [ + 10, + 35 + ], + "feature_caching": "Tea", + "coefficients": [ + [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02 + ], + [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01 + ] + ], + "use_ret_steps": false, + "teacache_thresh": 0.1 +} diff --git a/configs/deploy/wan_i2v.json b/configs/deploy/wan_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..9f1fbde2bdc8507a2e14bb36eb45f62f282a5ce2 --- /dev/null +++ b/configs/deploy/wan_i2v.json @@ -0,0 +1,30 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "sub_servers": { + "dit": [ + "http://localhost:9000" + ], + "prompt_enhancer": [ + "http://localhost:9001" + ], + "text_encoders": [ + "http://localhost:9002" + ], + "image_encoder": [ + "http://localhost:9003" + ], + "vae_model": [ + "http://localhost:9004" + ] + } +} diff --git a/configs/deploy/wan_t2v.json b/configs/deploy/wan_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..43656d7fa0a2c45125677ca2fb6e85ef37d9087a --- /dev/null +++ b/configs/deploy/wan_t2v.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "sub_servers": { + "dit": [ + "http://localhost:9000" + ], + "prompt_enhancer": [ + "http://localhost:9001" + ], + "text_encoders": [ + "http://localhost:9002" + ], + "image_encoder": [ + "http://localhost:9003" + ], + "vae_model": [ + "http://localhost:9004" + ] + } +} diff --git a/configs/dist_infer/wan22_moe_i2v_cfg.json b/configs/dist_infer/wan22_moe_i2v_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..bafa409ea8077198c458db5a0496e7a726b75074 --- /dev/null +++ b/configs/dist_infer/wan22_moe_i2v_cfg.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.900, + "use_image_encoder": false, + "parallel": { + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json b/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..91a15a0d6200a00e3f882079a8da2aac009b0fd8 --- /dev/null +++ 
b/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "boundary": 0.900, + "use_image_encoder": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_moe_i2v_ulysses.json b/configs/dist_infer/wan22_moe_i2v_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..eb0759b39a65be5f02f6e94bcd055de5c94ce406 --- /dev/null +++ b/configs/dist_infer/wan22_moe_i2v_ulysses.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.900, + "use_image_encoder": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/dist_infer/wan22_moe_t2v_cfg.json b/configs/dist_infer/wan22_moe_t2v_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..f47aad82f81120969cb2b177dddb05280fda7968 --- /dev/null +++ b/configs/dist_infer/wan22_moe_t2v_cfg.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 4.0, + 3.0 + ], + "sample_shift": 12.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.875, + "parallel": { + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json b/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..645e6d7f5216dc3033461427ad39f6a06f792c98 --- /dev/null +++ b/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 4.0, + 3.0 + ], + "sample_shift": 12.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.875, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_moe_t2v_ulysses.json b/configs/dist_infer/wan22_moe_t2v_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..03aaf9246536c23330cb1f502c66f815abef8fe3 --- /dev/null +++ b/configs/dist_infer/wan22_moe_t2v_ulysses.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + 
"cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 4.0, + 3.0 + ], + "sample_shift": 12.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.875, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/dist_infer/wan22_ti2v_i2v_cfg.json b/configs/dist_infer/wan22_ti2v_i2v_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..4754ff30d382e3dd0732bca8e3ba51f9613ac4a7 --- /dev/null +++ b/configs/dist_infer/wan22_ti2v_i2v_cfg.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "use_image_encoder": false, + "parallel": { + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json b/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..9e340f400f80ad10a579e07c3c294265cff18e94 --- /dev/null +++ b/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json @@ -0,0 +1,27 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "use_image_encoder": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_ti2v_i2v_ulysses.json b/configs/dist_infer/wan22_ti2v_i2v_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..56d0f65e5aa66d7c7b42b2c7a651a4ab6c1c1930 --- /dev/null +++ b/configs/dist_infer/wan22_ti2v_i2v_ulysses.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "use_image_encoder": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/dist_infer/wan22_ti2v_t2v_cfg.json b/configs/dist_infer/wan22_ti2v_t2v_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..c88cb2a10144a23cda3daebf5bc06d341db5a5e0 --- /dev/null +++ b/configs/dist_infer/wan22_ti2v_t2v_cfg.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "parallel": { + "cfg_p_size": 
2 + } +} diff --git a/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json b/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..79f4f03f0d1f16bf0825aabd1975c3105ea19fb6 --- /dev/null +++ b/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan22_ti2v_t2v_ulysses.json b/configs/dist_infer/wan22_ti2v_t2v_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..12b18509e0d98a2bbe9c89b054d96bfa79d66f5e --- /dev/null +++ b/configs/dist_infer/wan22_ti2v_t2v_ulysses.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "fps": 24, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json b/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..001f1f0c6e813e6c3d3a7fcd1014a973037b91f5 --- /dev/null +++ b/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan_i2v_dist_ring.json b/configs/dist_infer/wan_i2v_dist_ring.json new file mode 100644 index 0000000000000000000000000000000000000000..fe2608aae5ed52a1873a3df346b00bf59e54e404 --- /dev/null +++ b/configs/dist_infer/wan_i2v_dist_ring.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ring" + } +} diff --git a/configs/dist_infer/wan_i2v_dist_ulysses.json b/configs/dist_infer/wan_i2v_dist_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..a8af87968ceb00ee70a3bd8a2608e52cb88ca738 --- /dev/null +++ b/configs/dist_infer/wan_i2v_dist_ulysses.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + 
"enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/dist_infer/wan_t2v_dist_cfg.json b/configs/dist_infer/wan_t2v_dist_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..49bc48161ab648cf7bdb9358b10e42b3376aba0e --- /dev/null +++ b/configs/dist_infer/wan_t2v_dist_cfg.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan_t2v_dist_cfg_ring.json b/configs/dist_infer/wan_t2v_dist_cfg_ring.json new file mode 100644 index 0000000000000000000000000000000000000000..50450f43e56cb27770305903996a3ee0674f0096 --- /dev/null +++ b/configs/dist_infer/wan_t2v_dist_cfg_ring.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ring", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json b/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6808dfd7c8c270221a966dea79da3ad54848be --- /dev/null +++ b/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +} diff --git a/configs/dist_infer/wan_t2v_dist_ring.json b/configs/dist_infer/wan_t2v_dist_ring.json new file mode 100644 index 0000000000000000000000000000000000000000..a28549bd6c930885adcc691e63d7024e041d7861 --- /dev/null +++ b/configs/dist_infer/wan_t2v_dist_ring.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ring" + } +} diff --git a/configs/dist_infer/wan_t2v_dist_ulysses.json b/configs/dist_infer/wan_t2v_dist_ulysses.json new file mode 100644 index 0000000000000000000000000000000000000000..c9b9bc2ec71a41ea5f468dc6cf09e072e764e248 --- /dev/null +++ b/configs/dist_infer/wan_t2v_dist_ulysses.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + 
"seq_p_attn_type": "ulysses" + } +} diff --git a/configs/distill/wan_i2v_distill_4step_cfg.json b/configs/distill/wan_i2v_distill_4step_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..519f58a4ed1778b4a8b979482a9f7b7b34ea1571 --- /dev/null +++ b/configs/distill/wan_i2v_distill_4step_cfg.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/distill/wan_i2v_distill_4step_cfg_4090.json b/configs/distill/wan_i2v_distill_4step_cfg_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..c6863a558a5d1d9eafc8ce7825f64cfc2ee58be8 --- /dev/null +++ b/configs/distill/wan_i2v_distill_4step_cfg_4090.json @@ -0,0 +1,29 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "clip_cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f" +} diff --git a/configs/distill/wan_i2v_distill_4step_cfg_4090_lora.json b/configs/distill/wan_i2v_distill_4step_cfg_4090_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..100eb453da684c3c6c1e8dd751b68671b149af55 --- /dev/null +++ b/configs/distill/wan_i2v_distill_4step_cfg_4090_lora.json @@ -0,0 +1,35 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "clip_cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f", + "lora_configs": [ + { + "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/distill/wan_i2v_distill_4step_cfg_lora.json b/configs/distill/wan_i2v_distill_4step_cfg_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..51797dd123006864bf4dedc81cf1c76b21751eaf --- /dev/null +++ b/configs/distill/wan_i2v_distill_4step_cfg_lora.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "denoising_step_list": [1000, 750, 500, 250], + "lora_configs": [ + { + "path": 
"lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/distill/wan_t2v_distill_4step_cfg.json b/configs/distill/wan_t2v_distill_4step_cfg.json new file mode 100644 index 0000000000000000000000000000000000000000..89ea0675604af80a6042cd0caa36a744b7d38098 --- /dev/null +++ b/configs/distill/wan_t2v_distill_4step_cfg.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/distill/wan_t2v_distill_4step_cfg_4090.json b/configs/distill/wan_t2v_distill_4step_cfg_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..0255d0923a7f21a832b09d9943044e1c355d8efd --- /dev/null +++ b/configs/distill/wan_t2v_distill_4step_cfg_4090.json @@ -0,0 +1,30 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 6, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "clip_cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f" +} diff --git a/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json b/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json new file mode 100644 index 0000000000000000000000000000000000000000..f05e7283308161084ab5280d32186c7ee06d9460 --- /dev/null +++ b/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 5, + "enable_cfg": false, + "enable_dynamic_cfg": true, + "cfg_scale": 4.0, + "cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/distill/wan_t2v_distill_4step_cfg_lora.json b/configs/distill/wan_t2v_distill_4step_cfg_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..0388af4dd0c11cc2e7bd0c3668608e3ac9a2e946 --- /dev/null +++ b/configs/distill/wan_t2v_distill_4step_cfg_lora.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "denoising_step_list": [1000, 750, 500, 250], + "lora_configs": [ + { + "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/distill/wan_t2v_distill_4step_cfg_lora_4090.json b/configs/distill/wan_t2v_distill_4step_cfg_lora_4090.json new file mode 100644 
index 0000000000000000000000000000000000000000..5e87e18fdb3303202cb1e943785b15f728f42f88 --- /dev/null +++ b/configs/distill/wan_t2v_distill_4step_cfg_lora_4090.json @@ -0,0 +1,36 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 6, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "clip_cpu_offload": false, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f", + "lora_configs": [ + { + "path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..ce108bbdce5d95a17bc39768648ec2f106b61fd9 --- /dev/null +++ b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2", + "cpu_offload": true, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16_dist.json b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..8049eda8d56f7eaa6fd0c52b5509c14a1fda38c9 --- /dev/null +++ b/configs/hunyuan_video_15/4090/hy15_t2v_480p_bf16_dist.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 7.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2", + "parallel": { + "seq_p_attn_type": "ulysses-4090", + "seq_p_size": 8 + } +} diff --git a/configs/hunyuan_video_15/4090/hy15_t2v_480p_fp8.json b/configs/hunyuan_video_15/4090/hy15_t2v_480p_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..6096bf90a780e3f1eae249d4b5984071c2fbebf6 --- /dev/null +++ b/configs/hunyuan_video_15/4090/hy15_t2v_480p_fp8.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2", + "cpu_offload": true, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false, + "dit_quantized_ckpt": "/path/to/480p_t2v_fp8.safetensors", + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "qwen25vl_quantized_ckpt": "/path/to/qwen25vl_fp8.safetensors", + "qwen25vl_quantized": true, + "qwen25vl_quant_scheme": "fp8-q8f" +} diff --git a/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16.json 
b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..2668b2b0b42c937837bbf8a2ce96924e450eae65 --- /dev/null +++ b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn3", + "cpu_offload": true, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16_dist.json b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..d39a77310efddbca9a3361e00cf1856cbe812108 --- /dev/null +++ b/configs/hunyuan_video_15/5090/hy15_t2v_480p_bf16_dist.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn3", + "cpu_offload": true, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false, + "parallel": { + "seq_p_attn_type": "ulysses", + "seq_p_size": 8 + } +} diff --git a/configs/hunyuan_video_15/5090/hy15_t2v_480p_fp8.json b/configs/hunyuan_video_15/5090/hy15_t2v_480p_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..303adcd24821b3fd89bf2abd45df35b7498e5aca --- /dev/null +++ b/configs/hunyuan_video_15/5090/hy15_t2v_480p_fp8.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn3", + "qwen25vl_cpu_offload": true, + "dit_quantized_ckpt": "/path/to/480p_t2v_fp8.safetensors", + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "qwen25vl_quantized_ckpt": "/path/to/qwen25vl_fp8.safetensors", + "qwen25vl_quantized": true, + "qwen25vl_quant_scheme": "fp8-sgl" +} diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache.json b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache.json new file mode 100644 index 0000000000000000000000000000000000000000..91fe4b27c668bc1f4512a793bb158394a8ff472d --- /dev/null +++ b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2", + "feature_caching": "Mag", + "magcache_calibration": false, + "magcache_K": 6, + "magcache_thresh": 0.24, + "magcache_retention_ratio": 0.2, + "magcache_ratios": [[1.0, 1.01562, 1.00781, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99219, 0.99219, 0.99219, 0.98828, 0.98828, 0.98828, 0.98828, 0.98438, 0.98438, 0.98047, 0.98047, 0.97656, 0.97266, 0.96484, 0.95703, 0.94922, 
0.92969, 0.91016, 0.88672], [1.0, 1.02344, 1.00781, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 1.0, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99609, 0.99219, 0.99219, 0.99219, 0.99219, 0.98828, 0.98828, 0.98828, 0.98438, 0.98438, 0.98047, 0.98047, 0.97656, 0.97266, 0.96484, 0.95703, 0.94922, 0.93359, 0.91016, 0.88672]] +} diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache_calibration.json b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache_calibration.json new file mode 100644 index 0000000000000000000000000000000000000000..ba180071145dccc2617387780b9d091188e9dd4e --- /dev/null +++ b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_magcache_calibration.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2", + "feature_caching": "Mag", + "magcache_calibration": true +} diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_480p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_teacache.json new file mode 100644 index 0000000000000000000000000000000000000000..6872b0b2048a12b6b975133ee59e7f9cdea284df --- /dev/null +++ b/configs/hunyuan_video_15/cache/hy_15_i2v_480p_teacache.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "feature_caching": "Tea", + "coefficients": [8.08528429e+03 ,-2.44607178e+03, 2.49489589e+02, -9.10697865e+00, 1.20261379e-01], + "teacache_thresh": 0.15, + "cpu_offload": false, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/cache/hy_15_i2v_720p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_i2v_720p_teacache.json new file mode 100644 index 0000000000000000000000000000000000000000..252a80b9b3fdb57e4db219506631e56c27ffa98f --- /dev/null +++ b/configs/hunyuan_video_15/cache/hy_15_i2v_720p_teacache.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 50, + "transformer_model_name": "720p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 7.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "feature_caching": "Tea", + "coefficients": [3.84300014e+03, -1.39247433e+03, 1.69167679e+02, -7.07679232e+00, 1.02419011e-01], + "teacache_thresh": 0.15, + "cpu_offload": false, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/cache/hy_15_t2v_480p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_t2v_480p_teacache.json new file mode 100644 index 0000000000000000000000000000000000000000..d2925d37deb8b0c068ace1f1b3b3bb8f38760bff --- /dev/null +++ b/configs/hunyuan_video_15/cache/hy_15_t2v_480p_teacache.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "feature_caching": 
"Tea", + "coefficients": [-2.97190924e+04, 2.22834983e+04, -4.37418360e+03, 3.39340251e+02, -1.01365855e+01, 1.29101768e-01], + "teacache_thresh": 0.15, + "cpu_offload": false, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json b/configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json new file mode 100644 index 0000000000000000000000000000000000000000..97b24eac3dd07ae1047b20ea60fcef36fd021a40 --- /dev/null +++ b/configs/hunyuan_video_15/cache/hy_15_t2v_720p_teacache.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 50, + "transformer_model_name": "729p_t2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "feature_caching": "Tea", + "coefficients": [-3.08907507e+04, 1.67786188e+04, -3.19178643e+03, 2.60740519e+02, -8.19205881e+00, 1.07913775e-01], + "teacache_thresh": 0.15, + "cpu_offload": false, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/fp8comm/hy15_i2v_480p_int8_offload_dist_fp8_comm.json b/configs/hunyuan_video_15/fp8comm/hy15_i2v_480p_int8_offload_dist_fp8_comm.json new file mode 100644 index 0000000000000000000000000000000000000000..a9fc93364e7dbaaec0e8acef4804dcf13ed499c9 --- /dev/null +++ b/configs/hunyuan_video_15/fp8comm/hy15_i2v_480p_int8_offload_dist_fp8_comm.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": false, + "attn_type": "sage_attn3", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false, + "dit_quantized_ckpt": "/path/to/quant_model.safetensors", + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "parallel": { + "seq_p_size": 8, + "seq_p_fp8_comm": true, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json b/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..32624592b19679d3dc50e2d8baf236232cbe2f57 --- /dev/null +++ b/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json @@ -0,0 +1,11 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2" +} diff --git a/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json b/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..f993a5caeeb347c5e612aac024fc374e9b18c817 --- /dev/null +++ b/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json @@ -0,0 +1,11 @@ +{ + "infer_steps": 50, + "transformer_model_name": "720p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 7.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2" +} diff --git a/configs/hunyuan_video_15/hunyuan_video_i2v_720p_cfg_distilled.json b/configs/hunyuan_video_15/hunyuan_video_i2v_720p_cfg_distilled.json new file mode 100644 index 
0000000000000000000000000000000000000000..a467d446b0071724c928c7583444b7da9614cfe0 --- /dev/null +++ b/configs/hunyuan_video_15/hunyuan_video_i2v_720p_cfg_distilled.json @@ -0,0 +1,11 @@ +{ + "infer_steps": 50, + "transformer_model_name": "720p_i2v_distilled", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 9.0, + "sample_guide_scale": 6.0, + "enable_cfg": false, + "attn_type": "sage_attn2" +} diff --git a/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json b/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..28ca4b9d6cf1da8530a4a17333942342b913ec90 --- /dev/null +++ b/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json @@ -0,0 +1,12 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2" +} diff --git a/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json b/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json new file mode 100644 index 0000000000000000000000000000000000000000..f4d9ceb734956865d5a90a1cb30bdfe18a130f2a --- /dev/null +++ b/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 4, + "transformer_model_name": "480p_t2v", + "fps": 16, + "target_video_length": 81, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": -1.0, + "enable_cfg": false, + "attn_type": "sage_attn2", + "dit_original_ckpt": "hunyuanvideo-1.5/distill_models/480p_t2v/distill_model.safetensors", + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json b/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..f8923c1520129693bbfcd0e92396014b357b2995 --- /dev/null +++ b/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json @@ -0,0 +1,12 @@ +{ + "infer_steps": 50, + "transformer_model_name": "720p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 9.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2" +} diff --git a/configs/hunyuan_video_15/lightae/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/lightae/hy15_t2v_480p_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..6fc26bdb0e4ce8009fb72d66e480d720cb00f138 --- /dev/null +++ b/configs/hunyuan_video_15/lightae/hy15_t2v_480p_bf16.json @@ -0,0 +1,14 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 7.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "use_tae": true, + "tae_path": "/path/to/lighttae" +} diff --git a/configs/hunyuan_video_15/offload/hy15_t2v_480p_bf16.json b/configs/hunyuan_video_15/offload/hy15_t2v_480p_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..1055492b1922da21c4efb9c096f5d14e31c9c96e --- /dev/null +++ b/configs/hunyuan_video_15/offload/hy15_t2v_480p_bf16.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 7.0, + 
"sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "sage_attn2", + "cpu_offload": true, + "offload_granularity": "block", + "vae_cpu_offload": false, + "byt5_cpu_offload": false, + "qwen25vl_cpu_offload": true, + "siglip_cpu_offload": false +} diff --git a/configs/hunyuan_video_15/quant/hy15_t2v_480p_fp8.json b/configs/hunyuan_video_15/quant/hy15_t2v_480p_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..196a5b7eccfa5536daf34abd0aff5935498c0a9a --- /dev/null +++ b/configs/hunyuan_video_15/quant/hy15_t2v_480p_fp8.json @@ -0,0 +1,15 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_t2v", + "fps": 24, + "target_video_length": 121, + "aspect_ratio": "16:9", + "vae_stride": [4, 16, 16], + "sample_shift": 7.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "dit_quantized_ckpt": "/path/to/quant_model.safetensors", + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl" +} diff --git a/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json b/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..4f774fff6b1380d224e838b7403ee79e1428492e --- /dev/null +++ b/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 50, + "transformer_model_name": "480p_i2v", + "fps": 24, + "target_video_length": 121, + "vae_stride": [4, 16, 16], + "sample_shift": 5.0, + "sample_guide_scale": 6.0, + "enable_cfg": true, + "attn_type": "flash_attn3", + "video_super_resolution": { + "sr_version": "720p_sr_distilled", + "flow_shift": 2.0, + "base_resolution": "480p", + "guidance_scale": 1.0, + "num_inference_steps": 6, + "use_meanflow": true + } +} diff --git a/configs/matrix_game2/matrix_game2_gta_drive.json b/configs/matrix_game2/matrix_game2_gta_drive.json new file mode 100644 index 0000000000000000000000000000000000000000..d1952d813a4f56557f55d588a7414f18599595fa --- /dev/null +++ b/configs/matrix_game2/matrix_game2_gta_drive.json @@ -0,0 +1,72 @@ +{ + "infer_steps": 50, + "target_video_length": 150, + "num_output_frames": 150, + "text_len": 512, + "target_height": 352, + "target_width": 640, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "local_attn_size": 6, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 880, + "num_output_frames": 150, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 908.8427, 713.9794] + }, + "sub_model_folder": "gta_distilled_model", + "sub_model_name": "gta_keyboard2dim.safetensors", + "mode": "gta_drive", + "streaming": false, + "action_config": { + "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + "enable_keyboard": true, + "enable_mouse": true, + "heads_num": 16, + "hidden_size": 128, + "img_hidden_size": 1536, + "keyboard_dim_in": 4, + "keyboard_hidden_dim": 1024, + "mouse_dim_in": 2, + "mouse_hidden_dim": 1024, + "mouse_qk_dim_list": [ + 8, + 28, + 28 + ], + "patch_size": [ + 1, + 2, + 2 + ], + "qk_norm": true, + "qkv_bias": false, + "rope_dim_list": [ + 8, + 28, + 28 + ], + "rope_theta": 256, + "vae_time_compression_ratio": 4, + "windows_size": 3 + }, + "dim": 1536, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_dim": 36, + "inject_sample_info": false, + "model_type": "i2v", + "num_heads": 12, + "num_layers": 30, + "out_dim": 16 +} diff 
--git a/configs/matrix_game2/matrix_game2_gta_drive_streaming.json b/configs/matrix_game2/matrix_game2_gta_drive_streaming.json new file mode 100644 index 0000000000000000000000000000000000000000..584eab096d7c8c380d927c8518996c48f02c1119 --- /dev/null +++ b/configs/matrix_game2/matrix_game2_gta_drive_streaming.json @@ -0,0 +1,72 @@ +{ + "infer_steps": 50, + "target_video_length": 360, + "num_output_frames": 360, + "text_len": 512, + "target_height": 352, + "target_width": 640, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "local_attn_size": 6, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 880, + "num_output_frames": 360, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 908.8427, 713.9794] + }, + "sub_model_folder": "gta_distilled_model", + "sub_model_name": "gta_keyboard2dim.safetensors", + "mode": "gta_drive", + "streaming": true, + "action_config": { + "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + "enable_keyboard": true, + "enable_mouse": true, + "heads_num": 16, + "hidden_size": 128, + "img_hidden_size": 1536, + "keyboard_dim_in": 4, + "keyboard_hidden_dim": 1024, + "mouse_dim_in": 2, + "mouse_hidden_dim": 1024, + "mouse_qk_dim_list": [ + 8, + 28, + 28 + ], + "patch_size": [ + 1, + 2, + 2 + ], + "qk_norm": true, + "qkv_bias": false, + "rope_dim_list": [ + 8, + 28, + 28 + ], + "rope_theta": 256, + "vae_time_compression_ratio": 4, + "windows_size": 3 + }, + "dim": 1536, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_dim": 36, + "inject_sample_info": false, + "model_type": "i2v", + "num_heads": 12, + "num_layers": 30, + "out_dim": 16 +} diff --git a/configs/matrix_game2/matrix_game2_templerun.json b/configs/matrix_game2/matrix_game2_templerun.json new file mode 100644 index 0000000000000000000000000000000000000000..63f1aab5e4b5ff43535e6e2cf504beafd550f7a4 --- /dev/null +++ b/configs/matrix_game2/matrix_game2_templerun.json @@ -0,0 +1,65 @@ +{ + "infer_steps": 50, + "target_video_length": 150, + "num_output_frames": 150, + "text_len": 512, + "target_height": 352, + "target_width": 640, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "local_attn_size": 6, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 880, + "num_output_frames": 150, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 908.8427, 713.9794] + }, + "sub_model_folder": "templerun_distilled_model", + "sub_model_name": "templerun_7dim_onlykey.safetensors", + "mode": "templerun", + "streaming": false, + "action_config": { + "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + "enable_keyboard": true, + "enable_mouse": false, + "heads_num": 16, + "hidden_size": 128, + "img_hidden_size": 1536, + "keyboard_dim_in": 7, + "keyboard_hidden_dim": 1024, + "patch_size": [ + 1, + 2, + 2 + ], + "qk_norm": true, + "qkv_bias": false, + "rope_dim_list": [ + 8, + 28, + 28 + ], + "rope_theta": 256, + "vae_time_compression_ratio": 4, + "windows_size": 3 + }, + "dim": 1536, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_dim": 36, + "inject_sample_info": false, + "model_type": 
"i2v", + "num_heads": 12, + "num_layers": 30, + "out_dim": 16 +} diff --git a/configs/matrix_game2/matrix_game2_templerun_streaming.json b/configs/matrix_game2/matrix_game2_templerun_streaming.json new file mode 100644 index 0000000000000000000000000000000000000000..90fd2955856a5182c19deac209b9e87e789d8338 --- /dev/null +++ b/configs/matrix_game2/matrix_game2_templerun_streaming.json @@ -0,0 +1,72 @@ +{ + "infer_steps": 50, + "target_video_length": 360, + "num_output_frames": 360, + "text_len": 512, + "target_height": 352, + "target_width": 640, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "local_attn_size": 6, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 880, + "num_output_frames": 360, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 908.8427, 713.9794] + }, + "sub_model_folder": "templerun_distilled_model", + "sub_model_name": "templerun_7dim_onlykey.safetensors", + "mode": "templerun", + "streaming": true, + "action_config": { + "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + "enable_keyboard": true, + "enable_mouse": true, + "heads_num": 16, + "hidden_size": 128, + "img_hidden_size": 1536, + "keyboard_dim_in": 4, + "keyboard_hidden_dim": 1024, + "mouse_dim_in": 2, + "mouse_hidden_dim": 1024, + "mouse_qk_dim_list": [ + 8, + 28, + 28 + ], + "patch_size": [ + 1, + 2, + 2 + ], + "qk_norm": true, + "qkv_bias": false, + "rope_dim_list": [ + 8, + 28, + 28 + ], + "rope_theta": 256, + "vae_time_compression_ratio": 4, + "windows_size": 3 + }, + "dim": 1536, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_dim": 36, + "inject_sample_info": false, + "model_type": "i2v", + "num_heads": 12, + "num_layers": 30, + "out_dim": 16 +} diff --git a/configs/matrix_game2/matrix_game2_universal.json b/configs/matrix_game2/matrix_game2_universal.json new file mode 100644 index 0000000000000000000000000000000000000000..0b801d8492b19a9d7567534f775bb3804af53a1d --- /dev/null +++ b/configs/matrix_game2/matrix_game2_universal.json @@ -0,0 +1,72 @@ +{ + "infer_steps": 50, + "target_video_length": 150, + "num_output_frames": 150, + "text_len": 512, + "target_height": 352, + "target_width": 640, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "local_attn_size": 6, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 880, + "num_output_frames": 150, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 908.8427, 713.9794] + }, + "sub_model_folder": "base_distilled_model", + "sub_model_name": "base_distill.safetensors", + "mode": "universal", + "streaming": false, + "action_config": { + "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + "enable_keyboard": true, + "enable_mouse": true, + "heads_num": 16, + "hidden_size": 128, + "img_hidden_size": 1536, + "keyboard_dim_in": 4, + "keyboard_hidden_dim": 1024, + "mouse_dim_in": 2, + "mouse_hidden_dim": 1024, + "mouse_qk_dim_list": [ + 8, + 28, + 28 + ], + "patch_size": [ + 1, + 2, + 2 + ], + "qk_norm": true, + "qkv_bias": false, + "rope_dim_list": [ + 8, + 28, + 28 + ], + "rope_theta": 256, + "vae_time_compression_ratio": 4, + 
"windows_size": 3 + }, + "dim": 1536, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_dim": 36, + "inject_sample_info": false, + "model_type": "i2v", + "num_heads": 12, + "num_layers": 30, + "out_dim": 16 +} diff --git a/configs/matrix_game2/matrix_game2_universal_streaming.json b/configs/matrix_game2/matrix_game2_universal_streaming.json new file mode 100644 index 0000000000000000000000000000000000000000..f2c4575a78dfe3fed7f02819ad6f70758b637866 --- /dev/null +++ b/configs/matrix_game2/matrix_game2_universal_streaming.json @@ -0,0 +1,72 @@ +{ + "infer_steps": 50, + "target_video_length": 360, + "num_output_frames": 360, + "text_len": 512, + "target_height": 352, + "target_width": 640, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "local_attn_size": 6, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 880, + "num_output_frames": 360, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 908.8427, 713.9794] + }, + "sub_model_folder": "base_distilled_model", + "sub_model_name": "base_distill.safetensors", + "mode": "universal", + "streaming": true, + "action_config": { + "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + "enable_keyboard": true, + "enable_mouse": true, + "heads_num": 16, + "hidden_size": 128, + "img_hidden_size": 1536, + "keyboard_dim_in": 4, + "keyboard_hidden_dim": 1024, + "mouse_dim_in": 2, + "mouse_hidden_dim": 1024, + "mouse_qk_dim_list": [ + 8, + 28, + 28 + ], + "patch_size": [ + 1, + 2, + 2 + ], + "qk_norm": true, + "qkv_bias": false, + "rope_dim_list": [ + 8, + 28, + 28 + ], + "rope_theta": 256, + "vae_time_compression_ratio": 4, + "windows_size": 3 + }, + "dim": 1536, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_dim": 36, + "inject_sample_info": false, + "model_type": "i2v", + "num_heads": 12, + "num_layers": 30, + "out_dim": 16 +} diff --git a/configs/model_pipeline.json b/configs/model_pipeline.json new file mode 100644 index 0000000000000000000000000000000000000000..7f2f62b0a4f9f360e03b7ce3969e46ec1d5c65df --- /dev/null +++ b/configs/model_pipeline.json @@ -0,0 +1,179 @@ +{ + "data": + { + "t2v": { + "wan2.1-1.3B": { + "single_stage": { + "pipeline": { + "inputs": [], + "outputs": ["output_video"] + } + }, + "multi_stage": { + "text_encoder": { + "inputs": [], + "outputs": ["text_encoder_output"] + }, + "dit": { + "inputs": ["text_encoder_output"], + "outputs": ["latents"] + }, + "vae_decoder": { + "inputs": ["latents"], + "outputs": ["output_video"] + } + } + }, + "self-forcing-dmd": { + "single_stage": { + "pipeline": { + "inputs": [], + "outputs": ["output_video"] + } + } + } + }, + "i2v": { + "wan2.1-14B-480P": { + "single_stage": { + "pipeline": { + "inputs": ["input_image"], + "outputs": ["output_video"] + } + }, + "multi_stage": { + "text_encoder": { + "inputs": ["input_image"], + "outputs": ["text_encoder_output"] + }, + "image_encoder": { + "inputs": ["input_image"], + "outputs": ["clip_encoder_output"] + }, + "vae_encoder": { + "inputs": ["input_image"], + "outputs": ["vae_encoder_output"] + }, + "dit": { + "inputs": [ + "clip_encoder_output", + "vae_encoder_output", + "text_encoder_output" + ], + "outputs": ["latents"] + }, + "vae_decoder": { + "inputs": ["latents"], + "outputs": ["output_video"] + } + } + }, + "matrix-game2-gta-drive": { + "single_stage": { + 
"pipeline": { + "inputs": ["input_image"], + "outputs": ["output_video"] + } + } + }, + "matrix-game2-universal": { + "single_stage": { + "pipeline": { + "inputs": ["input_image"], + "outputs": ["output_video"] + } + } + }, + "matrix-game2-templerun": { + "single_stage": { + "pipeline": { + "inputs": ["input_image"], + "outputs": ["output_video"] + } + } + } + }, + "s2v": { + "SekoTalk": { + "single_stage": { + "pipeline": { + "inputs": ["input_image", "input_audio"], + "outputs": ["output_video"] + } + }, + "multi_stage": { + "text_encoder": { + "inputs": ["input_image"], + "outputs": ["text_encoder_output"] + }, + "image_encoder": { + "inputs": ["input_image"], + "outputs": ["clip_encoder_output"] + }, + "vae_encoder": { + "inputs": ["input_image"], + "outputs": ["vae_encoder_output"] + }, + "segment_dit": { + "inputs": [ + "input_audio", + "clip_encoder_output", + "vae_encoder_output", + "text_encoder_output" + ], + "outputs": ["output_video"] + } + } + } + }, + "animate": { + "wan2.2_animate": { + "single_stage": { + "pipeline": { + "inputs": ["input_image","input_video"], + "outputs": ["output_video"] + } + } + } + } + + }, + "meta": { + "special_types": { + "input_image": "IMAGE", + "input_audio": "AUDIO", + "input_video": "VIDEO", + "latents": "TENSOR", + "output_video": "VIDEO" + }, + "model_name_inner_to_outer": { + "seko_talk": "SekoTalk" + }, + "model_name_outer_to_inner": {}, + "monitor": { + "subtask_created_timeout": 1800, + "subtask_pending_timeout": 1800, + "subtask_running_timeouts": { + "t2v-wan2.1-1.3B-multi_stage-dit": 300, + "t2v-wan2.1-1.3B-single_stage-pipeline": 300, + "t2v-self-forcing-dmd-single_stage-pipeline": 300, + "i2v-wan2.1-14B-480P-multi_stage-dit": 600, + "i2v-wan2.1-14B-480P-single_stage-pipeline": 600, + "i2v-SekoTalk-Distill-single_stage-pipeline": 3600, + "i2v-SekoTalk-Distill-multi_stage-segment_dit": 3600 + }, + "worker_avg_window": 20, + "worker_offline_timeout": 5, + "worker_min_capacity": 20, + "worker_min_cnt": 1, + "worker_max_cnt": 10, + "task_timeout": 3600, + "schedule_ratio_high": 0.25, + "schedule_ratio_low": 0.02, + "ping_timeout": 30, + "user_max_active_tasks": 3, + "user_max_daily_tasks": 100, + "user_visit_frequency": 0.05 + } + } +} diff --git a/configs/offload/block/qwen_image_i2i_2509_block.json b/configs/offload/block/qwen_image_i2i_2509_block.json new file mode 100644 index 0000000000000000000000000000000000000000..7859a0a71e918004a639dfcd4efc68eda4a85b67 --- /dev/null +++ b/configs/offload/block/qwen_image_i2i_2509_block.json @@ -0,0 +1,66 @@ +{ + "batchsize": 1, + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 40, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "transformer_in_channels": 64, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "_auto_resize": true, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "sage_attn2", + "do_true_cfg": true, + "true_cfg_scale": 4.0, + "cpu_offload": true, + "offload_granularity": "block", + "CONDITION_IMAGE_SIZE": 147456, + "USE_IMAGE_ID_IN_PROMPT": true +} diff --git a/configs/offload/block/qwen_image_i2i_block.json b/configs/offload/block/qwen_image_i2i_block.json new file mode 100644 index 0000000000000000000000000000000000000000..7153b4a37070914f5d73c4f3ab4a87cdcd988949 --- /dev/null +++ b/configs/offload/block/qwen_image_i2i_block.json @@ -0,0 +1,66 @@ +{ + "batchsize": 1, + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 50, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "transformer_in_channels": 64, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "_auto_resize": true, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": true, + "true_cfg_scale": 4.0, + "cpu_offload": true, + "offload_granularity": "block", + "CONDITION_IMAGE_SIZE": 1048576, + "USE_IMAGE_ID_IN_PROMPT": false +} diff --git a/configs/offload/block/qwen_image_t2i_block.json b/configs/offload/block/qwen_image_t2i_block.json new file mode 100644 index 0000000000000000000000000000000000000000..5b6313fcad6338c255398ed94932b0de01489362 --- /dev/null +++ b/configs/offload/block/qwen_image_t2i_block.json @@ -0,0 +1,86 @@ +{ + "batchsize": 1, + "_comment": "Format: 'aspect ratio': [width, height]", + "aspect_ratios": { + "1:1": [ + 1328, + 1328 + ], + "16:9": [ + 1664, + 928 + ], + "9:16": [ + 928, + 1664 + ], + "4:3": [ + 1472, + 1140 + ], + "3:4": [ + 1140, + 1472 + ] + }, + "aspect_ratio": "16:9", + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 50, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 34, + "_auto_resize": false, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": false, + "cpu_offload": true, + "offload_granularity": "block" +} diff --git a/configs/offload/block/wan_i2v_block.json b/configs/offload/block/wan_i2v_block.json new file mode 100644 index 0000000000000000000000000000000000000000..4363093b7e5270d4d7bafbce10fe3b2546900001 --- /dev/null +++ b/configs/offload/block/wan_i2v_block.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f", + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "clip_cpu_offload": false +} diff --git a/configs/offload/block/wan_t2v_1_3b.json b/configs/offload/block/wan_t2v_1_3b.json new file mode 100644 index 0000000000000000000000000000000000000000..e03f8a7e776510bac40ae743004eb3b416a19125 --- /dev/null +++
b/configs/offload/block/wan_t2v_1_3b.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "t5_cpu_offload": true, + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "unload_modules": false, + "use_tiling_vae": false +} diff --git a/configs/offload/block/wan_t2v_block.json b/configs/offload/block/wan_t2v_block.json new file mode 100644 index 0000000000000000000000000000000000000000..2f926bce9b84b77afd81ae3df585e3a3563346f1 --- /dev/null +++ b/configs/offload/block/wan_t2v_block.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f", + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "clip_cpu_offload": false +} diff --git a/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json b/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json new file mode 100644 index 0000000000000000000000000000000000000000..fa0348eed97babb7b09e4f52ca1a205d628b32f6 --- /dev/null +++ b/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json @@ -0,0 +1,28 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "phase", + "dit_quantized_ckpt": "/path/to/dit_quant_model", + "dit_quantized": true, + "dit_quant_scheme": "fp8-vllm", + "t5_cpu_offload": true, + "t5_quantized": true, + "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth", + "t5_quant_scheme": "fp8", + "clip_quantized": true, + "clip_quantized_ckpt": "/path/to/clip-fp8.pth", + "clip_quant_scheme": "fp8", + "use_tiling_vae": true, + "use_tae": true, + "tae_path": "/path/to/taew2_1.pth", + "lazy_load": true +} diff --git a/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json b/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json new file mode 100644 index 0000000000000000000000000000000000000000..0516a9572af23f0dd91013fc9b011b5343e80731 --- /dev/null +++ b/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json @@ -0,0 +1,30 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 1280, + "target_width": 720, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "phase", + "dit_quantized_ckpt": "/path/to/dit_quant_model", + "dit_quantized": true, + "dit_quant_scheme": "fp8-vllm", + "t5_cpu_offload": true, + "t5_quantized": true, + "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth", + "t5_quant_scheme": "fp8", + "clip_quantized": true, + "clip_quantized_ckpt": "/path/to/clip-fp8.pth", + "clip_quant_scheme": "fp8", + 
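A minimal sketch of how one of the offload configs added above might be loaded and sanity-checked before inference. The direct `json` loading and the `load_offload_config` helper name are assumptions for illustration; they are not code from this patch.

```python
import json
from pathlib import Path

# Sketch only: read an offload config from this patch and check the offload keys
# before handing the dict to the runner. "offload_granularity" is either "block"
# or "phase" in the configs above.
def load_offload_config(path: str) -> dict:
    cfg = json.loads(Path(path).read_text(encoding="utf-8"))
    if cfg.get("cpu_offload"):
        assert cfg.get("offload_granularity") in {"block", "phase"}, cfg
    return cfg

if __name__ == "__main__":
    cfg = load_offload_config("configs/offload/block/wan_i2v_block.json")
    print(cfg["offload_granularity"], cfg["dit_quant_scheme"])
```

Run against `configs/offload/block/wan_i2v_block.json` from this patch, this prints `block fp8-q8f`.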
"use_tiling_vae": true, + "use_tae": true, + "tae_path": "/path/to/taew2_1.pth", + "lazy_load": true, + "rotary_chunk": true, + "clean_cuda_cache": true +} diff --git a/configs/offload/phase/wan_i2v_phase.json b/configs/offload/phase/wan_i2v_phase.json new file mode 100644 index 0000000000000000000000000000000000000000..c5bd698bd4345f99ed4d9f721fff98fd394c8311 --- /dev/null +++ b/configs/offload/phase/wan_i2v_phase.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "phase", + "t5_cpu_offload": false, + "clip_cpu_offload": false, + "vae_cpu_offload": false, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f" +} diff --git a/configs/offload/phase/wan_t2v_phase.json b/configs/offload/phase/wan_t2v_phase.json new file mode 100644 index 0000000000000000000000000000000000000000..b12036aa844b36b349f17b14ad7d0129faabcc45 --- /dev/null +++ b/configs/offload/phase/wan_t2v_phase.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "phase", + "t5_cpu_offload": false, + "clip_cpu_offload": false, + "vae_cpu_offload": false, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f" +} diff --git a/configs/quantization/gguf/wan_i2v_q4_k.json b/configs/quantization/gguf/wan_i2v_q4_k.json new file mode 100644 index 0000000000000000000000000000000000000000..e92f6b216c83dafca6e4e404851265f552d5ae81 --- /dev/null +++ b/configs/quantization/gguf/wan_i2v_q4_k.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "dit_quantized": true, + "dit_quant_scheme": "gguf-Q4_K_S" +} diff --git a/configs/quantization/wan_i2v.json b/configs/quantization/wan_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..7624495e62fc56b9824d09303a91668c3b959eed --- /dev/null +++ b/configs/quantization/wan_i2v.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "dit_quantized_ckpt": "/path/to/int8/model", + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm" +} diff --git a/configs/quantization/wan_i2v_q8f.json b/configs/quantization/wan_i2v_q8f.json new file mode 100644 index 
0000000000000000000000000000000000000000..15c5b1245e88dc225f6c5e3e8f3d484e2ddb26ce --- /dev/null +++ b/configs/quantization/wan_i2v_q8f.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "int8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "int8-q8f" +} diff --git a/configs/quantization/wan_i2v_torchao.json b/configs/quantization/wan_i2v_torchao.json new file mode 100644 index 0000000000000000000000000000000000000000..ca561b87cc079e6c43a7b2d11f38c8c92cc485e6 --- /dev/null +++ b/configs/quantization/wan_i2v_torchao.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-torchao", + "t5_quantized": true, + "t5_quant_scheme": "int8-torchao", + "clip_quantized": true, + "clip_quant_scheme": "int8-torchao" +} diff --git a/configs/qwen_image/qwen_image_i2i.json b/configs/qwen_image/qwen_image_i2i.json new file mode 100644 index 0000000000000000000000000000000000000000..0e0fd3881b53caeab6abc17b903f7241fffbfee2 --- /dev/null +++ b/configs/qwen_image/qwen_image_i2i.json @@ -0,0 +1,64 @@ +{ + "batchsize": 1, + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 50, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "transformer_in_channels": 64, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "_auto_resize": true, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": true, + "true_cfg_scale": 4.0, + "CONDITION_IMAGE_SIZE": 1048576, + "USE_IMAGE_ID_IN_PROMPT": false +} diff --git a/configs/qwen_image/qwen_image_i2i_2509.json b/configs/qwen_image/qwen_image_i2i_2509.json new file mode 100644 index 0000000000000000000000000000000000000000..974988e428650ab9320852bd4edbb9f5551e092b --- /dev/null +++ b/configs/qwen_image/qwen_image_i2i_2509.json @@ -0,0 +1,64 @@ +{ + "batchsize": 1, + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 40, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "transformer_in_channels": 64, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "_auto_resize": true, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": true, + "true_cfg_scale": 4.0, + "CONDITION_IMAGE_SIZE": 147456, + "USE_IMAGE_ID_IN_PROMPT": true +} diff --git a/configs/qwen_image/qwen_image_i2i_2509_quant.json b/configs/qwen_image/qwen_image_i2i_2509_quant.json new file mode 100644 index 0000000000000000000000000000000000000000..c6e73581288181269c5da3a917190d591db8986a --- /dev/null +++ b/configs/qwen_image/qwen_image_i2i_2509_quant.json @@ -0,0 +1,67 @@ +{ + "batchsize": 1, + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 40, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "transformer_in_channels": 64, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "_auto_resize": true, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": true, + "true_cfg_scale": 4.0, + "CONDITION_IMAGE_SIZE": 147456, + "USE_IMAGE_ID_IN_PROMPT": true, + "dit_quantized": true, + "dit_quantized_ckpt": "/path/to/qwen_2509_fp8.safetensors", + "dit_quant_scheme": "fp8-sgl" +} diff --git a/configs/qwen_image/qwen_image_i2i_lora.json b/configs/qwen_image/qwen_image_i2i_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..b0bbc49282aef7b54fe32d09729b495ae697b5aa --- /dev/null +++ b/configs/qwen_image/qwen_image_i2i_lora.json @@ -0,0 +1,70 @@ +{ + "batchsize": 1, + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 8, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "transformer_in_channels": 64, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "_auto_resize": true, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": true, + "true_cfg_scale": 4.0, + "CONDITION_IMAGE_SIZE": 1048576, + "USE_IMAGE_ID_IN_PROMPT": false, + "lora_configs": [ + { + "path": "/path/to/Qwen-Image-Edit-Lightning-4steps-V1.0.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/qwen_image/qwen_image_t2i.json b/configs/qwen_image/qwen_image_t2i.json new file mode 100644 index 0000000000000000000000000000000000000000..5581d50a2d5637396485114385b5dc40e9568461 --- /dev/null +++ b/configs/qwen_image/qwen_image_t2i.json @@ -0,0 +1,85 @@ +{ + "batchsize": 1, + "_comment": "格式: '宽高比': [width, height]", + "aspect_ratios": { + "1:1": [ + 1328, + 1328 + ], + "16:9": [ + 1664, + 928 + ], + "9:16": [ + 928, + 1664 + ], + "4:3": [ + 1472, + 1140 + ], + "3:4": [ + 142, + 184 + ] + }, + "aspect_ratio": "16:9", + "num_channels_latents": 16, + "vae_scale_factor": 8, + "infer_steps": 50, + "guidance_embeds": false, + "num_images_per_prompt": 1, + "vae_latents_mean": [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + "vae_latents_std": [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + "vae_z_dim": 16, + "feature_caching": "NoCaching", + "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 34, + "_auto_resize": false, + "num_layers": 60, + "attention_out_dim": 3072, + "attention_dim_head": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", + "attn_type": "flash_attn3", + "do_true_cfg": true, + "true_cfg_scale": 4.0 +} diff --git a/configs/seko_talk/5090/seko_talk_5090_bf16.json b/configs/seko_talk/5090/seko_talk_5090_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..a6cab78438410e463f554427fb9c6653d19875be --- /dev/null +++ b/configs/seko_talk/5090/seko_talk_5090_bf16.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn3", + "cross_attn_1_type": "sage_attn3", + "cross_attn_2_type": "sage_attn3", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": true, + "clip_cpu_offload": false, + "audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false, + "vae_cpu_offload": false +} diff --git a/configs/seko_talk/5090/seko_talk_5090_int8.json b/configs/seko_talk/5090/seko_talk_5090_int8.json new file mode 100644 index 0000000000000000000000000000000000000000..664983b908669400d34f1ff8bcb14683a6e1d1c9 --- 
/dev/null +++ b/configs/seko_talk/5090/seko_talk_5090_int8.json @@ -0,0 +1,29 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn3", + "cross_attn_1_type": "sage_attn3", + "cross_attn_2_type": "sage_attn3", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": false, + "clip_cpu_offload": false, + "audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false, + "vae_cpu_offload": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "int8-q8f" +} diff --git a/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json b/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json new file mode 100644 index 0000000000000000000000000000000000000000..3dfc9d533204841113ddd21db22f9095f34a7ae7 --- /dev/null +++ b/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json @@ -0,0 +1,34 @@ + +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn3", + "cross_attn_1_type": "sage_attn3", + "cross_attn_2_type": "sage_attn3", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": false, + "clip_cpu_offload": false, + "audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false, + "vae_cpu_offload": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "int8-q8f", + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses-4090" + } +} diff --git a/configs/seko_talk/A800/seko_talk_A800_int8.json b/configs/seko_talk/A800/seko_talk_A800_int8.json new file mode 100644 index 0000000000000000000000000000000000000000..a5da24c2a907ea377a3403d4fb14b826814d07cf --- /dev/null +++ b/configs/seko_talk/A800/seko_talk_A800_int8.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-vllm", + "t5_quantized": true, + "t5_quant_scheme": "int8-vllm" +} diff --git a/configs/seko_talk/A800/seko_talk_A800_int8_dist_2gpu.json b/configs/seko_talk/A800/seko_talk_A800_int8_dist_2gpu.json new file mode 100644 index 0000000000000000000000000000000000000000..f99ba12190fcc7dbc58b4dd9d60dacfa7f66e970 --- /dev/null +++ b/configs/seko_talk/A800/seko_talk_A800_int8_dist_2gpu.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 
5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-vllm", + "t5_quantized": true, + "t5_quant_scheme": "int8-vllm", + "parallel": { + "seq_p_size": 2, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/A800/seko_talk_A800_int8_dist_4gpu.json b/configs/seko_talk/A800/seko_talk_A800_int8_dist_4gpu.json new file mode 100644 index 0000000000000000000000000000000000000000..c0d0c45df8bf7061170987f1815047c38ec72ec0 --- /dev/null +++ b/configs/seko_talk/A800/seko_talk_A800_int8_dist_4gpu.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-vllm", + "t5_quantized": true, + "t5_quant_scheme": "int8-vllm", + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json b/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json new file mode 100644 index 0000000000000000000000000000000000000000..278d37664a0e79c0b76492c1e03982cdb291dd68 --- /dev/null +++ b/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-vllm", + "t5_quantized": true, + "t5_quant_scheme": "int8-vllm", + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/L40s/1gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/1gpu/seko_talk_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..fd1afa4d23e8d77779d02481c081fb2ab6d16a87 --- /dev/null +++ b/configs/seko_talk/L40s/1gpu/seko_talk_bf16.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 0.8, + "t5_cpu_offload": false, + "clip_cpu_offload": false, + "vae_cpu_offload": false, + "audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false +} diff --git a/configs/seko_talk/L40s/1gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/1gpu/seko_talk_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..d88f4542282f4e1e2b3fc2b37c1f37d3cfbf8e60 --- /dev/null +++ b/configs/seko_talk/L40s/1gpu/seko_talk_fp8.json @@ -0,0 +1,27 @@ +{ + "infer_steps": 4, + "target_fps": 16, + 
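A short sketch of the apparent `*_quant_scheme` naming convention used throughout these configs (`fp8-q8f`, `int8-vllm`, `gguf-Q4_K_S`, plain `fp8`). Reading the value as `<precision>[-<kernel backend>]`, and the `parse_quant_scheme` helper itself, are assumptions for illustration only, not part of this patch.

```python
# Sketch of an assumed convention: split a quant scheme string such as
# "fp8-q8f" into its precision and optional backend tag.
def parse_quant_scheme(scheme: str) -> tuple[str, str | None]:
    precision, _, backend = scheme.partition("-")
    return precision, backend or None

assert parse_quant_scheme("fp8-q8f") == ("fp8", "q8f")
assert parse_quant_scheme("int8-vllm") == ("int8", "vllm")
assert parse_quant_scheme("gguf-Q4_K_S") == ("gguf", "Q4_K_S")
assert parse_quant_scheme("fp8") == ("fp8", None)
```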
"video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8", + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true +} diff --git a/configs/seko_talk/L40s/2gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/2gpu/seko_talk_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..4cd8a3f099a5a887ad9d500224a22ef24c396740 --- /dev/null +++ b/configs/seko_talk/L40s/2gpu/seko_talk_bf16.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "parallel": { + "seq_p_size": 2, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/L40s/2gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/2gpu/seko_talk_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..b45e87f0d8786fc3797dcf21c4e0fd7562f6fb08 --- /dev/null +++ b/configs/seko_talk/L40s/2gpu/seko_talk_fp8.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8", + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "parallel": { + "seq_p_size": 2, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/L40s/4gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/4gpu/seko_talk_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..a3f7b6fb3b38f135dc2f2ead5514f3eebfac2526 --- /dev/null +++ b/configs/seko_talk/L40s/4gpu/seko_talk_bf16.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "parallel": { + "seq_p_size": 4, + 
"seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/L40s/4gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/4gpu/seko_talk_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..ec682a6e0a48c7be36c4769435d2d3a004cd8f5d --- /dev/null +++ b/configs/seko_talk/L40s/4gpu/seko_talk_fp8.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8", + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/L40s/8gpu/seko_talk_bf16.json b/configs/seko_talk/L40s/8gpu/seko_talk_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..079439fe1b3fee28435fecb837f3638145814422 --- /dev/null +++ b/configs/seko_talk/L40s/8gpu/seko_talk_bf16.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/L40s/8gpu/seko_talk_fp8.json b/configs/seko_talk/L40s/8gpu/seko_talk_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..d50f7ebb896e29d762830b20e67d0c14a30c318e --- /dev/null +++ b/configs/seko_talk/L40s/8gpu/seko_talk_fp8.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8", + "cpu_offload": false, + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "vae_cpu_offload": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/mlu/seko_talk_bf16.json b/configs/seko_talk/mlu/seko_talk_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..71741879328b0e9425bc00a9fc3c0c339791dc9e --- /dev/null +++ b/configs/seko_talk/mlu/seko_talk_bf16.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": 
"adaptive", + "self_attn_1_type": "mlu_sage_attn", + "cross_attn_1_type": "mlu_sage_attn", + "cross_attn_2_type": "mlu_sage_attn", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "rope_type": "torch", + "modulate_type": "torch" +} diff --git a/configs/seko_talk/mlu/seko_talk_int8.json b/configs/seko_talk/mlu/seko_talk_int8.json new file mode 100644 index 0000000000000000000000000000000000000000..673acf8619199f07eb251a5f9a5a08296f07e2d3 --- /dev/null +++ b/configs/seko_talk/mlu/seko_talk_int8.json @@ -0,0 +1,29 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "mlu_sage_attn", + "cross_attn_1_type": "mlu_sage_attn", + "cross_attn_2_type": "mlu_sage_attn", + "seed": 42, + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "clip_quantized": false, + "clip_quant_scheme": "int8-tmo", + "dit_quantized": true, + "dit_quant_scheme": "int8-tmo", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-tmo", + "t5_quantized": true, + "t5_quant_scheme": "int8-tmo", + "modulate_type": "torch", + "rope_type": "torch", + "ln_type": "Default", + "rms_type": "Default" +} diff --git a/configs/seko_talk/mlu/seko_talk_int8_dist.json b/configs/seko_talk/mlu/seko_talk_int8_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..fe5f515260e81d5f1d444be919afcd8f04b1493f --- /dev/null +++ b/configs/seko_talk/mlu/seko_talk_int8_dist.json @@ -0,0 +1,33 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "mlu_sage_attn", + "cross_attn_1_type": "mlu_sage_attn", + "cross_attn_2_type": "mlu_sage_attn", + "seed": 42, + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "clip_quantized": false, + "clip_quant_scheme": "int8-tmo", + "dit_quantized": true, + "dit_quant_scheme": "int8-tmo", + "adapter_quantized": true, + "adapter_quant_scheme": "int8-tmo", + "t5_quantized": true, + "t5_quant_scheme": "int8-tmo", + "modulate_type": "torch", + "rope_type": "torch", + "ln_type": "Default", + "rms_type": "Default", + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/multi_person/01_base.json b/configs/seko_talk/multi_person/01_base.json new file mode 100644 index 0000000000000000000000000000000000000000..417f62ce1caa0175ba79a13f76634a7224b0a621 --- /dev/null +++ b/configs/seko_talk/multi_person/01_base.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false +} diff --git a/configs/seko_talk/multi_person/02_base_fp8.json b/configs/seko_talk/multi_person/02_base_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..232b3987ea72a1979dd1d3af77252006847a5644 --- /dev/null +++ b/configs/seko_talk/multi_person/02_base_fp8.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + 
"target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8", + "t5_quantized": true, + "t5_quant_scheme": "fp8" +} diff --git a/configs/seko_talk/multi_person/03_dist.json b/configs/seko_talk/multi_person/03_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..5cc4cbd6e57a42c61b4683ed63e3d20ebdc237eb --- /dev/null +++ b/configs/seko_talk/multi_person/03_dist.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/multi_person/04_dist_fp8.json b/configs/seko_talk/multi_person/04_dist_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..f91b3fd57af5084d443a207607f59d18044609b3 --- /dev/null +++ b/configs/seko_talk/multi_person/04_dist_fp8.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8", + "t5_quantized": true, + "t5_quant_scheme": "fp8", + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/multi_person/15_base_compile.json b/configs/seko_talk/multi_person/15_base_compile.json new file mode 100644 index 0000000000000000000000000000000000000000..73cb92aacc0ecf1556d67c49a2fc8342777f3b37 --- /dev/null +++ b/configs/seko_talk/multi_person/15_base_compile.json @@ -0,0 +1,27 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "compile": true, + "compile_shapes": [ + [ + 480, + 832 + ], + [ + 720, + 1280 + ] + ] +} diff --git a/configs/seko_talk/seko_talk_01_base.json b/configs/seko_talk/seko_talk_01_base.json new file mode 100644 index 0000000000000000000000000000000000000000..417f62ce1caa0175ba79a13f76634a7224b0a621 --- /dev/null +++ b/configs/seko_talk/seko_talk_01_base.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + 
"cpu_offload": false, + "use_31_block": false +} diff --git a/configs/seko_talk/seko_talk_02_fp8.json b/configs/seko_talk/seko_talk_02_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..15cdbc046e57090cb3265f0b9489806e55438d21 --- /dev/null +++ b/configs/seko_talk/seko_talk_02_fp8.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "clip_quantized": true, + "clip_quant_scheme": "fp8-sgl", + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_03_dist.json b/configs/seko_talk/seko_talk_03_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..f7543370eac394efa72ed618ff65bd4389c62ffe --- /dev/null +++ b/configs/seko_talk/seko_talk_03_dist.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/seko_talk/seko_talk_04_fp8_dist.json b/configs/seko_talk/seko_talk_04_fp8_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..443295169fec8be2aaaa07e89f898092db61bb6d --- /dev/null +++ b/configs/seko_talk/seko_talk_04_fp8_dist.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_05_offload_fp8_4090.json b/configs/seko_talk/seko_talk_05_offload_fp8_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..96f8629d1c3a2e7614ed54eb51b0a8171784743a --- /dev/null +++ b/configs/seko_talk/seko_talk_05_offload_fp8_4090.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": false, + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_cpu_offload": false, + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f", + 
"audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false, + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-q8f", + "vae_cpu_offload": false, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f" +} diff --git a/configs/seko_talk/seko_talk_05_offload_fp8_4090_dist.json b/configs/seko_talk/seko_talk_05_offload_fp8_4090_dist.json new file mode 100644 index 0000000000000000000000000000000000000000..a2d66f64d4c78b4b131960624780c83f8ed4fa22 --- /dev/null +++ b/configs/seko_talk/seko_talk_05_offload_fp8_4090_dist.json @@ -0,0 +1,36 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": false, + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_cpu_offload": false, + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f", + "audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false, + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-q8f", + "vae_cpu_offload": false, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses-4090" + } +} diff --git a/configs/seko_talk/seko_talk_06_offload_fp8_H100.json b/configs/seko_talk/seko_talk_06_offload_fp8_H100.json new file mode 100644 index 0000000000000000000000000000000000000000..d5a5ae9b3948234a362576a886ead22a2e0a43df --- /dev/null +++ b/configs/seko_talk/seko_talk_06_offload_fp8_H100.json @@ -0,0 +1,30 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": true, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": true, + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "clip_cpu_offload": false, + "audio_encoder_cpu_offload": false, + "audio_adapter_cpu_offload": false, + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "vae_cpu_offload": false, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_07_dist_offload.json b/configs/seko_talk/seko_talk_07_dist_offload.json new file mode 100644 index 0000000000000000000000000000000000000000..b6a64eaa5c67e2b1c83b6cb0287b85496a9816bf --- /dev/null +++ b/configs/seko_talk/seko_talk_07_dist_offload.json @@ -0,0 +1,28 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + }, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": true, + "clip_cpu_offload": false, + "vae_cpu_offload": false, + "offload_ratio": 1, + 
"use_tiling_vae": true, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": false +} diff --git a/configs/seko_talk/seko_talk_08_5B_base.json b/configs/seko_talk/seko_talk_08_5B_base.json new file mode 100644 index 0000000000000000000000000000000000000000..855670fd5827e2ad0e39c8f557b28e2f444d01df --- /dev/null +++ b/configs/seko_talk/seko_talk_08_5B_base.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 4, + "target_fps": 24, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 121, + "resize_mode": "adaptive", + "text_len": 512, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": false, + "fps": 24, + "use_image_encoder": false, + "use_31_block": false, + "lora_configs": [ + { + "path": "/mnt/aigc/rtxiang/pretrain/qianhai_weights/lora_model.safetensors", + "strength": 0.125 + } + ] +} diff --git a/configs/seko_talk/seko_talk_09_base_fixed_min_area.json b/configs/seko_talk/seko_talk_09_base_fixed_min_area.json new file mode 100644 index 0000000000000000000000000000000000000000..2b04ae5ee3c1c4f6a0b54a3c8f9044e05cd2ea0e --- /dev/null +++ b/configs/seko_talk/seko_talk_09_base_fixed_min_area.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "fixed_min_area", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false +} diff --git a/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json b/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json new file mode 100644 index 0000000000000000000000000000000000000000..27f8afe412df886de87aff1dbcd1a9bb465a1037 --- /dev/null +++ b/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "fixed_min_area", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + }, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json b/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json new file mode 100644 index 0000000000000000000000000000000000000000..0ea9de6f852dd2b09cce33976b77181a9de949c6 --- /dev/null +++ b/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json @@ -0,0 +1,30 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "fixed_shape", + "fixed_shape": [ + 240, + 320 + ], + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": 
"ulysses" + }, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json b/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json new file mode 100644 index 0000000000000000000000000000000000000000..6952e86f1e592926ee0546af99efc82ffad678c0 --- /dev/null +++ b/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 17, + "prev_frame_length": 1, + "resize_mode": "fixed_shape", + "fixed_shape": [ + 480, + 480 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json b/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json new file mode 100644 index 0000000000000000000000000000000000000000..7c9abea232a2894375a0a8186d0881df8864c2ab --- /dev/null +++ b/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json @@ -0,0 +1,62 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "bucket_shape": { + "0.667": [ + [ + 480, + 832 + ], + [ + 544, + 960 + ] + ], + "1.500": [ + [ + 832, + 480 + ], + [ + 960, + 544 + ] + ], + "1.000": [ + [ + 480, + 480 + ], + [ + 576, + 576 + ], + [ + 704, + 704 + ] + ] + }, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json b/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json new file mode 100644 index 0000000000000000000000000000000000000000..9ad4fd7330a0f71d17c9100ca54cb96430fb0cd8 --- /dev/null +++ b/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json @@ -0,0 +1,63 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 17, + "prev_frame_length": 1, + "resize_mode": "adaptive", + "bucket_shape": { + "0.667": [ + [ + 480, + 832 + ], + [ + 544, + 960 + ] + ], + "1.500": [ + [ + 832, + 480 + ], + [ + 960, + 544 + ] + ], + "1.000": [ + [ + 480, + 480 + ], + [ + 576, + 576 + ], + [ + 704, + 704 + ] + ] + }, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + 
"seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_15_base_compile.json b/configs/seko_talk/seko_talk_15_base_compile.json new file mode 100644 index 0000000000000000000000000000000000000000..19547bc498b76474a1904ef6a8c095143153c795 --- /dev/null +++ b/configs/seko_talk/seko_talk_15_base_compile.json @@ -0,0 +1,59 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "compile": true, + "compile_shapes": [ + [ + 480, + 832 + ], + [ + 544, + 960 + ], + [ + 720, + 1280 + ], + [ + 832, + 480 + ], + [ + 960, + 544 + ], + [ + 1280, + 720 + ], + [ + 480, + 480 + ], + [ + 576, + 576 + ], + [ + 704, + 704 + ], + [ + 960, + 960 + ] + ] +} diff --git a/configs/seko_talk/seko_talk_16_fp8_dist_compile.json b/configs/seko_talk/seko_talk_16_fp8_dist_compile.json new file mode 100644 index 0000000000000000000000000000000000000000..f2ef9e412c96358f7343ac3a938fe77a06658cfc --- /dev/null +++ b/configs/seko_talk/seko_talk_16_fp8_dist_compile.json @@ -0,0 +1,71 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "clip_quantized": true, + "clip_quant_scheme": "fp8-sgl", + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "compile": true, + "compile_shapes": [ + [ + 480, + 832 + ], + [ + 544, + 960 + ], + [ + 720, + 1280 + ], + [ + 832, + 480 + ], + [ + 960, + 544 + ], + [ + 1280, + 720 + ], + [ + 480, + 480 + ], + [ + 576, + 576 + ], + [ + 704, + 704 + ], + [ + 960, + 960 + ] + ] +} diff --git a/configs/seko_talk/seko_talk_17_base_vsr.json b/configs/seko_talk/seko_talk_17_base_vsr.json new file mode 100644 index 0000000000000000000000000000000000000000..7f39cc56842a2d68120419afdb67cef7ad5df105 --- /dev/null +++ b/configs/seko_talk/seko_talk_17_base_vsr.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 2, + "target_fps": 25, + "video_duration": 1, + "audio_sr": 16000, + "target_video_length": 25, + "resize_mode": "fixed_shape", + "fixed_shape": [ + 192, + 320 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "video_super_resolution": { + "scale": 2.0, + "seed": 0, + "model_path": "/base_code/FlashVSR/examples/WanVSR/FlashVSR" + } +} diff --git a/configs/seko_talk/seko_talk_22_nbhd_attn.json b/configs/seko_talk/seko_talk_22_nbhd_attn.json new file mode 100644 index 0000000000000000000000000000000000000000..f1ad262450f66f208c2b019c30ef6eeda79b8b79 --- /dev/null 
+++ b/configs/seko_talk/seko_talk_22_nbhd_attn.json @@ -0,0 +1,16 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "nbhd_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false +} diff --git a/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json b/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json new file mode 100644 index 0000000000000000000000000000000000000000..fba7ca4b8cf73f75501f2ac0476cc7afd045cdf4 --- /dev/null +++ b/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json @@ -0,0 +1,28 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "nbhd_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "clip_quantized": true, + "clip_quant_scheme": "fp8-sgl", + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json b/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json new file mode 100644 index 0000000000000000000000000000000000000000..cc812be7e4d0a6343edb55d4e594716edaefd41b --- /dev/null +++ b/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json @@ -0,0 +1,74 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 360, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "nbhd_attn", + "nbhd_attn_setting": { + "coefficient": [1.0, 0.25, 0.056] + }, + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + }, + "clip_quantized": true, + "clip_quant_scheme": "fp8-sgl", + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "compile": true, + "compile_shapes": [ + [ + 480, + 832 + ], + [ + 544, + 960 + ], + [ + 720, + 1280 + ], + [ + 832, + 480 + ], + [ + 960, + 544 + ], + [ + 1280, + 720 + ], + [ + 480, + 480 + ], + [ + 576, + 576 + ], + [ + 704, + 704 + ], + [ + 960, + 960 + ] + ] +} diff --git a/configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json b/configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json new file mode 100644 index 0000000000000000000000000000000000000000..a509d5efbf96c91c46808b069ec7e59435b14493 --- /dev/null +++ b/configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json @@ -0,0 +1,40 @@ +{ + "infer_steps": 2, + "target_fps": 16, + "video_duration": 5, + "audio_sr": 16000, + "target_video_length": 81, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": false, + "offload_granularity": "block", 
+ "offload_ratio": 1, + "t5_cpu_offload": true, + "t5_quantized": true, + "t5_quant_scheme": "int8-q8f", + "clip_cpu_offload": true, + "clip_quantized": false, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "adapter_quantized": true, + "adapter_quant_scheme": "int8-q8f", + "vae_cpu_offload": true, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "resize_mode": "fixed_shape", + "fixed_shape": [ + 832, + 480 + ], + "parallel": { + "seq_p_size": 8, + "seq_p_fp8_comm": true, + "seq_p_attn_type": "ulysses-4090" + } +} diff --git a/configs/seko_talk/seko_talk_28_f2v.json b/configs/seko_talk/seko_talk_28_f2v.json new file mode 100644 index 0000000000000000000000000000000000000000..cf9bebdfa90eb0efd164da7924f2cd022446b03a --- /dev/null +++ b/configs/seko_talk/seko_talk_28_f2v.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 12, + "audio_sr": 16000, + "target_video_length": 81, + "prev_frame_length": 1, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": false, + "f2v_process": true, + "lora_configs": [ + { + "path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/self_forcing/wan_t2v_sf.json b/configs/self_forcing/wan_t2v_sf.json new file mode 100644 index 0000000000000000000000000000000000000000..df541f92c8fbf80554fef6a7227648d59ded74f4 --- /dev/null +++ b/configs/self_forcing/wan_t2v_sf.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + "seed": 0, + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": false, + "cpu_offload": false, + "sf_config": { + "sf_type": "dmd", + "local_attn_size": -1, + "shift": 5.0, + "num_frame_per_block": 3, + "num_transformer_blocks": 30, + "frame_seq_length": 1560, + "num_output_frames": 21, + "num_inference_steps": 1000, + "denoising_step_list": [1000.0000, 937.5000, 833.3333, 625.0000] + } +} diff --git a/configs/sparse_attn/spas_sage_attn/wan_i2v.json b/configs/sparse_attn/spas_sage_attn/wan_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..5fffe94bb689118c5123c486d2849559d2174b55 --- /dev/null +++ b/configs/sparse_attn/spas_sage_attn/wan_i2v.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn", + "cross_attn_1_type": "sage_attn", + "cross_attn_2_type": "sage_attn", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/sparse_attn/spas_sage_attn/wan_t2v.json b/configs/sparse_attn/spas_sage_attn/wan_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..d2d045647a86d2d3afef015936936c3867b66b43 --- /dev/null +++ b/configs/sparse_attn/spas_sage_attn/wan_t2v.json @@ -0,0 +1,14 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "spas_sage_attn", + "cross_attn_1_type": "spas_sage_attn", + "cross_attn_2_type": "spas_sage_attn", + "sample_guide_scale": 6, + "sample_shift": 8, + 
"enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/video_frame_interpolation/wan_t2v.json b/configs/video_frame_interpolation/wan_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..02cc5498142a3f1be298cf02cba5faf615fe7e1b --- /dev/null +++ b/configs/video_frame_interpolation/wan_t2v.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 24, + "model_path": "/path to flownet.pkl" + } +} diff --git a/configs/volcengine_voices_list.json b/configs/volcengine_voices_list.json new file mode 100644 index 0000000000000000000000000000000000000000..51316103a23d05e99cfdfa6328104c94dcc9b13b --- /dev/null +++ b/configs/volcengine_voices_list.json @@ -0,0 +1,4727 @@ +{ + "voices": [ + { + "name": "Vivi 2.0", + "voice_type": "zh_female_vv_uranus_bigtts", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "可爱女生", + "voice_type": "ICL_zh_female_keainvsheng_tob", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "大壹", + "voice_type": "zh_male_dayi_saturn_bigtts", + "gender": "male", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "魅力女友", + "voice_type": "ICL_zh_female_tiaopigongzhu_tob", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "黑猫侦探社咪仔", + "voice_type": "zh_female_mizai_saturn_bigtts", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "爽朗少年", + "voice_type": "ICL_zh_male_shuanglangshaonian_tob", + "gender": "male", + "version": "2.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "鸡汤女", + "voice_type": "zh_female_jitangnv_saturn_bigtts", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "天才同桌", + "voice_type": "ICL_zh_male_tiancaitongzhuo_tob", + "gender": "male", + "version": "2.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "魅力女友", + "voice_type": "zh_female_meilinvyou_saturn_bigtts", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "流畅女声", + "voice_type": "zh_female_santongyongns_saturn_bigtts", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-2.0", + 
"voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "儒雅逸辰", + "voice_type": "zh_male_ruyayichen_saturn_bigtts", + "gender": "male", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "知性灿灿", + "voice_type": "saturn_zh_female_cancan_tob", + "gender": "female", + "version": "2.0", + "resource_id": "seed-tts-2.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "温柔女神", + "voice_type": "ICL_zh_female_wenrounvshen_239eff5e8ffa_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "沪普男", + "voice_type": "zh_male_hupunan_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "纯真少女", + "voice_type": "ICL_zh_female_chunzhenshaonv_e588402fb8ad_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷酷哥哥", + "voice_type": "zh_male_lengkugege_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "sad", + "angry", + "fear", + "hate", + "coldness", + "neutral", + "depressed" + ] + }, + { + "name": "理性圆子", + "voice_type": "ICL_zh_female_lixingyuanzi_cs", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Tina老师", + "voice_type": "zh_female_yingyujiaoyu_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "教育场景", + "language": [ + "chinese", + "en_gb" + ], + "emotions": [] + }, + { + "name": "Lauren", + "voice_type": "en_female_lauren_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "内敛才俊", + "voice_type": "ICL_zh_male_neiliancaijun_e991be511569_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "悠悠君子", + "voice_type": "zh_male_M100_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Vivi", + "voice_type": "zh_female_vv_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "粤语小溏", + "voice_type": "zh_female_yueyunv_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] 
+ }, + { + "name": "奶气小生", + "voice_type": "ICL_zh_male_xiaonaigou_edf58cf28b8b_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "甜心小美", + "voice_type": "zh_female_tianxinxiaomei_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "sad", + "fear", + "hate", + "neutral" + ] + }, + { + "name": "清甜桃桃", + "voice_type": "ICL_zh_female_qingtiantaotao_cs", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Energetic Male II", + "voice_type": "en_male_campaign_jamal_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "温暖少年", + "voice_type": "ICL_zh_male_yangyang_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "文静毛毛", + "voice_type": "zh_female_maomao_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "亲切女声", + "voice_type": "zh_female_qinqienvsheng_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "鲁班七号", + "voice_type": "zh_male_lubanqihao_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "精灵向导", + "voice_type": "ICL_zh_female_jinglingxiangdao_1beb294a9e3e_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "高冷御姐", + "voice_type": "zh_female_gaolengyujie_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "sad", + "angry", + "surprised", + "fear", + "hate", + "excited", + "coldness", + "neutral" + ] + }, + { + "name": "清晰小雪", + "voice_type": "ICL_zh_female_qingxixiaoxue_cs", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Gotham Hero", + "voice_type": "en_male_chris_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "儒雅公子", + "voice_type": "ICL_zh_male_flc_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + 
"name": "倾心少女", + "voice_type": "ICL_zh_female_qiuling_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "机灵小伙", + "voice_type": "ICL_zh_male_shenmi_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "林潇", + "voice_type": "zh_female_yangmi_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "闷油瓶小哥", + "voice_type": "ICL_zh_male_menyoupingxiaoge_ffed9fc2fee7_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲娇霸总", + "voice_type": "zh_male_aojiaobazong_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "angry", + "hate", + "neutral" + ] + }, + { + "name": "清甜莓莓", + "voice_type": "ICL_zh_female_qingtianmeimei_cs", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Flirty Female", + "voice_type": "en_female_product_darcie_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "悬疑解说", + "voice_type": "zh_male_changtianyi_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "醇厚低音", + "voice_type": "ICL_zh_male_buyan_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "元气甜妹", + "voice_type": "ICL_zh_female_wuxi_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "玲玲姐姐", + "voice_type": "zh_female_linzhiling_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "黯刃秦主", + "voice_type": "ICL_zh_male_anrenqinzhu_cd62e63dcdab_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "广州德哥", + "voice_type": "zh_male_guangzhoudege_emo_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "angry", + "fear", + "neutral" + ] + }, + { + "name": "开朗婷婷", + "voice_type": "ICL_zh_female_kailangtingting_cs", + "gender": "female", + "version": "1.0", + "resource_id": 
"seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Peaceful Female", + "voice_type": "en_female_emotional_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "儒雅青年", + "voice_type": "zh_male_ruyaqingnian_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "咆哮小哥", + "voice_type": "ICL_zh_male_BV144_paoxiaoge_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "知心姐姐", + "voice_type": "ICL_zh_female_wenyinvsheng_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "春日部姐姐", + "voice_type": "zh_female_jiyejizi2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "霸道总裁", + "voice_type": "ICL_zh_male_badaozongcai_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "京腔侃爷", + "voice_type": "zh_male_jingqiangkanye_emo_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "angry", + "surprised", + "hate", + "neutral" + ] + }, + { + "name": "清新沐沐", + "voice_type": "ICL_zh_male_qingxinmumu_cs", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Nara", + "voice_type": "en_female_nara_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "霸气青叔", + "voice_type": "zh_male_baqiqingshu_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "和蔼奶奶", + "voice_type": "ICL_zh_female_heainainai_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "阳光阿辰", + "voice_type": "zh_male_qingyiyuxuan_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "唐僧", + "voice_type": "zh_male_tangseng_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "妩媚可人", + "voice_type": 
"ICL_zh_female_ganli_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "邻居阿姨", + "voice_type": "zh_female_linjuayi_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "angry", + "surprised", + "coldness", + "neutral", + "depressed" + ] + }, + { + "name": "爽朗小阳", + "voice_type": "ICL_zh_male_shuanglangxiaoyang_cs", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Bruce", + "voice_type": "en_male_bruce_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "擎苍", + "voice_type": "zh_male_qingcang_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "邻居阿姨", + "voice_type": "ICL_zh_female_linjuayi_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "快乐小东", + "voice_type": "zh_male_xudong_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "庄周", + "voice_type": "zh_male_zhuangzhou_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "邪魅御姐", + "voice_type": "ICL_zh_female_xiangliangya_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "优柔公子", + "voice_type": "zh_male_yourougongzi_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "angry", + "fear", + "hate", + "excited", + "neutral", + "depressed" + ] + }, + { + "name": "清新波波", + "voice_type": "ICL_zh_male_qingxinbobo_cs", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Michael", + "voice_type": "en_male_michael_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "活力小哥", + "voice_type": "zh_male_yangguangqingnian_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "温柔小雅", + "voice_type": "zh_female_wenrouxiaoya_moon_bigtts", + "gender": "female", + "version": 
"1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷酷哥哥", + "voice_type": "ICL_zh_male_lengkugege_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "猪八戒", + "voice_type": "zh_male_zhubajie_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "嚣张小哥", + "voice_type": "ICL_zh_male_ms_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "儒雅男友", + "voice_type": "zh_male_ruyayichen_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "sad", + "angry", + "fear", + "excited", + "coldness", + "neutral" + ] + }, + { + "name": "温婉珊珊", + "voice_type": "ICL_zh_female_wenwanshanshan_cs", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Cartoon Chef", + "voice_type": "ICL_en_male_cc_sha_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "古风少御", + "voice_type": "zh_female_gufengshaoyu_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "天才童声", + "voice_type": "zh_male_tiancaitongsheng_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "纯澈女生", + "voice_type": "ICL_zh_female_feicui_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "感冒电音姐姐", + "voice_type": "zh_female_ganmaodianyin_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "油腻大叔", + "voice_type": "ICL_zh_male_you_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "俊朗男友", + "voice_type": "zh_male_junlangnanyou_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "sad", + "angry", + "surprised", + "fear", + "neutral" + ] + }, + { + "name": "甜美小雨", + "voice_type": "ICL_zh_female_tianmeixiaoyu_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": 
"客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Lucas", + "voice_type": "zh_male_M100_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "温柔淑女", + "voice_type": "zh_female_wenroushunv_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "猴哥", + "voice_type": "zh_male_sunwukong_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "初恋女友", + "voice_type": "ICL_zh_female_yuxin_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "直率英子", + "voice_type": "zh_female_naying_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "孤傲公子", + "voice_type": "ICL_zh_male_guaogongzi_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "北京小爷", + "voice_type": "zh_male_beijingxiaoye_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "angry", + "surprised", + "fear", + "excited", + "coldness", + "neutral" + ] + }, + { + "name": "热情艾娜", + "voice_type": "ICL_zh_female_reqingaina_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Sophie", + "voice_type": "zh_female_sophie_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "反卷青年", + "voice_type": "zh_male_fanjuanqingnian_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "audiobook", + "scene": "有声阅读", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "熊二", + "voice_type": "zh_male_xionger_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "贴心闺蜜", + "voice_type": "ICL_zh_female_xnx_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "女雷神", + "voice_type": "zh_female_leidian_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "胡子叔叔", + "voice_type": "ICL_zh_male_huzi_v1_tob", + "gender": "male", + "version": 
"1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "柔美女友", + "voice_type": "zh_female_roumeinvyou_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "sad", + "angry", + "surprised", + "fear", + "hate", + "excited", + "coldness", + "neutral" + ] + }, + { + "name": "甜美小橘", + "voice_type": "ICL_zh_female_tianmeixiaoju_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Daisy", + "voice_type": "en_female_dacey_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "佩奇猪", + "voice_type": "zh_female_peiqi_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "温柔白月光", + "voice_type": "ICL_zh_female_yry_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "豫州子轩", + "voice_type": "zh_male_yuzhouzixuan_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "性感魅惑", + "voice_type": "ICL_zh_female_luoqing_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "阳光青年", + "voice_type": "zh_male_yangguangqingnian_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "happy", + "sad", + "angry", + "fear", + "excited", + "coldness", + "neutral" + ] + }, + { + "name": "沉稳明仔", + "voice_type": "ICL_zh_male_chenwenmingzai_cs_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Owen", + "voice_type": "en_male_charlie_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "武则天", + "voice_type": "zh_female_wuzetian_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "开朗学长", + "voice_type": "en_male_jason_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "呆萌川妹", + "voice_type": "zh_female_daimengchuanmei_moon_bigtts", + "gender": "female", + 
"version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病弱公子", + "voice_type": "ICL_zh_male_bingruogongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "魅力女友", + "voice_type": "zh_female_meilinvyou_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese" + ], + "emotions": [ + "sad", + "fear", + "neutral" + ] + }, + { + "name": "亲切小卓", + "voice_type": "ICL_zh_male_qinqiexiaozhuo_cs_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Luna", + "voice_type": "en_female_sarah_new_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "顾姐", + "voice_type": "zh_female_gujie_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "魅力苏菲", + "voice_type": "zh_female_sophie_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "广西远舟", + "voice_type": "zh_male_guangxiyuanzhou_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "邪魅女王", + "voice_type": "ICL_zh_female_bingjiao3_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "爽快思思", + "voice_type": "zh_female_shuangkuaisisi_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "chinese", + "en_gb" + ], + "emotions": [ + "happy", + "sad", + "angry", + "surprised", + "excited", + "coldness", + "neutral" + ] + }, + { + "name": "灵动欣欣", + "voice_type": "ICL_zh_female_lingdongxinxin_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Michael", + "voice_type": "ICL_en_male_michael_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "樱桃丸子", + "voice_type": "zh_female_yingtaowanzi_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "贴心妹妹", + "voice_type": "ICL_zh_female_yilin_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + 
"voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "双节棍小哥", + "voice_type": "zh_male_zhoujielun_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲慢青年", + "voice_type": "ICL_zh_male_aomanqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Candice", + "voice_type": "en_female_candice_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "en_us" + ], + "emotions": [ + "angry", + "neutral", + "ASMR", + "happy", + "chat", + "chat", + "warm", + "affectionate" + ] + }, + { + "name": "乖巧可儿", + "voice_type": "ICL_zh_female_guaiqiaokeer_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Charlie", + "voice_type": "ICL_en_female_cc_cm_v1_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "广告解说", + "voice_type": "zh_male_chunhui_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "甜美桃子", + "voice_type": "zh_female_tianmeitaozi_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "湾湾小何", + "voice_type": "zh_female_wanwanxiaohe_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "醋精男生", + "voice_type": "ICL_zh_male_cujingnansheng_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Serena", + "voice_type": "en_female_skye_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "en_us" + ], + "emotions": [ + "sad", + "angry", + "neutral", + "ASMR", + "happy", + "chat", + "chat", + "warm", + "affectionate" + ] + }, + { + "name": "暖心茜茜", + "voice_type": "ICL_zh_female_nuanxinqianqian_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Big Boogie", + "voice_type": "ICL_en_male_oogie2_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "少儿故事", + "voice_type": "zh_female_shaoergushi_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + 
"voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "清新女声", + "voice_type": "zh_female_qingxinnvsheng_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "湾区大叔", + "voice_type": "zh_female_wanqudashu_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "撒娇男友", + "voice_type": "ICL_zh_male_sajiaonanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Glen", + "voice_type": "en_male_glen_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "en_us" + ], + "emotions": [ + "sad", + "angry", + "neutral", + "ASMR", + "happy", + "chat", + "chat", + "warm", + "affectionate" + ] + }, + { + "name": "软萌团子", + "voice_type": "ICL_zh_female_ruanmengtuanzi_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Frosty Man", + "voice_type": "ICL_en_male_frosty1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "四郎", + "voice_type": "zh_male_silang_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "知性女声", + "voice_type": "zh_female_zhixingnvsheng_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "广州德哥", + "voice_type": "zh_male_guozhoudege_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "温柔男友", + "voice_type": "ICL_zh_male_wenrounanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Sylus", + "voice_type": "en_male_sylus_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "en_us" + ], + "emotions": [ + "sad", + "angry", + "neutral", + "ASMR", + "happy", + "chat", + "chat", + "warm", + "affectionate", + "authoritative" + ] + }, + { + "name": "阳光洋洋", + "voice_type": "ICL_zh_male_yangguangyangyang_cs_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "The Grinch", + "voice_type": "ICL_en_male_grinch2_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + 
"voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "俏皮女声", + "voice_type": "zh_female_qiaopinvsheng_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "清爽男大", + "voice_type": "zh_male_qingshuangnanda_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "浩宇小哥", + "voice_type": "zh_male_haoyuxiaoge_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "温顺少年", + "voice_type": "ICL_zh_male_wenshunshaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Corey", + "voice_type": "en_male_corey_emo_v2_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "en_gb" + ], + "emotions": [ + "sad", + "angry", + "neutral", + "ASMR", + "happy", + "chat", + "chat", + "warm", + "affectionate", + "authoritative" + ] + }, + { + "name": "软萌糖糖", + "voice_type": "ICL_zh_female_ruanmengtangtang_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Zayne", + "voice_type": "ICL_en_male_zayne_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "懒音绵宝", + "voice_type": "zh_male_lanxiaoyang_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "邻家女孩", + "voice_type": "zh_female_linjianvhai_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "北京小爷", + "voice_type": "zh_male_beijingxiaoye_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "粘人男友", + "voice_type": "ICL_zh_male_naigounanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Nadia", + "voice_type": "en_female_nadia_tips_emo_v2_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_emotion", + "scene": "多情感", + "language": [ + "en_gb" + ], + "emotions": [ + "sad", + "angry", + "neutral", + "ASMR", + "happy", + "chat", + "chat", + "warm", + "affectionate" + ] + }, + { + "name": "秀丽倩倩", + "voice_type": "ICL_zh_female_xiuliqianqian_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + 
"voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Jigsaw", + "voice_type": "ICL_en_male_cc_jigsaw_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "亮嗓萌仔", + "voice_type": "zh_male_dongmanhaimian_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "渊博小叔", + "voice_type": "zh_male_yuanboxiaoshu_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "京腔侃爷/Harmony", + "voice_type": "zh_male_jingqiangkanye_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "撒娇男生", + "voice_type": "ICL_zh_male_sajiaonansheng_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "开心小鸿", + "voice_type": "ICL_zh_female_kaixinxiaohong_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Chucky", + "voice_type": "ICL_en_male_cc_chucky_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "磁性解说男声/Morgan", + "voice_type": "zh_male_jieshuonansheng_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "阳光青年", + "voice_type": "zh_male_yangguangqingnian_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "妹坨洁儿", + "voice_type": "zh_female_meituojieer_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "accent", + "scene": "趣味口音", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "活泼男友", + "voice_type": "ICL_zh_male_huoponanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "轻盈朵朵", + "voice_type": "ICL_zh_female_qingyingduoduo_cs_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Clown Man", + "voice_type": "ICL_en_male_cc_penny_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "鸡汤妹妹/Hope", + "voice_type": 
"zh_female_jitangmeimei_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "甜美小源", + "voice_type": "zh_female_tianmeixiaoyuan_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "甜系男友", + "voice_type": "ICL_zh_male_tianxinanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "暖阳女声", + "voice_type": "zh_female_kefunvsheng_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "customer_service", + "scene": "客服场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Kevin McCallister", + "voice_type": "ICL_en_male_kevin2_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "贴心女声/Candy", + "voice_type": "zh_female_tiexinnvsheng_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "清澈梓梓", + "voice_type": "zh_female_qingchezizi_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "活力青年", + "voice_type": "ICL_zh_male_huoliqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Xavier", + "voice_type": "ICL_en_male_xavier1_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "萌丫头/Cutey", + "voice_type": "zh_female_mengyatou_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "video_dubbing", + "scene": "视频配音", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "解说小明", + "voice_type": "zh_male_jieshuoxiaoming_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "开朗青年", + "voice_type": "ICL_zh_male_kailangqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Noah", + "voice_type": "ICL_en_male_cc_dracula_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "开朗姐姐", + "voice_type": "zh_female_kailangjiejie_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": 
[ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷漠兄长", + "voice_type": "ICL_zh_male_lengmoxiongzhang_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Adam", + "voice_type": "en_male_adam_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "邻家男孩", + "voice_type": "zh_male_linjiananhai_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "翩翩公子", + "voice_type": "ICL_zh_male_pianpiangongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Amanda", + "voice_type": "en_female_amanda_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "甜美悦悦", + "voice_type": "zh_female_tianmeiyueyue_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "懵懂青年", + "voice_type": "ICL_zh_male_mengdongqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Jackson", + "voice_type": "en_male_jackson_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_us" + ], + "emotions": [] + }, + { + "name": "心灵鸡汤", + "voice_type": "zh_female_xinlingjitang_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷脸兄长", + "voice_type": "ICL_zh_male_lenglianxiongzhang_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Delicate Girl", + "voice_type": "en_female_daisy_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "知性温婉", + "voice_type": "ICL_zh_female_zhixingwenwan_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇少年", + "voice_type": "ICL_zh_male_bingjiaoshaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Dave", + "voice_type": "en_male_dave_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": 
[ + "en_gb" + ], + "emotions": [] + }, + { + "name": "暖心体贴", + "voice_type": "ICL_zh_male_nuanxintitie_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇男友", + "voice_type": "ICL_zh_male_bingjiaonanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Hades", + "voice_type": "en_male_hades_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "开朗轻快", + "voice_type": "ICL_zh_male_kailangqingkuai_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病弱少年", + "voice_type": "ICL_zh_male_bingruoshaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Onez", + "voice_type": "en_female_onez_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "活泼爽朗", + "voice_type": "ICL_zh_male_huoposhuanglang_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "意气少年", + "voice_type": "ICL_zh_male_yiqishaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Emily", + "voice_type": "en_female_emily_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "率真小伙", + "voice_type": "ICL_zh_male_shuaizhenxiaohuo_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "干净少年", + "voice_type": "ICL_zh_male_ganjingshaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Daniel", + "voice_type": "zh_male_xudong_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "温柔小哥", + "voice_type": "zh_male_wenrouxiaoge_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷漠男友", + "voice_type": "ICL_zh_male_lengmonanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + 
"emotions": [] + }, + { + "name": "Alastor", + "voice_type": "ICL_en_male_cc_alastor_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "灿灿/Shiny", + "voice_type": "zh_female_cancan_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "精英青年", + "voice_type": "ICL_zh_male_jingyingqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Smith", + "voice_type": "en_male_smith_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "爽快思思/Skye", + "voice_type": "zh_female_shuangkuaisisi_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "热血少年", + "voice_type": "ICL_zh_male_rexueshaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Anna", + "voice_type": "en_female_anna_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_gb" + ], + "emotions": [] + }, + { + "name": "温暖阿虎/Alvin", + "voice_type": "zh_male_wennuanahu_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "清爽少年", + "voice_type": "ICL_zh_male_qingshuangshaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Ethan", + "voice_type": "ICL_en_male_aussie_v1_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_au" + ], + "emotions": [] + }, + { + "name": "少年梓辛/Brayan", + "voice_type": "zh_male_shaonianzixin_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": "通用场景", + "language": [ + "chinese", + "en_us" + ], + "emotions": [] + }, + { + "name": "中二青年", + "voice_type": "ICL_zh_male_zhongerqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Sarah", + "voice_type": "en_female_sarah_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_au" + ], + "emotions": [] + }, + { + "name": "温柔文雅", + "voice_type": "ICL_zh_female_wenrouwenya_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "general", + "scene": 
"通用场景", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "凌云青年", + "voice_type": "ICL_zh_male_lingyunqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Dryw", + "voice_type": "en_male_dryw_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "en_au" + ], + "emotions": [] + }, + { + "name": "自负青年", + "voice_type": "ICL_zh_male_zifuqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Diana", + "voice_type": "multi_female_maomao_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "es" + ], + "emotions": [] + }, + { + "name": "不羁青年", + "voice_type": "ICL_zh_male_bujiqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Lucía", + "voice_type": "multi_male_M100_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "es" + ], + "emotions": [] + }, + { + "name": "儒雅君子", + "voice_type": "ICL_zh_male_ruyajunzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Sofía", + "voice_type": "multi_female_sophie_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "es" + ], + "emotions": [] + }, + { + "name": "低音沉郁", + "voice_type": "ICL_zh_male_diyinchenyu_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "Daníel", + "voice_type": "multi_male_xudong_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "es" + ], + "emotions": [] + }, + { + "name": "冷脸学霸", + "voice_type": "ICL_zh_male_lenglianxueba_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "ひかる(光)", + "voice_type": "multi_zh_male_youyoujunzi_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja" + ], + "emotions": [] + }, + { + "name": "儒雅总裁", + "voice_type": "ICL_zh_male_ruyazongcai_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "さとみ(智美)", + "voice_type": "multi_female_sophie_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + 
"voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja" + ], + "emotions": [] + }, + { + "name": "深沉总裁", + "voice_type": "ICL_zh_male_shenchenzongcai_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "まさお(正男)", + "voice_type": "multi_male_xudong_conversation_wvae_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja" + ], + "emotions": [] + }, + { + "name": "小侯爷", + "voice_type": "ICL_zh_male_xiaohouye_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "つき(月)", + "voice_type": "multi_female_maomao_conversation_wvae_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja" + ], + "emotions": [] + }, + { + "name": "孤高公子", + "voice_type": "ICL_zh_male_gugaogongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "あけみ(朱美)", + "voice_type": "multi_female_gaolengyujie_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja" + ], + "emotions": [] + }, + { + "name": "仗剑君子", + "voice_type": "ICL_zh_male_zhangjianjunzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "かずね(和音)/Javier or Álvaro", + "voice_type": "multi_male_jingqiangkanye_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja", + "es" + ], + "emotions": [] + }, + { + "name": "温润学者", + "voice_type": "ICL_zh_male_wenrunxuezhe_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "はるこ(晴子)/Esmeralda", + "voice_type": "multi_female_shuangkuaisisi_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja", + "es" + ], + "emotions": [] + }, + { + "name": "亲切青年", + "voice_type": "ICL_zh_male_qinqieqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "ひろし(広志)/Roberto", + "voice_type": "multi_male_wanqudashu_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "multi_language", + "scene": "多语种", + "language": [ + "ja", + "es" + ], + "emotions": [] + }, + { + "name": "温柔学长", + "voice_type": "ICL_zh_male_wenrouxuezhang_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "高冷总裁", + "voice_type": 
"ICL_zh_male_gaolengzongcai_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷峻高智", + "voice_type": "ICL_zh_male_lengjungaozhi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "孱弱少爷", + "voice_type": "ICL_zh_male_chanruoshaoye_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "自信青年", + "voice_type": "ICL_zh_male_zixinqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "青涩青年", + "voice_type": "ICL_zh_male_qingseqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "学霸同桌", + "voice_type": "ICL_zh_male_xuebatongzhuo_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷傲总裁", + "voice_type": "ICL_zh_male_lengaozongcai_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "元气少年", + "voice_type": "ICL_zh_male_yuanqishaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "洒脱青年", + "voice_type": "ICL_zh_male_satuoqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "直率青年", + "voice_type": "ICL_zh_male_zhishuaiqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "斯文青年", + "voice_type": "ICL_zh_male_siwenqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "俊逸公子", + "voice_type": "ICL_zh_male_junyigongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "仗剑侠客", + "voice_type": "ICL_zh_male_zhangjianxiake_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "机甲智能", + "voice_type": "ICL_zh_male_jijiaozhineng_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "奶气萌娃", + "voice_type": "zh_male_naiqimengwa_mars_bigtts", + "gender": 
"male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "婆婆", + "voice_type": "zh_female_popo_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "高冷御姐", + "voice_type": "zh_female_gaolengyujie_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲娇霸总", + "voice_type": "zh_male_aojiaobazong_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "魅力女友", + "voice_type": "zh_female_meilinvyou_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "深夜播客", + "voice_type": "zh_male_shenyeboke_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "柔美女友", + "voice_type": "zh_female_sajiaonvyou_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "撒娇学妹", + "voice_type": "zh_female_yuanqinvyou_moon_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病弱少女", + "voice_type": "ICL_zh_female_bingruoshaonv_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "活泼女孩", + "voice_type": "ICL_zh_female_huoponvhai_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "东方浩然", + "voice_type": "zh_male_dongfanghaoran_moon_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "绿茶小哥", + "voice_type": "ICL_zh_male_lvchaxiaoge_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "娇弱萝莉", + "voice_type": "ICL_zh_female_jiaoruoluoli_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷淡疏离", + "voice_type": "ICL_zh_male_lengdanshuli_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "憨厚敦实", + "voice_type": "ICL_zh_male_hanhoudunshi_tob", + "gender": "male", + "version": 
"1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "活泼刁蛮", + "voice_type": "ICL_zh_female_huopodiaoman_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "固执病娇", + "voice_type": "ICL_zh_male_guzhibingjiao_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "撒娇粘人", + "voice_type": "ICL_zh_male_sajiaonianren_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲慢娇声", + "voice_type": "ICL_zh_female_aomanjiaosheng_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "潇洒随性", + "voice_type": "ICL_zh_male_xiaosasuixing_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "诡异神秘", + "voice_type": "ICL_zh_male_guiyishenmi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "儒雅才俊", + "voice_type": "ICL_zh_male_ruyacaijun_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "正直青年", + "voice_type": "ICL_zh_male_zhengzhiqingnian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "娇憨女王", + "voice_type": "ICL_zh_female_jiaohannvwang_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇萌妹", + "voice_type": "ICL_zh_female_bingjiaomengmei_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "青涩小生", + "voice_type": "ICL_zh_male_qingsenaigou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "纯真学弟", + "voice_type": "ICL_zh_male_chunzhenxuedi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "优柔帮主", + "voice_type": "ICL_zh_male_youroubangzhu_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "优柔公子", + "voice_type": "ICL_zh_male_yourougongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", 
+ "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "贴心男友", + "voice_type": "ICL_zh_male_tiexinnanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "少年将军", + "voice_type": "ICL_zh_male_shaonianjiangjun_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇哥哥", + "voice_type": "ICL_zh_male_bingjiaogege_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "学霸男同桌", + "voice_type": "ICL_zh_male_xuebanantongzhuo_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "幽默叔叔", + "voice_type": "ICL_zh_male_youmoshushu_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "假小子", + "voice_type": "ICL_zh_female_jiaxiaozi_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "温柔男同桌", + "voice_type": "ICL_zh_male_wenrounantongzhuo_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "幽默大爷", + "voice_type": "ICL_zh_male_youmodaye_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "枕边低语", + "voice_type": "ICL_zh_male_asmryexiu_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "神秘法师", + "voice_type": "ICL_zh_male_shenmifashi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "娇喘女声", + "voice_type": "zh_female_jiaochuan_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "开朗弟弟", + "voice_type": "zh_male_livelybro_mars_bigtts", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "谄媚女声", + "voice_type": "zh_female_flattery_mars_bigtts", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "冷峻上司", + "voice_type": "ICL_zh_male_lengjunshangsi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": 
"角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "醋精男友", + "voice_type": "ICL_zh_male_cujingnanyou_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "风发少年", + "voice_type": "ICL_zh_male_fengfashaonian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "磁性男嗓", + "voice_type": "ICL_zh_male_cixingnansang_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "成熟总裁", + "voice_type": "ICL_zh_male_chengshuzongcai_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲娇精英", + "voice_type": "ICL_zh_male_aojiaojingying_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲娇公子", + "voice_type": "ICL_zh_male_aojiaogongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "霸道少爷", + "voice_type": "ICL_zh_male_badaoshaoye_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "腹黑公子", + "voice_type": "ICL_zh_male_fuheigongzi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "暖心学姐", + "voice_type": "ICL_zh_female_nuanxinxuejie_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "成熟姐姐", + "voice_type": "ICL_zh_female_chengshujiejie_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇姐姐", + "voice_type": "ICL_zh_female_bingjiaojiejie_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "妩媚御姐", + "voice_type": "ICL_zh_female_wumeiyujie_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲娇女友", + "voice_type": "ICL_zh_female_aojiaonvyou_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "贴心女友", + "voice_type": "ICL_zh_female_tiexinnvyou_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + 
], + "emotions": [] + }, + { + "name": "性感御姐", + "voice_type": "ICL_zh_female_xingganyujie_tob", + "gender": "female", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇弟弟", + "voice_type": "ICL_zh_male_bingjiaodidi_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲慢少爷", + "voice_type": "ICL_zh_male_aomanshaoye_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "傲气凌人", + "voice_type": "ICL_zh_male_aiqilingren_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + }, + { + "name": "病娇白莲", + "voice_type": "ICL_zh_male_bingjiaobailian_tob", + "gender": "male", + "version": "1.0", + "resource_id": "seed-tts-1.0", + "voice_category": "roleplay", + "scene": "角色扮演", + "language": [ + "chinese" + ], + "emotions": [] + } + ], + "languages": [ + { + "name": "chinese", + "zh": "中文", + "en": "Chinese" + }, + { + "name": "en_au", + "zh": "澳洲英语", + "en": "Australian English" + }, + { + "name": "en_gb", + "zh": "英式英语", + "en": "British English" + }, + { + "name": "en_us", + "zh": "美式英语", + "en": "American English" + }, + { + "name": "es", + "zh": "西语", + "en": "Spanish" + }, + { + "name": "ja", + "zh": "日语", + "en": "Japanese" + } + ], + "emotions": [ + { + "name": "ASMR", + "zh": "低语", + "en": "ASMR" + }, + { + "name": "affectionate", + "zh": "深情", + "en": "affectionate" + }, + { + "name": "angry", + "zh": "生气", + "en": "angry" + }, + { + "name": "authoritative", + "zh": "权威", + "en": "authoritative" + }, + { + "name": "chat", + "zh": "对话", + "en": "chat" + }, + { + "name": "coldness", + "zh": "冷漠", + "en": "coldness" + }, + { + "name": "depressed", + "zh": "沮丧", + "en": "depressed" + }, + { + "name": "excited", + "zh": "激动", + "en": "excited" + }, + { + "name": "fear", + "zh": "恐惧", + "en": "fear" + }, + { + "name": "happy", + "zh": "开心", + "en": "happy" + }, + { + "name": "hate", + "zh": "厌恶", + "en": "hate" + }, + { + "name": "neutral", + "zh": "中性", + "en": "neutral" + }, + { + "name": "sad", + "zh": "悲伤", + "en": "sad" + }, + { + "name": "surprised", + "zh": "惊讶", + "en": "surprised" + }, + { + "name": "warm", + "zh": "温暖", + "en": "warm" + } + ] +} diff --git a/configs/wan/wan_flf2v.json b/configs/wan/wan_flf2v.json new file mode 100644 index 0000000000000000000000000000000000000000..998f38f05a73f6ce0bd74116764c6ea317f5391d --- /dev/null +++ b/configs/wan/wan_flf2v.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 16, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/wan/wan_i2v.json b/configs/wan/wan_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..6c2107083c6b94f2dcd36e9eec30664292e3ba9a --- /dev/null +++ b/configs/wan/wan_i2v.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", 
+ "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 3, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/wan/wan_t2v.json b/configs/wan/wan_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..2e2825047a1d4ad72ae661befd729f26162dbabe --- /dev/null +++ b/configs/wan/wan_t2v.json @@ -0,0 +1,14 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/wan/wan_t2v_enhancer.json b/configs/wan/wan_t2v_enhancer.json new file mode 100644 index 0000000000000000000000000000000000000000..b27afdfbf5f645eabc9c9e179358b02c5cacf8e9 --- /dev/null +++ b/configs/wan/wan_t2v_enhancer.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 6, + "sample_shift": 8, + "enable_cfg": true, + "cpu_offload": false, + "sub_servers": { + "prompt_enhancer": [ + "http://localhost:9001" + ] + } +} diff --git a/configs/wan/wan_vace.json b/configs/wan/wan_vace.json new file mode 100644 index 0000000000000000000000000000000000000000..23fbff233a7243bb032e46d3013c6dce5be5b0b8 --- /dev/null +++ b/configs/wan/wan_vace.json @@ -0,0 +1,13 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5, + "sample_shift": 16, + "enable_cfg": true, + "cpu_offload": false +} diff --git a/configs/wan22/wan_animate.json b/configs/wan22/wan_animate.json new file mode 100644 index 0000000000000000000000000000000000000000..c4c3d59b32eb406c2c30fab6474fd23d74363b49 --- /dev/null +++ b/configs/wan22/wan_animate.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 20, + "target_video_length": 77, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "adapter_attn_type": "flash_attn3", + "sample_shift": 5.0, + "sample_guide_scale": 5.0, + "enable_cfg": false, + "cpu_offload": false, + "refert_num": 1, + "replace_flag": false, + "fps": 30 +} diff --git a/configs/wan22/wan_animate_4090.json b/configs/wan22/wan_animate_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..e0f89657ddea8ca989d2e33d51076cc6e2e3af95 --- /dev/null +++ b/configs/wan22/wan_animate_4090.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 20, + "target_video_length": 77, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "adapter_attn_type": "sage_attn2", + "sample_shift": 5.0, + "sample_guide_scale": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "phase", + "refert_num": 1, + "replace_flag": false, + "fps": 30, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8", + "clip_quantized": true, + "clip_quant_scheme": "fp8" +} diff --git 
a/configs/wan22/wan_animate_lora.json b/configs/wan22/wan_animate_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..0e1c8059cf0b716a7778d9665d8409574331a603 --- /dev/null +++ b/configs/wan22/wan_animate_lora.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 4, + "target_video_length": 77, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "adapter_attn_type": "flash_attn3", + "sample_shift": 5.0, + "sample_guide_scale": 1.0, + "enable_cfg": false, + "cpu_offload": false, + "refert_num": 1, + "replace_flag": false, + "fps": 30, + "lora_configs": [ + { + "path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/wan22/wan_animate_replace.json b/configs/wan22/wan_animate_replace.json new file mode 100644 index 0000000000000000000000000000000000000000..c4805a471386291791fdca873c0fc34d9df257d4 --- /dev/null +++ b/configs/wan22/wan_animate_replace.json @@ -0,0 +1,18 @@ +{ + "infer_steps": 20, + "target_video_length": 77, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "adapter_attn_type": "flash_attn3", + "sample_shift": 5.0, + "sample_guide_scale": 5.0, + "enable_cfg": false, + "cpu_offload": false, + "refert_num": 1, + "fps": 30, + "replace_flag": true +} diff --git a/configs/wan22/wan_animate_replace_4090.json b/configs/wan22/wan_animate_replace_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..fab2908b784cfc657d35853ea16e65a7706727cd --- /dev/null +++ b/configs/wan22/wan_animate_replace_4090.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 20, + "target_video_length": 77, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "adapter_attn_type": "sage_attn2", + "sample_shift": 5.0, + "sample_guide_scale": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "phase", + "refert_num": 1, + "fps": 30, + "replace_flag": true, + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f", + "clip_quantized": true, + "clip_quant_scheme": "fp8-q8f" +} diff --git a/configs/wan22/wan_distill_moe_flf2v.json b/configs/wan22/wan_distill_moe_flf2v.json new file mode 100644 index 0000000000000000000000000000000000000000..d3f90018fa5a2ada4174f18972be495ed4b9d39f --- /dev/null +++ b/configs/wan22/wan_distill_moe_flf2v.json @@ -0,0 +1,28 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 16, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/wan22/wan_distill_moe_flf2v_fp8.json b/configs/wan22/wan_distill_moe_flf2v_fp8.json new file mode 100644 index 0000000000000000000000000000000000000000..2e982f38b0bb5299a4c84b5b708be625f79fae26 --- /dev/null +++ 
b/configs/wan22/wan_distill_moe_flf2v_fp8.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 16, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8" +} diff --git a/configs/wan22/wan_distill_moe_flf2v_int8.json b/configs/wan22/wan_distill_moe_flf2v_int8.json new file mode 100644 index 0000000000000000000000000000000000000000..c3489b04395f9884582810be50778d3fa55f5d80 --- /dev/null +++ b/configs/wan22/wan_distill_moe_flf2v_int8.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 16, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm", + "t5_quantized": true, + "t5_quant_scheme": "int8" +} diff --git a/configs/wan22/wan_distill_moe_flf2v_with_lora.json b/configs/wan22/wan_distill_moe_flf2v_with_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..445ca8fe2dd7c356080a5dd08414faa6e5a096da --- /dev/null +++ b/configs/wan22/wan_distill_moe_flf2v_with_lora.json @@ -0,0 +1,40 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 16, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "lora_configs": [ + { + "name": "low_noise_model", + "path": "/path/to/low_noise_lora", + "strength": 1.0 + }, + { + "name": "high_noise_model", + "path": "/path/to/high_noise_lora", + "strength": 1.0 + } + ] +} diff --git a/configs/wan22/wan_moe_flf2v.json b/configs/wan22/wan_moe_flf2v.json new file mode 100644 index 0000000000000000000000000000000000000000..73879fc935900abc1a81b3dc840ca057a122a21e --- /dev/null +++ b/configs/wan22/wan_moe_flf2v.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_shift": 16, + "enable_cfg": true, + "cpu_offload": true, + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.900, + "use_image_encoder": false +} diff --git a/configs/wan22/wan_moe_i2v.json 
b/configs/wan22/wan_moe_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..eb29bfe8cefa414cedd0edd352028fd628f96dbb --- /dev/null +++ b/configs/wan22/wan_moe_i2v.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "boundary": 0.900, + "use_image_encoder": false +} diff --git a/configs/wan22/wan_moe_i2v_4090.json b/configs/wan22/wan_moe_i2v_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..d37b9289257dece6a4f5cbcc7a2d49d44b4d1f49 --- /dev/null +++ b/configs/wan22/wan_moe_i2v_4090.json @@ -0,0 +1,27 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "phase", + "boundary": 0.900, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/wan22/wan_moe_i2v_audio.json b/configs/wan22/wan_moe_i2v_audio.json new file mode 100644 index 0000000000000000000000000000000000000000..a0deb7ce7e08975da20a3e5e5e70d6eb21be8dec --- /dev/null +++ b/configs/wan22/wan_moe_i2v_audio.json @@ -0,0 +1,38 @@ +{ + "infer_steps": 6, + "target_fps": 16, + "video_duration": 16, + "audio_sr": 16000, + "text_len": 512, + "target_video_length": 81, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 1.0, + 1.0 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.900, + "use_image_encoder": false, + "use_31_block": false, + "lora_configs": [ + { + "name": "high_noise_model", + "path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors", + "strength": 1.0 + }, + { + "name": "low_noise_model", + "path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/wan22/wan_moe_i2v_distill.json b/configs/wan22/wan_moe_i2v_distill.json new file mode 100644 index 0000000000000000000000000000000000000000..92f3360d75a3801a359e1b403deba77a57d8c166 --- /dev/null +++ b/configs/wan22/wan_moe_i2v_distill.json @@ -0,0 +1,28 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ] +} diff --git a/configs/wan22/wan_moe_i2v_distill_4090.json 
b/configs/wan22/wan_moe_i2v_distill_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..279a0f25e88954426b732d70388d216b0c01a643 --- /dev/null +++ b/configs/wan22/wan_moe_i2v_distill_4090.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "fp8-q8f" +} diff --git a/configs/wan22/wan_moe_i2v_distill_5090.json b/configs/wan22/wan_moe_i2v_distill_5090.json new file mode 100644 index 0000000000000000000000000000000000000000..04494c2234874188f86c3aac7e2d92d9f311ad68 --- /dev/null +++ b/configs/wan22/wan_moe_i2v_distill_5090.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn3", + "cross_attn_1_type": "sage_attn3", + "cross_attn_2_type": "sage_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "t5_quantized": true, + "t5_quant_scheme": "int8-q8f" +} diff --git a/configs/wan22/wan_moe_i2v_distill_quant.json b/configs/wan22/wan_moe_i2v_distill_quant.json new file mode 100644 index 0000000000000000000000000000000000000000..ce4e1c4aeb8524304b2e87c95909d713116cab84 --- /dev/null +++ b/configs/wan22/wan_moe_i2v_distill_quant.json @@ -0,0 +1,32 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl" +} diff --git a/configs/wan22/wan_moe_i2v_distill_with_lora.json b/configs/wan22/wan_moe_i2v_distill_with_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..f28ddc45126899e4ad8f5365546bca0bf3fb745d --- /dev/null +++ b/configs/wan22/wan_moe_i2v_distill_with_lora.json @@ -0,0 +1,40 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + 
"vae_cpu_offload": false, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "lora_configs": [ + { + "name": "high_noise_model", + "path": "lightx2v/Wan2.2-Distill-Loras/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", + "strength": 1.0 + }, + { + "name": "low_noise_model", + "path": "lightx2v/Wan2.2-Distill-Loras/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/wan22/wan_moe_t2v.json b/configs/wan22/wan_moe_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..d8be0d1078ba45f9b915b3e44899ac04fe05c38c --- /dev/null +++ b/configs/wan22/wan_moe_t2v.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 40, + "target_video_length": 81, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 4.0, + 3.0 + ], + "sample_shift": 12.0, + "enable_cfg": true, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary": 0.875 +} diff --git a/configs/wan22/wan_moe_t2v_distill.json b/configs/wan22/wan_moe_t2v_distill.json new file mode 100644 index 0000000000000000000000000000000000000000..eabdf11a0c2232a0bc43e1d7bb1bf76e4ae44455 --- /dev/null +++ b/configs/wan22/wan_moe_t2v_distill.json @@ -0,0 +1,34 @@ +{ + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": [ + 4.0, + 3.0 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "lora_configs": [ + { + "name": "low_noise_model", + "path": "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors", + "strength": 1.0 + } + ] +} diff --git a/configs/wan22/wan_ti2v_i2v.json b/configs/wan22/wan_ti2v_i2v.json new file mode 100644 index 0000000000000000000000000000000000000000..44e9c4e334de58491bc6d4b6ce9e40acdc3628a6 --- /dev/null +++ b/configs/wan22/wan_ti2v_i2v.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 24, + "use_image_encoder": false +} diff --git a/configs/wan22/wan_ti2v_i2v_4090.json b/configs/wan22/wan_ti2v_i2v_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..742f8ab16d3de6d5bb29abdd116bf84604e5c3a7 --- /dev/null +++ b/configs/wan22/wan_ti2v_i2v_4090.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + 
"cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "fps": 24, + "use_image_encoder": false, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "vae_offload_cache": true +} diff --git a/configs/wan22/wan_ti2v_t2v.json b/configs/wan22/wan_ti2v_t2v.json new file mode 100644 index 0000000000000000000000000000000000000000..a1541e808a1d54508efc83757613619cf16dc6d7 --- /dev/null +++ b/configs/wan22/wan_ti2v_t2v.json @@ -0,0 +1,24 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "cpu_offload": false, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 24 +} diff --git a/configs/wan22/wan_ti2v_t2v_4090.json b/configs/wan22/wan_ti2v_t2v_4090.json new file mode 100644 index 0000000000000000000000000000000000000000..c1e5c49aced9fc2d3426aa6f5cc21a0defdc1f32 --- /dev/null +++ b/configs/wan22/wan_ti2v_t2v_4090.json @@ -0,0 +1,25 @@ +{ + "infer_steps": 50, + "target_video_length": 121, + "text_len": 512, + "target_height": 704, + "target_width": 1280, + "num_channels_latents": 48, + "vae_stride": [ + 4, + 16, + 16 + ], + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 5.0, + "sample_shift": 5.0, + "enable_cfg": true, + "fps": 24, + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "vae_offload_cache": true +} diff --git a/dockerfiles/Dockerfile b/dockerfiles/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ca883dcf14053ba7daeca5f6f65ff83a2c925607 --- /dev/null +++ b/dockerfiles/Dockerfile @@ -0,0 +1,105 @@ +FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel AS base + +WORKDIR /app + +ENV DEBIAN_FRONTEND=noninteractive +ENV LANG=C.UTF-8 +ENV LC_ALL=C.UTF-8 +ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH + +RUN apt-get update && apt-get install -y vim tmux zip unzip bzip2 wget git git-lfs build-essential libibverbs-dev ca-certificates \ + curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \ + libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev libjpeg-dev libpng-dev \ + && apt-get clean && rm -rf /var/lib/apt/lists/* && git lfs install + +RUN conda install conda-forge::ffmpeg=8.0.0 -y && conda clean -all -y + +RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U + +RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \ + && python use_existing_torch.py && pip install --no-cache-dir -r requirements/build.txt \ + && pip install --no-cache-dir --no-build-isolation -v -e . 
+ +RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel \ + && make build && make clean + +RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \ + imageio-ffmpeg einops loguru qtorch ftfy av decord matplotlib debugpy + +RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive + +RUN cd flash-attention && python setup.py install && rm -rf build + +RUN cd flash-attention/hopper && python setup.py install && rm -rf build + +RUN git clone https://github.com/ModelTC/SageAttention.git --depth 1 + +RUN cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e . + +RUN git clone https://github.com/ModelTC/SageAttention-1104.git --depth 1 + +RUN cd SageAttention-1104/sageattention3_blackwell && python setup.py install && rm -rf build + +RUN git clone https://github.com/SandAI-org/MagiAttention.git --recursive + +RUN cd MagiAttention && TORCH_CUDA_ARCH_LIST="9.0" pip install --no-cache-dir --no-build-isolation -v -e . + +RUN git clone https://github.com/ModelTC/FlashVSR.git --depth 1 + +RUN cd FlashVSR && pip install --no-cache-dir -v -e . + +COPY lightx2v_kernel /app/lightx2v_kernel + +RUN git clone https://github.com/NVIDIA/cutlass.git --depth 1 && cd /app/lightx2v_kernel && MAX_JOBS=32 && CMAKE_BUILD_PARALLEL_LEVEL=4 \ + uv build --wheel \ + -Cbuild-dir=build . \ + -Ccmake.define.CUTLASS_PATH=/app/cutlass \ + --verbose \ + --color=always \ + --no-build-isolation \ + && pip install dist/*whl --force-reinstall --no-deps \ + && rm -rf /app/lightx2v_kernel && rm -rf /app/cutlass + +# cloud deploy +RUN pip install --no-cache-dir aio-pika asyncpg>=0.27.0 aioboto3>=12.0.0 PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable +ENV PATH=/root/.cargo/bin:$PATH + +RUN cd /opt \ + && wget https://mirrors.tuna.tsinghua.edu.cn/gnu/libiconv/libiconv-1.15.tar.gz \ + && tar zxvf libiconv-1.15.tar.gz \ + && cd libiconv-1.15 \ + && ./configure \ + && make \ + && make install \ + && rm -rf /opt/libiconv-1.15 + +RUN cd /opt \ + && git clone https://github.com/GStreamer/gstreamer.git -b 1.27.2 --depth 1 \ + && cd gstreamer \ + && meson setup builddir \ + && meson compile -C builddir \ + && meson install -C builddir \ + && ldconfig \ + && rm -rf /opt/gstreamer + +RUN cd /opt \ + && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.27.2 --depth 1 \ + && cd gst-plugins-rs \ + && cargo build --package gst-plugin-webrtchttp --release \ + && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/ \ + && rm -rf /opt/gst-plugins-rs + +RUN ldconfig + + +# q8f for base docker +RUN git clone https://github.com/KONAKONA666/q8_kernels.git --depth 1 +RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build + +# q8f for 5090 docker +# RUN git clone https://github.com/ModelTC/LTX-Video-Q8-Kernels.git --depth 1 +# RUN cd LTX-Video-Q8-Kernels && git submodule init && git submodule update && python setup.py install && rm -rf build + +WORKDIR /workspace diff --git a/dockerfiles/Dockerfile_5090 b/dockerfiles/Dockerfile_5090 new file mode 100644 index 0000000000000000000000000000000000000000..c8bce20aa53f5a010e98c1dbdbc8f7351382b164 --- /dev/null +++ b/dockerfiles/Dockerfile_5090 @@ -0,0 +1,105 @@ +FROM 
pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel AS base + +WORKDIR /app + +ENV DEBIAN_FRONTEND=noninteractive +ENV LANG=C.UTF-8 +ENV LC_ALL=C.UTF-8 +ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH + +RUN apt-get update && apt-get install -y vim tmux zip unzip bzip2 wget git git-lfs build-essential libibverbs-dev ca-certificates \ + curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \ + libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev libjpeg-dev libpng-dev \ + && apt-get clean && rm -rf /var/lib/apt/lists/* && git lfs install + +RUN conda install conda-forge::ffmpeg=8.0.0 -y && conda clean -all -y + +RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U + +RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \ + && python use_existing_torch.py && pip install --no-cache-dir -r requirements/build.txt \ + && pip install --no-cache-dir --no-build-isolation -v -e . + +RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel \ + && make build && make clean + +RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \ + imageio-ffmpeg einops loguru qtorch ftfy av decord matplotlib debugpy + +RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive + +RUN cd flash-attention && python setup.py install && rm -rf build + +RUN cd flash-attention/hopper && python setup.py install && rm -rf build + +RUN git clone https://github.com/ModelTC/SageAttention.git --depth 1 + +RUN cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e . + +RUN git clone https://github.com/ModelTC/SageAttention-1104.git --depth 1 + +RUN cd SageAttention-1104/sageattention3_blackwell && python setup.py install && rm -rf build + +RUN git clone https://github.com/SandAI-org/MagiAttention.git --recursive + +RUN cd MagiAttention && TORCH_CUDA_ARCH_LIST="9.0" pip install --no-cache-dir --no-build-isolation -v -e . + +RUN git clone https://github.com/ModelTC/FlashVSR.git --depth 1 + +RUN cd FlashVSR && pip install --no-cache-dir -v -e . + +COPY lightx2v_kernel /app/lightx2v_kernel + +RUN git clone https://github.com/NVIDIA/cutlass.git --depth 1 && cd /app/lightx2v_kernel && MAX_JOBS=32 && CMAKE_BUILD_PARALLEL_LEVEL=4 \ + uv build --wheel \ + -Cbuild-dir=build . 
\ + -Ccmake.define.CUTLASS_PATH=/app/cutlass \ + --verbose \ + --color=always \ + --no-build-isolation \ + && pip install dist/*whl --force-reinstall --no-deps \ + && rm -rf /app/lightx2v_kernel && rm -rf /app/cutlass + +# cloud deploy +RUN pip install --no-cache-dir aio-pika asyncpg>=0.27.0 aioboto3>=12.0.0 PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable +ENV PATH=/root/.cargo/bin:$PATH + +RUN cd /opt \ + && wget https://mirrors.tuna.tsinghua.edu.cn/gnu/libiconv/libiconv-1.15.tar.gz \ + && tar zxvf libiconv-1.15.tar.gz \ + && cd libiconv-1.15 \ + && ./configure \ + && make \ + && make install \ + && rm -rf /opt/libiconv-1.15 + +RUN cd /opt \ + && git clone https://github.com/GStreamer/gstreamer.git -b 1.27.2 --depth 1 \ + && cd gstreamer \ + && meson setup builddir \ + && meson compile -C builddir \ + && meson install -C builddir \ + && ldconfig \ + && rm -rf /opt/gstreamer + +RUN cd /opt \ + && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.27.2 --depth 1 \ + && cd gst-plugins-rs \ + && cargo build --package gst-plugin-webrtchttp --release \ + && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/ \ + && rm -rf /opt/gst-plugins-rs + +RUN ldconfig + + +# q8f for base docker +# RUN git clone https://github.com/KONAKONA666/q8_kernels.git --depth 1 +# RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build + +# q8f for 5090 docker +RUN git clone https://github.com/ModelTC/LTX-Video-Q8-Kernels.git --depth 1 +RUN cd LTX-Video-Q8-Kernels && git submodule init && git submodule update && python setup.py install && rm -rf build + +WORKDIR /workspace diff --git a/dockerfiles/Dockerfile_cambricon_mlu590 b/dockerfiles/Dockerfile_cambricon_mlu590 new file mode 100644 index 0000000000000000000000000000000000000000..3cfa0278ea2199a52d39a2c79b3a7dd412f00281 --- /dev/null +++ b/dockerfiles/Dockerfile_cambricon_mlu590 @@ -0,0 +1,31 @@ +FROM cambricon-base/pytorch:v25.10.0-torch2.8.0-torchmlu1.29.1-ubuntu22.04-py310 AS base + +WORKDIR /workspace/LightX2V + +# Set envs +ENV PYTHONPATH=/workspace/LightX2V +ENV LD_LIBRARY_PATH=/usr/local/neuware/lib64:${LD_LIBRARY_PATH} + +# Install deps +RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg && \ + pip install --no-cache-dir \ + ftfy \ + imageio \ + imageio-ffmpeg \ + loguru \ + aiohttp \ + gguf \ + diffusers \ + peft==0.17.0 \ + transformers==4.57.1 && + +# Copy files +COPY app app +COPY assets assets +COPY configs configs +COPY lightx2v lightx2v +COPY lightx2v_kernel lightx2v_kernel +COPY lightx2v_platform lightx2v_platform +COPY scripts scripts +COPY test_cases test_cases +COPY tools tools diff --git a/dockerfiles/Dockerfile_cu124 b/dockerfiles/Dockerfile_cu124 new file mode 100644 index 0000000000000000000000000000000000000000..ada01a3b841603e1663229882635bb716a0996bc --- /dev/null +++ b/dockerfiles/Dockerfile_cu124 @@ -0,0 +1,76 @@ +FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel AS base + +WORKDIR /app + +ENV DEBIAN_FRONTEND=noninteractive +ENV LANG=C.UTF-8 +ENV LC_ALL=C.UTF-8 +ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH + +RUN apt-get update && apt-get install -y vim tmux zip unzip wget git git-lfs build-essential libibverbs-dev ca-certificates \ + curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \ + libsoup2.4-dev 
libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev\ + && apt-get clean && rm -rf /var/lib/apt/lists/* && git lfs install + +RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U + +RUN git clone https://github.com/vllm-project/vllm.git -b v0.10.0 && cd vllm \ + && python use_existing_torch.py && pip install -r requirements/build.txt \ + && pip install --no-cache-dir --no-build-isolation -v -e . + +RUN git clone https://github.com/sgl-project/sglang.git -b v0.4.10 && cd sglang/sgl-kernel \ + && make build && make clean + +RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \ + imageio-ffmpeg einops loguru qtorch ftfy av decord + +RUN conda install conda-forge::ffmpeg=8.0.0 -y && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg && conda clean -all -y + +RUN git clone https://github.com/Dao-AILab/flash-attention.git -b v2.8.3 --recursive + +RUN cd flash-attention && python setup.py install && rm -rf build + +RUN cd flash-attention/hopper && python setup.py install && rm -rf build + +RUN git clone https://github.com/ModelTC/SageAttention.git + +RUN cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e . + +RUN git clone https://github.com/KONAKONA666/q8_kernels.git + +RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build + +# cloud deploy +RUN pip install --no-cache-dir aio-pika asyncpg>=0.27.0 aioboto3>=12.0.0 PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable +ENV PATH=/root/.cargo/bin:$PATH + +RUN cd /opt \ + && wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \ + && tar zxvf libiconv-1.15.tar.gz \ + && cd libiconv-1.15 \ + && ./configure \ + && make \ + && make install \ + && rm -rf /opt/libiconv-1.15 + +RUN cd /opt \ + && git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \ + && cd gstreamer \ + && meson setup builddir \ + && meson compile -C builddir \ + && meson install -C builddir \ + && ldconfig \ + && rm -rf /opt/gstreamer + +RUN cd /opt \ + && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \ + && cd gst-plugins-rs \ + && cargo build --package gst-plugin-webrtchttp --release \ + && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/ \ + && rm -rf /opt/gst-plugins-rs + +RUN ldconfig + +WORKDIR /workspace diff --git a/dockerfiles/Dockerfile_deploy b/dockerfiles/Dockerfile_deploy new file mode 100644 index 0000000000000000000000000000000000000000..ee4b1fa068483f8bcafc636b572b557e36a23250 --- /dev/null +++ b/dockerfiles/Dockerfile_deploy @@ -0,0 +1,32 @@ +FROM node:alpine3.21 AS frontend_builder +COPY lightx2v /opt/lightx2v + +RUN cd /opt/lightx2v/deploy/server/frontend \ + && npm install \ + && npm run build + +FROM lightx2v/lightx2v:25111101-cu128 AS base + +RUN mkdir /workspace/LightX2V +WORKDIR /workspace/LightX2V +ENV PYTHONPATH=/workspace/LightX2V + +# for multi-person & animate +RUN pip install ultralytics moviepy pydub pyannote.audio onnxruntime decord peft onnxruntime pandas matplotlib loguru sentencepiece + +RUN export COMMIT=0e78a118995e66bb27d78518c4bd9a3e95b4e266 \ + && export TORCH_CUDA_ARCH_LIST="9.0" \ + && git 
clone --depth 1 https://github.com/facebookresearch/sam2.git \ + && cd sam2 \ + && git fetch --depth 1 origin $COMMIT \ + && git checkout $COMMIT \ + && python setup.py install + +COPY tools tools +COPY assets assets +COPY configs configs +COPY lightx2v lightx2v +COPY lightx2v_kernel lightx2v_kernel +COPY lightx2v_platform lightx2v_platform + +COPY --from=frontend_builder /opt/lightx2v/deploy/server/frontend/dist lightx2v/deploy/server/frontend/dist diff --git a/docs/EN/.readthedocs.yaml b/docs/EN/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dda16a6b056a96ef6afe797884912090a612f567 --- /dev/null +++ b/docs/EN/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.10" + +formats: + - epub + +sphinx: + configuration: docs/EN/source/conf.py + +python: + install: + - requirements: requirements-docs.txt diff --git a/docs/EN/Makefile b/docs/EN/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293 --- /dev/null +++ b/docs/EN/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/EN/make.bat b/docs/EN/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a --- /dev/null +++ b/docs/EN/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/EN/source/conf.py b/docs/EN/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..25f127e5b6e95951e941d510fce9cf6788bbaab9 --- /dev/null +++ b/docs/EN/source/conf.py @@ -0,0 +1,128 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. 
If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import logging +import os +import sys +from typing import List + +import sphinxcontrib.redoc +from sphinx.ext import autodoc + +logger = logging.getLogger(__name__) +sys.path.append(os.path.abspath("../..")) + +# -- Project information ----------------------------------------------------- + +project = "Lightx2v" +copyright = "2025, Lightx2v Team" +author = "the Lightx2v Team" + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.mathjax", + "myst_parser", + "sphinxarg.ext", + "sphinxcontrib.redoc", + "sphinxcontrib.openapi", +] + +myst_enable_extensions = [ + "dollarmath", + "amsmath", +] + +html_static_path = ["_static"] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns: List[str] = ["**/*.template.rst"] + +# Exclude the prompt "$" when copying code +copybutton_prompt_text = r"\$ " +copybutton_prompt_is_regexp = True + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_title = project +html_theme = "sphinx_book_theme" +# html_theme = 'sphinx_rtd_theme' +html_logo = "../../../assets/img_lightx2v.png" +html_theme_options = { + "path_to_docs": "docs/EN/source", + "repository_url": "https://github.com/ModelTC/lightx2v", + "use_repository_button": True, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] + + +# Generate additional rst documentation here. +def setup(app): + # from docs.source.generate_examples import generate_examples + # generate_examples() + pass + + +# Mock out external dependencies here. 
+autodoc_mock_imports = [ + "cpuinfo", + "torch", + "transformers", + "psutil", + "prometheus_client", + "sentencepiece", + "lightllmnumpy", + "tqdm", + "tensorizer", +] + +for mock_target in autodoc_mock_imports: + if mock_target in sys.modules: + logger.info( + "Potentially problematic mock target (%s) found; autodoc_mock_imports cannot mock modules that have already been loaded into sys.modules when the sphinx build starts.", + mock_target, + ) + + +class MockedClassDocumenter(autodoc.ClassDocumenter): + """Remove note about base class when a class is derived from object.""" + + def add_line(self, line: str, source: str, *lineno: int) -> None: + if line == " Bases: :py:class:`object`": + return + super().add_line(line, source, *lineno) + + +autodoc.ClassDocumenter = MockedClassDocumenter + +navigation_with_keys = False diff --git a/docs/EN/source/deploy_guides/deploy_comfyui.md b/docs/EN/source/deploy_guides/deploy_comfyui.md new file mode 100644 index 0000000000000000000000000000000000000000..afde7bd5307eda6874432a204da335a4fa733909 --- /dev/null +++ b/docs/EN/source/deploy_guides/deploy_comfyui.md @@ -0,0 +1,25 @@ +# ComfyUI Deployment + +## ComfyUI-Lightx2vWrapper + +The official ComfyUI integration nodes for LightX2V are now available in a dedicated repository, providing a complete modular configuration system and optimization features. + +### Project Repository + +- GitHub: [https://github.com/ModelTC/ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper) + +### Key Features + +- Modular Configuration System: Separate nodes for each aspect of video generation +- Support for both Text-to-Video (T2V) and Image-to-Video (I2V) generation modes +- Advanced Optimizations: + - TeaCache acceleration (up to 3x speedup) + - Quantization support (int8, fp8) + - Memory optimization with CPU offloading + - Lightweight VAE options +- LoRA Support: Chain multiple LoRA models for customization +- Multiple Model Support: wan2.1, hunyuan architectures + +### Installation and Usage + +Please visit the GitHub repository above for detailed installation instructions, usage tutorials, and example workflows. diff --git a/docs/EN/source/deploy_guides/deploy_gradio.md b/docs/EN/source/deploy_guides/deploy_gradio.md new file mode 100644 index 0000000000000000000000000000000000000000..5be510feee143be743e73149024b456c40d91c96 --- /dev/null +++ b/docs/EN/source/deploy_guides/deploy_gradio.md @@ -0,0 +1,240 @@ +# Gradio Deployment Guide + +## 📖 Overview + +Lightx2v is a lightweight video inference and generation engine that provides a web interface based on Gradio, supporting both Image-to-Video and Text-to-Video generation modes. + +For Windows systems, we provide a convenient one-click deployment solution with automatic environment configuration and intelligent parameter optimization. Please refer to the [One-Click Gradio Startup (Recommended)](./deploy_local_windows.md/#one-click-gradio-startup-recommended) section for detailed instructions. 
+ +![Gradio English Interface](../../../../assets/figs/portabl_windows/pic_gradio_en.png) + +## 📁 File Structure + +``` +LightX2V/app/ +├── gradio_demo.py # English interface demo +├── gradio_demo_zh.py # Chinese interface demo +├── run_gradio.sh # Startup script +├── README.md # Documentation +├── outputs/ # Generated video save directory +└── inference_logs.log # Inference logs +``` + +This project contains two main demo files: +- `gradio_demo.py` - English interface version +- `gradio_demo_zh.py` - Chinese interface version + +## 🚀 Quick Start + +### Environment Requirements + +Follow the [Quick Start Guide](../getting_started/quickstart.md) to install the environment + +#### Recommended Optimization Library Configuration + +- ✅ [Flash attention](https://github.com/Dao-AILab/flash-attention) +- ✅ [Sage attention](https://github.com/thu-ml/SageAttention) +- ✅ [vllm-kernel](https://github.com/vllm-project/vllm) +- ✅ [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) +- ✅ [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (only supports ADA architecture GPUs) + +Install according to the project homepage tutorials for each operator as needed. + +### 📥 Model Download + +Refer to the [Model Structure Documentation](../getting_started/model_structure.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions. + +#### wan2.1 Model Directory Structure + +``` +models/ +├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision +├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization +├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 quantization +├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 quantization block storage directory +├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 quantization block storage directory +├── Other weights (e.g., t2v) +├── t5/clip/xlm-roberta-large/google # text and image encoder +├── vae/lightvae/lighttae # vae +└── config.json # Model configuration file +``` + +#### wan2.2 Model Directory Structure + +``` +models/ +├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise original precision +├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 quantization +├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 quantization +├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 quantization block storage directory +├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise original precision +├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 quantization +├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 quantization +├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 quantization block storage directory +├── t5/clip/xlm-roberta-large/google # text and image encoder +├── vae/lightvae/lighttae # vae +└── config.json # Model configuration file +``` + +**📝 Download Instructions**: + +- Model weights can be downloaded from HuggingFace: + - [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) + - [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) +- Text and Image Encoders can be downloaded from [Encoders](https://huggingface.co/lightx2v/Encoders) +- VAE can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) +- For `xxx_split` 
directories (e.g., `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`), which store multiple safetensors by block, suitable for devices with insufficient memory. For example, devices with 16GB or less memory should download according to their own situation. + +### Startup Methods + +#### Method 1: Using Startup Script (Recommended) + +**Linux Environment:** +```bash +# 1. Edit the startup script to configure relevant paths +cd app/ +vim run_gradio.sh + +# Configuration items that need to be modified: +# - lightx2v_path: Lightx2v project root directory path +# - model_path: Model root directory path (contains all model files) + +# 💾 Important note: Recommend pointing model paths to SSD storage locations +# Example: /mnt/ssd/models/ or /data/ssd/models/ + +# 2. Run the startup script +bash run_gradio.sh + +# 3. Or start with parameters +bash run_gradio.sh --lang en --port 8032 +bash run_gradio.sh --lang zh --port 7862 +``` + +**Windows Environment:** +```cmd +# 1. Edit the startup script to configure relevant paths +cd app\ +notepad run_gradio_win.bat + +# Configuration items that need to be modified: +# - lightx2v_path: Lightx2v project root directory path +# - model_path: Model root directory path (contains all model files) + +# 💾 Important note: Recommend pointing model paths to SSD storage locations +# Example: D:\models\ or E:\models\ + +# 2. Run the startup script +run_gradio_win.bat + +# 3. Or start with parameters +run_gradio_win.bat --lang en --port 8032 +run_gradio_win.bat --lang zh --port 7862 +``` + +#### Method 2: Direct Command Line Startup + +```bash +pip install -v git+https://github.com/ModelTC/LightX2V.git +``` + +**Linux Environment:** + +**English Interface Version:** +```bash +python gradio_demo.py \ + --model_path /path/to/models \ + --server_name 0.0.0.0 \ + --server_port 7862 +``` + +**Chinese Interface Version:** +```bash +python gradio_demo_zh.py \ + --model_path /path/to/models \ + --server_name 0.0.0.0 \ + --server_port 7862 +``` + +**Windows Environment:** + +**English Interface Version:** +```cmd +python gradio_demo.py ^ + --model_path D:\models ^ + --server_name 127.0.0.1 ^ + --server_port 7862 +``` + +**Chinese Interface Version:** +```cmd +python gradio_demo_zh.py ^ + --model_path D:\models ^ + --server_name 127.0.0.1 ^ + --server_port 7862 +``` + +**💡 Tip**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface. + +## 📋 Command Line Parameters + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `--model_path` | str | ✅ | - | Model root directory path (directory containing all model files) | +| `--server_port` | int | ❌ | 7862 | Server port | +| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address | +| `--output_dir` | str | ❌ | ./outputs | Output video save directory | + +**💡 Note**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface. 
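If you prefer to fetch the weights programmatically rather than through a browser, the snippet below is a minimal sketch using `huggingface_hub`. The repository IDs are the ones listed in the Model Download section above; the filename patterns and on-disk layout are illustrative and may need adjusting to match the directory structure expected by `--model_path`.

```python
# Sketch: populate the --model_path directory with huggingface_hub.
# Repo IDs come from the Model Download section above; the allow_patterns filter
# and target layout are illustrative, not the only supported arrangement.
from huggingface_hub import snapshot_download

model_root = "/path/to/models"  # pass this same directory as --model_path

# DiT weights (select the variant/precision you need via allow_patterns)
snapshot_download(
    repo_id="lightx2v/Wan2.1-Distill-Models",
    local_dir=model_root,
    allow_patterns=["wan2.1_i2v_720p_*4step*"],
)

# Text/image encoders and VAE weights
snapshot_download(repo_id="lightx2v/Encoders", local_dir=model_root)
snapshot_download(repo_id="lightx2v/Autoencoders", local_dir=model_root)
```

Once the download finishes, launch the demo with `--model_path /path/to/models` as shown above; model type, task type, and the specific weight file are then selected in the Web interface.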
+ +## 🎯 Features + +### Model Configuration + +- **Model Type**: Supports wan2.1 and wan2.2 model architectures +- **Task Type**: Supports Image-to-Video (i2v) and Text-to-Video (t2v) generation modes +- **Model Selection**: Frontend automatically identifies and filters available model files, supports automatic quantization precision detection +- **Encoder Configuration**: Supports selection of T5 text encoder, CLIP image encoder, and VAE decoder +- **Operator Selection**: Supports multiple attention operators and quantization matrix multiplication operators, system automatically sorts by installation status + +### Input Parameters + +- **Prompt**: Describe the expected video content +- **Negative Prompt**: Specify elements you don't want to appear +- **Input Image**: Upload input image required in i2v mode +- **Resolution**: Supports multiple preset resolutions (480p/540p/720p) +- **Random Seed**: Controls the randomness of generation results +- **Inference Steps**: Affects the balance between generation quality and speed (defaults to 4 steps for distilled models) + +### Video Parameters + +- **FPS**: Frames per second +- **Total Frames**: Video length +- **CFG Scale Factor**: Controls prompt influence strength (1-10, defaults to 1 for distilled models) +- **Distribution Shift**: Controls generation style deviation degree (0-10) + +## 🔧 Auto-Configuration Feature + +The system automatically configures optimal inference options based on your hardware configuration (GPU VRAM and CPU memory) without manual adjustment. The best configuration is automatically applied on startup, including: + +- **GPU Memory Optimization**: Automatically enables CPU offloading, VAE tiling inference, etc. based on VRAM size +- **CPU Memory Optimization**: Automatically enables lazy loading, module unloading, etc. based on system memory +- **Operator Selection**: Automatically selects the best installed operators (sorted by priority) +- **Quantization Configuration**: Automatically detects and applies quantization precision based on model file names + + +### Log Viewing + +```bash +# View inference logs +tail -f inference_logs.log + +# View GPU usage +nvidia-smi + +# View system resources +htop +``` + +Welcome to submit Issues and Pull Requests to improve this project! + +**Note**: Please comply with relevant laws and regulations when using videos generated by this tool, and do not use them for illegal purposes. diff --git a/docs/EN/source/deploy_guides/deploy_local_windows.md b/docs/EN/source/deploy_guides/deploy_local_windows.md new file mode 100644 index 0000000000000000000000000000000000000000..e675f9c5af1bae7d2220747a85219a8ccbfab291 --- /dev/null +++ b/docs/EN/source/deploy_guides/deploy_local_windows.md @@ -0,0 +1,127 @@ +# Windows Local Deployment Guide + +## 📖 Overview + +This document provides detailed instructions for deploying LightX2V locally on Windows environments, including batch file inference, Gradio Web interface inference, and other usage methods. + +## 🚀 Quick Start + +### Environment Requirements + +#### Hardware Requirements +- **GPU**: NVIDIA GPU, recommended 8GB+ VRAM +- **Memory**: Recommended 16GB+ RAM +- **Storage**: Strongly recommended to use SSD solid-state drives, mechanical hard drives will cause slow model loading + + +## 🎯 Usage Methods + +### Method 1: Using Batch File Inference + +Refer to [Quick Start Guide](../getting_started/quickstart.md) to install environment, and use [batch files](https://github.com/ModelTC/LightX2V/tree/main/scripts/win) to run. 
+ +### Method 2: Using Gradio Web Interface Inference + +#### Manual Gradio Configuration + +Refer to [Quick Start Guide](../getting_started/quickstart.md) to install environment, refer to [Gradio Deployment Guide](./deploy_gradio.md) + +#### One-Click Gradio Startup (Recommended) + +**📦 Download Software Package** +- [Quark Cloud](https://pan.quark.cn/s/8af1162d7a15) + +**📁 Directory Structure** +After extraction, ensure the directory structure is as follows: + +``` +├── env/ # LightX2V environment directory +├── LightX2V/ # LightX2V project directory +├── start_lightx2v.bat # One-click startup script +├── lightx2v_config.txt # Configuration file +├── LightX2V使用说明.txt # LightX2V usage instructions +├── outputs/ # Generated video save directory +└── models/ # Model storage directory +``` + +**📥 Model Download**: + +Refer to [Model Structure Documentation](../getting_started/model_structure.md) or [Gradio Deployment Guide](./deploy_gradio.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions. + + +**📋 Configuration Parameters** + +Edit the `lightx2v_config.txt` file and modify the following parameters as needed: + +```ini + +# Interface language (zh: Chinese, en: English) +lang=en + +# Server port +port=8032 + +# GPU device ID (0, 1, 2...) +gpu=0 + +# Model path +model_path=models/ +``` + +**🚀 Start Service** + +Double-click to run the `start_lightx2v.bat` file, the script will: +1. Automatically read configuration file +2. Verify model paths and file integrity +3. Start Gradio Web interface +4. Automatically open browser to access service + + +![Gradio English Interface](../../../../assets/figs/portabl_windows/pic_gradio_en.png) + +**⚠️ Important Notes**: +- **Display Issues**: If the webpage opens blank or displays abnormally, please run `pip install --upgrade gradio` to upgrade the Gradio version. + +### Method 3: Using ComfyUI Inference + +This guide will instruct you on how to download and use the portable version of the Lightx2v-ComfyUI environment, so you can avoid manual environment configuration steps. This is suitable for users who want to quickly start experiencing accelerated video generation with Lightx2v on Windows systems. + +#### Download the Windows Portable Environment: + +- [Baidu Cloud Download](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid), extraction code: wfid + +The portable environment already packages all Python runtime dependencies, including the code and dependencies for ComfyUI and LightX2V. After downloading, simply extract to use. + +After extraction, the directory structure is as follows: + +```shell +lightx2v_env +├──📂 ComfyUI # ComfyUI code +├──📂 portable_python312_embed # Standalone Python environment +└── run_nvidia_gpu.bat # Windows startup script (double-click to start) +``` + +#### Start ComfyUI + +Directly double-click the run_nvidia_gpu.bat file. The system will open a Command Prompt window and run the program. The first startup may take a while, please be patient. After startup is complete, the browser will automatically open and display the ComfyUI frontend interface. + +![i2v example workflow](../../../../assets/figs/portabl_windows/pic1.png) + +The plugin used by LightX2V-ComfyUI is [ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper). Example workflows can be obtained from this project. 
+ +#### Tested Graphics Cards (offload mode) + +- Tested model: `Wan2.1-I2V-14B-480P` + +| GPU Model | Task Type | VRAM Capacity | Actual Max VRAM Usage | Actual Max RAM Usage | +|:-----------|:------------|:--------------|:---------------------|:---------------------| +| 3090Ti | I2V | 24G | 6.1G | 7.1G | +| 3080Ti | I2V | 12G | 6.1G | 7.1G | +| 3060Ti | I2V | 8G | 6.1G | 7.1G | +| 4070Ti Super | I2V | 16G | 6.1G | 7.1G | +| 4070 | I2V | 12G | 6.1G | 7.1G | +| 4060 | I2V | 8G | 6.1G | 7.1G | + +#### Environment Packaging and Usage Reference +- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) +- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2) diff --git a/docs/EN/source/deploy_guides/deploy_service.md b/docs/EN/source/deploy_guides/deploy_service.md new file mode 100644 index 0000000000000000000000000000000000000000..b34b349622d77ba874b394404794976652547a01 --- /dev/null +++ b/docs/EN/source/deploy_guides/deploy_service.md @@ -0,0 +1,88 @@ +# Service Deployment + +lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py) + +### Start the Service + +```shell +# Modify the paths in the script +bash scripts/start_server.sh +``` + +The `--port 8000` option means the service will bind to port `8000` on the local machine. You can change this as needed. + +### Client Sends Request + +```shell +python scripts/post.py +``` + +The service endpoint is: `/v1/tasks/` + +The `message` parameter in `scripts/post.py` is as follows: + +```python +message = { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "", +} +``` + +1. `prompt`, `negative_prompt`, and `image_path` are basic inputs for video generation. `image_path` can be an empty string, indicating no image input is needed. + + +### Client Checks Server Status + +```shell +python scripts/check_status.py +``` + +The service endpoints include: + +1. `/v1/service/status` is used to check the status of the service. It returns whether the service is `busy` or `idle`. The service only accepts new requests when it is `idle`. + +2. `/v1/tasks/` is used to get all tasks received and completed by the server. + +3. `/v1/tasks/{task_id}/status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`. + +### Client Stops the Current Task on the Server at Any Time + +```shell +python scripts/stop_running_task.py +``` + +The service endpoint is: `/v1/tasks/running` + +After terminating the task, the server will not exit but will return to waiting for new requests. + +### Starting Multiple Services on a Single Node + +On a single node, you can start multiple services using `scripts/start_server.sh` (Note that the port numbers under the same IP must be different for each service), or you can start multiple services at once using `scripts/start_multi_servers.sh`: + +```shell +num_gpus=8 bash scripts/start_multi_servers.sh +``` + +Where `num_gpus` indicates the number of services to start; the services will run on consecutive ports starting from `--start_port`. 
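Because the endpoints above are plain HTTP, a client can also be scripted directly. The snippet below is a minimal sketch that checks service availability, submits a task, and polls its status. The response field names (for example, the task id key) are assumptions; `scripts/post.py` and `scripts/check_status.py` show the exact request/response schema used by the reference clients.

```python
# Minimal client sketch against the documented endpoints. Response field names
# (such as the task id key) are assumptions; see scripts/post.py and
# scripts/check_status.py for the exact schema.
import time

import requests

BASE_URL = "http://localhost:8000"  # matches --port 8000 in start_server.sh

message = {
    "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    "negative_prompt": "",
    "image_path": "",  # empty string: no image input (t2v)
}

# Check whether the service is idle before submitting.
service_status = requests.get(f"{BASE_URL}/v1/service/status", timeout=10).json()
print("service status:", service_status)

# Create the task and read back its id (field name assumed).
resp = requests.post(f"{BASE_URL}/v1/tasks/", json=message, timeout=30).json()
task_id = resp.get("task_id")  # inspect resp if this is None

# Poll until the task is reported as completed.
while True:
    task_status = requests.get(f"{BASE_URL}/v1/tasks/{task_id}/status", timeout=10).json()
    print("task status:", task_status)
    if "completed" in str(task_status):
        break
    time.sleep(5)
```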
+ +### Scheduling Between Multiple Services + +```shell +python scripts/post_multi_servers.py +``` + +`post_multi_servers.py` will schedule multiple client requests based on the idle status of the services. + +### API Endpoints Summary + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/v1/tasks/` | POST | Create video generation task | +| `/v1/tasks/form` | POST | Create video generation task via form | +| `/v1/tasks/` | GET | Get all task list | +| `/v1/tasks/{task_id}/status` | GET | Get status of specified task | +| `/v1/tasks/{task_id}/result` | GET | Get result video file of specified task | +| `/v1/tasks/running` | DELETE | Stop currently running task | +| `/v1/files/download/{file_path}` | GET | Download file | +| `/v1/service/status` | GET | Get service status | diff --git a/docs/EN/source/deploy_guides/for_low_latency.md b/docs/EN/source/deploy_guides/for_low_latency.md new file mode 100644 index 0000000000000000000000000000000000000000..5e8df8a467452ff0223c9f62295e8c74b73b0002 --- /dev/null +++ b/docs/EN/source/deploy_guides/for_low_latency.md @@ -0,0 +1,41 @@ +# Deployment for Low Latency Scenarios + +In low latency scenarios, we pursue faster speed, ignoring issues such as video memory and RAM overhead. We provide two solutions: + +## 💡 Solution 1: Inference with Step Distillation Model + +This solution can refer to the [Step Distillation Documentation](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/step_distill.html) + +🧠 **Step Distillation** is a very direct acceleration inference solution for video generation models. By distilling from 50 steps to 4 steps, the time consumption will be reduced to 4/50 of the original. At the same time, under this solution, it can still be combined with the following solutions: +1. [Efficient Attention Mechanism Solution](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/attention.html) +2. [Model Quantization](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/quantization.html) + +## 💡 Solution 2: Inference with Non-Step Distillation Model + +Step distillation requires relatively large training resources, and the model after step distillation may have degraded video dynamic range. + +For the original model without step distillation, we can use the following solutions or a combination of multiple solutions for acceleration: + +1. [Parallel Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/parallel.html) for multi-GPU parallel acceleration. +2. [Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html) to reduce the actual inference steps. +3. [Efficient Attention Mechanism Solution](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/attention.html) to accelerate Attention inference. +4. [Model Quantization](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/quantization.html) to accelerate Linear layer inference. +5. [Variable Resolution Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/changing_resolution.html) to reduce the resolution of intermediate inference steps. + +## 💡 Using Tiny VAE + +In some cases, the VAE component can be time-consuming. You can use a lightweight VAE for acceleration, which can also reduce some GPU memory usage. 
+ +```python +{ + "use_tae": true, + "tae_path": "/path to taew2_1.pth" +} +``` +The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth) + +## ⚠️ Note + +Some acceleration solutions currently cannot be used together, and we are working to resolve this issue. + +If you have any questions, feel free to report bugs or request features in [🐛 GitHub Issues](https://github.com/ModelTC/lightx2v/issues) diff --git a/docs/EN/source/deploy_guides/for_low_resource.md b/docs/EN/source/deploy_guides/for_low_resource.md new file mode 100644 index 0000000000000000000000000000000000000000..ad11c7489ff6f5748dcfd2b482bfd93600d9437f --- /dev/null +++ b/docs/EN/source/deploy_guides/for_low_resource.md @@ -0,0 +1,219 @@ +# Lightx2v Low-Resource Deployment Guide + +## 📋 Overview + +This guide is specifically designed for hardware resource-constrained environments, particularly configurations with **8GB VRAM + 16/32GB RAM**, providing detailed instructions on how to successfully run Lightx2v 14B models for 480p and 720p video generation. + +Lightx2v is a powerful video generation model, but it requires careful optimization to run smoothly in resource-constrained environments. This guide provides a complete solution from hardware selection to software configuration, ensuring you can achieve the best video generation experience under limited hardware conditions. + +## 🎯 Target Hardware Configuration + +### Recommended Hardware Specifications + +**GPU Requirements**: +- **VRAM**: 8GB (RTX 3060/3070/4060/4060Ti, etc.) +- **Architecture**: NVIDIA graphics cards with CUDA support + +**System Memory**: +- **Minimum**: 16GB DDR4 +- **Recommended**: 32GB DDR4/DDR5 +- **Memory Speed**: 3200MHz or higher recommended + +**Storage Requirements**: +- **Type**: NVMe SSD strongly recommended +- **Capacity**: At least 50GB available space +- **Speed**: Read speed of 3000MB/s or higher recommended + +**CPU Requirements**: +- **Cores**: 8 cores or more recommended +- **Frequency**: 3.0GHz or higher recommended +- **Architecture**: Support for AVX2 instruction set + +## ⚙️ Core Optimization Strategies + +### 1. Environment Optimization + +Before running Lightx2v, it's recommended to set the following environment variables to optimize performance: + +```bash +# CUDA memory allocation optimization +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Enable CUDA Graph mode to improve inference performance +export ENABLE_GRAPH_MODE=true + +# Use BF16 precision for inference to reduce VRAM usage (default FP32 precision) +export DTYPE=BF16 +``` + +**Optimization Details**: +- `expandable_segments:True`: Allows dynamic expansion of CUDA memory segments, reducing memory fragmentation +- `ENABLE_GRAPH_MODE=true`: Enables CUDA Graph to reduce kernel launch overhead +- `DTYPE=BF16`: Uses BF16 precision to reduce VRAM usage while maintaining quality + +### 2. Quantization Strategy + +Quantization is a key optimization technique in low-resource environments, reducing memory usage by lowering model precision. 
+ +#### Quantization Scheme Comparison + +**FP8 Quantization** (Recommended for RTX 40 series): +```python +# Suitable for GPUs supporting FP8, providing better precision +dit_quant_scheme = "fp8" # DIT model quantization +t5_quant_scheme = "fp8" # T5 text encoder quantization +clip_quant_scheme = "fp8" # CLIP visual encoder quantization +``` + +**INT8 Quantization** (Universal solution): +```python +# Suitable for all GPUs, minimal memory usage +dit_quant_scheme = "int8" # 8-bit integer quantization +t5_quant_scheme = "int8" # Text encoder quantization +clip_quant_scheme = "int8" # Visual encoder quantization +``` + +### 3. Efficient Operator Selection Guide + +Choosing the right operators can significantly improve inference speed and reduce memory usage. + +#### Attention Operator Selection + +**Recommended Priority**: +1. **[Sage Attention](https://github.com/thu-ml/SageAttention)** (Highest priority) + +2. **[Flash Attention](https://github.com/Dao-AILab/flash-attention)** (Universal solution) + +#### Matrix Multiplication Operator Selection + +**ADA Architecture GPUs** (RTX 40 series): + +Recommended priority: +1. **[q8-kernel](https://github.com/KONAKONA666/q8_kernels)** (Highest performance, ADA architecture only) +2. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (Balanced solution) +3. **[vllm-kernel](https://github.com/vllm-project/vllm)** (Universal solution) + +**Other Architecture GPUs**: +1. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (Recommended) +2. **[vllm-kernel](https://github.com/vllm-project/vllm)** (Alternative) + +### 4. Parameter Offloading Strategy + +Parameter offloading technology allows models to dynamically schedule parameters between CPU and disk, breaking through VRAM limitations. + +#### Three-Level Offloading Architecture + +```python +# Disk-CPU-GPU three-level offloading configuration +cpu_offload=True # Enable CPU offloading +t5_cpu_offload=True # Enable T5 encoder CPU offloading +offload_granularity=phase # DIT model fine-grained offloading +t5_offload_granularity=block # T5 encoder fine-grained offloading +lazy_load = True # Enable lazy loading mechanism +num_disk_workers = 2 # Disk I/O worker threads +``` + +#### Offloading Strategy Details + +**Lazy Loading Mechanism**: +- Model parameters are loaded from disk to CPU on demand +- Reduces runtime memory usage +- Supports large models running with limited memory + +**Disk Storage Optimization**: +- Use high-speed SSD to store model parameters +- Store model files grouped by blocks +- Refer to conversion script [documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme.md), specify `--save_by_block` parameter during conversion + +### 5. VRAM Optimization Techniques + +VRAM optimization strategies for 720p video generation. + +#### CUDA Memory Management + +```python +# CUDA memory cleanup configuration +clean_cuda_cache = True # Timely cleanup of GPU cache +rotary_chunk = True # Rotary position encoding chunked computation +rotary_chunk_size = 100 # Chunk size, adjustable based on VRAM +``` + +#### Chunked Computation Strategy + +**Rotary Position Encoding Chunking**: +- Process long sequences in small chunks +- Reduce peak VRAM usage +- Maintain computational precision + +### 6. VAE Optimization + +VAE (Variational Autoencoder) is a key component in video generation, and optimizing VAE can significantly improve performance. 
+ +#### VAE Chunked Inference + +```python +# VAE optimization configuration +use_tiling_vae = True # Enable VAE chunked inference +``` + +#### Lightweight VAE + +```python +# VAE optimization configuration +use_tae = True # Use lightweight VAE +tae_path = "/path to taew2_1.pth" +``` +You can download taew2_1.pth [here](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth) + +**VAE Optimization Effects**: +- Standard VAE: Baseline performance, 100% quality retention +- Standard VAE chunked: Reduces VRAM usage, increases inference time, 100% quality retention +- Lightweight VAE: Extremely low VRAM usage, video quality loss + +### 7. Model Selection Strategy + +Choosing the right model version is crucial for low-resource environments. + +#### Recommended Model Comparison + +**Distilled Models** (Strongly recommended): +- ✅ **[Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v)** + +- ✅ **[Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v)** + +#### Performance Optimization Suggestions + +When using the above distilled models, you can further optimize performance: +- Disable CFG: `"enable_cfg": false` +- Reduce inference steps: `infer_step: 4` +- Reference configuration files: [config](https://github.com/ModelTC/LightX2V/tree/main/configs/distill) + +## 🚀 Complete Configuration Examples + +### Pre-configured Templates + +- **[14B Model 480p Video Generation Configuration](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json)** + +- **[14B Model 720p Video Generation Configuration](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json)** + +- **[1.3B Model 720p Video Generation Configuration](https://github.com/ModelTC/LightX2V/tree/main/configs/offload/block/wan_t2v_1_3b.json)** + - The inference bottleneck for 1.3B models is the T5 encoder, so the configuration file specifically optimizes for T5 + +**[Launch Script](https://github.com/ModelTC/LightX2V/tree/main/scripts/wan/run_wan_i2v_lazy_load.sh)** + +## 📚 Reference Resources + +- [Parameter Offloading Mechanism Documentation](../method_tutorials/offload.md) - In-depth understanding of offloading technology principles +- [Quantization Technology Guide](../method_tutorials/quantization.md) - Detailed explanation of quantization technology +- [Gradio Deployment Guide](deploy_gradio.md) - Detailed Gradio deployment instructions + +## ⚠️ Important Notes + +1. **Hardware Requirements**: Ensure your hardware meets minimum configuration requirements +2. **Driver Version**: Recommend using the latest NVIDIA drivers (535+) +3. **CUDA Version**: Ensure CUDA version is compatible with PyTorch (recommend CUDA 11.8+) +4. **Storage Space**: Reserve sufficient disk space for model caching (at least 50GB) +5. **Network Environment**: Stable network connection required for initial model download +6. **Environment Variables**: Be sure to set the recommended environment variables to optimize performance + +**Technical Support**: If you encounter issues, please submit an Issue to the project repository. 
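As a quick recap, the options discussed in the sections above can be collected into a single configuration file. The sketch below assembles them with illustrative values; the key names follow the sections above (the `*_quantized` switches follow the configs shipped under `configs/`), and the pre-configured templates linked in "Complete Configuration Examples" remain the authoritative reference.

```python
# Consolidated low-resource configuration sketch (illustrative values only).
# Key names follow the sections above; prefer the pre-configured templates
# linked in "Complete Configuration Examples" for production use.
import json

low_resource_config = {
    # 2. Quantization strategy
    "dit_quantized": True,
    "dit_quant_scheme": "int8",
    "t5_quantized": True,
    "t5_quant_scheme": "int8",
    "clip_quant_scheme": "int8",
    # 4. Parameter offloading
    "cpu_offload": True,
    "t5_cpu_offload": True,
    "offload_granularity": "phase",
    "t5_offload_granularity": "block",
    "lazy_load": True,
    "num_disk_workers": 2,
    # 5. VRAM optimization
    "clean_cuda_cache": True,
    "rotary_chunk": True,
    "rotary_chunk_size": 100,
    # 6. VAE optimization
    "use_tiling_vae": True,
}

with open("wan_i2v_low_resource.json", "w") as f:
    json.dump(low_resource_config, f, indent=2)
```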
diff --git a/docs/EN/source/deploy_guides/lora_deploy.md b/docs/EN/source/deploy_guides/lora_deploy.md new file mode 100644 index 0000000000000000000000000000000000000000..769d1d57b4f3c620f36e077fe35dbb4145196cc2 --- /dev/null +++ b/docs/EN/source/deploy_guides/lora_deploy.md @@ -0,0 +1,214 @@ +# LoRA Model Deployment and Related Tools + +LoRA (Low-Rank Adaptation) is an efficient model fine-tuning technique that significantly reduces the number of trainable parameters through low-rank matrix decomposition. LightX2V fully supports LoRA technology, including LoRA inference, LoRA extraction, and LoRA merging functions. + +## 🎯 LoRA Technical Features + +- **Efficient Fine-tuning**: Dramatically reduces training parameters through low-rank adaptation +- **Flexible Deployment**: Supports dynamic loading and removal of LoRA weights +- **Multiple Formats**: Supports various LoRA weight formats and naming conventions +- **Comprehensive Tools**: Provides complete LoRA extraction and merging toolchain + +## 📜 LoRA Inference Deployment + +### Configuration File Method + +Specify LoRA path in configuration file: + +```json +{ + "lora_configs": [ + { + "path": "/path/to/your/lora.safetensors", + "strength": 1.0 + } + ] +} +``` + +**Configuration Parameter Description:** + +- `lora_path`: LoRA weight file path list, supports loading multiple LoRAs simultaneously +- `strength_model`: LoRA strength coefficient (alpha), controls LoRA's influence on the original model + +### Command Line Method + +Specify LoRA path directly in command line (supports loading single LoRA only): + +```bash +python -m lightx2v.infer \ + --model_cls wan2.1 \ + --task t2v \ + --model_path /path/to/model \ + --config_json /path/to/config.json \ + --lora_path /path/to/your/lora.safetensors \ + --lora_strength 0.8 \ + --prompt "Your prompt here" +``` + +### Multiple LoRAs Configuration + +To use multiple LoRAs with different strengths, specify them in the config JSON file: + +```json +{ + "lora_configs": [ + { + "path": "/path/to/first_lora.safetensors", + "strength": 0.8 + }, + { + "path": "/path/to/second_lora.safetensors", + "strength": 0.5 + } + ] +} +``` + +### Supported LoRA Formats + +LightX2V supports multiple LoRA weight naming conventions: + +| Format Type | Weight Naming | Description | +|-------------|---------------|-------------| +| **Standard LoRA** | `lora_A.weight`, `lora_B.weight` | Standard LoRA matrix decomposition format | +| **Down/Up Format** | `lora_down.weight`, `lora_up.weight` | Another common naming convention | +| **Diff Format** | `diff` | `weight` difference values | +| **Bias Diff** | `diff_b` | `bias` weight difference values | +| **Modulation Diff** | `diff_m` | `modulation` weight difference values | + +### Inference Script Examples + +**Step Distillation LoRA Inference:** + +```bash +# T2V LoRA Inference +bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh + +# I2V LoRA Inference +bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh +``` + +**Audio-Driven LoRA Inference:** + +```bash +bash scripts/wan/run_wan_i2v_audio.sh +``` + +### Using LoRA in API Service + +Specify through [config file](wan_t2v_distill_4step_cfg_lora.json), modify the startup command in [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh): + +```bash +python -m lightx2v.api_server \ + --model_cls wan2.1_distill \ + --task t2v \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \ + --port 8000 \ + 
--nproc_per_node 1 +``` + +## 🔧 LoRA Extraction Tool + +Use `tools/extract/lora_extractor.py` to extract LoRA weights from the difference between two models. + +### Basic Usage + +```bash +python tools/extract/lora_extractor.py \ + --source-model /path/to/base/model \ + --target-model /path/to/finetuned/model \ + --output /path/to/extracted/lora.safetensors \ + --rank 32 +``` + +### Parameter Description + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `--source-model` | str | ✅ | - | Base model path | +| `--target-model` | str | ✅ | - | Fine-tuned model path | +| `--output` | str | ✅ | - | Output LoRA file path | +| `--source-type` | str | ❌ | `safetensors` | Base model format (`safetensors`/`pytorch`) | +| `--target-type` | str | ❌ | `safetensors` | Fine-tuned model format (`safetensors`/`pytorch`) | +| `--output-format` | str | ❌ | `safetensors` | Output format (`safetensors`/`pytorch`) | +| `--rank` | int | ❌ | `32` | LoRA rank value | +| `--output-dtype` | str | ❌ | `bf16` | Output data type | +| `--diff-only` | bool | ❌ | `False` | Save weight differences only, without LoRA decomposition | + +### Advanced Usage Examples + +**Extract High-Rank LoRA:** + +```bash +python tools/extract/lora_extractor.py \ + --source-model /path/to/base/model \ + --target-model /path/to/finetuned/model \ + --output /path/to/high_rank_lora.safetensors \ + --rank 64 \ + --output-dtype fp16 +``` + +**Save Weight Differences Only:** + +```bash +python tools/extract/lora_extractor.py \ + --source-model /path/to/base/model \ + --target-model /path/to/finetuned/model \ + --output /path/to/weight_diff.safetensors \ + --diff-only +``` + +## 🔀 LoRA Merging Tool + +Use `tools/extract/lora_merger.py` to merge LoRA weights into the base model for subsequent quantization and other operations. 
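+
+Conceptually, extraction approximates the weight difference between the fine-tuned and base models with a low-rank factorization (commonly a truncated SVD), and merging folds that factorization back into the base weights as `W' = W + alpha * B @ A`. The sketch below illustrates the idea only — it is not the actual code of `lora_extractor.py` / `lora_merger.py`, and the function names are made up:
+
+```python
+import torch
+
+def extract_lora(w_base: torch.Tensor, w_tuned: torch.Tensor, rank: int = 32):
+    # Assumed approach: truncated SVD of the per-layer weight difference.
+    diff = (w_tuned - w_base).float()                # (out_features, in_features)
+    U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
+    lora_up = U[:, :rank] * S[:rank]                 # B: (out_features, rank)
+    lora_down = Vh[:rank, :]                         # A: (rank, in_features)
+    return lora_up, lora_down
+
+def merge_lora(w_base: torch.Tensor, lora_down: torch.Tensor,
+               lora_up: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
+    # W' = W + alpha * (B @ A), then cast back to the base dtype.
+    delta = lora_up.float() @ lora_down.float()
+    return (w_base.float() + alpha * delta).to(w_base.dtype)
+```
+
+Here `alpha` plays the same role as the `--alpha` merge strength described below.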
+ +### Basic Usage + +```bash +python tools/extract/lora_merger.py \ + --source-model /path/to/base/model \ + --lora-model /path/to/lora.safetensors \ + --output /path/to/merged/model.safetensors \ + --alpha 1.0 +``` + +### Parameter Description + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `--source-model` | str | ✅ | - | Base model path | +| `--lora-model` | str | ✅ | - | LoRA weights path | +| `--output` | str | ✅ | - | Output merged model path | +| `--source-type` | str | ❌ | `safetensors` | Base model format | +| `--lora-type` | str | ❌ | `safetensors` | LoRA weights format | +| `--output-format` | str | ❌ | `safetensors` | Output format | +| `--alpha` | float | ❌ | `1.0` | LoRA merge strength | +| `--output-dtype` | str | ❌ | `bf16` | Output data type | + +### Advanced Usage Examples + +**Partial Strength Merging:** + +```bash +python tools/extract/lora_merger.py \ + --source-model /path/to/base/model \ + --lora-model /path/to/lora.safetensors \ + --output /path/to/merged_model.safetensors \ + --alpha 0.7 \ + --output-dtype fp32 +``` + +**Multi-Format Support:** + +```bash +python tools/extract/lora_merger.py \ + --source-model /path/to/base/model.pt \ + --source-type pytorch \ + --lora-model /path/to/lora.safetensors \ + --lora-type safetensors \ + --output /path/to/merged_model.safetensors \ + --output-format safetensors \ + --alpha 1.0 +``` diff --git a/docs/EN/source/getting_started/benchmark.md b/docs/EN/source/getting_started/benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..ecc8dc95e801f4e6a6f77e5888e014c0e0e4b846 --- /dev/null +++ b/docs/EN/source/getting_started/benchmark.md @@ -0,0 +1,3 @@ +# Benchmark + +For a better display of video playback effects and detailed performance comparisons, you can get better presentation and corresponding documentation content on this [🔗 page](https://github.com/ModelTC/LightX2V/blob/main/docs/EN/source/getting_started/benchmark_source.md). diff --git a/docs/EN/source/getting_started/benchmark_source.md b/docs/EN/source/getting_started/benchmark_source.md new file mode 100644 index 0000000000000000000000000000000000000000..35bafdfef46fc56020302b701724199ea5c16117 --- /dev/null +++ b/docs/EN/source/getting_started/benchmark_source.md @@ -0,0 +1,149 @@ +# 🚀 Benchmark + +> This document showcases the performance test results of LightX2V across different hardware environments, including detailed comparison data for H200 and RTX 4090 platforms. 
+ +--- + +## 🖥️ H200 Environment (~140GB VRAM) + +### 📋 Software Environment Configuration + +| Component | Version | +|:----------|:--------| +| **Python** | 3.11 | +| **PyTorch** | 2.7.1+cu128 | +| **SageAttention** | 2.2.0 | +| **vLLM** | 0.9.2 | +| **sgl-kernel** | 0.1.8 | + +--- + +### 🎬 480P 5s Video Test + +**Test Configuration:** +- **Model**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) +- **Parameters**: `infer_steps=40`, `seed=42`, `enable_cfg=True` + +#### 📊 Performance Comparison Table + +| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect | +|:-------------|:-----------------:|:--------------:|:-------:|:------------:| +| **Wan2.1 Official** | 366 | 71 | 1.0x | | +| **FastVideo** | 292 | 26 | **1.25x** | | +| **LightX2V_1** | 250 | 53 | **1.46x** | | +| **LightX2V_2** | 216 | 50 | **1.70x** | | +| **LightX2V_3** | 191 | 35 | **1.92x** | | +| **LightX2V_3-Distill** | 14 | 35 | **🏆 20.85x** | | +| **LightX2V_4** | 107 | 35 | **3.41x** | | + +--- + +### 🎬 720P 5s Video Test + +**Test Configuration:** +- **Model**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) +- **Parameters**: `infer_steps=40`, `seed=1234`, `enable_cfg=True` + +#### 📊 Performance Comparison Table + +| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect | +|:-------------|:-----------------:|:--------------:|:-------:|:------------:| +| **Wan2.1 Official** | 974 | 81 | 1.0x | | +| **FastVideo** | 914 | 40 | **1.07x** | | +| **LightX2V_1** | 807 | 65 | **1.21x** | | +| **LightX2V_2** | 751 | 57 | **1.30x** | | +| **LightX2V_3** | 671 | 43 | **1.45x** | | +| **LightX2V_3-Distill** | 44 | 43 | **🏆 22.14x** | | +| **LightX2V_4** | 344 | 46 | **2.83x** | | + +--- + +## 🖥️ RTX 4090 Environment (~24GB VRAM) + +### 📋 Software Environment Configuration + +| Component | Version | +|:----------|:--------| +| **Python** | 3.9.16 | +| **PyTorch** | 2.5.1+cu124 | +| **SageAttention** | 2.1.0 | +| **vLLM** | 0.6.6 | +| **sgl-kernel** | 0.0.5 | +| **q8-kernels** | 0.0.0 | + +--- + +### 🎬 480P 5s Video Test + +**Test Configuration:** +- **Model**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) +- **Parameters**: `infer_steps=40`, `seed=42`, `enable_cfg=True` + +#### 📊 Performance Comparison Table + +| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect | +|:-------------|:-----------------:|:--------------:|:-------:|:------------:| +| **Wan2GP(profile=3)** | 779 | 20 | **1.0x** | | +| **LightX2V_5** | 738 | 16 | **1.05x** | | +| **LightX2V_5-Distill** | 68 | 16 | **11.45x** | | +| **LightX2V_6** | 630 | 12 | **1.24x** | | +| **LightX2V_6-Distill** | 63 | 12 | **🏆 12.36x** | + +--- + +### 🎬 720P 5s Video Test + +**Test Configuration:** +- **Model**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) +- **Parameters**: `infer_steps=40`, `seed=1234`, `enable_cfg=True` + +#### 📊 Performance Comparison Table + +| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect | +|:-------------|:-----------------:|:--------------:|:-------:|:------------:| +| **Wan2GP(profile=3)** | -- | OOM | -- | | +| **LightX2V_5** | 2473 | 23 | -- | | +| **LightX2V_5-Distill** | 183 | 23 | -- | | +| **LightX2V_6** | 2169 | 18 | -- | | +| **LightX2V_6-Distill** | 171 | 18 | -- | | + +--- + +## 📖 Configuration Descriptions + +### 🖥️ H200 Environment Configuration Descriptions + +| 
Configuration | Technical Features | +|:--------------|:------------------| +| **Wan2.1 Official** | Based on [Wan2.1 official repository](https://github.com/Wan-Video/Wan2.1) original implementation | +| **FastVideo** | Based on [FastVideo official repository](https://github.com/hao-ai-lab/FastVideo), using SageAttention2 backend optimization | +| **LightX2V_1** | Uses SageAttention2 to replace native attention mechanism, adopts DIT BF16+FP32 (partial sensitive layers) mixed precision computation, improving computational efficiency while maintaining precision | +| **LightX2V_2** | Unified BF16 precision computation, further reducing memory usage and computational overhead while maintaining generation quality | +| **LightX2V_3** | Introduces FP8 quantization technology to significantly reduce computational precision requirements, combined with Tiling VAE technology to optimize memory usage | +| **LightX2V_3-Distill** | Based on LightX2V_3 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality | +| **LightX2V_4** | Based on LightX2V_3 with TeaCache(teacache_thresh=0.2) caching reuse technology, achieving acceleration through intelligent redundant computation skipping | + +### 🖥️ RTX 4090 Environment Configuration Descriptions + +| Configuration | Technical Features | +|:--------------|:------------------| +| **Wan2GP(profile=3)** | Implementation based on [Wan2GP repository](https://github.com/deepbeepmeep/Wan2GP), using MMGP optimization technology. Profile=3 configuration is suitable for RTX 3090/4090 environments with at least 32GB RAM and 24GB VRAM, adapting to limited memory resources by sacrificing VRAM. Uses quantized models: [480P model](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors) and [720P model](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors) | +| **LightX2V_5** | Uses SageAttention2 to replace native attention mechanism, adopts DIT FP8+FP32 (partial sensitive layers) mixed precision computation, enables CPU offload technology, executes partial sensitive layers with FP32 precision, asynchronously offloads DIT inference process data to CPU, saves VRAM, with block-level offload granularity | +| **LightX2V_5-Distill** | Based on LightX2V_5 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality | +| **LightX2V_6** | Based on LightX2V_3 with CPU offload technology enabled, executes partial sensitive layers with FP32 precision, asynchronously offloads DIT inference process data to CPU, saves VRAM, with block-level offload granularity | +| **LightX2V_6-Distill** | Based on LightX2V_6 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality | + +--- + +## 📁 Configuration Files Reference + +Benchmark-related configuration files and execution scripts are available at: + +| Type | Link | Description | +|:-----|:-----|:------------| +| **Configuration Files** | [configs/bench](https://github.com/ModelTC/LightX2V/tree/main/configs/bench) | Contains JSON files with various optimization configurations | +| **Execution Scripts** | [scripts/bench](https://github.com/ModelTC/LightX2V/tree/main/scripts/bench) | Contains benchmark execution scripts | + +--- + +> 💡 **Tip**: It is recommended to choose the appropriate 
optimization solution based on your hardware configuration to achieve the best performance. diff --git a/docs/EN/source/getting_started/model_structure.md b/docs/EN/source/getting_started/model_structure.md new file mode 100644 index 0000000000000000000000000000000000000000..6edfc288bfc916392d91ee2205d3b4a3a71736b9 --- /dev/null +++ b/docs/EN/source/getting_started/model_structure.md @@ -0,0 +1,573 @@ +# Model Format and Loading Guide + +## 📖 Overview + +LightX2V is a flexible video generation inference framework that supports multiple model sources and formats, providing users with rich options: + +- ✅ **Wan Official Models**: Directly compatible with officially released complete models from Wan2.1 and Wan2.2 +- ✅ **Single-File Models**: Supports single-file format models released by LightX2V (including quantized versions) +- ✅ **LoRA Models**: Supports loading distilled LoRAs released by LightX2V + +This document provides detailed instructions on how to use various model formats, configuration parameters, and best practices. + +--- + +## 🗂️ Format 1: Wan Official Models + +### Model Repositories +- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b) +- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8) + +### Model Features +- **Official Guarantee**: Complete models officially released by Wan-AI with highest quality +- **Complete Components**: Includes all necessary components (DIT, T5, CLIP, VAE) +- **Original Precision**: Uses BF16/FP32 precision with no quantization loss +- **Strong Compatibility**: Fully compatible with Wan official toolchain + +### Wan2.1 Official Models + +#### Directory Structure + +Using [Wan2.1-I2V-14B-720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) as an example: + +``` +Wan2.1-I2V-14B-720P/ +├── diffusion_pytorch_model-00001-of-00007.safetensors # DIT model shard 1 +├── diffusion_pytorch_model-00002-of-00007.safetensors # DIT model shard 2 +├── diffusion_pytorch_model-00003-of-00007.safetensors # DIT model shard 3 +├── diffusion_pytorch_model-00004-of-00007.safetensors # DIT model shard 4 +├── diffusion_pytorch_model-00005-of-00007.safetensors # DIT model shard 5 +├── diffusion_pytorch_model-00006-of-00007.safetensors # DIT model shard 6 +├── diffusion_pytorch_model-00007-of-00007.safetensors # DIT model shard 7 +├── diffusion_pytorch_model.safetensors.index.json # Shard index file +├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder +├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP encoder +├── Wan2.1_VAE.pth # VAE encoder/decoder +├── config.json # Model configuration +├── xlm-roberta-large/ # CLIP tokenizer +├── google/ # T5 tokenizer +├── assets/ +└── examples/ +``` + +#### Usage + +```bash +# Download model +huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \ + --local-dir ./models/Wan2.1-I2V-14B-720P + +# Configure launch script +model_path=./models/Wan2.1-I2V-14B-720P +lightx2v_path=/path/to/LightX2V + +# Run inference +cd LightX2V/scripts +bash wan/run_wan_i2v.sh +``` + +### Wan2.2 Official Models + +#### Directory Structure + +Using [Wan2.2-I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) as an example: + +``` +Wan2.2-I2V-A14B/ +├── high_noise_model/ # High-noise model directory +│ ├── diffusion_pytorch_model-00001-of-00009.safetensors +│ ├── diffusion_pytorch_model-00002-of-00009.safetensors +│ ├── ... 
+│ ├── diffusion_pytorch_model-00009-of-00009.safetensors +│ └── diffusion_pytorch_model.safetensors.index.json +├── low_noise_model/ # Low-noise model directory +│ ├── diffusion_pytorch_model-00001-of-00009.safetensors +│ ├── diffusion_pytorch_model-00002-of-00009.safetensors +│ ├── ... +│ ├── diffusion_pytorch_model-00009-of-00009.safetensors +│ └── diffusion_pytorch_model.safetensors.index.json +├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder +├── Wan2.1_VAE.pth # VAE encoder/decoder +├── configuration.json # Model configuration +├── google/ # T5 tokenizer +├── assets/ # Example assets (optional) +└── examples/ # Example files (optional) +``` + +#### Usage + +```bash +# Download model +huggingface-cli download Wan-AI/Wan2.2-I2V-A14B \ + --local-dir ./models/Wan2.2-I2V-A14B + +# Configure launch script +model_path=./models/Wan2.2-I2V-A14B +lightx2v_path=/path/to/LightX2V + +# Run inference +cd LightX2V/scripts +bash wan22/run_wan22_moe_i2v.sh +``` + +### Available Model List + +#### Wan2.1 Official Model List + +| Model Name | Download Link | +|---------|----------| +| Wan2.1-I2V-14B-720P | [Link](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) | +| Wan2.1-I2V-14B-480P | [Link](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | +| Wan2.1-T2V-14B | [Link](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | +| Wan2.1-T2V-1.3B | [Link](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | +| Wan2.1-FLF2V-14B-720P | [Link](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) | +| Wan2.1-VACE-14B | [Link](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) | +| Wan2.1-VACE-1.3B | [Link](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) | + +#### Wan2.2 Official Model List + +| Model Name | Download Link | +|---------|----------| +| Wan2.2-I2V-A14B | [Link](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) | +| Wan2.2-T2V-A14B | [Link](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) | +| Wan2.2-TI2V-5B | [Link](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) | +| Wan2.2-Animate-14B | [Link](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) | + +### Usage Tips + +> 💡 **Quantized Model Usage**: To use quantized models, refer to the [Model Conversion Script](https://github.com/ModelTC/LightX2V/blob/main/tools/convert/readme_zh.md) for conversion, or directly use pre-converted quantized models in Format 2 below +> +> 💡 **Memory Optimization**: For devices with RTX 4090 24GB or smaller memory, it's recommended to combine quantization techniques with CPU offload features: +> - Quantization Configuration: Refer to [Quantization Documentation](../method_tutorials/quantization.md) +> - CPU Offload: Refer to [Parameter Offload Documentation](../method_tutorials/offload.md) +> - Wan2.1 Configuration: Refer to [offload config files](https://github.com/ModelTC/LightX2V/tree/main/configs/offload) +> - Wan2.2 Configuration: Refer to [wan22 config files](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) with `4090` suffix + +--- + +## 🗂️ Format 2: LightX2V Single-File Models (Recommended) + +### Model Repositories +- [Wan2.1-LightX2V](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan2.2-LightX2V](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) + +### Model Features +- **Single-File Management**: Single safetensors file, easy to manage and deploy +- **Multi-Precision Support**: Provides original precision, FP8, INT8, and other precision versions +- **Distillation Acceleration**: Supports 4-step fast inference +- **Tool Compatibility**: Compatible with ComfyUI and other tools + 
+**Examples**: +- `wan2.1_i2v_720p_lightx2v_4step.safetensors` - 720P I2V original precision +- `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` - 720P I2V FP8 quantization +- `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` - 480P I2V INT8 quantization +- ... + +### Wan2.1 Single-File Models + +#### Scenario A: Download Single Model File + +**Step 1: Select and Download Model** + +```bash +# Create model directory +mkdir -p ./models/wan2.1_i2v_720p + +# Download 720P I2V FP8 quantized model +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p \ + --include "wan2.1_i2v_720p_lightx2v_4step.safetensors" +``` + +**Step 2: Manually Organize Other Components** + +Directory structure as follows: +``` +wan2.1_i2v_720p/ +├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision +└── t5/clip/vae/config.json/xlm-roberta-large/google and other components # Need manual organization +``` + +**Step 3: Configure Launch Script** + +```bash +# Set in launch script (point to directory containing model file) +model_path=./models/wan2.1_i2v_720p +lightx2v_path=/path/to/LightX2V + +# Run script +cd LightX2V/scripts +bash wan/run_wan_i2v_distill_4step_cfg.sh +``` + +> 💡 **Tip**: When there's only one model file in the directory, LightX2V will automatically load it. + +#### Scenario B: Download Multiple Model Files + +When you download multiple models with different precisions to the same directory, you need to explicitly specify which model to use in the configuration file. + +**Step 1: Download Multiple Models** + +```bash +# Create model directory +mkdir -p ./models/wan2.1_i2v_720p_multi + +# Download original precision model +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p_multi \ + --include "wan2.1_i2v_720p_lightx2v_4step.safetensors" + +# Download FP8 quantized model +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p_multi \ + --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# Download INT8 quantized model +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p_multi \ + --include "wan2.1_i2v_720p_int8_lightx2v_4step.safetensors" +``` + +**Step 2: Manually Organize Other Components** + +Directory structure as follows: + +``` +wan2.1_i2v_720p_multi/ +├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision +├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization +└── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 quantization +└── t5/clip/vae/config.json/xlm-roberta-large/google and other components # Need manual organization +``` + +**Step 3: Specify Model in Configuration File** + +Edit configuration file (e.g., `configs/distill/wan_i2v_distill_4step_cfg.json`): + +```json +{ + // Use original precision model + "dit_original_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_lightx2v_4step.safetensors", + + // Or use FP8 quantized model + // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "fp8-vllm", + + // Or use INT8 quantized model + // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_int8_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "int8-vllm", + + // Other configurations... 
+} +``` +### Usage Tips + +> 💡 **Configuration Parameter Description**: +> - **dit_original_ckpt**: Used to specify the path to original precision models (BF16/FP32/FP16) +> - **dit_quantized_ckpt**: Used to specify the path to quantized models (FP8/INT8), must be used with `dit_quantized` and `dit_quant_scheme` parameters + +**Step 4: Start Inference** + +```bash +cd LightX2V/scripts +bash wan/run_wan_i2v_distill_4step_cfg.sh +``` + +> 💡 **Tip**: Other components (T5, CLIP, VAE, tokenizer, etc.) need to be manually organized into the model directory + +### Wan2.2 Single-File Models + +#### Directory Structure Requirements + +When using Wan2.2 single-file models, you need to manually create a specific directory structure: + +``` +wan2.2_models/ +├── high_noise_model/ # High-noise model directory (required) +│ └── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors +├── low_noise_model/ # Low-noise model directory (required) +│ └── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors +└── t5/clip/vae/config.json/... # Other components (manually organized) +``` + +#### Scenario A: Only One Model File Per Directory + +```bash +# Create required subdirectories +mkdir -p ./models/wan2.2_models/high_noise_model +mkdir -p ./models/wan2.2_models/low_noise_model + +# Download high-noise model to corresponding directory +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models/high_noise_model \ + --include "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# Download low-noise model to corresponding directory +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models/low_noise_model \ + --include "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# Configure launch script (point to parent directory) +model_path=./models/wan2.2_models +lightx2v_path=/path/to/LightX2V + +# Run script +cd LightX2V/scripts +bash wan22/run_wan22_moe_i2v_distill.sh +``` + +> 💡 **Tip**: When there's only one model file in each subdirectory, LightX2V will automatically load it. + +#### Scenario B: Multiple Model Files Per Directory + +When you place multiple models with different precisions in both `high_noise_model/` and `low_noise_model/` directories, you need to explicitly specify them in the configuration file. 
+ +```bash +# Create directories +mkdir -p ./models/wan2.2_models_multi/high_noise_model +mkdir -p ./models/wan2.2_models_multi/low_noise_model + +# Download multiple versions of high-noise model +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models_multi/high_noise_model \ + --include "wan2.2_i2v_A14b_high_noise_*.safetensors" + +# Download multiple versions of low-noise model +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models_multi/low_noise_model \ + --include "wan2.2_i2v_A14b_low_noise_*.safetensors" +``` + +**Directory Structure**: + +``` +wan2.2_models_multi/ +├── high_noise_model/ +│ ├── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors # Original precision +│ ├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization +│ └── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors # INT8 quantization +└── low_noise_model/ +│ ├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # Original precision +│ ├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization +│ └── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # INT8 quantization +└── t5/vae/config.json/xlm-roberta-large/google and other components # Need manual organization +``` + +**Configuration File Settings**: + +```json +{ + // Use original precision model + "high_noise_original_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors", + + // Or use FP8 quantized model + // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors", + // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "fp8-vllm" + + // Or use INT8 quantized model + // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "int8-vllm" +} +``` + +### Usage Tips + +> 💡 **Configuration Parameter Description**: +> - **high_noise_original_ckpt** / **low_noise_original_ckpt**: Used to specify the path to original precision models (BF16/FP32/FP16) +> - **high_noise_quantized_ckpt** / **low_noise_quantized_ckpt**: Used to specify the path to quantized models (FP8/INT8), must be used with `dit_quantized` and `dit_quant_scheme` parameters + + +### Available Model List + +#### Wan2.1 Single-File Model List + +**Image-to-Video Models (I2V)** + +| Filename | Precision | Description | +|--------|------|------| +| `wan2.1_i2v_480p_lightx2v_4step.safetensors` | BF16 | 4-step model original precision | +| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4-step model FP8 quantization | +| `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` | INT8 | 4-step model INT8 quantization | +| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4-step model ComfyUI format | +| `wan2.1_i2v_720p_lightx2v_4step.safetensors` | BF16 | 4-step model original precision | +| 
`wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4-step model FP8 quantization | +| `wan2.1_i2v_720p_int8_lightx2v_4step.safetensors` | INT8 | 4-step model INT8 quantization | +| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4-step model ComfyUI format | + +**Text-to-Video Models (T2V)** + +| Filename | Precision | Description | +|--------|------|------| +| `wan2.1_t2v_14b_lightx2v_4step.safetensors` | BF16 | 4-step model original precision | +| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4-step model FP8 quantization | +| `wan2.1_t2v_14b_int8_lightx2v_4step.safetensors` | INT8 | 4-step model INT8 quantization | +| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4-step model ComfyUI format | + +#### Wan2.2 Single-File Model List + +**Image-to-Video Models (I2V) - A14B Series** + +| Filename | Precision | Description | +|--------|------|------| +| `wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors` | BF16 | High-noise model - 4-step original precision | +| `wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | High-noise model - 4-step FP8 quantization | +| `wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors` | INT8 | High-noise model - 4-step INT8 quantization | +| `wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors` | BF16 | Low-noise model - 4-step original precision | +| `wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | Low-noise model - 4-step FP8 quantization | +| `wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors` | INT8 | Low-noise model - 4-step INT8 quantization | + +> 💡 **Usage Tips**: +> - Wan2.2 models use a dual-noise architecture, requiring both high-noise and low-noise models to be downloaded +> - Refer to the "Wan2.2 Single-File Models" section above for detailed directory organization + +--- + +## 🗂️ Format 3: LightX2V LoRA Models + +LoRA (Low-Rank Adaptation) models provide a lightweight model fine-tuning solution that enables customization for specific effects without modifying the base model. + +### Model Repositories + +- **Wan2.1 LoRA Models**: [lightx2v/Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras) +- **Wan2.2 LoRA Models**: [lightx2v/Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras) + +### Usage Methods + +#### Method 1: Offline Merging + +Merge LoRA weights offline into the base model to generate a new complete model file. + +**Steps**: + +Refer to the [Model Conversion Documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) for offline merging. + +**Advantages**: +- ✅ No need to load LoRA during inference +- ✅ Better performance + +**Disadvantages**: +- ❌ Requires additional storage space +- ❌ Switching different LoRAs requires re-merging + +#### Method 2: Online Loading + +Dynamically load LoRA weights during inference without modifying the base model. 
+ +**LoRA Application Principle**: + +```python +# LoRA weight application formula +# lora_scale = (alpha / rank) +# W' = W + lora_scale * B @ A +# Where: B = up_proj (out_features, rank) +# A = down_proj (rank, in_features) + +if weights_dict["alpha"] is not None: + lora_scale = weights_dict["alpha"] / lora_down.shape[0] +elif alpha is not None: + lora_scale = alpha / lora_down.shape[0] +else: + lora_scale = 1.0 +``` + +**Configuration Method**: + +**Wan2.1 LoRA Configuration**: + +```json +{ + "lora_configs": [ + { + "path": "wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0, + "alpha": null + } + ] +} +``` + +**Wan2.2 LoRA Configuration**: + +Since Wan2.2 uses a dual-model architecture (high-noise/low-noise), LoRA needs to be configured separately for both models: + +```json +{ + "lora_configs": [ + { + "name": "low_noise_model", + "path": "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0, + "alpha": null + }, + { + "name": "high_noise_model", + "path": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0, + "alpha": null + } + ] +} +``` + +**Parameter Description**: + +| Parameter | Description | Default | +|------|------|--------| +| `path` | LoRA model file path | Required | +| `strength` | LoRA strength coefficient, range [0.0, 1.0] | 1.0 | +| `alpha` | LoRA scaling factor, uses model's built-in value when `null` | null | +| `name` | (Wan2.2 only) Specifies which model to apply to | Required | + +**Advantages**: +- ✅ Flexible switching between different LoRAs +- ✅ Saves storage space +- ✅ Can dynamically adjust LoRA strength + +**Disadvantages**: +- ❌ Additional loading time during inference +- ❌ Slightly increases memory usage + +--- + +## 📚 Related Resources + +### Official Repositories +- [LightX2V GitHub](https://github.com/ModelTC/LightX2V) +- [LightX2V Single-File Model Repository](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan-AI Official Model Repository](https://huggingface.co/Wan-AI) + +### Model Download Links + +**Wan2.1 Series** +- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b) + +**Wan2.2 Series** +- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8) + +**LightX2V Single-File Models** +- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) + +### Documentation Links +- [Quantization Documentation](../method_tutorials/quantization.md) +- [Parameter Offload Documentation](../method_tutorials/offload.md) +- [Configuration File Examples](https://github.com/ModelTC/LightX2V/tree/main/configs) + +--- + +Through this document, you should be able to: + +✅ Understand all model formats supported by LightX2V +✅ Select appropriate models and precisions based on your needs +✅ Correctly download and organize model files +✅ Configure launch parameters and successfully run inference +✅ Resolve common model loading issues + +If you have other questions, feel free to ask in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues). diff --git a/docs/EN/source/getting_started/quickstart.md b/docs/EN/source/getting_started/quickstart.md new file mode 100644 index 0000000000000000000000000000000000000000..799959e4ae2cd6dcbf0f8ad262c276a87370806b --- /dev/null +++ b/docs/EN/source/getting_started/quickstart.md @@ -0,0 +1,349 @@ +# LightX2V Quick Start Guide + +Welcome to LightX2V! 
This guide will help you quickly set up the environment and start using LightX2V for video generation. + +## 📋 Table of Contents + +- [System Requirements](#system-requirements) +- [Linux Environment Setup](#linux-environment-setup) + - [Docker Environment (Recommended)](#docker-environment-recommended) + - [Conda Environment Setup](#conda-environment-setup) +- [Windows Environment Setup](#windows-environment-setup) +- [Inference Usage](#inference-usage) + +## 🚀 System Requirements + +- **Operating System**: Linux (Ubuntu 18.04+) or Windows 10/11 +- **Python**: 3.10 or higher +- **GPU**: NVIDIA GPU with CUDA support, at least 8GB VRAM +- **Memory**: 16GB or more recommended +- **Storage**: At least 50GB available space + +## 🐧 Linux Environment Setup + +### 🐳 Docker Environment (Recommended) + +We strongly recommend using the Docker environment, which is the simplest and fastest installation method. + +#### 1. Pull Image + +Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags), select a tag with the latest date, such as `25111101-cu128`: + +```bash +docker pull lightx2v/lightx2v:25111101-cu128 +``` + +We recommend using the `cuda128` environment for faster inference speed. If you need to use the `cuda124` environment, you can use image versions with the `-cu124` suffix: + +```bash +docker pull lightx2v/lightx2v:25101501-cu124 +``` + +#### 2. Run Container + +```bash +docker run --gpus all -itd --ipc=host --name [container_name] -v [mount_settings] --entrypoint /bin/bash [image_id] +``` + +#### 3. China Mirror Source (Optional) + +For mainland China, if the network is unstable when pulling images, you can pull from Alibaba Cloud: + +```bash +# cuda128 +docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25111101-cu128 + +# cuda124 +docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25101501-cu124 +``` + +### 🐍 Conda Environment Setup + +If you prefer to set up the environment yourself using Conda, please follow these steps: + +#### Step 1: Clone Repository + +```bash +# Download project code +git clone https://github.com/ModelTC/LightX2V.git +cd LightX2V +``` + +#### Step 2: Create Conda Virtual Environment + +```bash +# Create and activate conda environment +conda create -n lightx2v python=3.11 -y +conda activate lightx2v +``` + +#### Step 3: Install Dependencies + +```bash +pip install -v -e . +``` + +#### Step 4: Install Attention Operators + +**Option A: Flash Attention 2** +```bash +git clone https://github.com/Dao-AILab/flash-attention.git --recursive +cd flash-attention && python setup.py install +``` + +**Option B: Flash Attention 3 (for Hopper architecture GPUs)** +```bash +cd flash-attention/hopper && python setup.py install +``` + +**Option C: SageAttention 2 (Recommended)** +```bash +git clone https://github.com/thu-ml/SageAttention.git +cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install -v -e . +``` + +#### Step 4: Install Quantization Operators (Optional) + +Quantization operators are used to support model quantization, which can significantly reduce memory usage and accelerate inference. Choose the appropriate quantization operator based on your needs: + +**Option A: VLLM Kernels (Recommended)** +Suitable for various quantization schemes, supports FP8 and other quantization formats. 
+ +```bash +pip install vllm +``` + +Or install from source for the latest features: + +```bash +git clone https://github.com/vllm-project/vllm.git +cd vllm +uv pip install -e . +``` + +**Option B: SGL Kernels** +Suitable for SGL quantization scheme, requires torch == 2.8.0. + +```bash +pip install sgl-kernel --upgrade +``` + +**Option C: Q8 Kernels** +Suitable for Ada architecture GPUs (such as RTX 4090, L40S, etc.). + +```bash +git clone https://github.com/KONAKONA666/q8_kernels.git +cd q8_kernels && git submodule init && git submodule update +python setup.py install +``` + +> 💡 **Note**: +> - You can skip this step if you don't need quantization functionality +> - Quantized models can be downloaded from [LightX2V HuggingFace](https://huggingface.co/lightx2v) +> - For more quantization information, please refer to the [Quantization Documentation](method_tutorials/quantization.html) + +#### Step 5: Verify Installation + +```python +import lightx2v +print(f"LightX2V Version: {lightx2v.__version__}") +``` + +## 🪟 Windows Environment Setup + +Windows systems only support Conda environment setup. Please follow these steps: + +### 🐍 Conda Environment Setup + +#### Step 1: Check CUDA Version + +First, confirm your GPU driver and CUDA version: + +```cmd +nvidia-smi +``` + +Record the **CUDA Version** information in the output, which needs to be consistent in subsequent installations. + +#### Step 2: Create Python Environment + +```cmd +# Create new environment (Python 3.12 recommended) +conda create -n lightx2v python=3.12 -y + +# Activate environment +conda activate lightx2v +``` + +> 💡 **Note**: Python 3.10 or higher is recommended for best compatibility. + +#### Step 3: Install PyTorch Framework + +**Method 1: Download Official Wheel Package (Recommended)** + +1. Visit the [PyTorch Official Download Page](https://download.pytorch.org/whl/torch/) +2. Select the corresponding version wheel package, paying attention to matching: + - **Python Version**: Consistent with your environment + - **CUDA Version**: Matches your GPU driver + - **Platform**: Select Windows version + +**Example (Python 3.12 + PyTorch 2.6 + CUDA 12.4):** + +```cmd +# Download and install PyTorch +pip install torch-2.6.0+cu124-cp312-cp312-win_amd64.whl + +# Install supporting packages +pip install torchvision==0.21.0 torchaudio==2.6.0 +``` + +**Method 2: Direct Installation via pip** + +```cmd +# CUDA 12.4 version example +pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124 +``` + +#### Step 4: Install Windows Version vLLM + +Download the corresponding wheel package from [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases). 
+ +**Version Matching Requirements:** +- Python version matching +- PyTorch version matching +- CUDA version matching + +```cmd +# Install vLLM (please adjust according to actual filename) +pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl +``` + +#### Step 5: Install Attention Mechanism Operators + +**Option A: Flash Attention 2** + +```cmd +pip install flash-attn==2.7.2.post1 +``` + +**Option B: SageAttention 2 (Strongly Recommended)** + +**Download Sources:** +- [Windows Special Version 1](https://github.com/woct0rdho/SageAttention/releases) +- [Windows Special Version 2](https://github.com/sdbds/SageAttention-for-windows/releases) + +```cmd +# Install SageAttention (please adjust according to actual filename) +pip install sageattention-2.1.1+cu126torch2.6.0-cp312-cp312-win_amd64.whl +``` + +> ⚠️ **Note**: SageAttention's CUDA version doesn't need to be strictly aligned, but Python and PyTorch versions must match. + +#### Step 6: Clone Repository + +```cmd +# Clone project code +git clone https://github.com/ModelTC/LightX2V.git +cd LightX2V + +# Install Windows-specific dependencies +pip install -r requirements_win.txt +pip install -v -e . +``` + +#### Step 7: Install Quantization Operators (Optional) + +Quantization operators are used to support model quantization, which can significantly reduce memory usage and accelerate inference. + +**Install VLLM (Recommended):** + +Download the corresponding wheel package from [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) and install it. + +```cmd +# Install vLLM (please adjust according to actual filename) +pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl +``` + +> 💡 **Note**: +> - You can skip this step if you don't need quantization functionality +> - Quantized models can be downloaded from [LightX2V HuggingFace](https://huggingface.co/lightx2v) +> - For more quantization information, please refer to the [Quantization Documentation](method_tutorials/quantization.html) + +## 🎯 Inference Usage + +### 📥 Model Preparation + +Before starting inference, you need to download the model files in advance. We recommend: + +- **Download Source**: Download models from [LightX2V Official Hugging Face](https://huggingface.co/lightx2v/) or other open-source model repositories +- **Storage Location**: It's recommended to store models on SSD disks for better read performance +- **Available Models**: Including Wan2.1-I2V, Wan2.1-T2V, and other models supporting different resolutions and functionalities + +### 📁 Configuration Files and Scripts + +The configuration files used for inference are available [here](https://github.com/ModelTC/LightX2V/tree/main/configs), and scripts are available [here](https://github.com/ModelTC/LightX2V/tree/main/scripts). + +You need to configure the downloaded model path in the run script. In addition to the input arguments in the script, there are also some necessary parameters in the configuration file specified by `--config_json`. You can modify them as needed. 
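+
+For orientation, the file passed via `--config_json` is a plain JSON dictionary of inference options. The snippet below is purely illustrative — it only reuses option names that appear elsewhere in these docs (`infer_steps`, `enable_cfg`, and the attention type keys); the exact set of fields differs between config files, so always start from one of the provided templates:
+
+```json
+{
+    "infer_steps": 40,
+    "enable_cfg": true,
+    "self_attn_1_type": "sage_attn2",
+    "cross_attn_1_type": "flash_attn3",
+    "cross_attn_2_type": "flash_attn3"
+}
+```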
+ +### 🚀 Start Inference + +#### Linux Environment + +```bash +# Run after modifying the path in the script +bash scripts/wan/run_wan_t2v.sh +``` + +#### Windows Environment + +```cmd +# Use Windows batch script +scripts\win\run_wan_t2v.bat +``` + +#### Python Script Launch + +```python +from lightx2v import LightX2VPipeline + +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-T2V-14B", + model_cls="wan2.1", + task="t2v", +) + +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + height=480, # 720 + width=832, # 1280 + num_frames=81, + guidance_scale=5.0, + sample_shift=5.0, +) + +seed = 42 +prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path="/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) +``` + +> 💡 **More Examples**: For more usage examples including quantization, offloading, caching, and other advanced configurations, please refer to the [examples directory](https://github.com/ModelTC/LightX2V/tree/main/examples). + +## 📞 Get Help + +If you encounter problems during installation or usage, please: + +1. Search for related issues in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) +2. Submit a new Issue describing your problem + +--- + +🎉 **Congratulations!** You have successfully set up the LightX2V environment and can now start enjoying video generation! diff --git a/docs/EN/source/index.rst b/docs/EN/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..3e244e62ab26025fc1d5ae328eaf29205a5cf254 --- /dev/null +++ b/docs/EN/source/index.rst @@ -0,0 +1,67 @@ +Welcome to Lightx2v! +================== + +.. figure:: ../../../assets/img_lightx2v.png + :width: 80% + :align: center + :alt: Lightx2v + :class: no-scaled-link + +.. raw:: html + +
+ + License + Ask DeepWiki + Doc + Doc + Docker + +
+ +
+ LightX2V: Light Video Generation Inference Framework +
+ +LightX2V is a lightweight video generation inference framework designed to provide an inference tool that leverages multiple advanced video generation inference techniques. As a unified inference platform, this framework supports various generation tasks such as text-to-video (T2V) and image-to-video (I2V) across different models. X2V means transforming different input modalities (such as text or images) to video output. + +GitHub: https://github.com/ModelTC/lightx2v + +HuggingFace: https://huggingface.co/lightx2v + +Documentation +------------- + +.. toctree:: + :maxdepth: 1 + :caption: Quick Start + + Quick Start + Model Structure + Benchmark + +.. toctree:: + :maxdepth: 1 + :caption: Method Tutorials + + Model Quantization + Feature Caching + Attention Module + Offload + Parallel Inference + Changing Resolution Inference + Step Distill + Autoregressive Distill + Video Frame Interpolation + +.. toctree:: + :maxdepth: 1 + :caption: Deployment Guides + + Low Latency Deployment + Low Resource Deployment + Lora Deployment + Service Deployment + Gradio Deployment + ComfyUI Deployment + Local Windows Deployment diff --git a/docs/EN/source/method_tutorials/attention.md b/docs/EN/source/method_tutorials/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..1396140c74bd7af94b7dc6ccb6f5bb867ed05794 --- /dev/null +++ b/docs/EN/source/method_tutorials/attention.md @@ -0,0 +1,35 @@ +# Attention Mechanisms + +## Attention Mechanisms Supported by LightX2V + +| Name | Type Name | GitHub Link | +|--------------------|------------------|-------------| +| Flash Attention 2 | `flash_attn2` | [flash-attention v2](https://github.com/Dao-AILab/flash-attention) | +| Flash Attention 3 | `flash_attn3` | [flash-attention v3](https://github.com/Dao-AILab/flash-attention) | +| Sage Attention 2 | `sage_attn2` | [SageAttention](https://github.com/thu-ml/SageAttention) | +| Radial Attention | `radial_attn` | [Radial Attention](https://github.com/mit-han-lab/radial-attention) | +| Sparge Attention | `sparge_ckpt` | [Sparge Attention](https://github.com/thu-ml/SpargeAttn) | + +--- + +## Configuration Examples + +The configuration files for attention mechanisms are located [here](https://github.com/ModelTC/lightx2v/tree/main/configs/attentions) + +By specifying --config_json to a specific config file, you can test different attention mechanisms. + +For example, for radial_attn, the configuration is as follows: + +```json +{ + "self_attn_1_type": "radial_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3" +} +``` + +To switch to other types, simply replace the corresponding values with the type names from the table above. + +Tips: radial_attn can only be used in self attention due to the limitations of its sparse algorithm principle. + +For further customization of attention mechanism behavior, please refer to the official documentation or implementation code of each attention library. diff --git a/docs/EN/source/method_tutorials/autoregressive_distill.md b/docs/EN/source/method_tutorials/autoregressive_distill.md new file mode 100644 index 0000000000000000000000000000000000000000..32b9ed50ad551f69e283ff062a825906e6b8a8fe --- /dev/null +++ b/docs/EN/source/method_tutorials/autoregressive_distill.md @@ -0,0 +1,53 @@ +# Autoregressive Distillation + +Autoregressive distillation is a technical exploration in LightX2V. 
By training distilled models, it reduces inference steps from the original 40-50 steps to **8 steps**, achieving inference acceleration while enabling infinite-length video generation through KV Cache technology. + +> ⚠️ Warning: Currently, autoregressive distillation has mediocre effects and the acceleration improvement has not met expectations, but it can serve as a long-term research project. Currently, LightX2V only supports autoregressive models for T2V. + +## 🔍 Technical Principle + +Autoregressive distillation is implemented through [CausVid](https://github.com/tianweiy/CausVid) technology. CausVid performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements: + +1. **Larger Models**: Supports autoregressive distillation training for 14B models; +2. **More Complete Data Processing Pipeline**: Generates a training dataset of 50,000 prompt-video pairs; + +For detailed implementation, refer to [CausVid-Plus](https://github.com/GoatWu/CausVid-Plus). + +## 🛠️ Configuration Files + +### Configuration File + +Configuration options are provided in the [configs/causvid/](https://github.com/ModelTC/lightx2v/tree/main/configs/causvid) directory: + +| Configuration File | Model Address | +|-------------------|---------------| +| [wan_t2v_causvid.json](https://github.com/ModelTC/lightx2v/blob/main/configs/causvid/wan_t2v_causvid.json) | https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid | + +### Key Configuration Parameters + +```json +{ + "enable_cfg": false, // Disable CFG for speed improvement + "num_fragments": 3, // Number of video segments generated at once, 5s each + "num_frames": 21, // Frames per video segment, modify with caution! + "num_frame_per_block": 3, // Frames per autoregressive block, modify with caution! + "num_blocks": 7, // Autoregressive blocks per video segment, modify with caution! + "frame_seq_length": 1560, // Encoding length per frame, modify with caution! + "denoising_step_list": [ // Denoising timestep list + 999, 934, 862, 756, 603, 410, 250, 140, 74 + ] +} +``` + +## 📜 Usage + +### Model Preparation + +Place the downloaded model (`causal_model.pt` or `causal_model.safetensors`) in the `causvid_models/` folder under the Wan model root directory: +- For T2V: `Wan2.1-T2V-14B/causvid_models/` + +### Inference Script + +```bash +bash scripts/wan/run_wan_t2v_causvid.sh +``` diff --git a/docs/EN/source/method_tutorials/cache.md b/docs/EN/source/method_tutorials/cache.md new file mode 100644 index 0000000000000000000000000000000000000000..953c36af8a7ec604ecf98f12fb401ecb1803307e --- /dev/null +++ b/docs/EN/source/method_tutorials/cache.md @@ -0,0 +1,3 @@ +# Feature Cache + +To demonstrate some video playback effects, you can get better display effects and corresponding documentation content on this [🔗 page](https://github.com/ModelTC/LightX2V/blob/main/docs/EN/source/method_tutorials/cache_source.md). diff --git a/docs/EN/source/method_tutorials/cache_source.md b/docs/EN/source/method_tutorials/cache_source.md new file mode 100644 index 0000000000000000000000000000000000000000..3fd14c010867058f64644e43a07a625a438840b2 --- /dev/null +++ b/docs/EN/source/method_tutorials/cache_source.md @@ -0,0 +1,139 @@ +# Feature Caching + +## Cache Acceleration Algorithm +- In the inference process of diffusion models, cache reuse is an important acceleration algorithm. +- The core idea is to skip redundant computations at certain time steps by reusing historical cache results to improve inference efficiency. 
+- The key to the algorithm is how to decide which time steps to perform cache reuse, usually based on dynamic judgment of model state changes or error thresholds. +- During inference, key content such as intermediate features, residuals, and attention outputs need to be cached. When entering reusable time steps, the cached content is directly utilized, and the current output is reconstructed through approximation methods like Taylor expansion, thereby reducing repeated calculations and achieving efficient inference. + +### TeaCache +The core idea of `TeaCache` is to accumulate the **relative L1** distance between adjacent time step inputs. When the accumulated distance reaches a set threshold, it determines that the current time step should not use cache reuse; conversely, when the accumulated distance does not reach the set threshold, cache reuse is used to accelerate the inference process. +- Specifically, the algorithm calculates the relative L1 distance between the current input and the previous step input at each inference step and accumulates it. +- When the accumulated distance does not exceed the threshold, it indicates that the model state change is not obvious, so the most recently cached content is directly reused, skipping some redundant calculations. This can significantly reduce the number of forward computations of the model and improve inference speed. + +In practical effects, TeaCache achieves significant acceleration while ensuring generation quality. On a single H200 card, the time consumption and video comparison before and after acceleration are as follows: + + + + + + + + + + +
+ Before acceleration: 58s + + After acceleration: 17.9s +
+ + + +
+ + +- Acceleration ratio: **3.24** +- Config: [wan_t2v_1_3b_tea_480p.json](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json) +- Reference paper: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108) + +### TaylorSeer Cache +The core of `TaylorSeer Cache` lies in using Taylor's formula to recalculate cached content as residual compensation for cache reuse time steps. +- The specific approach is to not only simply reuse historical cache at cache reuse time steps, but also approximately reconstruct the current output through Taylor expansion. This can further improve output accuracy while reducing computational load. +- Taylor expansion can effectively capture minor changes in model state, allowing errors caused by cache reuse to be compensated, thereby ensuring generation quality while accelerating. + +`TaylorSeer Cache` is suitable for scenarios with high output accuracy requirements and can further improve model inference performance based on cache reuse. + + + + + + + + + + +
+ Before acceleration: 57.7s + + After acceleration: 41.3s +
+ + + +
+ + +- Acceleration ratio: **1.39** +- Config: [wan_t2v_taylorseer](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/taylorseer/wan_t2v_taylorseer.json) +- Reference paper: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923) + +### AdaCache +The core idea of `AdaCache` is to dynamically adjust the step size of cache reuse based on partial cached content in specified block chunks. +- The algorithm analyzes feature differences between two adjacent time steps within specific blocks and adaptively determines the next cache reuse time step interval based on the difference magnitude. +- When model state changes are small, the step size automatically increases, reducing cache update frequency; when state changes are large, the step size decreases to ensure output quality. + +This allows flexible adjustment of caching strategies based on dynamic changes in the actual inference process, achieving more efficient acceleration and better generation results. AdaCache is suitable for application scenarios that have high requirements for both inference speed and generation quality. + + + + + + + + + + +
+ Before acceleration: 227s + + After acceleration: 83s +
+ + + +
+ + +- Acceleration ratio: **2.73** +- Config: [wan_i2v_ada](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/adacache/wan_i2v_ada.json) +- Reference paper: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397) + +### CustomCache +`CustomCache` combines the advantages of `TeaCache` and `TaylorSeer Cache`. +- It combines the real-time and reasonable cache decision-making of `TeaCache`, determining when to perform cache reuse through dynamic thresholds. +- At the same time, it utilizes `TaylorSeer`'s Taylor expansion method to make use of cached content. + +This not only efficiently determines the timing of cache reuse but also maximizes the utilization of cached content, improving output accuracy and generation quality. Actual testing shows that `CustomCache` produces video quality superior to using `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone across multiple content generation tasks, making it one of the currently optimal comprehensive cache acceleration algorithms. + + + + + + + + + + +
+ Before acceleration: 57.9s + + After acceleration: 16.6s +
+ + + +
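+
+Conceptually, one denoising step of such a combined scheme chains the two sketches above: a TeaCache-style decision gates the step, and a TaylorSeer-style extrapolation consumes the cache (purely illustrative, assuming the `TeaCacheDecider` and `taylor_extrapolate` sketches shown earlier):
+
+```python
+def combined_cache_step(model, x, decider, cache):
+    """One step with threshold-gated reuse plus Taylor-based reconstruction (sketch)."""
+    if cache["prev_feat"] is not None and decider.should_reuse(x):
+        cache["skipped"] += 1
+        return taylor_extrapolate(cache["feat"], cache["prev_feat"],
+                                  cache["skipped"], cache["interval"])
+    out = model(x)  # full forward pass
+    cache["prev_feat"], cache["feat"] = cache["feat"], out.detach()
+    cache["interval"] = max(cache["skipped"] + 1, 1)
+    cache["skipped"] = 0
+    return out
+```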
+ + +- Acceleration ratio: **3.49** +- Config: [wan_t2v_custom_1_3b](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/custom/wan_t2v_custom_1_3b.json) + + +## Usage + +The config files for feature caching are located [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching) + +By specifying --config_json to the specific config file, you can test different cache algorithms. + +[Here](https://github.com/ModelTC/lightx2v/tree/main/scripts/cache) are some running scripts for use. diff --git a/docs/EN/source/method_tutorials/changing_resolution.md b/docs/EN/source/method_tutorials/changing_resolution.md new file mode 100644 index 0000000000000000000000000000000000000000..9c936949efc29464dfa313454844618b34d80ef3 --- /dev/null +++ b/docs/EN/source/method_tutorials/changing_resolution.md @@ -0,0 +1,66 @@ +# Variable Resolution Inference + +## Overview + +Variable resolution inference is a technical strategy for optimizing the denoising process. It improves computational efficiency while maintaining generation quality by using different resolutions at different stages of the denoising process. The core idea of this method is to use lower resolution for coarse denoising in the early stages and switch to normal resolution for fine processing in the later stages. + +## Technical Principles + +### Multi-stage Denoising Strategy + +Variable resolution inference is based on the following observations: + +- **Early-stage denoising**: Mainly handles coarse noise and overall structure, requiring less detailed information +- **Late-stage denoising**: Focuses on detail optimization and high-frequency information recovery, requiring complete resolution information + +### Resolution Switching Mechanism + +1. **Low-resolution stage** (early stage) + - Downsample the input to a lower resolution (e.g., 0.75x of original size) + - Execute initial denoising steps + - Quickly remove most noise and establish basic structure + +2. **Normal resolution stage** (late stage) + - Upsample the denoising result from the first step back to original resolution + - Continue executing remaining denoising steps + - Restore detailed information and complete fine processing + +### U-shaped Resolution Strategy + +If resolution is reduced at the very beginning of the denoising steps, it may cause significant differences between the final generated video and the video generated through normal inference. Therefore, a U-shaped resolution strategy can be adopted, where the original resolution is maintained for the first few steps, then resolution is reduced for inference. + +## Usage + +The config files for variable resolution inference are located [here](https://github.com/ModelTC/LightX2V/tree/main/configs/changing_resolution) + +You can test variable resolution inference by specifying --config_json to the specific config file. + +You can refer to the scripts [here](https://github.com/ModelTC/LightX2V/blob/main/scripts/changing_resolution) to run. + +### Example 1: +``` +{ + "infer_steps": 50, + "changing_resolution": true, + "resolution_rate": [0.75], + "changing_resolution_steps": [25] +} +``` + +This means a total of 50 steps, with resolution at 0.75x original resolution from step 1 to 25, and original resolution from step 26 to the final step. 
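+
+A minimal sketch of how such a schedule can be resolved per step is shown below (a hypothetical helper mirroring the semantics of `resolution_rate` and `changing_resolution_steps`; it is not the framework's actual API):
+
+```python
+def resolution_scale_for_step(step: int,
+                              changing_resolution_steps: list[int],
+                              resolution_rate: list[float]) -> float:
+    """Return the resolution scale for a 1-indexed denoising step.
+
+    Each boundary closes a left-open, right-closed interval; steps after
+    the last boundary run at the original resolution (scale 1.0).
+    """
+    for boundary, rate in zip(changing_resolution_steps, resolution_rate):
+        if step <= boundary:
+            return rate
+    return 1.0
+
+
+# Example 1 above: steps 1-25 run at 0.75x, steps 26-50 at the original resolution
+assert resolution_scale_for_step(10, [25], [0.75]) == 0.75
+assert resolution_scale_for_step(30, [25], [0.75]) == 1.0
+```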
+ +### Example 2: +``` +{ + "infer_steps": 50, + "changing_resolution": true, + "resolution_rate": [1.0, 0.75], + "changing_resolution_steps": [10, 35] +} +``` + +This means a total of 50 steps, with original resolution from step 1 to 10, 0.75x original resolution from step 11 to 35, and original resolution from step 36 to the final step. + +Generally, if `changing_resolution_steps` is [A, B, C], the denoising starts at step 1, and the total number of steps is X, then the inference process will be divided into four segments. + +Specifically, these segments are (0, A], (A, B], (B, C], and (C, X], where each segment is a left-open, right-closed interval. diff --git a/docs/EN/source/method_tutorials/offload.md b/docs/EN/source/method_tutorials/offload.md new file mode 100644 index 0000000000000000000000000000000000000000..0fff2dc569d02bd6c669ab52112d4c70d5f9c03d --- /dev/null +++ b/docs/EN/source/method_tutorials/offload.md @@ -0,0 +1,177 @@ +# Parameter Offload + +## 📖 Overview + +LightX2V implements an advanced parameter offload mechanism specifically designed for large model inference under limited hardware resources. The system provides an excellent speed-memory balance by intelligently managing model weights across different memory hierarchies. + +**Core Features:** +- **Block/Phase-level Offload**: Efficiently manages model weights in block/phase units for optimal memory usage + - **Block**: The basic computational unit of Transformer models, containing complete Transformer layers (self-attention, cross-attention, feedforward networks, etc.), serving as a larger memory management unit + - **Phase**: Finer-grained computational stages within blocks, containing individual computational components (such as self-attention, cross-attention, feedforward networks, etc.), providing more precise memory control +- **Multi-tier Storage Support**: GPU → CPU → Disk hierarchy with intelligent caching +- **Asynchronous Operations**: Overlaps computation and data transfer using CUDA streams +- **Disk/NVMe Serialization**: Supports secondary storage when memory is insufficient + +## 🎯 Offload Strategies + +### Strategy 1: GPU-CPU Block/Phase Offload + +**Use Case**: Insufficient GPU memory but sufficient system memory + +**How It Works**: Manages model weights in block or phase units between GPU and CPU memory, utilizing CUDA streams to overlap computation and data transfer. Blocks contain complete Transformer layers, while Phases are individual computational components within blocks. + +
+GPU-CPU block/phase offload workflow +
+ +
+Swap operation +
+ +
+Swap concept +
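+
+The compute/transfer overlap illustrated above can be sketched in plain PyTorch roughly as follows. This is a simplified, hypothetical example (not LightX2V's internal offload classes) showing how prefetching the next block on a separate CUDA stream hides the host-to-device copy behind the current block's computation; in a real implementation the weights would live in pinned CPU memory and finished blocks would be offloaded back on a third stream:
+
+```python
+import torch
+
+compute_stream = torch.cuda.Stream(priority=-1)  # high priority: runs the current block
+load_stream = torch.cuda.Stream()                # prefetches the next block in parallel
+
+
+def run_blocks(blocks_cpu, x):
+    """Run CPU-resident blocks on GPU, prefetching block i+1 while block i computes."""
+    compute_stream.wait_stream(torch.cuda.current_stream())  # x may come from the default stream
+    with torch.cuda.stream(load_stream):
+        prefetched = blocks_cpu[0].to("cuda", non_blocking=True)
+    for i in range(len(blocks_cpu)):
+        # do not start computing until the prefetch of this block has finished
+        compute_stream.wait_stream(load_stream)
+        current = prefetched
+        if i + 1 < len(blocks_cpu):
+            with torch.cuda.stream(load_stream):
+                prefetched = blocks_cpu[i + 1].to("cuda", non_blocking=True)
+        with torch.cuda.stream(compute_stream):
+            x = current(x)
+    torch.cuda.current_stream().wait_stream(compute_stream)
+    return x
+```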
+ + +**Block vs Phase Explanation**: +- **Block Granularity**: Larger memory management unit containing complete Transformer layers (self-attention, cross-attention, feedforward networks, etc.), suitable for sufficient memory scenarios with reduced management overhead +- **Phase Granularity**: Finer-grained memory management containing individual computational components (such as self-attention, cross-attention, feedforward networks, etc.), suitable for memory-constrained scenarios with more flexible memory control + +**Key Features:** +- **Asynchronous Transfer**: Uses three CUDA streams with different priorities for parallel computation and transfer + - Compute stream (priority=-1): High priority, handles current computation + - GPU load stream (priority=0): Medium priority, handles CPU to GPU prefetching + - CPU load stream (priority=0): Medium priority, handles GPU to CPU offloading +- **Prefetch Mechanism**: Preloads the next block/phase to GPU in advance +- **Intelligent Caching**: Maintains weight cache in CPU memory +- **Stream Synchronization**: Ensures correctness of data transfer and computation +- **Swap Operation**: Rotates block/phase positions after computation for continuous execution + + + + +### Strategy 2: Disk-CPU-GPU Block/Phase Offload (Lazy Loading) + +**Use Case**: Both GPU memory and system memory are insufficient + +**How It Works**: Builds upon Strategy 1 by introducing disk storage, implementing a three-tier storage hierarchy (Disk → CPU → GPU). CPU continues to serve as a cache pool with configurable size, suitable for devices with limited CPU memory. + + +
+Disk-CPU-GPU block/phase offload workflow +
+ + +
+Working steps +
+ +**Key Features:** +- **Lazy Loading**: Model weights are loaded from disk on-demand, avoiding loading the entire model at once +- **Intelligent Caching**: CPU memory buffer uses FIFO strategy with configurable size +- **Multi-threaded Prefetch**: Uses multiple disk worker threads for parallel loading +- **Asynchronous Transfer**: Uses CUDA streams to overlap computation and data transfer +- **Swap Rotation**: Achieves continuous computation through position rotation, avoiding repeated loading/offloading + +**Working Steps**: +- **Disk Storage**: Model weights are stored on SSD/NVMe by block, one .safetensors file per block +- **Task Scheduling**: When a block/phase is needed, priority task queue assigns disk worker threads +- **Asynchronous Loading**: Multiple disk threads load weight files from disk to CPU memory buffer in parallel +- **Intelligent Caching**: CPU memory buffer manages cache using FIFO strategy with configurable size +- **Cache Hit**: If weights are already in cache, transfer directly to GPU without disk read +- **Prefetch Transfer**: Weights in cache are asynchronously transferred to GPU memory (using GPU load stream) +- **Compute Execution**: Weights on GPU perform computation (using compute stream) while background continues prefetching next block/phase +- **Swap Rotation**: After computation completes, rotate block/phase positions for continuous computation +- **Memory Management**: When CPU cache is full, automatically evict the least recently used weight block/phase + + + +## ⚙️ Configuration Parameters + +### GPU-CPU Offload Configuration + +```python +config = { + "cpu_offload": True, + "offload_ratio": 1.0, # Offload ratio (0.0-1.0) + "offload_granularity": "block", # Offload granularity: "block" or "phase" + "lazy_load": False, # Disable lazy loading +} +``` + +### Disk-CPU-GPU Offload Configuration + +```python +config = { + "cpu_offload": True, + "lazy_load": True, # Enable lazy loading + "offload_ratio": 1.0, # Offload ratio + "offload_granularity": "phase", # Recommended to use phase granularity + "num_disk_workers": 2, # Number of disk worker threads + "offload_to_disk": True, # Enable disk offload +} +``` + +**Intelligent Cache Key Parameters:** +- `max_memory`: Controls CPU cache size, affects cache hit rate and memory usage +- `num_disk_workers`: Controls number of disk loading threads, affects prefetch speed +- `offload_granularity`: Controls cache granularity (block or phase), affects cache efficiency + - `"block"`: Cache management in complete Transformer layer units + - `"phase"`: Cache management in individual computational component units + +**Offload Configuration for Non-DIT Model Components (T5, CLIP, VAE):** + +The offload behavior of these components follows these rules: +- **Default Behavior**: If not specified separately, T5, CLIP, VAE will follow the `cpu_offload` setting +- **Independent Configuration**: Can set offload strategy separately for each component for fine-grained control + +**Configuration Example**: +```json +{ + "cpu_offload": true, // DIT model offload switch + "t5_cpu_offload": false, // T5 encoder independent setting + "clip_cpu_offload": false, // CLIP encoder independent setting + "vae_cpu_offload": false // VAE encoder independent setting +} +``` + +For memory-constrained devices, a progressive offload strategy is recommended: + +1. **Step 1**: Only enable `cpu_offload`, disable `t5_cpu_offload`, `clip_cpu_offload`, `vae_cpu_offload` +2. 
**Step 2**: If memory is still insufficient, gradually enable CPU offload for T5, CLIP, VAE +3. **Step 3**: If memory is still not enough, consider using quantization + CPU offload or enable `lazy_load` + +**Practical Experience**: +- **RTX 4090 24GB + 14B Model**: Usually only need to enable `cpu_offload`, manually set other component offload to `false`, and use FP8 quantized version +- **Smaller Memory GPUs**: Need to combine quantization, CPU offload, and lazy loading +- **Quantization Schemes**: Refer to [Quantization Documentation](../method_tutorials/quantization.md) to select appropriate quantization strategy + + +**Configuration File Reference**: +- **Wan2.1 Series Models**: Refer to [offload config files](https://github.com/ModelTC/lightx2v/tree/main/configs/offload) +- **Wan2.2 Series Models**: Refer to [wan22 config files](https://github.com/ModelTC/lightx2v/tree/main/configs/wan22) with `4090` suffix + +## 🎯 Usage Recommendations +- 🔄 GPU-CPU Block/Phase Offload: Suitable for insufficient GPU memory (RTX 3090/4090 24G) but sufficient system memory (>64/128G) + +- 💾 Disk-CPU-GPU Block/Phase Offload: Suitable for both insufficient GPU memory (RTX 3060/4090 8G) and system memory (16/32G) + +- 🚫 No Offload: Suitable for high-end hardware configurations pursuing best performance + + +## 🔍 Troubleshooting + +### Common Issues and Solutions + +1. **Disk I/O Bottleneck** + - Solution: Use NVMe SSD, increase num_disk_workers + + +2. **Memory Buffer Overflow** + - Solution: Increase max_memory or reduce num_disk_workers + +3. **Loading Timeout** + - Solution: Check disk performance, optimize file system + + +**Note**: This offload mechanism is specifically designed for LightX2V, fully utilizing the asynchronous computing capabilities of modern hardware, significantly lowering the hardware threshold for large model inference. diff --git a/docs/EN/source/method_tutorials/parallel.md b/docs/EN/source/method_tutorials/parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..324ea7b74c28117cfebc5d6fa66baefad31bc32f --- /dev/null +++ b/docs/EN/source/method_tutorials/parallel.md @@ -0,0 +1,53 @@ +# Parallel Inference + +LightX2V supports distributed parallel inference, enabling the utilization of multiple GPUs for inference. The DiT component supports two parallel attention mechanisms: **Ulysses** and **Ring**, while also supporting **Cfg parallel inference**. Parallel inference significantly reduces inference time and alleviates memory overhead on each GPU. + +## DiT Parallel Configuration + +### 1. Ulysses Parallel + +**Configuration method:** +```json + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +``` + +### 2. Ring Parallel + +**Configuration method:** +```json + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ring" + } +``` + +## Cfg Parallel Configuration + +**Configuration method:** +```json + "parallel": { + "cfg_p_size": 2 + } +``` + +## Hybrid Parallel Configuration + +**Configuration method:** +```json + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +``` + +## Usage + +Parallel inference configuration files are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/dist_infer) + +By specifying --config_json to a specific config file, you can test parallel inference. + +[Here](https://github.com/ModelTC/lightx2v/tree/main/scripts/dist_infer) are some run scripts for your use. 
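+
+As a rough rule of thumb (an assumption about the launch setup, not something enforced by the config itself), the number of GPUs you launch with is expected to cover the product of the parallel degrees, e.g.:
+
+```python
+# Hypothetical sanity check for a hybrid parallel config
+parallel = {"seq_p_size": 4, "seq_p_attn_type": "ulysses", "cfg_p_size": 2}
+
+world_size = parallel["seq_p_size"] * parallel.get("cfg_p_size", 1)
+print(f"launch with {world_size} GPUs, e.g. --nproc_per_node {world_size}")
+```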
diff --git a/docs/EN/source/method_tutorials/quantization.md b/docs/EN/source/method_tutorials/quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..da355a92a08c9280cfb3fd0447d851b8a658c0f7 --- /dev/null +++ b/docs/EN/source/method_tutorials/quantization.md @@ -0,0 +1,158 @@ +# Model Quantization Techniques + +## 📖 Overview + +LightX2V supports quantized inference for DIT, T5, and CLIP models, reducing memory usage and improving inference speed by lowering model precision. + +--- + +## 🔧 Quantization Modes + +| Quantization Mode | Weight Quantization | Activation Quantization | Compute Kernel | Supported Hardware | +|--------------|----------|----------|----------|----------| +| `fp8-vllm` | FP8 channel symmetric | FP8 channel dynamic symmetric | [VLLM](https://github.com/vllm-project/vllm) | H100/H200/H800, RTX 40 series, etc. | +| `int8-vllm` | INT8 channel symmetric | INT8 channel dynamic symmetric | [VLLM](https://github.com/vllm-project/vllm) | A100/A800, RTX 30/40 series, etc. | +| `fp8-sgl` | FP8 channel symmetric | FP8 channel dynamic symmetric | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | H100/H200/H800, RTX 40 series, etc. | +| `int8-sgl` | INT8 channel symmetric | INT8 channel dynamic symmetric | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | A100/A800, RTX 30/40 series, etc. | +| `fp8-q8f` | FP8 channel symmetric | FP8 channel dynamic symmetric | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40 series, L40S, etc. | +| `int8-q8f` | INT8 channel symmetric | INT8 channel dynamic symmetric | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40 series, L40S, etc. | +| `int8-torchao` | INT8 channel symmetric | INT8 channel dynamic symmetric | [TorchAO](https://github.com/pytorch/ao) | A100/A800, RTX 30/40 series, etc. | +| `int4-g128-marlin` | INT4 group symmetric | FP16 | [Marlin](https://github.com/IST-DASLab/marlin) | H200/H800/A100/A800, RTX 30/40 series, etc. 
| +| `fp8-b128-deepgemm` | FP8 block symmetric | FP8 group symmetric | [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) | H100/H200/H800, RTX 40 series, etc.| + +--- + +## 🔧 Obtaining Quantized Models + +### Method 1: Download Pre-Quantized Models + +Download pre-quantized models from LightX2V model repositories: + +**DIT Models** + +Download pre-quantized DIT models from [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models): + +```bash +# Download DIT FP8 quantized model +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models \ + --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors" +``` + +**Encoder Models** + +Download pre-quantized T5 and CLIP models from [Encoders-LightX2V](https://huggingface.co/lightx2v/Encoders-Lightx2v): + +```bash +# Download T5 FP8 quantized model +huggingface-cli download lightx2v/Encoders-Lightx2v \ + --local-dir ./models \ + --include "models_t5_umt5-xxl-enc-fp8.pth" + +# Download CLIP FP8 quantized model +huggingface-cli download lightx2v/Encoders-Lightx2v \ + --local-dir ./models \ + --include "models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth" +``` + +### Method 2: Self-Quantize Models + +For detailed quantization tool usage, refer to: [Model Conversion Documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) + +--- + +## 🚀 Using Quantized Models + +### DIT Model Quantization + +#### Supported Quantization Modes + +DIT quantization modes (`dit_quant_scheme`) support: `fp8-vllm`, `int8-vllm`, `fp8-sgl`, `int8-sgl`, `fp8-q8f`, `int8-q8f`, `int8-torchao`, `int4-g128-marlin`, `fp8-b128-deepgemm` + +#### Configuration Example + +```json +{ + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "dit_quantized_ckpt": "/path/to/dit_quantized_model" // Optional +} +``` + +> 💡 **Tip**: When there's only one DIT model in the script's `model_path`, `dit_quantized_ckpt` doesn't need to be specified separately. + +### T5 Model Quantization + +#### Supported Quantization Modes + +T5 quantization modes (`t5_quant_scheme`) support: `int8-vllm`, `fp8-sgl`, `int8-q8f`, `fp8-q8f`, `int8-torchao` + +#### Configuration Example + +```json +{ + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "t5_quantized_ckpt": "/path/to/t5_quantized_model" // Optional +} +``` + +> 💡 **Tip**: When a T5 quantized model exists in the script's specified `model_path` (such as `models_t5_umt5-xxl-enc-fp8.pth` or `models_t5_umt5-xxl-enc-int8.pth`), `t5_quantized_ckpt` doesn't need to be specified separately. + +### CLIP Model Quantization + +#### Supported Quantization Modes + +CLIP quantization modes (`clip_quant_scheme`) support: `int8-vllm`, `fp8-sgl`, `int8-q8f`, `fp8-q8f`, `int8-torchao` + +#### Configuration Example + +```json +{ + "clip_quantized": true, + "clip_quant_scheme": "fp8-sgl", + "clip_quantized_ckpt": "/path/to/clip_quantized_model" // Optional +} +``` + +> 💡 **Tip**: When a CLIP quantized model exists in the script's specified `model_path` (such as `models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth` or `models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8.pth`), `clip_quantized_ckpt` doesn't need to be specified separately. + +### Performance Optimization Strategy + +If memory is insufficient, you can combine parameter offloading to further reduce memory usage. 
Refer to [Parameter Offload Documentation](../method_tutorials/offload.md): + +> - **Wan2.1 Configuration**: Refer to [offload config files](https://github.com/ModelTC/LightX2V/tree/main/configs/offload) +> - **Wan2.2 Configuration**: Refer to [wan22 config files](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) with `4090` suffix + +--- + +## 📚 Related Resources + +### Configuration File Examples +- [INT8 Quantization Config](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v.json) +- [Q8F Quantization Config](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_q8f.json) +- [TorchAO Quantization Config](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_torchao.json) + +### Run Scripts +- [Quantization Inference Scripts](https://github.com/ModelTC/LightX2V/tree/main/scripts/quantization) + +### Tool Documentation +- [Quantization Tool Documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) +- [LightCompress Quantization Documentation](https://github.com/ModelTC/llmc/blob/main/docs/zh_cn/source/backend/lightx2v.md) + +### Model Repositories +- [Wan2.1-LightX2V Quantized Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan2.2-LightX2V Quantized Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) +- [Encoders Quantized Models](https://huggingface.co/lightx2v/Encoders-Lightx2v) + +--- + +Through this document, you should be able to: + +✅ Understand quantization schemes supported by LightX2V +✅ Select appropriate quantization strategies based on hardware +✅ Correctly configure quantization parameters +✅ Obtain and use quantized models +✅ Optimize inference performance and memory usage + +If you have other questions, feel free to ask in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues). diff --git a/docs/EN/source/method_tutorials/step_distill.md b/docs/EN/source/method_tutorials/step_distill.md new file mode 100644 index 0000000000000000000000000000000000000000..023a30b0d21a518e2c0f3ef8a4cbc91c7249b8d5 --- /dev/null +++ b/docs/EN/source/method_tutorials/step_distill.md @@ -0,0 +1,183 @@ +# Step Distillation + +Step distillation is an important optimization technique in LightX2V. By training distilled models, it significantly reduces inference steps from the original 40-50 steps to **4 steps**, dramatically improving inference speed while maintaining video quality. LightX2V implements step distillation along with CFG distillation to further enhance inference speed. + +## 🔍 Technical Principle + +### DMD Distillation + +The core technology of step distillation is [DMD Distillation](https://arxiv.org/abs/2311.18828). The DMD distillation framework is shown in the following diagram: + +
+DMD Distillation Framework +
+
+The core idea of DMD distillation is to minimize the KL divergence between the output distributions of the distilled model and the original model:
+
+$$
+\begin{aligned}
+D_{KL}\left(p_\text{fake} \,\|\, p_\text{real}\right) &= \mathbb{E}_{x\sim p_\text{fake}}\left[\log\left(\frac{p_\text{fake}(x)}{p_\text{real}(x)}\right)\right] \\
+&= \mathbb{E}_{\substack{z \sim \mathcal{N}(0; \mathbf{I}) \\ x = G_\theta(z)}}\Big[-\big(\log p_\text{real}(x) - \log p_\text{fake}(x)\big)\Big].
+\end{aligned}
+$$
+
+Since directly computing the probability density is nearly impossible, DMD distillation instead computes the gradient of this KL divergence:
+
+$$
+\begin{aligned}
+\nabla_\theta D_{KL}
+&= \mathbb{E}_{\substack{z \sim \mathcal{N}(0; \mathbf{I}) \\ x = G_\theta(z)}} \Big[-\big(s_\text{real}(x) - s_\text{fake}(x)\big)\,\frac{dG}{d\theta}\Big],
+\end{aligned}
+$$
+
+where $s_\text{real}(x) = \nabla_{x} \log p_\text{real}(x)$ and $s_\text{fake}(x) = \nabla_{x} \log p_\text{fake}(x)$ are score functions, which can be estimated by a diffusion model. DMD distillation therefore maintains three models in total:
+
+- `real_score` computes the score of the real distribution; since the real distribution is fixed, DMD distillation uses the original model with frozen weights as its score function;
+- `fake_score` computes the score of the fake distribution; since the fake distribution keeps changing, DMD distillation initializes it from the original model and fine-tunes it to track the generator's output distribution;
+- `generator`, the student model, is guided by the gradient of the KL divergence computed from `real_score` and `fake_score`.
+
+> References:
+> 1. [DMD (One-step Diffusion with Distribution Matching Distillation)](https://arxiv.org/abs/2311.18828)
+> 2. [DMD2 (Improved Distribution Matching Distillation for Fast Image Synthesis)](https://arxiv.org/abs/2405.14867)
+
+### Self-Forcing
+
+DMD distillation was originally designed for image generation. Step distillation in LightX2V is implemented on top of [Self-Forcing](https://github.com/guandeh17/Self-Forcing). The overall implementation of Self-Forcing is similar to DMD, but following DMD2, it removes the regression loss and uses ODE initialization instead. In addition, Self-Forcing introduces an important optimization for video generation tasks:
+
+Current DMD-based methods struggle to generate videos in a single step. Self-Forcing therefore selects one timestep to optimize at a time, and the generator computes gradients only at that step. This significantly speeds up Self-Forcing's training and improves denoising quality at intermediate timesteps, which in turn improves the final results.
+
+> References:
+> 1. [Self-Forcing (Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion)](https://arxiv.org/abs/2506.08009)
+
+### LightX2V
+
+Self-Forcing performs step distillation and CFG distillation on a 1.3B autoregressive model. LightX2V extends it with a series of enhancements:
+
+1. **Larger Models**: Supports step distillation training for 14B models;
+2. **More Model Types**: Supports step distillation training for standard bidirectional models and I2V models;
+3. **Better Results**: LightX2V trains with high-quality prompts from roughly 50,000 data entries;
+
+For detailed implementation, refer to [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus).
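+
+To connect the gradient formula above to code, here is a highly simplified sketch (hypothetical function names; a real implementation samples a noise level and derives both scores from the corresponding denoiser predictions):
+
+```python
+import torch
+
+
+def dmd_generator_loss(x_fake, real_score, fake_score):
+    """Surrogate loss whose gradient w.r.t. the generator matches the DMD formula.
+
+    x_fake:     generator output G_theta(z), differentiable w.r.t. theta
+    real_score: frozen teacher estimating grad_x log p_real(x)
+    fake_score: online model estimating grad_x log p_fake(x)
+    """
+    with torch.no_grad():
+        grad = -(real_score(x_fake) - fake_score(x_fake))  # -(s_real - s_fake)
+    # d(loss)/d(x_fake) == grad, so backprop reproduces the KL gradient through dG/dtheta
+    return (grad * x_fake).sum()
+```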
+ +## 🎯 Technical Features + +- **Inference Acceleration**: Reduces inference steps from 40-50 to 4 steps without CFG, achieving approximately **20-24x** speedup +- **Quality Preservation**: Maintains original video generation quality through distillation techniques +- **Strong Compatibility**: Supports both T2V and I2V tasks +- **Flexible Usage**: Supports loading complete step distillation models or loading step distillation LoRA on top of native models; compatible with int8/fp8 model quantization + +## 🛠️ Configuration Files + +### Basic Configuration Files + +Multiple configuration options are provided in the [configs/distill/](https://github.com/ModelTC/lightx2v/tree/main/configs/distill) directory: + +| Configuration File | Purpose | Model Address | +|-------------------|---------|---------------| +| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | Load T2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) | +| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | Load I2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) | +| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | Load Wan-T2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) | +| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | Load Wan-I2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) | + +### Key Configuration Parameters + +- Since DMD distillation only trains a few fixed timesteps, we recommend using `LCM Scheduler` for inference. In [WanStepDistillScheduler](https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/schedulers/wan/step_distill/scheduler.py), `LCM Scheduler` is already fixed in use, requiring no user configuration. +- `infer_steps`, `denoising_step_list` and `sample_shift` are set to parameters matching those during training, and are generally not recommended for user modification. +- `enable_cfg` must be set to `false` (equivalent to setting `sample_guide_scale = 1`), otherwise the video may become completely blurred. +- `lora_configs` supports merging multiple LoRAs with different strengths. When `lora_configs` is not empty, the original `Wan2.1` model is loaded by default. Therefore, when using `lora_config` and wanting to use step distillation, please set the path and strength of the step distillation LoRA. 
+ +```json +{ + "infer_steps": 4, // Inference steps + "denoising_step_list": [1000, 750, 500, 250], // Denoising timestep list + "sample_shift": 5, // Scheduler timestep shift + "enable_cfg": false, // Disable CFG for speed improvement + "lora_configs": [ // LoRA weights path (optional) + { + "path": "path/to/distill_lora.safetensors", + "strength": 1.0 + } + ] +} +``` + +## 📜 Usage + +### Model Preparation + +**Complete Model:** +Place the downloaded model (`distill_model.pt` or `distill_model.safetensors`) in the `distill_models/` folder under the Wan model root directory: + +- For T2V: `Wan2.1-T2V-14B/distill_models/` +- For I2V-480P: `Wan2.1-I2V-14B-480P/distill_models/` + +**LoRA:** + +1. Place the downloaded LoRA in any location +2. Modify the `lora_path` parameter in the configuration file to the LoRA storage path + +### Inference Scripts + +**T2V Complete Model:** + +```bash +bash scripts/wan/run_wan_t2v_distill_4step_cfg.sh +``` + +**I2V Complete Model:** + +```bash +bash scripts/wan/run_wan_i2v_distill_4step_cfg.sh +``` + +### Step Distillation LoRA Inference Scripts + +**T2V LoRA:** + +```bash +bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh +``` + +**I2V LoRA:** + +```bash +bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh +``` + +## 🔧 Service Deployment + +### Start Distillation Model Service + +Modify the startup command in [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh): + +```bash +python -m lightx2v.api_server \ + --model_cls wan2.1_distill \ + --task t2v \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \ + --port 8000 \ + --nproc_per_node 1 +``` + +Run the service startup script: + +```bash +scripts/server/start_server.sh +``` + +For more details, see [Service Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_service.html). + +### Usage in Gradio Interface + +See [Gradio Documentation](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html) diff --git a/docs/EN/source/method_tutorials/video_frame_interpolation.md b/docs/EN/source/method_tutorials/video_frame_interpolation.md new file mode 100644 index 0000000000000000000000000000000000000000..f262f52a1c8762a4c349507fad0a379a1dcc64cb --- /dev/null +++ b/docs/EN/source/method_tutorials/video_frame_interpolation.md @@ -0,0 +1,246 @@ +# Video Frame Interpolation (VFI) + +> **Important Note**: Video frame interpolation is enabled through configuration files, not command-line parameters. Please add a `video_frame_interpolation` configuration block to your JSON config file to enable this feature. + +## Overview + +Video Frame Interpolation (VFI) is a technique that generates intermediate frames between existing frames to increase the frame rate and create smoother video playback. LightX2V integrates the RIFE (Real-Time Intermediate Flow Estimation) model to provide high-quality frame interpolation capabilities. + +## What is RIFE? + +RIFE is a state-of-the-art video frame interpolation method that uses optical flow estimation to generate intermediate frames. 
It can effectively: + +- Increase video frame rate (e.g., from 16 FPS to 32 FPS) +- Create smooth motion transitions +- Maintain high visual quality with minimal artifacts +- Process videos in real-time + +## Installation and Setup + +### Download RIFE Model + +First, download the RIFE model weights using the provided script: + +```bash +python tools/download_rife.py +``` + +For example, to download to the location: +```bash +python tools/download_rife.py /path/to/rife/train_log +``` + +This script will: +- Download RIFEv4.26 model from HuggingFace +- Extract and place the model files in the correct directory +- Clean up temporary files + +## Usage + +### Configuration File Setup + +Video frame interpolation is enabled through configuration files. Add a `video_frame_interpolation` configuration block to your JSON config file: + +```json +{ + "infer_steps": 50, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "fps": 16, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 32, + "model_path": "/path/to/rife/train_log" + } +} +``` + +### Command Line Interface + +Run inference using a configuration file that includes VFI settings: + +```bash +python lightx2v/infer.py \ + --model_cls wan2.1 \ + --task t2v \ + --model_path /path/to/model \ + --config_json ./configs/video_frame_interpolation/wan_t2v.json \ + --prompt "A beautiful sunset over the ocean" \ + --save_result_path ./output.mp4 +``` + +### Configuration Parameters + +In the `video_frame_interpolation` configuration block: + +- `algo`: Frame interpolation algorithm, currently supports "rife" +- `target_fps`: Target frame rate for the output video +- `model_path`: RIFE model path, typically "/path/to/rife/train_log" + +Other related configurations: +- `fps`: Source video frame rate (default 16) + +### Configuration Priority + +The system automatically handles video frame rate configuration with the following priority: +1. `video_frame_interpolation.target_fps` - If video frame interpolation is enabled, this frame rate is used as the output frame rate +2. `fps` (default 16) - If video frame interpolation is not enabled, this frame rate is used; it's always used as the source frame rate + + +## How It Works + +### Frame Interpolation Process + +1. **Source Video Generation**: The base model generates video frames at the source FPS +2. **Frame Analysis**: RIFE analyzes adjacent frames to estimate optical flow +3. **Intermediate Frame Generation**: New frames are generated between existing frames +4. 
**Temporal Smoothing**: The interpolated frames create smooth motion transitions + +### Technical Details + +- **Input Format**: ComfyUI Image tensors [N, H, W, C] in range [0, 1] +- **Output Format**: Interpolated ComfyUI Image tensors [M, H, W, C] in range [0, 1] +- **Processing**: Automatic padding and resolution handling +- **Memory Optimization**: Efficient GPU memory management + +## Example Configurations + +### Basic Frame Rate Doubling + +Create configuration file `wan_t2v_vfi_32fps.json`: + +```json +{ + "infer_steps": 50, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "seed": 42, + "sample_guide_scale": 6, + "enable_cfg": true, + "fps": 16, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 32, + "model_path": "/path/to/rife/train_log" + } +} +``` + +Run command: +```bash +python lightx2v/infer.py \ + --model_cls wan2.1 \ + --task t2v \ + --model_path ./models/wan2.1 \ + --config_json ./wan_t2v_vfi_32fps.json \ + --prompt "A cat playing in the garden" \ + --save_result_path ./output_32fps.mp4 +``` + +### Higher Frame Rate Enhancement + +Create configuration file `wan_i2v_vfi_60fps.json`: + +```json +{ + "infer_steps": 30, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "seed": 42, + "sample_guide_scale": 6, + "enable_cfg": true, + "fps": 16, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 60, + "model_path": "/path/to/rife/train_log" + } +} +``` + +Run command: +```bash +python lightx2v/infer.py \ + --model_cls wan2.1 \ + --task i2v \ + --model_path ./models/wan2.1 \ + --config_json ./wan_i2v_vfi_60fps.json \ + --image_path ./input.jpg \ + --prompt "Smooth camera movement" \ + --save_result_path ./output_60fps.mp4 +``` + +## Performance Considerations + +### Memory Usage + +- RIFE processing requires additional GPU memory +- Memory usage scales with video resolution and length +- Consider using lower resolutions for longer videos + +### Processing Time + +- Frame interpolation adds processing overhead +- Higher target frame rates require more computation +- Processing time is roughly proportional to the number of interpolated frames + +### Quality vs Speed Trade-offs + +- Higher interpolation ratios may introduce artifacts +- Optimal range: 2x to 4x frame rate increase +- For extreme interpolation (>4x), consider multiple passes + +## Best Practices + +### Optimal Use Cases + +- **Motion-heavy videos**: Benefit most from frame interpolation +- **Camera movements**: Smoother panning and zooming +- **Action sequences**: Reduced motion blur perception +- **Slow-motion effects**: Create fluid slow-motion videos + +### Recommended Settings + +- **Source FPS**: 16-24 FPS (generated by base model) +- **Target FPS**: 32-60 FPS (2x to 4x increase) +- **Resolution**: Up to 720p for best performance + +### Troubleshooting + +#### Common Issues + +1. **Out of Memory**: Reduce video resolution or target FPS +2. **Artifacts in output**: Lower the interpolation ratio +3. 
**Slow processing**: Check GPU memory and consider using CPU offloading + +#### Solutions + +Solve issues by modifying the configuration file: + +```json +{ + // For memory issues, use lower resolution + "target_height": 480, + "target_width": 832, + + // For quality issues, use moderate interpolation + "video_frame_interpolation": { + "target_fps": 24 // instead of 60 + }, + + // For performance issues, enable offloading + "cpu_offload": true +} +``` + +## Technical Implementation + +The RIFE integration in LightX2V includes: + +- **RIFEWrapper**: ComfyUI-compatible wrapper for RIFE model +- **Automatic Model Loading**: Seamless integration with the inference pipeline +- **Memory Optimization**: Efficient tensor management and GPU memory usage +- **Quality Preservation**: Maintains original video quality while adding frames diff --git a/docs/PAPERS_ZH_CN/.readthedocs.yaml b/docs/PAPERS_ZH_CN/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..784bfbe697f4f39630a2beb948b20f8dc7c22a66 --- /dev/null +++ b/docs/PAPERS_ZH_CN/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.10" + +formats: + - epub + +sphinx: + configuration: docs/PAPERS_ZH_CN/source/conf.py + +python: + install: + - requirements: requirements-docs.txt diff --git a/docs/PAPERS_ZH_CN/Makefile b/docs/PAPERS_ZH_CN/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293 --- /dev/null +++ b/docs/PAPERS_ZH_CN/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/PAPERS_ZH_CN/make.bat b/docs/PAPERS_ZH_CN/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a --- /dev/null +++ b/docs/PAPERS_ZH_CN/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/PAPERS_ZH_CN/source/conf.py b/docs/PAPERS_ZH_CN/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d499df2da21e3609d39882eda567372fdd544f --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/conf.py @@ -0,0 +1,122 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import logging +import os +import sys +from typing import List + +import sphinxcontrib.redoc +from sphinx.ext import autodoc + +logger = logging.getLogger(__name__) +sys.path.append(os.path.abspath("../..")) + +# -- Project information ----------------------------------------------------- + +project = "Lightx2v" +copyright = "2025, Lightx2v Team" +author = "the Lightx2v Team" + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "myst_parser", + "sphinxarg.ext", + "sphinxcontrib.redoc", + "sphinxcontrib.openapi", +] + +html_static_path = ["_static"] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns: List[str] = ["**/*.template.rst"] + +# Exclude the prompt "$" when copying code +copybutton_prompt_text = r"\$ " +copybutton_prompt_is_regexp = True + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_title = project +html_theme = "sphinx_book_theme" +# html_theme = 'sphinx_rtd_theme' +html_logo = "../../../assets/img_lightx2v.png" +html_theme_options = { + "path_to_docs": "docs/ZH_CN/source", + "repository_url": "https://github.com/ModelTC/lightx2v", + "use_repository_button": True, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] + + +# Generate additional rst documentation here. +def setup(app): + # from docs.source.generate_examples import generate_examples + # generate_examples() + pass + + +# Mock out external dependencies here. 
+autodoc_mock_imports = [ + "cpuinfo", + "torch", + "transformers", + "psutil", + "prometheus_client", + "sentencepiece", + "lightllmnumpy", + "tqdm", + "tensorizer", +] + +for mock_target in autodoc_mock_imports: + if mock_target in sys.modules: + logger.info( + "Potentially problematic mock target (%s) found; autodoc_mock_imports cannot mock modules that have already been loaded into sys.modules when the sphinx build starts.", + mock_target, + ) + + +class MockedClassDocumenter(autodoc.ClassDocumenter): + """Remove note about base class when a class is derived from object.""" + + def add_line(self, line: str, source: str, *lineno: int) -> None: + if line == " Bases: :py:class:`object`": + return + super().add_line(line, source, *lineno) + + +autodoc.ClassDocumenter = MockedClassDocumenter + +navigation_with_keys = False diff --git a/docs/PAPERS_ZH_CN/source/index.rst b/docs/PAPERS_ZH_CN/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..cb4414a9a3c80e31572fe99b052aa9e065122274 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/index.rst @@ -0,0 +1,53 @@ +欢迎了解 Lightx2v 论文收藏集! +================== + +.. figure:: ../../../assets/img_lightx2v.png + :width: 80% + :align: center + :alt: Lightx2v + :class: no-scaled-link + +.. raw:: html + +
+ + License + Ask DeepWiki + Doc + Doc + Papers + Docker + +
+ +
+ LightX2V: 一个轻量级的视频生成推理框架 +
+ + +LightX2V 是一个轻量级的视频生成推理框架。这里是我们维护的一个视频生成推理加速相关的论文收藏集,帮助你快速了解视频生成推理加速相关的经典方法和最新进展。 + +GitHub: https://github.com/ModelTC/lightx2v + +HuggingFace: https://huggingface.co/lightx2v + +论文收藏集 +------------- + +.. toctree:: + :maxdepth: 1 + :caption: 论文分类 + + 图像视频生成基础 + 开源模型 + 模型量化 + 特征缓存 + 注意力机制 + 参数卸载 + 并行推理 + 变分辨率推理 + 步数蒸馏 + 自回归模型 + vae加速 + prompt增强 + 强化学习 diff --git a/docs/PAPERS_ZH_CN/source/papers/RL.md b/docs/PAPERS_ZH_CN/source/papers/RL.md new file mode 100644 index 0000000000000000000000000000000000000000..0081f92328a9fe548654098b76adc5815aba6171 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/RL.md @@ -0,0 +1,3 @@ +# 强化学习 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/attention.md b/docs/PAPERS_ZH_CN/source/papers/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..a17d295b51fb08e97f2a41c7c79bd5a1831cddfc --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/attention.md @@ -0,0 +1,113 @@ +# 注意力机制 + +### Sparse VideoGen: Accelerating Video Diffusion Transformers with Spatial-Temporal Sparsity + +[paper](https://arxiv.org/abs/2502.01776) | [code](https://github.com/svg-project/Sparse-VideoGen) + +### Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation + +[paper](https://arxiv.org/abs/2505.18875) + +### Training-free and Adaptive Sparse Attention for Efficient Long Video Generation + +[paper](https://arxiv.org/abs/2502.21079) + +### DSV: Exploiting Dynamic Sparsity to Accelerate Large-Scale Video DiT Training + +[paper](https://arxiv.org/abs/2502.07590) + +### MMInference: Accelerating Pre-filling for Long-Context VLMs via Modality-Aware Permutation Sparse Attention + +[paper](https://github.com/microsoft/MInference) + +### FPSAttention: Training-Aware FP8 and Sparsity Co-Design for Fast Video Diffusion + +[paper](https://arxiv.org/abs/2506.04648) + +### VORTA: Efficient Video Diffusion via Routing Sparse Attention + +[paper](https://arxiv.org/abs/2505.18809) + +### Training-Free Efficient Video Generation via Dynamic Token Carving + +[paper](https://arxiv.org/abs/2505.16864) + +### RainFusion: Adaptive Video Generation Acceleration via Multi-Dimensional Visual Redundancy + +[paper](https://arxiv.org/abs/2505.21036) + +### Radial Attention: O(nlogn) Sparse Attention with Energy Decay for Long Video Generation + +[paper](https://arxiv.org/abs/2506.19852) + +### VMoBA: Mixture-of-Block Attention for Video Diffusion Models + +[paper](https://arxiv.org/abs/2506.23858) + +### SpargeAttention: Accurate and Training-free Sparse Attention Accelerating Any Model Inference + +[paper](https://arxiv.org/abs/2502.18137) | [code](https://github.com/thu-ml/SpargeAttn) + +### Fast Video Generation with Sliding Tile Attention + +[paper](https://arxiv.org/abs/2502.04507) | [code](https://github.com/hao-ai-lab/FastVideo) + +### PAROAttention: Pattern-Aware ReOrdering for Efficient Sparse and Quantized Attention in Visual Generation Models + +[paper](https://arxiv.org/abs/2506.16054) + +### Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light + +[paper](https://arxiv.org/abs/2504.16922) + +### Astraea: A GPU-Oriented Token-wise Acceleration Framework for Video Diffusion Transformers + +[paper](https://arxiv.org/abs/2506.05096) + +### ∇NABLA: Neighborhood Adaptive Block-Level Attention + +[paper](https://arxiv.org/abs/2507.13546v1) [code](https://github.com/gen-ai-team/Wan2.1-NABLA) + +### Compact Attention: Exploiting Structured Spatio-Temporal Sparsity for Fast Video Generation + 
+[paper](https://arxiv.org/abs/2508.12969) + +### A Survey of Efficient Attention Methods: Hardware-efficient, Sparse, Compact, and Linear Attention + +[paper](https://attention-survey.github.io/files/Attention_Survey.pdf) + +### Bidirectional Sparse Attention for Faster Video Diffusion Training + +[paper](https://arxiv.org/abs/2509.01085) + +### Mixture of Contexts for Long Video Generation + +[paper](https://arxiv.org/abs/2508.21058) + +### LoViC: Efficient Long Video Generation with Context Compression + +[paper](https://arxiv.org/abs/2507.12952) + +### MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training + +[paper](https://sandai-org.github.io/MagiAttention/blog/) [code](https://github.com/SandAI-org/MagiAttention) + +### DraftAttention: Fast Video Diffusion via Low-Resolution Attention Guidance + +[paper](https://arxiv.org/abs/2505.14708) [code](https://github.com/shawnricecake/draft-attention) + +### XAttention: Block Sparse Attention with Antidiagonal Scoring + +[paper](https://arxiv.org/abs/2503.16428) [code](https://github.com/mit-han-lab/x-attention) + +### VSA: Faster Video Diffusion with Trainable Sparse Attention + +[paper](https://arxiv.org/abs/2505.13389) [code](https://github.com/hao-ai-lab/FastVideo) + +### QuantSparse: Comprehensively Compressing Video Diffusion Transformer with Model Quantization and Attention Sparsification + +[paper](https://arxiv.org/abs/2509.23681) + +### SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention + +[paper](https://arxiv.org/abs/2509.24006) diff --git a/docs/PAPERS_ZH_CN/source/papers/autoregressive.md b/docs/PAPERS_ZH_CN/source/papers/autoregressive.md new file mode 100644 index 0000000000000000000000000000000000000000..27c2d1dad269960325ee64c86df155622f2410c8 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/autoregressive.md @@ -0,0 +1,3 @@ +# 自回归模型 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/cache.md b/docs/PAPERS_ZH_CN/source/papers/cache.md new file mode 100644 index 0000000000000000000000000000000000000000..661ef10213cbf5af62d0f3443a8ab774285f4be4 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/cache.md @@ -0,0 +1,3 @@ +# 特征缓存 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/changing_resolution.md b/docs/PAPERS_ZH_CN/source/papers/changing_resolution.md new file mode 100644 index 0000000000000000000000000000000000000000..175eea9e79dc2936404ad276ffe96f86ae8b7a35 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/changing_resolution.md @@ -0,0 +1,3 @@ +# 变分辨率推理 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/generation_basics.md b/docs/PAPERS_ZH_CN/source/papers/generation_basics.md new file mode 100644 index 0000000000000000000000000000000000000000..6d5b97f72edf792e6f670998849da9a2a3e7d96d --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/generation_basics.md @@ -0,0 +1,3 @@ +# 图像视频生成基础 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/models.md b/docs/PAPERS_ZH_CN/source/papers/models.md new file mode 100644 index 0000000000000000000000000000000000000000..de9e00c69d9b1d02e35766addcc9df106604b842 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/models.md @@ -0,0 +1,276 @@ + +
+ +# Open-Source Models + +📢: Collections of Awesome Open-Source Model Resources. + +
+ + +## 📚 *Contents* + +- Open-Source Models + - [Foundation Models](#foundation-models) + - [World Models](#world-models) + + +### Foundation Models: + +- **Stable Video Diffusion: Scaling Latent Video Diffusion Models to Large Datasets**, Technical Report 2023. + + *Andreas Blattmann, Tim Dockhorn, Sumith Kulal, Daniel Mendelevitch, Maciej Kilian, et al.* + + [[Paper](https://arxiv.org/abs/2311.15127)] [[Code](https://github.com/Stability-AI/generative-models)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/UNet-brown) + +
BibTex + + ```text + @article{blattmann2023stable, + title={Stable video diffusion: Scaling latent video diffusion models to large datasets}, + author={Blattmann, Andreas and Dockhorn, Tim and Kulal, Sumith and Mendelevitch, Daniel and Kilian, Maciej and Lorenz, Dominik and Levi, Yam and English, Zion and Voleti, Vikram and Letts, Adam and others}, + journal={arXiv preprint arXiv:2311.15127}, + year={2023} + } + ``` +
+ +- **Wan: Open and Advanced Large-Scale Video Generative Models**, Technical Report 2025. + + *Team Wan, Ang Wang, Baole Ai, Bin Wen, Chaojie Mao, Chen-Wei Xie, et al.* + + [[Paper](https://arxiv.org/abs/2503.20314)] [[Code](https://github.com/Wan-Video/Wan2.1)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/DIT-brown) + +
BibTex + + ```text + @article{wan2025wan, + title={Wan: Open and advanced large-scale video generative models}, + author={Wan, Team and Wang, Ang and Ai, Baole and Wen, Bin and Mao, Chaojie and Xie, Chen-Wei and Chen, Di and Yu, Feiwu and Zhao, Haiming and Yang, Jianxiao and others}, + journal={arXiv preprint arXiv:2503.20314}, + year={2025} + } + ``` +
+ +- **HunyuanVideo: A Systematic Framework For Large Video Generation Model**, Technical Report 2024. + + *Weijie Kong, Qi Tian, Zijian Zhang, Rox Min, Zuozhuo Dai, Jin Zhou, et al.* + + [[Paper](https://arxiv.org/abs/2412.03603)] [[Code](https://github.com/Tencent-Hunyuan/HunyuanVideo)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/DIT-brown) + +
BibTex + + ```text + @article{kong2024hunyuanvideo, + title={Hunyuanvideo: A systematic framework for large video generative models}, + author={Kong, Weijie and Tian, Qi and Zhang, Zijian and Min, Rox and Dai, Zuozhuo and Zhou, Jin and Xiong, Jiangfeng and Li, Xin and Wu, Bo and Zhang, Jianwei and others}, + journal={arXiv preprint arXiv:2412.03603}, + year={2024} + } + ``` +
+ +- **CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer**, ICLR 2025. + + *Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, et al.* + + [[Paper](https://arxiv.org/abs/2408.06072)] [[Code](https://github.com/zai-org/CogVideo)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/DIT-brown) + +
BibTex + + ```text + @article{yang2024cogvideox, + title={Cogvideox: Text-to-video diffusion models with an expert transformer}, + author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others}, + journal={arXiv preprint arXiv:2408.06072}, + year={2024} + } + ``` +
+ + +- **SkyReels V2: Infinite-Length Film Generative Model**, Technical Report 2025. + + *Guibin Chen, Dixuan Lin, Jiangping Yang, Chunze Lin, Junchen Zhu, Mingyuan Fan, et al.* + + [[Paper](https://arxiv.org/abs/2504.13074)] [[Code](https://github.com/SkyworkAI/SkyReels-V2)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/DIT-brown) + +
BibTex + + ```text + @misc{chen2025skyreelsv2infinitelengthfilmgenerative, + title={SkyReels-V2: Infinite-length Film Generative Model}, + author={Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou}, + year={2025}, + eprint={2504.13074}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2504.13074}, + } + ``` +
+ + +- **Open-Sora: Democratizing Efficient Video Production for All**, Technical Report 2025. + + *Xiangyu Peng, Zangwei Zheng, Chenhui Shen, Tom Young, Xinying Guo, et al.* + + [[Paper](https://arxiv.org/abs/2503.09642v2)] [[Code](https://github.com/hpcaitech/Open-Sora)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/DIT-brown) ![](https://img.shields.io/badge/V2V-orange) + +
BibTex + + ```text + @article{peng2025open, + title={Open-sora 2.0: Training a commercial-level video generation model in $200 k}, + author={Peng, Xiangyu and Zheng, Zangwei and Shen, Chenhui and Young, Tom and Guo, Xinying and Wang, Binluo and Xu, Hang and Liu, Hongxin and Jiang, Mingyan and Li, Wenjun and others}, + journal={arXiv preprint arXiv:2503.09642}, + year={2025} + } + ``` +
+ +- **Pyramidal Flow Matching for Efficient Video Generative Modeling**, Technical Report 2024. + + *Yang Jin, Zhicheng Sun, Ningyuan Li, Kun Xu, Kun Xu, et al.* + + [[Paper](https://arxiv.org/abs/2410.05954)] [[Code](https://github.com/jy0205/Pyramid-Flow)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/AR-brown) +
BibTex + + ```text + @article{jin2024pyramidal, + title={Pyramidal flow matching for efficient video generative modeling}, + author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen}, + journal={arXiv preprint arXiv:2410.05954}, + year={2024} + } + ``` +
+ +- **MAGI-1: Autoregressive Video Generation at Scale**, Technical Report 2025. + + *Sand.ai, Hansi Teng, Hongyu Jia, Lei Sun, Lingzhi Li, Maolin Li, Mingqiu Tang, et al.* + + [[Paper](https://arxiv.org/pdf/2505.13211)] [[Code](https://github.com/SandAI-org/Magi-1)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/AR-brown) ![](https://img.shields.io/badge/V2V-orange) +
BibTex + + ```text + @article{teng2025magi, + title={MAGI-1: Autoregressive Video Generation at Scale}, + author={Teng, Hansi and Jia, Hongyu and Sun, Lei and Li, Lingzhi and Li, Maolin and Tang, Mingqiu and Han, Shuai and Zhang, Tianning and Zhang, WQ and Luo, Weifeng and others}, + journal={arXiv preprint arXiv:2505.13211}, + year={2025} + } + ``` +
+ +- **From Slow Bidirectional to Fast Autoregressive Video Diffusion Models**, CVPR 2025. + + *Tianwei Yin, Qiang Zhang, Richard Zhang, William T. Freeman, Fredo Durand, et al.* + + [[Paper](http://arxiv.org/abs/2412.07772)] [[Code](https://github.com/tianweiy/CausVid)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/AR-brown) +
BibTex + + ```text + @inproceedings{yin2025slow, + title={From slow bidirectional to fast autoregressive video diffusion models}, + author={Yin, Tianwei and Zhang, Qiang and Zhang, Richard and Freeman, William T and Durand, Fredo and Shechtman, Eli and Huang, Xun}, + booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, + pages={22963--22974}, + year={2025} + } + ``` +
+ +- **Packing Input Frame Context in Next-Frame Prediction Models for Video Generation**, arxiv 2025. + + *Lvmin Zhang, Maneesh Agrawala.* + + [[Paper](https://arxiv.org/abs/2504.12626)] [[Code](https://github.com/lllyasviel/FramePack)] ![](https://img.shields.io/badge/T2V-blue) ![](https://img.shields.io/badge/I2V-green) ![](https://img.shields.io/badge/AR-brown) +
BibTex + + ```text + @article{zhang2025packing, + title={Packing input frame context in next-frame prediction models for video generation}, + author={Zhang, Lvmin and Agrawala, Maneesh}, + journal={arXiv preprint arXiv:2504.12626}, + year={2025} + } + ``` +
+ +### World Models: + +- **Matrix-Game 2.0: An Open-Source, Real-Time, and Streaming Interactive World Model**, Technical Report 2025. + + *Xianglong He, Chunli Peng, Zexiang Liu, Boyang Wang, Yifan Zhang, et al.* + + [[Paper](https://arxiv.org/abs/2508.13009)] [[Code](https://matrix-game-v2.github.io/)] ![](https://img.shields.io/badge/keyboard-blue) ![](https://img.shields.io/badge/mouse-green) ![](https://img.shields.io/badge/DIT-brown) +
BibTex + + ```text + @article{he2025matrix, + title={Matrix-Game 2.0: An Open-Source, Real-Time, and Streaming Interactive World Model}, + author={He, Xianglong and Peng, Chunli and Liu, Zexiang and Wang, Boyang and Zhang, Yifan and Cui, Qi and Kang, Fei and Jiang, Biao and An, Mengyin and Ren, Yangyang and others}, + journal={arXiv preprint arXiv:2508.13009}, + year={2025} + } + ``` +
+ +- **HunyuanWorld 1.0: Generating Immersive, Explorable, and Interactive 3D Worlds from Words or Pixels**, Technical Report 2025. + + *HunyuanWorld Team, Zhenwei Wang, Yuhao Liu, Junta Wu, Zixiao Gu, Haoyuan Wang, et al.* + + [[Paper](https://arxiv.org/abs/2507.21809)] [[Code](https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0)] ![](https://img.shields.io/badge/image-blue) ![](https://img.shields.io/badge/text-green) ![](https://img.shields.io/badge/DIT-brown) +
BibTex + + ```text + @article{team2025hunyuanworld, + title={HunyuanWorld 1.0: Generating Immersive, Explorable, and Interactive 3D Worlds from Words or Pixels}, + author={Team, HunyuanWorld and Wang, Zhenwei and Liu, Yuhao and Wu, Junta and Gu, Zixiao and Wang, Haoyuan and Zuo, Xuhui and Huang, Tianyu and Li, Wenhuan and Zhang, Sheng and others}, + journal={arXiv preprint arXiv:2507.21809}, + year={2025} + } + ``` +
+ +- **Cosmos-Drive-Dreams: Scalable Synthetic Driving Data Generation with World Foundation Models**, Technical Report 2025. + + *Xuanchi Ren, Yifan Lu, Tianshi Cao, Ruiyuan Gao, Shengyu Huang, Amirmojtaba Sabour, et al.* + + [[Paper](https://arxiv.org/abs/2506.09042)] [[Code](https://research.nvidia.com/labs/toronto-ai/cosmos_drive_dreams)] ![](https://img.shields.io/badge/drive-blue) ![](https://img.shields.io/badge/DIT-brown) +
BibTex + + ```text + @article{ren2025cosmos, + title={Cosmos-Drive-Dreams: Scalable Synthetic Driving Data Generation with World Foundation Models}, + author={Ren, Xuanchi and Lu, Yifan and Cao, Tianshi and Gao, Ruiyuan and Huang, Shengyu and Sabour, Amirmojtaba and Shen, Tianchang and Pfaff, Tobias and Wu, Jay Zhangjie and Chen, Runjian and others}, + journal={arXiv preprint arXiv:2506.09042}, + year={2025} + } + ``` +
+ +- **Genie 3: A new frontier for world models**, Blog 2025. + + *Google DeepMind* + + [[Blog](https://deepmind.google/discover/blog/genie-3-a-new-frontier-for-world-models/)] ![](https://img.shields.io/badge/event-blue) ![](https://img.shields.io/badge/DIT-brown) + +- **GAIA-2: A Controllable Multi-View Generative World Model for Autonomous Driving.**, Technical Report 2025. + + *Lloyd Russell, Anthony Hu, Lorenzo Bertoni, George Fedoseev, Jamie Shotton, et al.* + + [[Paper](https://arxiv.org/abs/2503.20523)] [[Code](https://github.com/Tencent-Hunyuan/HunyuanWorld-1.0)] ![](https://img.shields.io/badge/drive-blue) ![](https://img.shields.io/badge/transformer-brown) +
BibTex + + ```text + @article{russell2025gaia, + title={Gaia-2: A controllable multi-view generative world model for autonomous driving}, + author={Russell, Lloyd and Hu, Anthony and Bertoni, Lorenzo and Fedoseev, George and Shotton, Jamie and Arani, Elahe and Corrado, Gianluca}, + journal={arXiv preprint arXiv:2503.20523}, + year={2025} + } + ``` +
diff --git a/docs/PAPERS_ZH_CN/source/papers/offload.md b/docs/PAPERS_ZH_CN/source/papers/offload.md new file mode 100644 index 0000000000000000000000000000000000000000..302f0b6869df1d6263f2748a7dfd11bd7dae4f1e --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/offload.md @@ -0,0 +1,3 @@ +# 参数卸载 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/parallel.md b/docs/PAPERS_ZH_CN/source/papers/parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..b68de1c0ebdc89c76e3155a1c22daed2ee3a7a8d --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/parallel.md @@ -0,0 +1,3 @@ +# 并行推理 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/prompt_enhance.md b/docs/PAPERS_ZH_CN/source/papers/prompt_enhance.md new file mode 100644 index 0000000000000000000000000000000000000000..d51d77bf8a8338c6eb04de88a09eb7cf65d01681 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/prompt_enhance.md @@ -0,0 +1,3 @@ +# prompt增强 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/quantization.md b/docs/PAPERS_ZH_CN/source/papers/quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..d29d8f6af2e8e44f9a179d7a33150286ec500504 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/quantization.md @@ -0,0 +1,3 @@ +# 模型量化 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/step_distill.md b/docs/PAPERS_ZH_CN/source/papers/step_distill.md new file mode 100644 index 0000000000000000000000000000000000000000..e5e772f9a881bdce0e42f0a2862741da9c630830 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/step_distill.md @@ -0,0 +1,3 @@ +# 步数蒸馏 + +xxx diff --git a/docs/PAPERS_ZH_CN/source/papers/vae.md b/docs/PAPERS_ZH_CN/source/papers/vae.md new file mode 100644 index 0000000000000000000000000000000000000000..914d54db1b4b307acceba09a66d1eb5bcd9bbee3 --- /dev/null +++ b/docs/PAPERS_ZH_CN/source/papers/vae.md @@ -0,0 +1,3 @@ +# vae加速 + +xxx diff --git a/docs/ZH_CN/.readthedocs.yaml b/docs/ZH_CN/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3677c6092fd6559cf83d4663cfcf32df98d891a --- /dev/null +++ b/docs/ZH_CN/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.10" + +formats: + - epub + +sphinx: + configuration: docs/ZH_CN/source/conf.py + +python: + install: + - requirements: requirements-docs.txt diff --git a/docs/ZH_CN/Makefile b/docs/ZH_CN/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293 --- /dev/null +++ b/docs/ZH_CN/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/ZH_CN/make.bat b/docs/ZH_CN/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a --- /dev/null +++ b/docs/ZH_CN/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/ZH_CN/source/conf.py b/docs/ZH_CN/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..659c3130dff364844704bfed3350f0ff797fd9f0 --- /dev/null +++ b/docs/ZH_CN/source/conf.py @@ -0,0 +1,128 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import logging +import os +import sys +from typing import List + +import sphinxcontrib.redoc +from sphinx.ext import autodoc + +logger = logging.getLogger(__name__) +sys.path.append(os.path.abspath("../..")) + +# -- Project information ----------------------------------------------------- + +project = "Lightx2v" +copyright = "2025, Lightx2v Team" +author = "the Lightx2v Team" + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.mathjax", + "myst_parser", + "sphinxarg.ext", + "sphinxcontrib.redoc", + "sphinxcontrib.openapi", +] + +myst_enable_extensions = [ + "dollarmath", + "amsmath", +] + +html_static_path = ["_static"] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns: List[str] = ["**/*.template.rst"] + +# Exclude the prompt "$" when copying code +copybutton_prompt_text = r"\$ " +copybutton_prompt_is_regexp = True + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. 
See the documentation for +# a list of builtin themes. +# +html_title = project +html_theme = "sphinx_book_theme" +# html_theme = 'sphinx_rtd_theme' +html_logo = "../../../assets/img_lightx2v.png" +html_theme_options = { + "path_to_docs": "docs/ZH_CN/source", + "repository_url": "https://github.com/ModelTC/lightx2v", + "use_repository_button": True, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] + + +# Generate additional rst documentation here. +def setup(app): + # from docs.source.generate_examples import generate_examples + # generate_examples() + pass + + +# Mock out external dependencies here. +autodoc_mock_imports = [ + "cpuinfo", + "torch", + "transformers", + "psutil", + "prometheus_client", + "sentencepiece", + "lightllmnumpy", + "tqdm", + "tensorizer", +] + +for mock_target in autodoc_mock_imports: + if mock_target in sys.modules: + logger.info( + "Potentially problematic mock target (%s) found; autodoc_mock_imports cannot mock modules that have already been loaded into sys.modules when the sphinx build starts.", + mock_target, + ) + + +class MockedClassDocumenter(autodoc.ClassDocumenter): + """Remove note about base class when a class is derived from object.""" + + def add_line(self, line: str, source: str, *lineno: int) -> None: + if line == " Bases: :py:class:`object`": + return + super().add_line(line, source, *lineno) + + +autodoc.ClassDocumenter = MockedClassDocumenter + +navigation_with_keys = False diff --git a/docs/ZH_CN/source/deploy_guides/deploy_comfyui.md b/docs/ZH_CN/source/deploy_guides/deploy_comfyui.md new file mode 100644 index 0000000000000000000000000000000000000000..9355731c754462092d18b36cc46dca1c0761f8c4 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/deploy_comfyui.md @@ -0,0 +1,25 @@ +# ComfyUI 部署 + +## ComfyUI-Lightx2vWrapper + +LightX2V 的官方 ComfyUI 集成节点已经发布在独立仓库中,提供了完整的模块化配置系统和优化功能。 + +### 项目地址 + +- GitHub: [https://github.com/ModelTC/ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper) + +### 主要特性 + +- 模块化配置系统:为视频生成的各个方面提供独立节点 +- 支持文生视频(T2V)和图生视频(I2V)两种生成模式 +- 高级优化功能: + - TeaCache 加速(最高 3 倍加速) + - 量化支持(int8、fp8) + - CPU 卸载内存优化 + - 轻量级 VAE 选项 +- LoRA 支持:可链式组合多个 LoRA 模型 +- 多模型支持:wan2.1、hunyuan 等架构 + +### 安装和使用 + +请访问上述 GitHub 仓库查看详细的安装说明、使用教程和示例工作流。 diff --git a/docs/ZH_CN/source/deploy_guides/deploy_gradio.md b/docs/ZH_CN/source/deploy_guides/deploy_gradio.md new file mode 100644 index 0000000000000000000000000000000000000000..a2f16b57c8144ace18a01bd7001dddc2a1dc69a5 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/deploy_gradio.md @@ -0,0 +1,241 @@ +# Gradio 部署指南 + +## 📖 概述 + +Lightx2v 是一个轻量级的视频推理和生成引擎,提供基于 Gradio 的 Web 界面,支持图像到视频(Image-to-Video)和文本到视频(Text-to-Video)两种生成模式。 + +对于Windows系统,我们提供了便捷的一键部署方式,支持自动环境配置和智能参数优化。详细操作请参考[一键启动Gradio](./deploy_local_windows.md/#一键启动gradio推荐)章节。 + +![Gradio中文界面](../../../../assets/figs/portabl_windows/pic_gradio_zh.png) + +## 📁 文件结构 + +``` +LightX2V/app/ +├── gradio_demo.py # 英文界面演示 +├── gradio_demo_zh.py # 中文界面演示 +├── run_gradio.sh # 启动脚本 +├── README.md # 说明文档 +├── outputs/ # 生成视频保存目录 +└── inference_logs.log # 推理日志 +``` + +本项目包含两个主要演示文件: +- `gradio_demo.py` - 英文界面版本 +- `gradio_demo_zh.py` - 中文界面版本 + +## 🚀 快速开始 + +### 环境要求 + +按照[快速开始文档](../getting_started/quickstart.md)安装环境 + +#### 推荐优化库配置 + +- ✅ [Flash attention](https://github.com/Dao-AILab/flash-attention) +- ✅ 
[Sage attention](https://github.com/thu-ml/SageAttention) +- ✅ [vllm-kernel](https://github.com/vllm-project/vllm) +- ✅ [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) +- ✅ [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (仅支持ADA架构的GPU) + +可根据需要,按照各算子的项目主页教程进行安装。 + +### 📥 模型下载 + +可参考[模型结构文档](../getting_started/model_structure.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。 + +#### wan2.1 模型目录结构 + +``` +models/ +├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度 +├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化 +├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 量化 +├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 量化分block存储目录 +├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 量化分block存储目录 +├── 其他权重(例如t2v) +├── t5/clip/xlm-roberta-large/google # text和image encoder +├── vae/lightvae/lighttae # vae +└── config.json # 模型配置文件 +``` + +#### wan2.2 模型目录结构 + +``` +models/ +├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise 原始精度 +├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 量化 +├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 量化 +├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 量化分block存储目录 +├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise 原始精度 +├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 量化 +├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 量化 +├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 量化分block存储目录 +├── t5/clip/xlm-roberta-large/google # text和image encoder +├── vae/lightvae/lighttae # vae +└── config.json # 模型配置文件 +``` + +**📝 下载说明**: + +- 模型权重可从 HuggingFace 下载: + - [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) + - [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) +- Text 和 Image Encoder 可从 [Encoders](https://huggingface.co/lightx2v/Encoders) 下载 +- VAE 可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载 +- 对于 `xxx_split` 目录(例如 `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`),即按照 block 存储的多个 safetensors,适用于内存不足的设备。例如内存 16GB 以内,请根据自身情况下载 + + +### 启动方式 + +#### 方式一:使用启动脚本(推荐) + +**Linux 环境:** +```bash +# 1. 编辑启动脚本,配置相关路径 +cd app/ +vim run_gradio.sh + +# 需要修改的配置项: +# - lightx2v_path: Lightx2v项目根目录路径 +# - model_path: 模型根目录路径(包含所有模型文件) + +# 💾 重要提示:建议将模型路径指向SSD存储位置 +# 例如:/mnt/ssd/models/ 或 /data/ssd/models/ + +# 2. 运行启动脚本 +bash run_gradio.sh + +# 3. 或使用参数启动 +bash run_gradio.sh --lang zh --port 8032 +bash run_gradio.sh --lang en --port 7862 +``` + +**Windows 环境:** +```cmd +# 1. 编辑启动脚本,配置相关路径 +cd app\ +notepad run_gradio_win.bat + +# 需要修改的配置项: +# - lightx2v_path: Lightx2v项目根目录路径 +# - model_path: 模型根目录路径(包含所有模型文件) + +# 💾 重要提示:建议将模型路径指向SSD存储位置 +# 例如:D:\models\ 或 E:\models\ + +# 2. 运行启动脚本 +run_gradio_win.bat + +# 3. 
或使用参数启动 +run_gradio_win.bat --lang zh --port 8032 +run_gradio_win.bat --lang en --port 7862 +``` + +#### 方式二:直接命令行启动 + +```bash +pip install -v git+https://github.com/ModelTC/LightX2V.git +``` + +**Linux 环境:** + +**中文界面版本:** +```bash +python gradio_demo_zh.py \ + --model_path /path/to/models \ + --server_name 0.0.0.0 \ + --server_port 7862 +``` + +**英文界面版本:** +```bash +python gradio_demo.py \ + --model_path /path/to/models \ + --server_name 0.0.0.0 \ + --server_port 7862 +``` + +**Windows 环境:** + +**中文界面版本:** +```cmd +python gradio_demo_zh.py ^ + --model_path D:\models ^ + --server_name 127.0.0.1 ^ + --server_port 7862 +``` + +**英文界面版本:** +```cmd +python gradio_demo.py ^ + --model_path D:\models ^ + --server_name 127.0.0.1 ^ + --server_port 7862 +``` + +**💡 提示**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。 + +## 📋 命令行参数 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `--model_path` | str | ✅ | - | 模型根目录路径(包含所有模型文件的目录) | +| `--server_port` | int | ❌ | 7862 | 服务器端口 | +| `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 | +| `--output_dir` | str | ❌ | ./outputs | 输出视频保存目录 | + +**💡 说明**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。 + +## 🎯 功能特性 + +### 模型配置 + +- **模型类型**: 支持 wan2.1 和 wan2.2 两种模型架构 +- **任务类型**: 支持图像到视频(i2v)和文本到视频(t2v)两种生成模式 +- **模型选择**: 前端自动识别并筛选可用的模型文件,支持自动检测量化精度 +- **编码器配置**: 支持选择 T5 文本编码器、CLIP 图像编码器和 VAE 解码器 +- **算子选择**: 支持多种注意力算子和量化矩阵乘法算子,系统会根据安装状态自动排序 + +### 输入参数 + +- **提示词 (Prompt)**: 描述期望的视频内容 +- **负向提示词 (Negative Prompt)**: 指定不希望出现的元素 +- **输入图像**: i2v 模式下需要上传输入图像 +- **分辨率**: 支持多种预设分辨率(480p/540p/720p) +- **随机种子**: 控制生成结果的随机性 +- **推理步数**: 影响生成质量和速度的平衡(蒸馏模型默认为 4 步) + +### 视频参数 + +- **FPS**: 每秒帧数 +- **总帧数**: 视频长度 +- **CFG缩放因子**: 控制提示词影响强度(1-10,蒸馏模型默认为 1) +- **分布偏移**: 控制生成风格偏离程度(0-10) + +## 🔧 自动配置功能 + +系统会根据您的硬件配置(GPU 显存和 CPU 内存)自动配置最优推理选项,无需手动调整。启动时会自动应用最佳配置,包括: + +- **GPU 内存优化**: 根据显存大小自动启用 CPU 卸载、VAE 分块推理等 +- **CPU 内存优化**: 根据系统内存自动启用延迟加载、模块卸载等 +- **算子选择**: 自动选择已安装的最优算子(按优先级排序) +- **量化配置**: 根据模型文件名自动检测并应用量化精度 + + +### 日志查看 + +```bash +# 查看推理日志 +tail -f inference_logs.log + +# 查看GPU使用情况 +nvidia-smi + +# 查看系统资源 +htop +``` + +欢迎提交Issue和Pull Request来改进这个项目! 
+ +**注意**: 使用本工具生成的视频内容请遵守相关法律法规,不得用于非法用途。 diff --git a/docs/ZH_CN/source/deploy_guides/deploy_local_windows.md b/docs/ZH_CN/source/deploy_guides/deploy_local_windows.md new file mode 100644 index 0000000000000000000000000000000000000000..e2baabbe703a7df6fbe68e7f51c140248b7be527 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/deploy_local_windows.md @@ -0,0 +1,129 @@ +# Windows 本地部署指南 + +## 📖 概述 + +本文档将详细指导您在Windows环境下完成LightX2V的本地部署配置,包括批处理文件推理、Gradio Web界面推理等多种使用方式。 + +## 🚀 快速开始 + +### 环境要求 + +#### 硬件要求 +- **GPU**: NVIDIA GPU,建议 8GB+ VRAM +- **内存**: 建议 16GB+ RAM +- **存储**: 强烈建议使用 SSD 固态硬盘,机械硬盘会导致模型加载缓慢 + +## 🎯 使用方式 + +### 方式一:使用批处理文件推理 + +参考[快速开始文档](../getting_started/quickstart.md)安装环境,并使用[批处理文件](https://github.com/ModelTC/LightX2V/tree/main/scripts/win)运行。 + +### 方式二:使用Gradio Web界面推理 + +#### 手动配置Gradio + +参考[快速开始文档](../getting_started/quickstart.md)安装环境,参考[Gradio部署指南](./deploy_gradio.md) + +#### 一键启动Gradio(推荐) + +**📦 下载软件包** +- [夸克网盘](https://pan.quark.cn/s/8af1162d7a15) + +**📁 目录结构** +解压后,确保目录结构如下: + +``` +├── env/ # LightX2V 环境目录 +├── LightX2V/ # LightX2V 项目目录 +├── start_lightx2v.bat # 一键启动脚本 +├── lightx2v_config.txt # 配置文件 +├── LightX2V使用说明.txt # LightX2V使用说明 +├── outputs/ # 生成的视频保存目录 +└── models/ # 模型存放目录 +``` + +**📥 下载模型**: + +可参考[模型结构文档](../getting_started/model_structure.md)或者[gradio部署文档](./deploy_gradio.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。 + + +**📋 配置参数** + +编辑 `lightx2v_config.txt` 文件,根据需要修改以下参数: + +```ini + +# 界面语言 (zh: 中文, en: 英文) +lang=zh + +# 服务器端口 +port=8032 + +# GPU设备ID (0, 1, 2...) +gpu=0 + +# 模型路径 +model_path=models/ +``` + +**🚀 启动服务** + +双击运行 `start_lightx2v.bat` 文件,脚本将: +1. 自动读取配置文件 +2. 验证模型路径和文件完整性 +3. 启动 Gradio Web 界面 +4. 自动打开浏览器访问服务 + + +![Gradio中文界面](../../../../assets/figs/portabl_windows/pic_gradio_zh.png) + +**⚠️ 重要提示**: +- **页面显示问题**: 如果网页打开空白或显示异常,请运行 `pip install --upgrade gradio` 升级Gradio版本。 + + +### 方式三:使用ComfyUI推理 + +此说明将指导您如何下载与使用便携版的Lightx2v-ComfyUI环境,如此可以免去手动配置环境的步骤,适用于想要在Windows系统下快速开始体验使用Lightx2v加速视频生成的用户。 + +#### 下载Windows便携环境: + +- [百度网盘下载](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid),提取码:wfid + +便携环境中已经打包了所有Python运行相关的依赖,也包括ComfyUI和LightX2V的代码及其相关依赖,下载后解压即可使用。 + +解压后对应的文件目录说明如下: + +```shell +lightx2v_env +├──📂 ComfyUI # ComfyUI代码 +├──📂 portable_python312_embed # 独立的Python环境 +└── run_nvidia_gpu.bat # Windows启动脚本(双击启动) +``` + +#### 启动ComfyUI + +直接双击run_nvidia_gpu.bat文件,系统会打开一个Command Prompt窗口并运行程序,一般第一次启动时间会比较久,请耐心等待,启动完成后会自动打开浏览器并出现ComfyUI的前端界面。 + +![i2v示例工作流](../../../../assets/figs/portabl_windows/pic1.png) + +LightX2V-ComfyUI的插件使用的是,[ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper),示例工作流可以从此项目中获取。 + +#### 已测试显卡(offload模式) + +- 测试模型`Wan2.1-I2V-14B-480P` + +| 显卡型号 | 任务类型 | 显存容量 | 实际最大显存占用 | 实际最大内存占用 | +|:----------|:-----------|:-----------|:-------- |:---------- | +| 3090Ti | I2V | 24G | 6.1G | 7.1G | +| 3080Ti | I2V | 12G | 6.1G | 7.1G | +| 3060Ti | I2V | 8G | 6.1G | 7.1G | +| 4070Ti Super | I2V | 16G | 6.1G | 7.1G | +| 4070 | I2V | 12G | 6.1G | 7.1G | +| 4060 | I2V | 8G | 6.1G | 7.1G | + + + +#### 环境打包和使用参考 +- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) +- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2) diff --git a/docs/ZH_CN/source/deploy_guides/deploy_service.md b/docs/ZH_CN/source/deploy_guides/deploy_service.md new file mode 100644 index 0000000000000000000000000000000000000000..b0111b32e8186661f862ecd339f0a2e59649b058 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/deploy_service.md @@ 
-0,0 +1,88 @@ +# 服务化部署 + +lightx2v 提供异步服务功能。代码入口点在 [这里](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py) + +### 启动服务 + +```shell +# 修改脚本中的路径 +bash scripts/start_server.sh +``` + +`--port 8000` 选项表示服务将绑定到本地机器的 `8000` 端口。您可以根据需要更改此端口。 + +### 客户端发送请求 + +```shell +python scripts/post.py +``` + +服务端点:`/v1/tasks/` + +`scripts/post.py` 中的 `message` 参数如下: + +```python +message = { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "" +} +``` + +1. `prompt`、`negative_prompt` 和 `image_path` 是视频生成的基本输入。`image_path` 可以是空字符串,表示不需要图像输入。 + + +### 客户端检查服务器状态 + +```shell +python scripts/check_status.py +``` + +服务端点包括: + +1. `/v1/service/status` 用于检查服务状态。返回服务是 `busy` 还是 `idle`。服务只有在 `idle` 时才接受新请求。 + +2. `/v1/tasks/` 用于获取服务器接收和完成的所有任务。 + +3. `/v1/tasks/{task_id}/status` 用于获取指定 `task_id` 的任务状态。返回任务是 `processing` 还是 `completed`。 + +### 客户端随时停止服务器上的当前任务 + +```shell +python scripts/stop_running_task.py +``` + +服务端点:`/v1/tasks/running` + +终止任务后,服务器不会退出,而是返回等待新请求的状态。 + +### 在单个节点上启动多个服务 + +在单个节点上,您可以使用 `scripts/start_server.sh` 启动多个服务(注意同一 IP 下的端口号必须不同),或者可以使用 `scripts/start_multi_servers.sh` 同时启动多个服务: + +```shell +num_gpus=8 bash scripts/start_multi_servers.sh +``` + +其中 `num_gpus` 表示要启动的服务数量;服务将从 `--start_port` 开始在连续端口上运行。 + +### 多个服务之间的调度 + +```shell +python scripts/post_multi_servers.py +``` + +`post_multi_servers.py` 将根据服务的空闲状态调度多个客户端请求。 + +### API 端点总结 + +| 端点 | 方法 | 描述 | +|------|------|------| +| `/v1/tasks/` | POST | 创建视频生成任务 | +| `/v1/tasks/form` | POST | 通过表单创建视频生成任务 | +| `/v1/tasks/` | GET | 获取所有任务列表 | +| `/v1/tasks/{task_id}/status` | GET | 获取指定任务状态 | +| `/v1/tasks/{task_id}/result` | GET | 获取指定任务的结果视频文件 | +| `/v1/tasks/running` | DELETE | 停止当前运行的任务 | +| `/v1/files/download/{file_path}` | GET | 下载文件 | +| `/v1/service/status` | GET | 获取服务状态 | diff --git a/docs/ZH_CN/source/deploy_guides/for_low_latency.md b/docs/ZH_CN/source/deploy_guides/for_low_latency.md new file mode 100644 index 0000000000000000000000000000000000000000..b101f6494a579773cf80ba0e886e42f802357646 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/for_low_latency.md @@ -0,0 +1,42 @@ +# 低延迟场景部署 + +在低延迟的场景,我们会追求更快的速度,忽略显存和内存开销等问题。我们提供两套方案: + +## 💡 方案一:步数蒸馏模型的推理 + +该方案可以参考[步数蒸馏文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html) + +🧠 **步数蒸馏**是非常直接的视频生成模型的加速推理方案。从50步蒸馏到4步,耗时将缩短到原来的4/50。同时,该方案下,仍然可以和以下方案结合使用: +1. [高效注意力机制方案](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) +2. [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) + +## 💡 方案二:非步数蒸馏模型的推理 + +步数蒸馏需要比较大的训练资源,以及步数蒸馏后的模型,可能会出现视频动态范围变差的问题。 + +对于非步数蒸馏的原始模型,我们可以使用以下方案或者多种方案结合的方式进行加速: + +1. [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) 进行多卡并行加速。 +2. [特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) 降低实际推理步数。 +3. [高效注意力机制方案](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) 加速 Attention 的推理。 +4. [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) 加速 Linear 层的推理。 +5. 
[变分辨率推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) 降低中间推理步的分辨率。 + +## 💡 使用Tiny VAE + +在某些情况下,VAE部分耗时会比较大,可以使用轻量级VAE进行加速,同时也可以降低一部分显存。 + +```python +{ + "use_tae": true, + "tae_path": "/path to taew2_1.pth" +} +``` +taew2_1.pth 权重可以从[这里](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)下载 + + +## ⚠️ 注意 + +有一部分的加速方案之间目前无法结合使用,我们目前正在致力于解决这一问题。 + +如有问题,欢迎在 [🐛 GitHub Issues](https://github.com/ModelTC/lightx2v/issues) 中进行错误报告或者功能请求 diff --git a/docs/ZH_CN/source/deploy_guides/for_low_resource.md b/docs/ZH_CN/source/deploy_guides/for_low_resource.md new file mode 100644 index 0000000000000000000000000000000000000000..c5e50dea8edddd5242bfc0c08bfbc1109e867438 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/for_low_resource.md @@ -0,0 +1,223 @@ +# Lightx2v 低资源部署指南 + +## 📋 概述 + +本指南专门针对硬件资源受限的环境,特别是**8GB显存 + 16/32GB内存**的配置,详细说明如何成功运行Lightx2v 14B模型进行480p和720p视频生成。 + +Lightx2v是一个强大的视频生成模型,但在资源受限的环境下需要精心优化才能流畅运行。本指南将为您提供从硬件选择到软件配置的完整解决方案,确保您能够在有限的硬件条件下获得最佳的视频生成体验。 + +## 🎯 目标硬件配置详解 + +### 推荐硬件规格 + +**GPU要求**: +- **显存**: 8GB (RTX 3060/3070/4060/4060Ti 等) +- **架构**: 支持CUDA的NVIDIA显卡 + +**系统内存**: +- **最低要求**: 16GB DDR4 +- **推荐配置**: 32GB DDR4/DDR5 +- **内存速度**: 建议3200MHz及以上 + +**存储要求**: +- **类型**: 强烈推荐NVMe SSD +- **容量**: 至少50GB可用空间 +- **速度**: 读取速度建议3000MB/s以上 + +**CPU要求**: +- **核心数**: 建议8核心及以上 +- **频率**: 建议3.0GHz及以上 +- **架构**: 支持AVX2指令集 + +## ⚙️ 核心优化策略详解 + +### 1. 环境优化 + +在运行Lightx2v之前,建议设置以下环境变量以优化性能: + +```bash +# CUDA内存分配优化 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# 启用CUDA Graph模式,提升推理性能 +export ENABLE_GRAPH_MODE=true + +# 使用BF16精度推理,减少显存占用(默认FP32精度) +export DTYPE=BF16 +``` + +**优化说明**: +- `expandable_segments:True`: 允许CUDA内存段动态扩展,减少内存碎片 +- `ENABLE_GRAPH_MODE=true`: 启用CUDA Graph,减少内核启动开销 +- `DTYPE=BF16`: 使用BF16精度,在保持质量的同时减少显存占用 + +### 2. 量化策略 + +量化是低资源环境下的关键优化技术,通过降低模型精度来减少内存占用。 + +#### 量化方案对比 + +**FP8量化** (推荐用于RTX 40系列): +```python +# 适用于支持FP8的GPU,提供更好的精度 +dit_quant_scheme = "fp8" # DIT模型量化 +t5_quant_scheme = "fp8" # T5文本编码器量化 +clip_quant_scheme = "fp8" # CLIP视觉编码器量化 +``` + +**INT8量化** (通用方案): +```python +# 适用于所有GPU,内存占用最小 +dit_quant_scheme = "int8" # 8位整数量化 +t5_quant_scheme = "int8" # 文本编码器量化 +clip_quant_scheme = "int8" # 视觉编码器量化 +``` +### 3. 高效算子选择指南 + +选择合适的算子可以显著提升推理速度和减少内存占用。 + +#### 注意力算子选择 + +**推荐优先级**: +1. **[Sage Attention](https://github.com/thu-ml/SageAttention)** (最高优先级) + +2. **[Flash Attention](https://github.com/Dao-AILab/flash-attention)** (通用方案) + + +#### 矩阵乘算子选择 + +**ADA架构显卡** (RTX 40系列): + +推荐优先级: +1. **[q8-kernel](https://github.com/KONAKONA666/q8_kernels)** (最高性能,仅支持ADA架构) +2. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (平衡方案) +3. **[vllm-kernel](https://github.com/vllm-project/vllm)** (通用方案) + +**其他架构显卡**: +1. **[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)** (推荐) +2. **[vllm-kernel](https://github.com/vllm-project/vllm)** (备选) + +### 4. 参数卸载策略详解 + +参数卸载技术允许模型在CPU和磁盘之间动态调度参数,突破显存限制。 + +#### 三级卸载架构 + +```python +# 磁盘-CPU-GPU三级卸载配置 +cpu_offload=True # 启用CPU卸载 +t5_cpu_offload=True # 启用T5编码器CPU卸载 +offload_granularity=phase # DIT模型细粒度卸载 +t5_offload_granularity=block # T5编码器细粒度卸载 +lazy_load = True # 启用延迟加载机制 +num_disk_workers = 2 # 磁盘I/O工作线程数 +``` + +#### 卸载策略详解 + +**延迟加载机制**: +- 模型参数按需从磁盘加载到CPU +- 减少运行时内存占用 +- 支持大模型在有限内存下运行 + +**磁盘存储优化**: +- 使用高速SSD存储模型参数 +- 按照block分组存储模型文件 +- 参考转换脚本[文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md),转换时指定`--save_by_block`参数 + +### 5. 
显存优化技术详解 + +针对720p视频生成的显存优化策略。 + +#### CUDA内存管理 + +```python +# CUDA内存清理配置 +clean_cuda_cache = True # 及时清理GPU缓存 +rotary_chunk = True # 旋转位置编码分块计算 +rotary_chunk_size = 100 # 分块大小,可根据显存调整 +``` + +#### 分块计算策略 + +**旋转位置编码分块**: +- 将长序列分成小块处理 +- 减少峰值显存占用 +- 保持计算精度 + +### 6. VAE优化详解 + +VAE (变分自编码器) 是视频生成的关键组件,优化VAE可以显著提升性能。 + +#### VAE分块推理 + +```python +# VAE优化配置 +use_tiling_vae = True # 启用VAE分块推理 +``` + +#### 轻量级VAE + +```python +# VAE优化配置 +use_tae = True +tae_path = "/path to taew2_1.pth" +``` +taew2_1.pth 权重可以从[这里](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)下载 + +**VAE优化效果**: +- 标准VAE: 基准性能,100%质量保持 +- 标准VAE分块: 降低显存,增加推理时间,100%质量保持 +- 轻量VAE: 极低显存,视频质量有损 + + +### 7. 模型选择策略 + +选择合适的模型版本对低资源环境至关重要。 + +#### 推荐模型对比 + +**蒸馏模型** (强烈推荐): +- ✅ **[Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v)** + +- ✅ **[Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v)** + + +#### 性能优化建议 + +使用上述蒸馏模型时,可以进一步优化性能: +- 关闭CFG: `"enable_cfg": false` +- 减少推理步数: `infer_step: 4` +- 参考配置文件: [config](https://github.com/ModelTC/LightX2V/tree/main/configs/distill) + +## 🚀 完整配置示例 + +### 预配置模板 + +- **[14B模型480p视频生成配置](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json)** + +- **[14B模型720p视频生成配置](https://github.com/ModelTC/lightx2v/tree/main/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json)** + +- **[1.3B模型720p视频生成配置](https://github.com/ModelTC/LightX2V/tree/main/configs/offload/block/wan_t2v_1_3b.json)** + - 1.3B模型推理瓶颈是T5 encoder,配置文件专门针对T5进行优化 + +**[启动脚本](https://github.com/ModelTC/LightX2V/tree/main/scripts/wan/run_wan_i2v_lazy_load.sh)** + + +## 📚 参考资源 + +- [参数卸载机制文档](../method_tutorials/offload.md) - 深入了解卸载技术原理 +- [量化技术指南](../method_tutorials/quantization.md) - 量化技术详细说明 +- [Gradio部署指南](deploy_gradio.md) - Gradio部署详细说明 + +## ⚠️ 重要注意事项 + +1. **硬件要求**: 确保您的硬件满足最低配置要求 +2. **驱动版本**: 建议使用最新的NVIDIA驱动 (535+) +3. **CUDA版本**: 确保CUDA版本与PyTorch兼容 (建议CUDA 11.8+) +4. **存储空间**: 预留足够的磁盘空间用于模型缓存 (至少50GB) +5. **网络环境**: 首次下载模型需要稳定的网络连接 +6. 
**环境变量**: 务必设置推荐的环境变量以优化性能 + + +**技术支持**: 如遇到问题,请提交Issue到项目仓库。 diff --git a/docs/ZH_CN/source/deploy_guides/lora_deploy.md b/docs/ZH_CN/source/deploy_guides/lora_deploy.md new file mode 100644 index 0000000000000000000000000000000000000000..c75b71dc88c0d2a7f042ad2d06d77de10295f455 --- /dev/null +++ b/docs/ZH_CN/source/deploy_guides/lora_deploy.md @@ -0,0 +1,213 @@ +# LoRA 模型部署与相关工具 + +LoRA (Low-Rank Adaptation) 是一种高效的模型微调技术,通过低秩矩阵分解显著减少可训练参数数量。LightX2V 全面支持 LoRA 技术,包括 LoRA 推理、LoRA 提取和 LoRA 合并等功能。 + +## 🎯 LoRA 技术特性 + +- **灵活部署**:支持动态加载和移除 LoRA 权重 +- **多种格式**:支持多种 LoRA 权重格式和命名约定 +- **工具完善**:提供完整的 LoRA 提取、合并工具链 + +## 📜 LoRA 推理部署 + +### 配置文件方式 + +在配置文件中指定 LoRA 路径: + +```json +{ + "lora_configs": [ + { + "path": "/path/to/your/lora.safetensors", + "strength": 1.0 + } + ] +} +``` + +**配置参数说明:** + +- `lora_path`: LoRA 权重文件路径列表,支持多个 LoRA 同时加载 +- `strength_model`: LoRA 强度系数 (alpha),控制 LoRA 对原模型的影响程度 + +### 命令行方式 + +直接在命令行中指定 LoRA 路径(仅支持加载单个 LoRA): + +```bash +python -m lightx2v.infer \ + --model_cls wan2.1 \ + --task t2v \ + --model_path /path/to/model \ + --config_json /path/to/config.json \ + --lora_path /path/to/your/lora.safetensors \ + --lora_strength 0.8 \ + --prompt "Your prompt here" +``` + +### 多LoRA配置 + +要使用多个具有不同强度的LoRA,请在配置JSON文件中指定: + +```json +{ + "lora_configs": [ + { + "path": "/path/to/first_lora.safetensors", + "strength": 0.8 + }, + { + "path": "/path/to/second_lora.safetensors", + "strength": 0.5 + } + ] +} +``` + +### 支持的 LoRA 格式 + +LightX2V 支持多种 LoRA 权重命名约定: + +| 格式类型 | 权重命名 | 说明 | +|----------|----------|------| +| **标准 LoRA** | `lora_A.weight`, `lora_B.weight` | 标准的 LoRA 矩阵分解格式 | +| **Down/Up 格式** | `lora_down.weight`, `lora_up.weight` | 另一种常见的命名约定 | +| **差值格式** | `diff` | `weight` 权重差值 | +| **偏置差值** | `diff_b` | `bias` 权重差值 | +| **调制差值** | `diff_m` | `modulation` 权重差值 | + +### 推理脚本示例 + +**步数蒸馏 LoRA 推理:** + +```bash +# T2V LoRA 推理 +bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh + +# I2V LoRA 推理 +bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh +``` + +**音频驱动 LoRA 推理:** + +```bash +bash scripts/wan/run_wan_i2v_audio.sh +``` + +### API 服务中使用 LoRA + +在 API 服务中通过 [config 文件](wan_t2v_distill_4step_cfg_lora.json) 指定,对 [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh) 中的启动命令进行修改: + +```bash +python -m lightx2v.api_server \ + --model_cls wan2.1_distill \ + --task t2v \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \ + --port 8000 \ + --nproc_per_node 1 +``` + +## 🔧 LoRA 提取工具 + +使用 `tools/extract/lora_extractor.py` 从两个模型的差异中提取 LoRA 权重。 + +### 基本用法 + +```bash +python tools/extract/lora_extractor.py \ + --source-model /path/to/base/model \ + --target-model /path/to/finetuned/model \ + --output /path/to/extracted/lora.safetensors \ + --rank 32 +``` + +### 参数说明 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `--source-model` | str | ✅ | - | 基础模型路径 | +| `--target-model` | str | ✅ | - | 微调后模型路径 | +| `--output` | str | ✅ | - | 输出 LoRA 文件路径 | +| `--source-type` | str | ❌ | `safetensors` | 基础模型格式 (`safetensors`/`pytorch`) | +| `--target-type` | str | ❌ | `safetensors` | 微调模型格式 (`safetensors`/`pytorch`) | +| `--output-format` | str | ❌ | `safetensors` | 输出格式 (`safetensors`/`pytorch`) | +| `--rank` | int | ❌ | `32` | LoRA 秩值 | +| `--output-dtype` | str | ❌ | `bf16` | 输出数据类型 | +| `--diff-only` | bool | ❌ | `False` | 仅保存权重差值,不进行 LoRA 分解 | + +### 高级用法示例 + +**提取高秩 LoRA:** + +```bash +python tools/extract/lora_extractor.py \ + 
--source-model /path/to/base/model \ + --target-model /path/to/finetuned/model \ + --output /path/to/high_rank_lora.safetensors \ + --rank 64 \ + --output-dtype fp16 +``` + +**仅保存权重差值:** + +```bash +python tools/extract/lora_extractor.py \ + --source-model /path/to/base/model \ + --target-model /path/to/finetuned/model \ + --output /path/to/weight_diff.safetensors \ + --diff-only +``` + +## 🔀 LoRA 合并工具 + +使用 `tools/extract/lora_merger.py` 将 LoRA 权重合并到基础模型中,以进行后续量化等操作。 + +### 基本用法 + +```bash +python tools/extract/lora_merger.py \ + --source-model /path/to/base/model \ + --lora-model /path/to/lora.safetensors \ + --output /path/to/merged/model.safetensors \ + --alpha 1.0 +``` + +### 参数说明 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `--source-model` | str | ✅ | 无 | 基础模型路径 | +| `--lora-model` | str | ✅ | 无 | LoRA 权重路径 | +| `--output` | str | ✅ | 无 | 输出合并模型路径 | +| `--source-type` | str | ❌ | `safetensors` | 基础模型格式 | +| `--lora-type` | str | ❌ | `safetensors` | LoRA 权重格式 | +| `--output-format` | str | ❌ | `safetensors` | 输出格式 | +| `--alpha` | float | ❌ | `1.0` | LoRA 合并强度 | +| `--output-dtype` | str | ❌ | `bf16` | 输出数据类型 | + +### 高级用法示例 + +**部分强度合并:** + +```bash +python tools/extract/lora_merger.py \ + --source-model /path/to/base/model \ + --lora-model /path/to/lora.safetensors \ + --output /path/to/merged_model.safetensors \ + --alpha 0.7 \ + --output-dtype fp32 +``` + +**多格式支持:** + +```bash +python tools/extract/lora_merger.py \ + --source-model /path/to/base/model.pt \ + --source-type pytorch \ + --lora-model /path/to/lora.safetensors \ + --lora-type safetensors \ + --output /path/to/merged_model.safetensors \ + --output-format safetensors \ + --alpha 1.0 +``` diff --git a/docs/ZH_CN/source/getting_started/benchmark.md b/docs/ZH_CN/source/getting_started/benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..6adc116c38878349b9175f2e23c39cc67606c5a8 --- /dev/null +++ b/docs/ZH_CN/source/getting_started/benchmark.md @@ -0,0 +1,3 @@ +# 基准测试 + +由于要展示一些视频的播放效果和详细的性能对比,您可以在这个[🔗 页面](https://github.com/ModelTC/LightX2V/blob/main/docs/ZH_CN/source/getting_started/benchmark_source.md)获得更好的展示效果以及相对应的文档内容。 diff --git a/docs/ZH_CN/source/getting_started/benchmark_source.md b/docs/ZH_CN/source/getting_started/benchmark_source.md new file mode 100644 index 0000000000000000000000000000000000000000..c399d1837cee6a435b042f30b2af5b870cb53503 --- /dev/null +++ b/docs/ZH_CN/source/getting_started/benchmark_source.md @@ -0,0 +1,149 @@ +# 🚀 基准测试 + +> 本文档展示了LightX2V在不同硬件环境下的性能测试结果,包括H200和RTX 4090平台的详细对比数据。 + +--- + +## 🖥️ H200 环境 (~140GB显存) + +### 📋 软件环境配置 + +| 组件 | 版本 | +|:-----|:-----| +| **Python** | 3.11 | +| **PyTorch** | 2.7.1+cu128 | +| **SageAttention** | 2.2.0 | +| **vLLM** | 0.9.2 | +| **sgl-kernel** | 0.1.8 | + +--- + +### 🎬 480P 5s视频测试 + +**测试配置:** +- **模型**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) +- **参数**: `infer_steps=40`, `seed=42`, `enable_cfg=True` + +#### 📊 性能对比表 + +| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 | +|:-----|:----------:|:---------------:|:------:|:--------:| +| **Wan2.1 Official** | 366 | 71 | 1.0x | | +| **FastVideo** | 292 | 26 | **1.25x** | | +| **LightX2V_1** | 250 | 53 | **1.46x** | | +| **LightX2V_2** | 216 | 50 | **1.70x** | | +| **LightX2V_3** | 191 | 35 | **1.92x** | | +| **LightX2V_3-Distill** | 14 | 35 | **🏆 20.85x** | | +| **LightX2V_4** | 107 | 35 | **3.41x** | | + +--- + +### 🎬 720P 5s视频测试 + +**测试配置:** +- **模型**: 
[Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) +- **参数**: `infer_steps=40`, `seed=1234`, `enable_cfg=True` + +#### 📊 性能对比表 + +| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 | +|:-----|:----------:|:---------------:|:------:|:--------:| +| **Wan2.1 Official** | 974 | 81 | 1.0x | | +| **FastVideo** | 914 | 40 | **1.07x** | | +| **LightX2V_1** | 807 | 65 | **1.21x** | | +| **LightX2V_2** | 751 | 57 | **1.30x** | | +| **LightX2V_3** | 671 | 43 | **1.45x** | | +| **LightX2V_3-Distill** | 44 | 43 | **🏆 22.14x** | | +| **LightX2V_4** | 344 | 46 | **2.83x** | | + +--- + +## 🖥️ RTX 4090 环境 (~24GB显存) + +### 📋 软件环境配置 + +| 组件 | 版本 | +|:-----|:-----| +| **Python** | 3.9.16 | +| **PyTorch** | 2.5.1+cu124 | +| **SageAttention** | 2.1.0 | +| **vLLM** | 0.6.6 | +| **sgl-kernel** | 0.0.5 | +| **q8-kernels** | 0.0.0 | + +--- + +### 🎬 480P 5s视频测试 + +**测试配置:** +- **模型**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) +- **参数**: `infer_steps=40`, `seed=42`, `enable_cfg=True` + +#### 📊 性能对比表 + +| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 | +|:-----|:----------:|:---------------:|:------:|:--------:| +| **Wan2GP(profile=3)** | 779 | 20 | **1.0x** | | +| **LightX2V_5** | 738 | 16 | **1.05x** | | +| **LightX2V_5-Distill** | 68 | 16 | **11.45x** | | +| **LightX2V_6** | 630 | 12 | **1.24x** | | +| **LightX2V_6-Distill** | 63 | 12 | **🏆 12.36x** | | + +--- + +### 🎬 720P 5s视频测试 + +**测试配置:** +- **模型**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) +- **参数**: `infer_steps=40`, `seed=1234`, `enable_cfg=True` + +#### 📊 性能对比表 + +| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 | +|:-----|:----------:|:---------------:|:------:|:--------:| +| **Wan2GP(profile=3)** | -- | OOM | -- | | +| **LightX2V_5** | 2473 | 23 | -- | | +| **LightX2V_5-Distill** | 183 | 23 | -- | | +| **LightX2V_6** | 2169 | 18 | -- | | +| **LightX2V_6-Distill** | 171 | 18 | -- | | + +--- + +## 📖 配置说明 + +### 🖥️ H200 环境配置说明 + +| 配置 | 技术特点 | +|:-----|:---------| +| **Wan2.1 Official** | 基于[Wan2.1官方仓库](https://github.com/Wan-Video/Wan2.1)的原始实现 | +| **FastVideo** | 基于[FastVideo官方仓库](https://github.com/hao-ai-lab/FastVideo),使用SageAttention2后端优化 | +| **LightX2V_1** | 使用SageAttention2替换原生注意力机制,采用DIT BF16+FP32(部分敏感层)混合精度计算,在保持精度的同时提升计算效率 | +| **LightX2V_2** | 统一使用BF16精度计算,进一步减少显存占用和计算开销,同时保持生成质量 | +| **LightX2V_3** | 引入FP8量化技术显著减少计算精度要求,结合Tiling VAE技术优化显存使用 | +| **LightX2V_3-Distill** | 在LightX2V_3基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 | +| **LightX2V_4** | 在LightX2V_3基础上加入TeaCache(teacache_thresh=0.2)缓存复用技术,通过智能跳过冗余计算实现加速 | + +### 🖥️ RTX 4090 环境配置说明 + +| 配置 | 技术特点 | +|:-----|:---------| +| **Wan2GP(profile=3)** | 基于[Wan2GP仓库](https://github.com/deepbeepmeep/Wan2GP)实现,使用MMGP优化技术。profile=3配置适用于至少32GB内存和24GB显存的RTX 3090/4090环境,通过牺牲显存来适应有限的内存资源。使用量化模型:[480P模型](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors)和[720P模型](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors) | +| **LightX2V_5** | 使用SageAttention2替换原生注意力机制,采用DIT FP8+FP32(部分敏感层)混合精度计算,启用CPU offload技术,将部分敏感层执行FP32精度计算,将DIT推理过程中异步数据卸载到CPU上,节省显存,offload粒度为block级别 | +| **LightX2V_5-Distill** | 在LightX2V_5基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 | +| **LightX2V_6** | 在LightX2V_3基础上启用CPU offload技术,将部分敏感层执行FP32精度计算,将DIT推理过程中异步数据卸载到CPU上,节省显存,offload粒度为block级别 | +| **LightX2V_6-Distill** | 
在LightX2V_6基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 | + +--- + +## 📁 配置文件参考 + +基准测试相关的配置文件和运行脚本可在以下位置获取: + +| 类型 | 链接 | 说明 | +|:-----|:-----|:-----| +| **配置文件** | [configs/bench](https://github.com/ModelTC/LightX2V/tree/main/configs/bench) | 包含各种优化配置的JSON文件 | +| **运行脚本** | [scripts/bench](https://github.com/ModelTC/LightX2V/tree/main/scripts/bench) | 包含基准测试的执行脚本 | + +--- + +> 💡 **提示**: 建议根据您的硬件配置选择合适的优化方案,以获得最佳的性能表现。 diff --git a/docs/ZH_CN/source/getting_started/model_structure.md b/docs/ZH_CN/source/getting_started/model_structure.md new file mode 100644 index 0000000000000000000000000000000000000000..8ceb48a107ef321c12829b487b2c5b1e9dcd1154 --- /dev/null +++ b/docs/ZH_CN/source/getting_started/model_structure.md @@ -0,0 +1,571 @@ +# 模型格式与加载指南 + +## 📖 概述 + +LightX2V 是一个灵活的视频生成推理框架,支持多种模型来源和格式,为用户提供丰富的选择: + +- ✅ **Wan 官方模型**:直接兼容 Wan2.1 和 Wan2.2 官方发布的完整模型 +- ✅ **单文件模型**:支持 LightX2V 发布的单文件格式模型(包含量化版本) +- ✅ **LoRA 模型**:支持加载 LightX2V 发布的蒸馏 LoRA + +本文档将详细介绍各种模型格式的使用方法、配置参数和最佳实践。 + +--- + +## 🗂️ 格式一:Wan 官方模型 + +### 模型仓库 +- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b) +- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8) + +### 模型特点 +- **官方保证**:Wan-AI 官方发布的完整模型,质量最高 +- **完整组件**:包含所有必需的组件(DIT、T5、CLIP、VAE) +- **原始精度**:使用 BF16/FP32 精度,无量化损失 +- **兼容性强**:与 Wan 官方工具链完全兼容 + +### Wan2.1 官方模型 + +#### 目录结构 + +以 [Wan2.1-I2V-14B-720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 为例: + +``` +Wan2.1-I2V-14B-720P/ +├── diffusion_pytorch_model-00001-of-00007.safetensors # DIT 模型分片 1 +├── diffusion_pytorch_model-00002-of-00007.safetensors # DIT 模型分片 2 +├── diffusion_pytorch_model-00003-of-00007.safetensors # DIT 模型分片 3 +├── diffusion_pytorch_model-00004-of-00007.safetensors # DIT 模型分片 4 +├── diffusion_pytorch_model-00005-of-00007.safetensors # DIT 模型分片 5 +├── diffusion_pytorch_model-00006-of-00007.safetensors # DIT 模型分片 6 +├── diffusion_pytorch_model-00007-of-00007.safetensors # DIT 模型分片 7 +├── diffusion_pytorch_model.safetensors.index.json # 分片索引文件 +├── models_t5_umt5-xxl-enc-bf16.pth # T5 文本编码器 +├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP 编码器 +├── Wan2.1_VAE.pth # VAE 编解码器 +├── config.json # 模型配置 +├── xlm-roberta-large/ # CLIP tokenizer +├── google/ # T5 tokenizer +├── assets/ +└── examples/ +``` + +#### 使用方法 + +```bash +# 下载模型 +huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \ + --local-dir ./models/Wan2.1-I2V-14B-720P + +# 配置启动脚本 +model_path=./models/Wan2.1-I2V-14B-720P +lightx2v_path=/path/to/LightX2V + +# 运行推理 +cd LightX2V/scripts +bash wan/run_wan_i2v.sh +``` + +### Wan2.2 官方模型 + +#### 目录结构 + +以 [Wan2.2-I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) 为例: + +``` +Wan2.2-I2V-A14B/ +├── high_noise_model/ # 高噪声模型目录 +│ ├── diffusion_pytorch_model-00001-of-00009.safetensors +│ ├── diffusion_pytorch_model-00002-of-00009.safetensors +│ ├── ... +│ ├── diffusion_pytorch_model-00009-of-00009.safetensors +│ └── diffusion_pytorch_model.safetensors.index.json +├── low_noise_model/ # 低噪声模型目录 +│ ├── diffusion_pytorch_model-00001-of-00009.safetensors +│ ├── diffusion_pytorch_model-00002-of-00009.safetensors +│ ├── ... 
+│ ├── diffusion_pytorch_model-00009-of-00009.safetensors +│ └── diffusion_pytorch_model.safetensors.index.json +├── models_t5_umt5-xxl-enc-bf16.pth # T5 文本编码器 +├── Wan2.1_VAE.pth # VAE 编解码器 +├── configuration.json # 模型配置 +├── google/ # T5 tokenizer +├── assets/ # 示例资源(可选) +└── examples/ # 示例文件(可选) +``` + +#### 使用方法 + +```bash +# 下载模型 +huggingface-cli download Wan-AI/Wan2.2-I2V-A14B \ + --local-dir ./models/Wan2.2-I2V-A14B + +# 配置启动脚本 +model_path=./models/Wan2.2-I2V-A14B +lightx2v_path=/path/to/LightX2V + +# 运行推理 +cd LightX2V/scripts +bash wan22/run_wan22_moe_i2v.sh +``` + +### 可用模型列表 + +#### Wan2.1 官方模型列表 + +| 模型名称 | 下载链接 | +|---------|----------| +| Wan2.1-I2V-14B-720P | [链接](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) | +| Wan2.1-I2V-14B-480P | [链接](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | +| Wan2.1-T2V-14B | [链接](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | +| Wan2.1-T2V-1.3B | [链接](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | +| Wan2.1-FLF2V-14B-720P | [链接](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) | +| Wan2.1-VACE-14B | [链接](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) | +| Wan2.1-VACE-1.3B | [链接](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) | + +#### Wan2.2 官方模型列表 + +| 模型名称 | 下载链接 | +|---------|----------| +| Wan2.2-I2V-A14B | [链接](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) | +| Wan2.2-T2V-A14B | [链接](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) | +| Wan2.2-TI2V-5B | [链接](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) | +| Wan2.2-Animate-14B | [链接](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) | + +### 使用提示 + +> 💡 **量化模型使用**:如需使用量化模型,可参考[模型转换脚本](https://github.com/ModelTC/LightX2V/blob/main/tools/convert/readme_zh.md)进行转换,或直接使用下方格式二中的预转换量化模型 +> +> 💡 **显存优化**:对于 RTX 4090 24GB 或更小显存的设备,建议结合量化技术和 CPU 卸载功能: +> - 量化配置:参考[量化技术文档](../method_tutorials/quantization.md) +> - CPU 卸载:参考[参数卸载文档](../method_tutorials/offload.md) +> - Wan2.1 配置:参考 [offload 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/offload) +> - Wan2.2 配置:参考 [wan22 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) 中以 `4090` 结尾的配置 + +--- + +## 🗂️ 格式二:LightX2V 单文件模型(推荐) + +### 模型仓库 +- [Wan2.1-LightX2V](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan2.2-LightX2V](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) + +### 模型特点 +- **单文件管理**:单个 safetensors 文件,易于管理和部署 +- **多精度支持**:提供原始精度、FP8、INT8 等多种精度版本 +- **蒸馏加速**:支持 4-step 快速推理 +- **工具兼容**:兼容 ComfyUI 等其他工具 + +**示例**: +- `wan2.1_i2v_720p_lightx2v_4step.safetensors` - 720P 图生视频原始精度 +- `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` - 720P 图生视频 FP8 量化 +- `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` - 480P 图生视频 INT8 量化 +- ... 
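+
+下载后如果不确定某个单文件权重属于哪种精度,可以用下面的示意脚本快速查看其中张量的数据类型(仅为演示用的草图:假设已安装 `safetensors` 与 PyTorch,文件路径为示例路径,请替换为实际下载的文件):
+
+```python
+# 示意脚本:查看 safetensors 单文件权重中张量的数据类型,
+# 以判断下载的是原始精度(bfloat16)还是量化(float8 / int8)版本。
+# 注意:加载 FP8 张量需要较新版本的 PyTorch。
+from safetensors import safe_open
+
+# 示例路径,请替换为实际下载的模型文件
+ckpt = "./models/wan2.1_i2v_720p/wan2.1_i2v_720p_lightx2v_4step.safetensors"
+
+with safe_open(ckpt, framework="pt", device="cpu") as f:
+    for i, key in enumerate(f.keys()):
+        tensor = f.get_tensor(key)
+        print(key, tuple(tensor.shape), tensor.dtype)
+        if i >= 9:  # 打印前 10 个张量通常足以判断精度
+            break
+```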
+ +### Wan2.1 单文件模型 + +#### 场景 A:下载单个模型文件 + +**步骤 1:选择并下载模型** + +```bash +# 创建模型目录 +mkdir -p ./models/wan2.1_i2v_720p + +# 下载 720P 图生视频 FP8 量化模型 +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p \ + --include "wan2.1_i2v_720p_lightx2v_4step.safetensors" +``` + +**步骤 2:手动组织其他模块** + +目录结构如下 +``` +wan2.1_i2v_720p/ +├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度 +└── t5/clip/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织 +``` + +**步骤 3:配置启动脚本** + +```bash +# 在启动脚本中设置(指向包含模型文件的目录) +model_path=./models/wan2.1_i2v_720p +lightx2v_path=/path/to/LightX2V + +# 运行脚本 +cd LightX2V/scripts +bash wan/run_wan_i2v_distill_4step_cfg.sh +``` + +> 💡 **提示**:当目录下只有一个模型文件时,LightX2V 会自动加载该文件。 + +#### 场景 B:下载多个模型文件 + +当您下载了多个不同精度的模型到同一目录时,需要在配置文件中明确指定使用哪个模型。 + +**步骤 1:下载多个模型** + +```bash +# 创建模型目录 +mkdir -p ./models/wan2.1_i2v_720p_multi + +# 下载原始精度模型 +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p_multi \ + --include "wan2.1_i2v_720p_lightx2v_4step.safetensors" + +# 下载 FP8 量化模型 +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p_multi \ + --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# 下载 INT8 量化模型 +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models/wan2.1_i2v_720p_multi \ + --include "wan2.1_i2v_720p_int8_lightx2v_4step.safetensors" +``` + +**步骤 2:手动组织其他模块** + +目录结构如下: + +``` +wan2.1_i2v_720p_multi/ +├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度 +├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化 +└── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 量化 +└── t5/clip/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织 +``` + +**步骤 3:在配置文件中指定模型** + +编辑配置文件(如 `configs/distill/wan_i2v_distill_4step_cfg.json`): + +```json +{ + // 使用原始精度模型 + "dit_original_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_lightx2v_4step.safetensors", + + // 或使用 FP8 量化模型 + // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "fp8-vllm", + + // 或使用 INT8 量化模型 + // "dit_quantized_ckpt": "./models/wan2.1_i2v_720p_multi/wan2.1_i2v_720p_int8_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "int8-vllm", + + // 其他配置... 
+} +``` +### 使用提示 + +> 💡 **配置参数说明**: +> - **dit_original_ckpt**:用于指定原始精度模型(BF16/FP32/FP16)的路径 +> - **dit_quantized_ckpt**:用于指定量化模型(FP8/INT8)的路径,需配合 `dit_quantized` 和 `dit_quant_scheme` 参数使用 + +**步骤 4:启动推理** + +```bash +cd LightX2V/scripts +bash wan/run_wan_i2v_distill_4step_cfg.sh +``` + +### Wan2.2 单文件模型 + +#### 目录结构要求 + +使用 Wan2.2 单文件模型时,需要手动创建特定的目录结构: + +``` +wan2.2_models/ +├── high_noise_model/ # 高噪声模型目录(必须) +│ └── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors # 高噪声模型文件 +└── low_noise_model/ # 低噪声模型目录(必须) +│ └── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # 低噪声模型文件 +└── t5/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织 +``` + +#### 场景 A:每个目录下只有一个模型文件 + +```bash +# 创建必需的子目录 +mkdir -p ./models/wan2.2_models/high_noise_model +mkdir -p ./models/wan2.2_models/low_noise_model + +# 下载高噪声模型到对应目录 +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models/high_noise_model \ + --include "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# 下载低噪声模型到对应目录 +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models/low_noise_model \ + --include "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# 配置启动脚本(指向父目录) +model_path=./models/wan2.2_models +lightx2v_path=/path/to/LightX2V + +# 运行脚本 +cd LightX2V/scripts +bash wan22/run_wan22_moe_i2v_distill.sh +``` + +> 💡 **提示**:当每个子目录下只有一个模型文件时,LightX2V 会自动加载。 + +#### 场景 B:每个目录下有多个模型文件 + +当您在 `high_noise_model/` 和 `low_noise_model/` 目录下分别放置了多个不同精度的模型时,需要在配置文件中明确指定。 + +```bash +# 创建目录 +mkdir -p ./models/wan2.2_models_multi/high_noise_model +mkdir -p ./models/wan2.2_models_multi/low_noise_model + +# 下载高噪声模型的多个版本 +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models_multi/high_noise_model \ + --include "wan2.2_i2v_A14b_high_noise_*.safetensors" + +# 下载低噪声模型的多个版本 +huggingface-cli download lightx2v/Wan2.2-Distill-Models \ + --local-dir ./models/wan2.2_models_multi/low_noise_model \ + --include "wan2.2_i2v_A14b_low_noise_*.safetensors" +``` + +**目录结构**: + +``` +wan2.2_models_multi/ +├── high_noise_model/ +│ ├── wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors # 原始精度 +│ ├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化 +│ └── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors # INT8 量化 +└── low_noise_model/ +│ ├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # 原始精度 +│ ├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化 +│ └── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # INT8 量化 +└── t5/vae/config.json/xlm-roberta-large/google等其他组件 # 需要手动组织 +``` + +**配置文件设置**: + +```json +{ + // 使用原始精度模型 + "high_noise_original_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors", + + // 或使用 FP8 量化模型 + // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step.safetensors", + // "low_noise_quantized_ckpt": "./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "fp8-vllm" + + // 或使用 INT8 量化模型 + // "high_noise_quantized_ckpt": "./models/wan2.2_models_multi/high_noise_model/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + // "low_noise_quantized_ckpt": 
"./models/wan2.2_models_multi/low_noise_model/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + // "dit_quantized": true, + // "dit_quant_scheme": "int8-vllm" +} +``` + +### 使用提示 + +> 💡 **配置参数说明**: +> - **high_noise_original_ckpt** / **low_noise_original_ckpt**:用于指定原始精度模型(BF16/FP32/FP16)的路径 +> - **high_noise_quantized_ckpt** / **low_noise_quantized_ckpt**:用于指定量化模型(FP8/INT8)的路径,需配合 `dit_quantized` 和 `dit_quant_scheme` 参数使用 + + +### 可用模型列表 + +#### Wan2.1 单文件模型列表 + +**图生视频模型(I2V)** + +| 文件名 | 精度 | 说明 | +|--------|------|------| +| `wan2.1_i2v_480p_lightx2v_4step.safetensors` | BF16 | 4步模型原始精度 | +| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4步模型FP8 量化 | +| `wan2.1_i2v_480p_int8_lightx2v_4step.safetensors` | INT8 | 4步模型INT8 量化 | +| `wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4步模型ComfyUI 格式 | +| `wan2.1_i2v_720p_lightx2v_4step.safetensors` | BF16 | 4步模型原始精度 | +| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4步模型FP8 量化 | +| `wan2.1_i2v_720p_int8_lightx2v_4step.safetensors` | INT8 | 4步模型INT8 量化 | +| `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4步模型ComfyUI 格式 | + +**文生视频模型(T2V)** + +| 文件名 | 精度 | 说明 | +|--------|------|------| +| `wan2.1_t2v_14b_lightx2v_4step.safetensors` | BF16 | 4步模型原始精度 | +| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 4步模型FP8 量化 | +| `wan2.1_t2v_14b_int8_lightx2v_4step.safetensors` | INT8 | 4步模型INT8 量化 | +| `wan2.1_t2v_14b_scaled_fp8_e4m3_lightx2v_4step_comfyui.safetensors` | FP8 | 4步模型ComfyUI 格式 | + +#### Wan2.2 单文件模型列表 + +**图生视频模型(I2V)- A14B 系列** + +| 文件名 | 精度 | 说明 | +|--------|------|------| +| `wan2.2_i2v_A14b_high_noise_lightx2v_4step.safetensors` | BF16 | 高噪声模型-4步原始精度 | +| `wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 高噪声模型-4步FP8量化 | +| `wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors` | INT8 | 高噪声模型-4步INT8量化 | +| `wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors` | BF16 | 低噪声模型-4步原始精度 | +| `wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors` | FP8 | 低噪声模型-4步FP8量化 | +| `wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors` | INT8 | 低噪声模型-4步INT8量化 | + +> 💡 **使用提示**: +> - Wan2.2 模型采用双噪声架构,需要同时下载高噪声(high_noise)和低噪声(low_noise)模型 +> - 详细的目录组织方式请参考上方"Wan2.2 单文件模型"部分 + +--- + +## 🗂️ 格式三:LightX2V LoRA 模型 + +LoRA(Low-Rank Adaptation)模型提供了一种轻量级的模型微调方案,可以在不修改基础模型的情况下实现特定效果的定制化。 + +### 模型仓库 + +- **Wan2.1 LoRA 模型**:[lightx2v/Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras) +- **Wan2.2 LoRA 模型**:[lightx2v/Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras) + +### 使用方式 + +#### 方式一:离线合并 + +将 LoRA 权重离线合并到基础模型中,生成新的完整模型文件。 + +**操作步骤**: + +参考 [模型转换文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) 进行离线合并。 + +**优点**: +- ✅ 推理时无需额外加载 LoRA +- ✅ 性能更优 + +**缺点**: +- ❌ 需要额外存储空间 +- ❌ 切换不同 LoRA 需要重新合并 + +#### 方式二:在线加载 + +在推理时动态加载 LoRA 权重,无需修改基础模型。 + +**LoRA 应用原理**: + +```python +# LoRA 权重应用公式 +# lora_scale = (alpha / rank) +# W' = W + lora_scale * B @ A +# 其中:B = up_proj (out_features, rank) +# A = down_proj (rank, in_features) + +if weights_dict["alpha"] is not None: + lora_scale = weights_dict["alpha"] / lora_down.shape[0] +elif alpha is not None: + lora_scale = alpha / lora_down.shape[0] +else: + lora_scale = 1.0 +``` + +**配置方法**: + +**Wan2.1 LoRA 配置**: + +```json +{ + "lora_configs": [ + { + "path": "wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0, + "alpha": null + } + ] +} +``` + 
+**Wan2.2 LoRA 配置**: + +由于 Wan2.2 采用双模型架构(高噪声/低噪声),需要分别为两个模型配置 LoRA: + +```json +{ + "lora_configs": [ + { + "name": "low_noise_model", + "path": "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0, + "alpha": null + }, + { + "name": "high_noise_model", + "path": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors", + "strength": 1.0, + "alpha": null + } + ] +} +``` + +**参数说明**: + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `path` | LoRA 模型文件路径 | 必填 | +| `strength` | LoRA 强度系数,范围 [0.0, 1.0] | 1.0 | +| `alpha` | LoRA 缩放因子,`null` 时使用模型内置值 | null | +| `name` | (仅 Wan2.2)指定应用到哪个模型 | 必填 | + +**优点**: +- ✅ 灵活切换不同 LoRA +- ✅ 节省存储空间 +- ✅ 可动态调整 LoRA 强度 + +**缺点**: +- ❌ 推理时需额外加载时间 +- ❌ 略微增加显存占用 + +--- + +## 📚 相关资源 + +### 官方仓库 +- [LightX2V GitHub](https://github.com/ModelTC/LightX2V) +- [LightX2V 单文件模型仓库](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan-AI 官方模型仓库](https://huggingface.co/Wan-AI) + +### 模型下载链接 + +**Wan2.1 系列** +- [Wan2.1 Collection](https://huggingface.co/collections/Wan-AI/wan21-68ac4ba85372ae5a8e282a1b) + +**Wan2.2 系列** +- [Wan2.2 Collection](https://huggingface.co/collections/Wan-AI/wan22-68ac4ae80a8b477e79636fc8) + +**LightX2V 单文件模型** +- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) + +### 文档链接 +- [量化技术文档](../method_tutorials/quantization.md) +- [参数卸载文档](../method_tutorials/offload.md) +- [配置文件示例](https://github.com/ModelTC/LightX2V/tree/main/configs) + +--- + +通过本文档,您应该能够: + +✅ 理解 LightX2V 支持的所有模型格式 +✅ 根据需求选择合适的模型和精度 +✅ 正确下载和组织模型文件 +✅ 配置启动参数并成功运行推理 +✅ 解决常见的模型加载问题 + +如有其他问题,欢迎在 [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中提问。 diff --git a/docs/ZH_CN/source/getting_started/quickstart.md b/docs/ZH_CN/source/getting_started/quickstart.md new file mode 100644 index 0000000000000000000000000000000000000000..a62a1939cbc4a1f65169d4a2e11b49b2f70295d0 --- /dev/null +++ b/docs/ZH_CN/source/getting_started/quickstart.md @@ -0,0 +1,352 @@ +# LightX2V 快速入门指南 + +欢迎使用 LightX2V!本指南将帮助您快速搭建环境并开始使用 LightX2V 进行视频生成。 + +## 📋 目录 + +- [系统要求](#系统要求) +- [Linux 系统环境搭建](#linux-系统环境搭建) + - [Docker 环境(推荐)](#docker-环境推荐) + - [Conda 环境搭建](#conda-环境搭建) +- [Windows 系统环境搭建](#windows-系统环境搭建) +- [推理使用](#推理使用) + +## 🚀 系统要求 + +- **操作系统**: Linux (Ubuntu 18.04+) 或 Windows 10/11 +- **Python**: 3.10 或更高版本 +- **GPU**: NVIDIA GPU,支持 CUDA,至少 8GB 显存 +- **内存**: 建议 16GB 或更多 +- **存储**: 至少 50GB 可用空间 + +## 🐧 Linux 系统环境搭建 + +### 🐳 Docker 环境(推荐) + +我们强烈推荐使用 Docker 环境,这是最简单快捷的安装方式。 + +#### 1. 拉取镜像 + +访问 LightX2V 的 [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags),选择一个最新日期的 tag,比如 `25111101-cu128`: + +```bash +docker pull lightx2v/lightx2v:25111101-cu128 +``` + +我们推荐使用`cuda128`环境,以获得更快的推理速度,若需要使用`cuda124`环境,可以使用带`-cu124`后缀的镜像版本: + +```bash +docker pull lightx2v/lightx2v:25101501-cu124 +``` + +#### 2. 运行容器 + +```bash +docker run --gpus all -itd --ipc=host --name [容器名] -v [挂载设置] --entrypoint /bin/bash [镜像id] +``` + +#### 3. 
中国镜像源(可选)
+
+对于中国大陆地区,如果拉取镜像时网络不稳定,可以从阿里云上拉取:
+
+```bash
+# cuda128
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25111101-cu128
+
+# cuda124
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25101501-cu124
+```
+
+### 🐍 Conda 环境搭建
+
+如果您希望使用 Conda 自行搭建环境,请按照以下步骤操作:
+
+#### 步骤 1: 克隆项目
+
+```bash
+# 下载项目代码
+git clone https://github.com/ModelTC/LightX2V.git
+cd LightX2V
+```
+
+#### 步骤 2: 创建 conda 虚拟环境
+
+```bash
+# 创建并激活 conda 环境
+conda create -n lightx2v python=3.11 -y
+conda activate lightx2v
+```
+
+#### 步骤 3: 安装依赖及代码
+
+```bash
+pip install -v -e .
+```
+
+#### 步骤 4: 安装注意力机制算子
+
+**选项 A: Flash Attention 2**
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git --recursive
+cd flash-attention && python setup.py install
+```
+
+**选项 B: Flash Attention 3(用于 Hopper 架构显卡)**
+```bash
+cd flash-attention/hopper && python setup.py install
+```
+
+**选项 C: SageAttention 2(推荐)**
+```bash
+git clone https://github.com/thu-ml/SageAttention.git
+cd SageAttention && CUDA_ARCHITECTURES="8.0,8.6,8.9,9.0,12.0" EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install -v -e .
+```
+
+#### 步骤 5: 安装量化算子(可选)
+
+量化算子用于支持模型量化功能,可以显著降低显存占用并加速推理。根据您的需求选择合适的量化算子:
+
+**选项 A: VLLM Kernels(推荐)**
+适用于多种量化方案,支持 FP8 等量化格式。
+
+```bash
+pip install vllm
+```
+
+或者从源码安装以获得最新功能:
+
+```bash
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+uv pip install -e .
+```
+
+**选项 B: SGL Kernels**
+适用于 SGL 量化方案,需要 torch == 2.8.0。
+
+```bash
+pip install sgl-kernel --upgrade
+```
+
+**选项 C: Q8 Kernels**
+适用于 Ada 架构显卡(如 RTX 4090、L40S 等)。
+
+```bash
+git clone https://github.com/KONAKONA666/q8_kernels.git
+cd q8_kernels && git submodule init && git submodule update
+python setup.py install
+```
+
+> 💡 **提示**:
+> - 如果不需要使用量化功能,可以跳过此步骤
+> - 量化模型可以从 [LightX2V HuggingFace](https://huggingface.co/lightx2v) 下载
+> - 更多量化相关信息请参考 [量化文档](method_tutorials/quantization.html)
+
+#### 步骤 6: 验证安装
+```python
+import lightx2v
+print(f"LightX2V 版本: {lightx2v.__version__}")
+```
+
+## 🪟 Windows 系统环境搭建
+
+Windows 系统仅支持 Conda 环境搭建方式,请按照以下步骤操作:
+
+### 🐍 Conda 环境搭建
+
+#### 步骤 1: 检查 CUDA 版本
+
+首先确认您的 GPU 驱动和 CUDA 版本:
+
+```cmd
+nvidia-smi
+```
+
+记录输出中的 **CUDA Version** 信息,后续安装时需要保持版本一致。
+
+#### 步骤 2: 创建 Python 环境
+
+```cmd
+# 创建新环境(推荐 Python 3.12)
+conda create -n lightx2v python=3.12 -y
+
+# 激活环境
+conda activate lightx2v
+```
+
+> 💡 **提示**: 建议使用 Python 3.10 或更高版本以获得最佳兼容性。
+
+#### 步骤 3: 安装 PyTorch 框架
+
+**方法一:下载官方 wheel 包(推荐)**
+
+1. 访问 [PyTorch 官方下载页面](https://download.pytorch.org/whl/torch/)
+2. 
选择对应版本的 wheel 包,注意匹配: + - **Python 版本**: 与您的环境一致 + - **CUDA 版本**: 与您的 GPU 驱动匹配 + - **平台**: 选择 Windows 版本 + +**示例(Python 3.12 + PyTorch 2.6 + CUDA 12.4):** + +```cmd +# 下载并安装 PyTorch +pip install torch-2.6.0+cu124-cp312-cp312-win_amd64.whl + +# 安装配套包 +pip install torchvision==0.21.0 torchaudio==2.6.0 +``` + +**方法二:使用 pip 直接安装** + +```cmd +# CUDA 12.4 版本示例 +pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124 +``` + +#### 步骤 4: 安装 Windows 版 vLLM + +从 [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) 下载对应的 wheel 包。 + +**版本匹配要求:** +- Python 版本匹配 +- PyTorch 版本匹配 +- CUDA 版本匹配 + +```cmd +# 安装 vLLM(请根据实际文件名调整) +pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl +``` + +#### 步骤 5: 安装注意力机制算子 + +**选项 A: Flash Attention 2** + +```cmd +pip install flash-attn==2.7.2.post1 +``` + +**选项 B: SageAttention 2(强烈推荐)** + +**下载源:** +- [Windows 专用版本 1](https://github.com/woct0rdho/SageAttention/releases) +- [Windows 专用版本 2](https://github.com/sdbds/SageAttention-for-windows/releases) + +```cmd +# 安装 SageAttention(请根据实际文件名调整) +pip install sageattention-2.1.1+cu126torch2.6.0-cp312-cp312-win_amd64.whl +``` + +> ⚠️ **注意**: SageAttention 的 CUDA 版本可以不严格对齐,但 Python 和 PyTorch 版本必须匹配。 + +#### 步骤 6: 克隆项目 + +```cmd +# 克隆项目代码 +git clone https://github.com/ModelTC/LightX2V.git +cd LightX2V + +# 安装 Windows 专用依赖 +pip install -r requirements_win.txt +pip install -v -e . +``` + +#### 步骤 7: 安装量化算子(可选) + +量化算子用于支持模型量化功能,可以显著降低显存占用并加速推理。 + +**安装 VLLM(推荐):** + +从 [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) 下载对应的 wheel 包并安装。 + +```cmd +# 安装 vLLM(请根据实际文件名调整) +pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl +``` + +> 💡 **提示**: +> - 如果不需要使用量化功能,可以跳过此步骤 +> - 量化模型可以从 [LightX2V HuggingFace](https://huggingface.co/lightx2v) 下载 +> - 更多量化相关信息请参考 [量化文档](method_tutorials/quantization.html) + +#### 步骤 8: 验证安装 +```python +import lightx2v +print(f"LightX2V 版本: {lightx2v.__version__}") +``` + +## 🎯 推理使用 + +### 📥 模型准备 + +在开始推理之前,您需要提前下载好模型文件。我们推荐: + +- **下载源**: 从 [LightX2V 官方 Hugging Face](https://huggingface.co/lightx2v/)或者其他开源模型库下载模型 +- **存储位置**: 建议将模型存储在 SSD 磁盘上以获得更好的读取性能 +- **可用模型**: 包括 Wan2.1-I2V、Wan2.1-T2V 等多种模型,支持不同分辨率和功能 + +### 📁 配置文件与脚本 + +推理会用到的配置文件都在[这里](https://github.com/ModelTC/LightX2V/tree/main/configs),脚本都在[这里](https://github.com/ModelTC/LightX2V/tree/main/scripts)。 + +需要将下载的模型路径配置到运行脚本中。除了脚本中的输入参数,`--config_json` 指向的配置文件中也会包含一些必要参数,您可以根据需要自行修改。 + +### 🚀 开始推理 + +#### Linux 环境 + +```bash +# 修改脚本中的路径后运行 +bash scripts/wan/run_wan_t2v.sh +``` + +#### Windows 环境 + +```cmd +# 使用 Windows 批处理脚本 +scripts\win\run_wan_t2v.bat +``` +#### Python脚本启动 + +```python +from lightx2v import LightX2VPipeline + +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-T2V-14B", + model_cls="wan2.1", + task="t2v", +) + +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + height=480, # 720 + width=832, # 1280 + num_frames=81, + guidance_scale=5.0, + sample_shift=5.0, +) + +seed = 42 +prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path="/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) +``` + + +## 📞 获取帮助 + +如果您在安装或使用过程中遇到问题,请: + +1. 
在 [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中搜索相关问题 +2. 提交新的 Issue 描述您的问题 + +--- + +🎉 **恭喜!** 现在您已经成功搭建了 LightX2V 环境,可以开始享受视频生成的乐趣了! diff --git a/docs/ZH_CN/source/index.rst b/docs/ZH_CN/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..3ea4ebd747337c67518da537f3833879d0e4187a --- /dev/null +++ b/docs/ZH_CN/source/index.rst @@ -0,0 +1,68 @@ +欢迎了解 Lightx2v! +================== + +.. figure:: ../../../assets/img_lightx2v.png + :width: 80% + :align: center + :alt: Lightx2v + :class: no-scaled-link + +.. raw:: html + +
+ + License + Ask DeepWiki + Doc + Doc + Docker + +
+ +
+ LightX2V: 一个轻量级的视频生成推理框架 +
+ + +LightX2V 是一个轻量级的视频生成推理框架,集成多种先进的视频生成推理技术,统一支持 文本生成视频 (T2V)、图像生成视频 (I2V) 等多种生成任务及模型。X2V 表示将不同的输入模态(X,如文本或图像)转换(to)为视频输出(V)。 + +GitHub: https://github.com/ModelTC/lightx2v + +HuggingFace: https://huggingface.co/lightx2v + +文档列表 +------------- + +.. toctree:: + :maxdepth: 1 + :caption: 快速入门 + + 快速入门 + 模型结构 + 基准测试 + +.. toctree:: + :maxdepth: 1 + :caption: 方法教程 + + 模型量化 + 特征缓存 + 注意力机制 + 参数卸载 + 并行推理 + 变分辨率推理 + 步数蒸馏 + 自回归蒸馏 + 视频帧插值 + +.. toctree:: + :maxdepth: 1 + :caption: 部署指南 + + 低延迟场景部署 + 低资源场景部署 + Lora模型部署 + 服务化部署 + Gradio部署 + ComfyUI部署 + 本地windows电脑部署 diff --git a/docs/ZH_CN/source/method_tutorials/attention.md b/docs/ZH_CN/source/method_tutorials/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..06f9570579cdfb4042e3cfdcf6f63329bf4a9554 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/attention.md @@ -0,0 +1,35 @@ +# 注意力机制 + +## LightX2V支持的注意力机制 + +| 名称 | 类型名称 | GitHub 链接 | +|--------------------|------------------|-------------| +| Flash Attention 2 | `flash_attn2` | [flash-attention v2](https://github.com/Dao-AILab/flash-attention) | +| Flash Attention 3 | `flash_attn3` | [flash-attention v3](https://github.com/Dao-AILab/flash-attention) | +| Sage Attention 2 | `sage_attn2` | [SageAttention](https://github.com/thu-ml/SageAttention) | +| Radial Attention | `radial_attn` | [Radial Attention](https://github.com/mit-han-lab/radial-attention) | +| Sparge Attention | `sparge_ckpt` | [Sparge Attention](https://github.com/thu-ml/SpargeAttn) | + +--- + +## 配置示例 + +注意力机制的config文件在[这里](https://github.com/ModelTC/lightx2v/tree/main/configs/attentions) + +通过指定--config_json到具体的config文件,即可以测试不同的注意力机制 + +比如对于radial_attn,配置如下: + +```json +{ + "self_attn_1_type": "radial_attn", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3" +} +``` + +如需更换为其他类型,只需将对应值替换为上述表格中的类型名称即可。 + +tips: radial_attn因为稀疏算法原理的限制只能用在self attention + +如需进一步定制注意力机制的行为,请参考各注意力库的官方文档或实现代码。 diff --git a/docs/ZH_CN/source/method_tutorials/autoregressive_distill.md b/docs/ZH_CN/source/method_tutorials/autoregressive_distill.md new file mode 100644 index 0000000000000000000000000000000000000000..19845ce07f44e7b6ea3d15e7255637cc4d13bea4 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/autoregressive_distill.md @@ -0,0 +1,53 @@ +# 自回归蒸馏 + +自回归蒸馏是 LightX2V 中的一个技术探索,通过训练蒸馏模型将推理步数从原始的 40-50 步减少到 **8 步**,在实现推理加速的同时能够通过 KV Cache 技术生成无限长视频。 + +> ⚠️ 警告:目前自回归蒸馏的效果一般,加速效果也没有达到预期,但是可以作为一个长期的研究项目。目前 LightX2V 仅支持 T2V 的自回归模型。 + +## 🔍 技术原理 + +自回归蒸馏通过 [CausVid](https://github.com/tianweiy/CausVid) 技术实现。CausVid 针对 1.3B 的自回归模型进行步数蒸馏、CFG蒸馏。LightX2V 在其基础上,进行了一系列扩展: + +1. **更大的模型**:支持 14B 模型的自回归蒸馏训练; +2. **更完整的数据处理流程**:生成 50,000 个提示词-视频对的训练数据集; + +具体实现可参考 [CausVid-Plus](https://github.com/GoatWu/CausVid-Plus)。 + +## 🛠️ 配置文件说明 + +### 配置文件 + +在 [configs/causvid/](https://github.com/ModelTC/lightx2v/tree/main/configs/causvid) 目录下提供了配置选项: + +| 配置文件 | 模型地址 | +|----------|------------| +| [wan_t2v_causvid.json](https://github.com/ModelTC/lightx2v/blob/main/configs/causvid/wan_t2v_causvid.json) | https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid | + +### 关键配置参数 + +```json +{ + "enable_cfg": false, // 关闭CFG以提升速度 + "num_fragments": 3, // 一次生成视频的段数,每段5s + "num_frames": 21, // 每段视频的帧数,谨慎修改! + "num_frame_per_block": 3, // 每个自回归块的帧数,谨慎修改! + "num_blocks": 7, // 每段视频的自回归块数,谨慎修改! + "frame_seq_length": 1560, // 每帧的编码长度,谨慎修改! 
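+    // 注:frame_seq_length 与分辨率相关。例如 480×832 分辨率下,每帧 token 数为 (480/8/2)×(832/8/2) = 30×52 = 1560
+    // (此处假设 VAE 空间下采样 8 倍、patch 大小为 2,仅作换算示意)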
+ "denoising_step_list": [ // 去噪时间步列表 + 999, 934, 862, 756, 603, 410, 250, 140, 74 + ] +} +``` + +## 📜 使用方法 + +### 模型准备 + +将下载好的模型(`causal_model.pt` 或者 `causal_model.safetensors`)放到 Wan 模型根目录的 `causvid_models/` 文件夹下即可 +- 对于 T2V:`Wan2.1-T2V-14B/causvid_models/` + +### 推理脚本 + +```bash +bash scripts/wan/run_wan_t2v_causvid.sh +``` diff --git a/docs/ZH_CN/source/method_tutorials/cache.md b/docs/ZH_CN/source/method_tutorials/cache.md new file mode 100644 index 0000000000000000000000000000000000000000..a867119c81d3f779a66105feed3be7df695a2ea5 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/cache.md @@ -0,0 +1,3 @@ +# 特征缓存 + +由于要展示一些视频的播放效果,你可以在这个[🔗 页面](https://github.com/ModelTC/LightX2V/blob/main/docs/ZH_CN/source/method_tutorials/cache_source.md)获得更好的展示效果以及相对应的文档内容。 diff --git a/docs/ZH_CN/source/method_tutorials/cache_source.md b/docs/ZH_CN/source/method_tutorials/cache_source.md new file mode 100644 index 0000000000000000000000000000000000000000..37624af05f9d2ad63a65d071891a62a9911b28d8 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/cache_source.md @@ -0,0 +1,139 @@ +# 特征缓存 + +## 缓存加速算法 +- 在扩散模型的推理过程中,缓存复用是一种重要的加速算法。 +- 其核心思想是在部分时间步跳过冗余计算,通过复用历史缓存结果提升推理效率。 +- 算法的关键在于如何决策在哪些时间步进行缓存复用,通常基于模型状态变化或误差阈值动态判断。 +- 在推理过程中,需要缓存如中间特征、残差、注意力输出等关键内容。当进入可复用时间步时,直接利用已缓存的内容,通过泰勒展开等近似方法重构当前输出,从而减少重复计算,实现高效推理。 + +### TeaCache +`TeaCache`的核心思想是通过对相邻时间步输入的**相对L1**距离进行累加,当累计距离达到设定阈值时,判定当前时间步不使用缓存复用;相反,当累计距离未达到设定阈值时则使用缓存复用加速推理过程。 +- 具体来说,算法在每一步推理时计算当前输入与上一步输入的相对L1距离,并将其累加。 +- 当累计距离未超过阈值,说明模型状态变化不明显,则直接复用最近一次缓存的内容,跳过部分冗余计算。这样可以显著减少模型的前向计算次数,提高推理速度。 + +实际效果上,TeaCache 在保证生成质量的前提下,实现了明显的加速。在单卡H200上,加速前后的用时与视频对比如下: + + + + + + + + + + +
+ 加速前:58s + + 加速后:17.9s +
+ + + +
+ + +- 加速比为:**3.24** +- config:[wan_t2v_1_3b_tea_480p.json](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json) +- 参考论文:[https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108) + +### TaylorSeer Cache +`TaylorSeer Cache`的核心在于利用泰勒公式对缓存内容进行再次计算,作为缓存复用时间步的残差补偿。 +- 具体做法是在缓存复用的时间步,不仅简单地复用历史缓存,还通过泰勒展开对当前输出进行近似重构。这样可以在减少计算量的同时,进一步提升输出的准确性。 +- 泰勒展开能够有效捕捉模型状态的微小变化,使得缓存复用带来的误差得到补偿,从而在加速的同时保证生成质量。 + +`TaylorSeer Cache`适用于对输出精度要求较高的场景,能够在缓存复用的基础上进一步提升模型推理的表现。 + + + + + + + + + + +
+ 加速前:57.7s + + 加速后:41.3s +
+ + + +
+ + +- 加速比为:**1.39** +- config:[wan_t2v_taylorseer](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/taylorseer/wan_t2v_taylorseer.json) +- 参考论文:[https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923) + +### AdaCache +`AdaCache`的核心思想是根据指定block块中的部分缓存内容,动态调整缓存复用的步长。 +- 算法会分析相邻两个时间步在特定 block 内的特征差异,根据差异大小自适应地决定下一个缓存复用的时间步间隔。 +- 当模型状态变化较小时,步长自动加大,减少缓存更新频率;当状态变化较大时,步长缩小,保证输出质量。 + +这样可以根据实际推理过程中的动态变化,灵活调整缓存策略,实现更高效的加速和更优的生成效果。AdaCache 适合对推理速度和生成质量都有较高要求的应用场景。 + + + + + + + + + + +
+ 加速前:227s + + 加速后:83s +
+ + + +
+ + +- 加速比为:**2.73** +- config:[wan_i2v_ada](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/adacache/wan_i2v_ada.json) +- 参考论文:[https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397) + +### CustomCache +`CustomCache`综合了`TeaCache`和`TaylorSeer Cache`的优势。 +- 它结合了`TeaCache`在缓存决策上的实时性和合理性,通过动态阈值判断何时进行缓存复用. +- 同时利用`TaylorSeer`的泰勒展开方法对已缓存内容进行利用。 + +这样不仅能够高效地决定缓存复用的时机,还能最大程度地利用缓存内容,提升输出的准确性和生成质量。实际测试表明,`CustomCache`在多个内容生成任务上,生成的视频质量优于单独使用`TeaCache、TaylorSeer Cache`或`AdaCache`的方案,是目前综合性能最优的缓存加速算法之一。 + + + + + + + + + + +
+ 加速前:57.9s + + 加速后:16.6s +
+ + + +
+ + +- 加速比为:**3.49** +- config:[wan_t2v_custom_1_3b](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/custom/wan_t2v_custom_1_3b.json) + + +## 使用方式 + +特征缓存的config文件在[这里](https://github.com/ModelTC/lightx2v/tree/main/configs/caching) + +通过指定--config_json到具体的config文件,即可以测试不同的cache算法 + +[这里](https://github.com/ModelTC/lightx2v/tree/main/scripts/cache)有一些运行脚本供使用。 diff --git a/docs/ZH_CN/source/method_tutorials/changing_resolution.md b/docs/ZH_CN/source/method_tutorials/changing_resolution.md new file mode 100644 index 0000000000000000000000000000000000000000..e4a51490d032bcf556d3ca1cae9f06b7648bbc9d --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/changing_resolution.md @@ -0,0 +1,68 @@ +# 变分辨率推理 + +## 概述 + +变分辨率推理是一种优化去噪过程的技术策略,通过在不同的去噪阶段采用不同的分辨率来提升计算效率并保持生成质量。该方法的核心思想是:在去噪过程的前期使用较低分辨率进行粗略去噪,在后期切换到正常分辨率进行精细化处理。 + +## 技术原理 + +### 分阶段去噪策略 + +变分辨率推理基于以下观察: + +- **前期去噪**:主要处理粗糙的噪声和整体结构,不需要过多的细节信息 +- **后期去噪**:专注于细节优化和高频信息恢复,需要完整的分辨率信息 + +### 分辨率切换机制 + +1. **低分辨率阶段**(前期) + - 将输入降采样到较低分辨率(如原尺寸的0.75) + - 执行初始的去噪步骤 + - 快速移除大部分噪声,建立基本结构 + +2. **正常分辨率阶段**(后期) + - 将第一步的去噪结果上采样回原始分辨率 + - 继续执行剩余的去噪步骤 + - 恢复细节信息,完成精细化处理 + + +### U型分辨率策略 + +如果在刚开始的去噪步,降低分辨率,可能会导致最后的生成的视频和正常推理的生成的视频,整体差异偏大。因此可以采用U型的分辨率策略,即最一开始保持几步的原始分辨率,再降低分辨率推理。 + +## 使用方式 + +变分辨率推理的config文件在[这里](https://github.com/ModelTC/LightX2V/tree/main/configs/changing_resolution) + +通过指定--config_json到具体的config文件,即可以测试变分辨率推理。 + +可以参考[这里](https://github.com/ModelTC/LightX2V/blob/main/scripts/changing_resolution)的脚本运行。 + + +### 举例1: +``` +{ + "infer_steps": 50, + "changing_resolution": true, + "resolution_rate": [0.75], + "changing_resolution_steps": [25] +} +``` + +表示总步数为50,1到25步的分辨率为0.75倍原始分辨率,26到最后一步的分辨率为原始分辨率。 + +### 举例2: +``` +{ + "infer_steps": 50, + "changing_resolution": true, + "resolution_rate": [1.0, 0.75], + "changing_resolution_steps": [10, 35] +} +``` + +表示总步数为50,1到10步的分辨率为原始分辨率,11到35步的分辨率为0.75倍原始分辨率,36到最后一步的分辨率为原始分辨率。 + +通常来说,假设`changing_resolution_steps`为[A, B, C],去噪的起始步为1,总步数为X,则推理会被分成4个部分。 + +分别是,(0, A], (A, B]. (B, C], (C, X],每个部分是左开右闭集合。 diff --git a/docs/ZH_CN/source/method_tutorials/offload.md b/docs/ZH_CN/source/method_tutorials/offload.md new file mode 100644 index 0000000000000000000000000000000000000000..2f30f47d8496efb08581bc1bf5fc80dab8c3657b --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/offload.md @@ -0,0 +1,177 @@ +# 参数卸载 + +## 📖 概述 + +Lightx2v 实现了先进的参数卸载机制,专为在有限硬件资源下处理大型模型推理而设计。该系统通过智能管理不同内存层次中的模型权重,提供了优秀的速度-内存平衡。 + +**核心特性:** +- **分block/phase卸载**:高效地以block/phase为单位管理模型权重,实现最优内存使用 + - **Block**:Transformer模型的基本计算单元,包含完整的Transformer层(自注意力、交叉注意力、前馈网络等),是较大的内存管理单位 + - **Phase**:Block内部的更细粒度计算阶段,包含单个计算组件(如自注意力、交叉注意力、前馈网络等),提供更精细的内存控制 +- **多级存储支持**:GPU → CPU → 磁盘层次结构,配合智能缓存 +- **异步操作**:使用 CUDA 流实现计算和数据传输的重叠 +- **磁盘/NVMe 序列化**:当内存不足时支持二级存储 + +## 🎯 卸载策略 + +### 策略一:GPU-CPU 分block/phase卸载 + +**适用场景**:GPU 显存不足但系统内存充足 + +**工作原理**:在 GPU 和 CPU 内存之间以block或phase为单位管理模型权重,利用 CUDA 流实现计算和数据传输的重叠。Block包含完整的Transformer层,而Phase则是Block内部的单个计算组件。 + +
+GPU-CPU block/phase卸载流程图 +
+ +
+Swap操作 +
+ +
+Swap思想 +
+ + +**Block vs Phase 说明**: +- **Block粒度**:较大的内存管理单位,包含完整的Transformer层(自注意力、交叉注意力、前馈网络等),适合内存充足的情况,减少管理开销 +- **Phase粒度**:更细粒度的内存管理,包含单个计算组件(如自注意力、交叉注意力、前馈网络等),适合内存受限的情况,提供更灵活的内存控制 + +**关键特性:** +- **异步传输**:使用三个不同优先级的CUDA流实现计算和传输的并行 + - 计算流(priority=-1):高优先级,负责当前计算 + - GPU加载流(priority=0):中优先级,负责从CPU到GPU的预取 + - CPU加载流(priority=0):中优先级,负责从GPU到CPU的卸载 +- **预取机制**:提前将下一个block/phase加载到 GPU +- **智能缓存**:在 CPU 内存中维护权重缓存 +- **流同步**:确保数据传输和计算的正确性 +- **Swap操作**:计算完成后轮换block/phase位置,实现连续计算 + + + + +### 策略二:磁盘-CPU-GPU 分block/phase卸载(延迟加载) + +**适用场景**:GPU 显存和系统内存都不足 + +**工作原理**:在策略一的基础上引入磁盘存储,实现三级存储层次(磁盘 → CPU → GPU)。CPU继续作为缓存池,但大小可配置,适用于CPU内存受限的设备。 + + +
+磁盘-CPU-GPU 分block/phase卸载工作流程 +
+ + +
+工作步骤 +
+ +**关键特性:** +- **延迟加载**:模型权重按需从磁盘加载,避免一次性加载全部模型 +- **智能缓存**:CPU内存缓冲区使用FIFO策略管理,可配置大小 +- **多线程预取**:使用多个磁盘工作线程并行加载 +- **异步传输**:使用CUDA流实现计算和数据传输的重叠 +- **Swap轮换**:通过位置轮换实现连续计算,避免重复加载/卸载 + +**工作步骤**: +- **磁盘存储**:模型权重按block存储在SSD/NVMe上,每个block一个.safetensors文件 +- **任务调度**:当需要某个block/phase时,优先级任务队列分配磁盘工作线程 +- **异步加载**:多个磁盘线程并行从磁盘读取权重文件到CPU内存缓冲区 +- **智能缓存**:CPU内存缓冲区使用FIFO策略管理缓存,可配置大小 +- **缓存命中**:如果权重已在缓存中,直接传输到GPU,无需磁盘读取 +- **预取传输**:缓存中的权重异步传输到GPU内存(使用GPU加载流) +- **计算执行**:GPU上的权重进行计算(使用计算流),同时后台继续预取下一个block/phase +- **Swap轮换**:计算完成后轮换block/phase位置,实现连续计算 +- **内存管理**:当CPU缓存满时,自动淘汰最早使用的权重block/phase + + + +## ⚙️ 配置参数 + +### GPU-CPU 卸载配置 + +```python +config = { + "cpu_offload": True, + "offload_ratio": 1.0, # 卸载比例(0.0-1.0) + "offload_granularity": "block", # 卸载粒度:"block"或"phase" + "lazy_load": False, # 禁用延迟加载 +} +``` + +### 磁盘-CPU-GPU 卸载配置 + +```python +config = { + "cpu_offload": True, + "lazy_load": True, # 启用延迟加载 + "offload_ratio": 1.0, # 卸载比例 + "offload_granularity": "phase", # 推荐使用phase粒度 + "num_disk_workers": 2, # 磁盘工作线程数 + "offload_to_disk": True, # 启用磁盘卸载 +} +``` + +**智能缓存关键参数:** +- `max_memory`:控制CPU缓存大小,影响缓存命中率和内存使用 +- `num_disk_workers`:控制磁盘加载线程数,影响预取速度 +- `offload_granularity`:控制缓存粒度(block或phase),影响缓存效率 + - `"block"`:以完整的Transformer层为单位进行缓存管理 + - `"phase"`:以单个计算组件为单位进行缓存管理 + +**非 DIT 模型组件(T5、CLIP、VAE)的卸载配置:** + +这些组件的卸载行为遵循以下规则: +- **默认行为**:如果没有单独指定,T5、CLIP、VAE 会跟随 `cpu_offload` 的设置 +- **独立配置**:可以为每个组件单独设置卸载策略,实现精细控制 + +**配置示例**: +```json +{ + "cpu_offload": true, // DIT 模型卸载开关 + "t5_cpu_offload": false, // T5 编码器独立设置 + "clip_cpu_offload": false, // CLIP 编码器独立设置 + "vae_cpu_offload": false // VAE 编码器独立设置 +} +``` + +在显存受限的设备上,建议采用渐进式卸载策略: + +1. **第一步**:仅开启 `cpu_offload`,关闭 `t5_cpu_offload`、`clip_cpu_offload`、`vae_cpu_offload` +2. **第二步**:如果显存仍不足,逐步开启 T5、CLIP、VAE 的 CPU 卸载 +3. **第三步**:如果显存仍然不够,考虑使用量化 + CPU 卸载或启用 `lazy_load` + +**实践经验**: +- **RTX 4090 24GB + 14B 模型**:通常只需开启 `cpu_offload`,其他组件卸载需要手动设为 `false`,同时使用 FP8 量化版本 +- **更小显存的 GPU**:需要组合使用量化、CPU 卸载和延迟加载 +- **量化方案**:建议参考[量化技术文档](../method_tutorials/quantization.md)选择合适的量化策略 + + +**配置文件参考**: +- **Wan2.1 系列模型**:参考 [offload 配置文件](https://github.com/ModelTC/lightx2v/tree/main/configs/offload) +- **Wan2.2 系列模型**:参考 [wan22 配置文件](https://github.com/ModelTC/lightx2v/tree/main/configs/wan22) 中以 `4090` 结尾的配置文件 + +## 🎯 使用建议 +- 🔄 GPU-CPU分block/phase卸载:适合GPU显存不足(RTX 3090/4090 24G)但系统内存(>64/128G)充足 + +- 💾 磁盘-CPU-GPU分block/phase卸载:适合GPU显存(RTX 3060/4090 8G)和系统内存(16/32G)都不足 + +- 🚫 无Offload:适合高端硬件配置,追求最佳性能 + + +## 🔍 故障排除 + +### 常见问题及解决方案 + +1. **磁盘I/O瓶颈** + - 解决方案:使用NVMe SSD,增加num_disk_workers + + +2. **内存缓冲区溢出** + - 解决方案:增加max_memory或减少num_disk_workers + +3. **加载超时** + - 解决方案:检查磁盘性能,优化文件系统 + + +**注意**:本卸载机制专为Lightx2v设计,充分利用了现代硬件的异步计算能力,能够显著降低大模型推理的硬件门槛。 diff --git a/docs/ZH_CN/source/method_tutorials/parallel.md b/docs/ZH_CN/source/method_tutorials/parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..a6ff0dfe2c2a65322a2434c5365c4143d4ff6764 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/parallel.md @@ -0,0 +1,55 @@ +# 并行推理 + +LightX2V 支持分布式并行推理,能够利用多个 GPU 进行推理。DiT部分支持两种并行注意力机制:**Ulysses** 和 **Ring**,同时还支持 **Cfg 并行推理**。并行推理,显著降低推理耗时和减轻每个GPU的显存开销。 + +## DiT 并行配置 + +### 1. Ulysses 并行 + +**配置方式:** +```json + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +``` + +### 2. 
Ring 并行 + + +**配置方式:** +```json + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ring" + } +``` + +## Cfg 并行配置 + +**配置方式:** +```json + "parallel": { + "cfg_p_size": 2 + } +``` + +## 混合并行配置 + +**配置方式:** +```json + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses", + "cfg_p_size": 2 + } +``` + + +## 使用方式 + +并行推理的config文件在[这里](https://github.com/ModelTC/lightx2v/tree/main/configs/dist_infer) + +通过指定--config_json到具体的config文件,即可以测试并行推理 + +[这里](https://github.com/ModelTC/lightx2v/tree/main/scripts/dist_infer)有一些运行脚本供使用。 diff --git a/docs/ZH_CN/source/method_tutorials/quantization.md b/docs/ZH_CN/source/method_tutorials/quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..311367cc612f29744974dba4345d9d7c10ac10b9 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/quantization.md @@ -0,0 +1,158 @@ +# 模型量化技术 + +## 📖 概述 + +LightX2V 支持对 DIT、T5 和 CLIP 模型进行量化推理,通过降低模型精度来减少显存占用并提升推理速度。 + +--- + +## 🔧 量化模式 + +| 量化模式 | 权重量化 | 激活量化 | 计算内核 | 适用硬件 | +|--------------|----------|----------|----------|----------| +| `fp8-vllm` | FP8 通道对称 | FP8 通道动态对称 | [VLLM](https://github.com/vllm-project/vllm) | H100/H200/H800, RTX 40系等 | +| `int8-vllm` | INT8 通道对称 | INT8 通道动态对称 | [VLLM](https://github.com/vllm-project/vllm) | A100/A800, RTX 30/40系等 | +| `fp8-sgl` | FP8 通道对称 | FP8 通道动态对称 | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | H100/H200/H800, RTX 40系等 | +| `int8-sgl` | INT8 通道对称 | INT8 通道动态对称 | [SGL](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) | A100/A800, RTX 30/40系等 | +| `fp8-q8f` | FP8 通道对称 | FP8 通道动态对称 | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40系, L40S等 | +| `int8-q8f` | INT8 通道对称 | INT8 通道动态对称 | [Q8-Kernels](https://github.com/KONAKONA666/q8_kernels) | RTX 40系, L40S等 | +| `int8-torchao` | INT8 通道对称 | INT8 通道动态对称 | [TorchAO](https://github.com/pytorch/ao) | A100/A800, RTX 30/40系等 | +| `int4-g128-marlin` | INT4 分组对称 | FP16 | [Marlin](https://github.com/IST-DASLab/marlin) | H200/H800/A100/A800, RTX 30/40系等 | +| `fp8-b128-deepgemm` | FP8 分块对称 | FP8 分组对称 | [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) | H100/H200/H800, RTX 40系等| + +--- + +## 🔧 量化模型获取 + +### 方式一:下载预量化模型 + +从 LightX2V 模型仓库下载预量化的模型: + +**DIT 模型** + +从 [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) 下载预量化的 DIT 模型: + +```bash +# 下载 DIT FP8 量化模型 +huggingface-cli download lightx2v/Wan2.1-Distill-Models \ + --local-dir ./models \ + --include "wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors" +``` + +**Encoder 模型** + +从 [Encoders-LightX2V](https://huggingface.co/lightx2v/Encoders-Lightx2v) 下载预量化的 T5 和 CLIP 模型: + +```bash +# 下载 T5 FP8 量化模型 +huggingface-cli download lightx2v/Encoders-Lightx2v \ + --local-dir ./models \ + --include "models_t5_umt5-xxl-enc-fp8.pth" + +# 下载 CLIP FP8 量化模型 +huggingface-cli download lightx2v/Encoders-Lightx2v \ + --local-dir ./models \ + --include "models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth" +``` + +### 方式二:自行量化模型 + +详细量化工具使用方法请参考:[模型转换文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) + +--- + +## 🚀 量化模型使用 + +### DIT 模型量化 + +#### 支持的量化模式 + +DIT 量化模式(`dit_quant_scheme`)支持:`fp8-vllm`、`int8-vllm`、`fp8-sgl`、`int8-sgl`、`fp8-q8f`、`int8-q8f`、`int8-torchao`、`int4-g128-marlin`、`fp8-b128-deepgemm` + +#### 配置示例 + +```json +{ + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "dit_quantized_ckpt": "/path/to/dit_quantized_model" // 可选 +} +``` + +> 💡 **提示**:当运行脚本的 `model_path` 中只有一个 DIT 模型时,`dit_quantized_ckpt` 可以不用单独指定。 + +### T5 
模型量化 + +#### 支持的量化模式 + +T5 量化模式(`t5_quant_scheme`)支持:`int8-vllm`、`fp8-sgl`、`int8-q8f`、`fp8-q8f`、`int8-torchao` + +#### 配置示例 + +```json +{ + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "t5_quantized_ckpt": "/path/to/t5_quantized_model" // 可选 +} +``` + +> 💡 **提示**:当运行脚本指定的 `model_path` 中存在 T5 量化模型(如 `models_t5_umt5-xxl-enc-fp8.pth` 或 `models_t5_umt5-xxl-enc-int8.pth`)时,`t5_quantized_ckpt` 可以不用单独指定。 + +### CLIP 模型量化 + +#### 支持的量化模式 + +CLIP 量化模式(`clip_quant_scheme`)支持:`int8-vllm`、`fp8-sgl`、`int8-q8f`、`fp8-q8f`、`int8-torchao` + +#### 配置示例 + +```json +{ + "clip_quantized": true, + "clip_quant_scheme": "fp8-sgl", + "clip_quantized_ckpt": "/path/to/clip_quantized_model" // 可选 +} +``` + +> 💡 **提示**:当运行脚本指定的 `model_path` 中存在 CLIP 量化模型(如 `models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth` 或 `models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8.pth`)时,`clip_quantized_ckpt` 可以不用单独指定。 + +### 性能优化策略 + +如果显存不够,可以结合参数卸载来进一步减少显存占用,参考[参数卸载文档](../method_tutorials/offload.md): + +> - **Wan2.1 配置**:参考 [offload 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/offload) +> - **Wan2.2 配置**:参考 [wan22 配置文件](https://github.com/ModelTC/LightX2V/tree/main/configs/wan22) 中以 `4090` 结尾的配置 + +--- + +## 📚 相关资源 + +### 配置文件示例 +- [INT8 量化配置](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v.json) +- [Q8F 量化配置](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_q8f.json) +- [TorchAO 量化配置](https://github.com/ModelTC/LightX2V/blob/main/configs/quantization/wan_i2v_torchao.json) + +### 运行脚本 +- [量化推理脚本](https://github.com/ModelTC/LightX2V/tree/main/scripts/quantization) + +### 工具文档 +- [量化工具文档](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md) +- [LightCompress 量化文档](https://github.com/ModelTC/llmc/blob/main/docs/zh_cn/source/backend/lightx2v.md) + +### 模型仓库 +- [Wan2.1-LightX2V 量化模型](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- [Wan2.2-LightX2V 量化模型](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) +- [Encoders 量化模型](https://huggingface.co/lightx2v/Encoders-Lightx2v) + +--- + +通过本文档,您应该能够: + +✅ 理解 LightX2V 支持的量化方案 +✅ 根据硬件选择合适的量化策略 +✅ 正确配置量化参数 +✅ 获取和使用量化模型 +✅ 优化推理性能和显存使用 + +如有其他问题,欢迎在 [GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中提问。 diff --git a/docs/ZH_CN/source/method_tutorials/step_distill.md b/docs/ZH_CN/source/method_tutorials/step_distill.md new file mode 100644 index 0000000000000000000000000000000000000000..2ed48c810dedee47a5b7a657e155ce0f6b4e6a34 --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/step_distill.md @@ -0,0 +1,183 @@ +# 步数蒸馏 + +步数蒸馏是 LightX2V 中的一项重要优化技术,通过训练蒸馏模型将推理步数从原始的 40-50 步大幅减少到 **4 步**,在保持视频质量的同时显著提升推理速度。LightX2V 在实现步数蒸馏的同时也加入了 CFG 蒸馏,进一步提升推理速度。 + +## 🔍 技术原理 + +### DMD 蒸馏 + +步数蒸馏的核心技术是 [DMD 蒸馏](https://arxiv.org/abs/2311.18828)。DMD 蒸馏的框架如下图所示: + +
+DMD 蒸馏框架 +
+
+DMD蒸馏的核心思想是最小化蒸馏模型与原始模型输出分布的 KL 散度:
+
+$$
+\begin{aligned}
+D_{KL}\left(p_{\text{fake}} \; \| \; p_{\text{real}} \right) &= \mathbb{E}_{x\sim p_\text{fake}}\left(\log\left(\frac{p_\text{fake}(x)}{p_\text{real}(x)}\right)\right)\\
+&= \mathbb{E}_{\substack{
+z \sim \mathcal{N}(0; \mathbf{I}) \\
+x = G_\theta(z)
+}}-\big(\log~p_\text{real}(x) - \log~p_\text{fake}(x)\big).
+\end{aligned}
+$$
+
+由于直接计算概率密度几乎是不可能的,因此 DMD 蒸馏改为计算这个 KL 散度的梯度:
+
+$$
+\begin{aligned}
+\nabla_\theta D_{KL}
+&= \mathbb{E}_{\substack{
+z \sim \mathcal{N}(0; \mathbf{I}) \\
+x = G_\theta(z)
+}} \Big[-
+\big(
+s_\text{real}(x) - s_\text{fake}(x)\big)
+\hspace{.5mm} \frac{dG}{d\theta}
+\Big],
+\end{aligned}
+$$
+
+其中 $s_\text{real}(x) = \nabla_{x} \log p_\text{real}(x)$ 和 $s_\text{fake}(x) = \nabla_{x} \log p_\text{fake}(x)$ 为得分函数。得分函数可以由模型进行计算。因此,DMD 蒸馏一共维护三个模型:
+
+- `real_score`,计算真实分布的得分;由于真实分布是固定的,因此 DMD 蒸馏使用固定权重的原始模型作为其得分函数;
+- `fake_score`,计算伪分布的得分;由于伪分布是不断更新的,因此 DMD 蒸馏使用原始模型对其初始化,并对其进行微调以学习生成器的输出分布;
+- `generator`,学生模型,通过计算 `real_score` 与 `fake_score` KL 散度的梯度指导其优化方向。
+
+> 参考文献:
+> 1. [DMD (One-step Diffusion with Distribution Matching Distillation)](https://arxiv.org/abs/2311.18828)
+> 2. [DMD2 (Improved Distribution Matching Distillation for Fast Image Synthesis)](https://arxiv.org/abs/2405.14867)
+
+### Self-Forcing
+
+DMD 蒸馏技术是针对图像生成的。Lightx2v 中的步数蒸馏基于 [Self-Forcing](https://github.com/guandeh17/Self-Forcing) 技术实现。Self-Forcing 的整体实现与 DMD 类似,但是仿照 DMD2,去掉了它的回归损失,而是使用了 ODE 初始化。此外,Self-Forcing 针对视频生成任务加入了一个重要优化:
+
+目前基于 DMD 蒸馏的方法难以一步生成视频。Self-Forcing 每次选择一个时间步进行优化,generator 仅仅在这一步计算梯度。这种方法使得 Self-Forcing 的训练速度显著提升,并且提升了中间时间步的去噪质量,其效果亦有所提升。
+
+> 参考文献:
+> 1. [Self-Forcing (Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion)](https://arxiv.org/abs/2506.08009)
+
+### Lightx2v
+
+Self-Forcing 针对 1.3B 的自回归模型进行步数蒸馏、CFG蒸馏。LightX2V 在其基础上,进行了一系列扩展:
+
+1. **更大的模型**:支持 14B 模型的步数蒸馏训练;
+2. **更多的模型**:支持标准的双向模型,以及 I2V 模型的步数蒸馏训练;
+3. 
**更好的效果**:Lightx2v 使用了约 50,000 条数据的高质量 prompt 进行训练; + +具体实现可参考 [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus)。 + +## 🎯 技术特性 + +- **推理加速**:推理步数从 40-50 步减少到 4 步且无需 CFG,速度提升约 **20-24x** +- **质量保持**:通过蒸馏技术保持原有的视频生成质量 +- **兼容性强**:支持 T2V 和 I2V 任务 +- **使用灵活**:支持加载完整步数蒸馏模型,或者在原生模型的基础上加载步数蒸馏LoRA;支持与 int8/fp8 模型量化相兼容 + +## 🛠️ 配置文件说明 + +### 基础配置文件 + +在 [configs/distill/](https://github.com/ModelTC/lightx2v/tree/main/configs/distill) 目录下提供了多种配置选项: + +| 配置文件 | 用途 | 模型地址 | +|----------|------|------------| +| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | 加载 T2V 4步蒸馏完整模型 | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) | +| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | 加载 I2V 4步蒸馏完整模型 | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) | +| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | 加载 Wan-T2V 模型和步数蒸馏 LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) | +| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | 加载 Wan-I2V 模型和步数蒸馏 LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) | + +### 关键配置参数 + +- 由于 DMD 蒸馏仅训练几个固定的时间步,因此我们推荐使用 `LCM Scheduler` 进行推理。[WanStepDistillScheduler](https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/schedulers/wan/step_distill/scheduler.py) 中,已经固定使用 `LCM Scheduler`,无需用户进行配置。 +- `infer_steps`, `denoising_step_list` 和 `sample_shift` 设置为与训练时相匹配的参数,一般不建议用户修改。 +- `enable_cfg` 一定设置为 `false`(等价于设置 `sample_guide_scale = 1`),否则可能出现视频完全模糊的现象。 +- `lora_configs` 支持融合不同强度的多个 lora。当 `lora_configs` 不为空时,默认加载原始的 `Wan2.1` 模型。因此使用 `lora_config` 并且想要使用步数蒸馏时,请设置步数蒸馏lora的路径与强度。 + +```json +{ + "infer_steps": 4, // 推理步数 + "denoising_step_list": [1000, 750, 500, 250], // 去噪时间步列表 + "sample_shift": 5, // 调度器 timestep shift + "enable_cfg": false, // 关闭CFG以提升速度 + "lora_configs": [ // LoRA权重路径(可选) + { + "path": "path/to/distill_lora.safetensors", + "strength": 1.0 + } + ] +} +``` + +## 📜 使用方法 + +### 模型准备 + +**完整模型:** +将下载好的模型(`distill_model.pt` 或者 `distill_model.safetensors`)放到 Wan 模型根目录的 `distill_models/` 文件夹下即可 + +- 对于 T2V:`Wan2.1-T2V-14B/distill_models/` +- 对于 I2V-480P:`Wan2.1-I2V-14B-480P/distill_models/` + +**LoRA:** + +1. 将下载好的 LoRA 放到任意位置 +2. 
修改配置文件中的 `lora_path` 参数为 LoRA 存放路径即可 + +### 推理脚本 + +**T2V 完整模型:** + +```bash +bash scripts/wan/run_wan_t2v_distill_4step_cfg.sh +``` + +**I2V 完整模型:** + +```bash +bash scripts/wan/run_wan_i2v_distill_4step_cfg.sh +``` + +### 步数蒸馏 LoRA 推理脚本 + +**T2V LoRA:** + +```bash +bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh +``` + +**I2V LoRA:** + +```bash +bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh +``` + +## 🔧 服务化部署 + +### 启动蒸馏模型服务 + +对 [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh) 中的启动命令进行修改: + +```bash +python -m lightx2v.api_server \ + --model_cls wan2.1_distill \ + --task t2v \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \ + --port 8000 \ + --nproc_per_node 1 +``` + +运行服务启动脚本: + +```bash +scripts/server/start_server.sh +``` + +更多详细信息见[服务化部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html)。 + +### 在 Gradio 界面中使用 + +见 [Gradio 文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) diff --git a/docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md b/docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md new file mode 100644 index 0000000000000000000000000000000000000000..7df2b6b4edc0cd16d97a57a3d222639c3ca3b2ae --- /dev/null +++ b/docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md @@ -0,0 +1,246 @@ +# 视频帧插值 (VFI) + +> **重要说明**: 视频帧插值功能通过配置文件启用,而不是通过命令行参数。请在配置 JSON 文件中添加 `video_frame_interpolation` 配置块来启用此功能。 + +## 概述 + +视频帧插值(VFI)是一种在现有帧之间生成中间帧的技术,用于提高帧率并创建更流畅的视频播放效果。LightX2V 集成了 RIFE(Real-Time Intermediate Flow Estimation)模型,提供高质量的帧插值能力。 + +## 什么是 RIFE? + +RIFE 是一种最先进的视频帧插值方法,使用光流估计来生成中间帧。它能够有效地: + +- 提高视频帧率(例如,从 16 FPS 提升到 32 FPS) +- 创建平滑的运动过渡 +- 保持高视觉质量,最少伪影 +- 实时处理视频 + +## 安装和设置 + +### 下载 RIFE 模型 + +首先,使用提供的脚本下载 RIFE 模型权重: + +```bash +python tools/download_rife.py <目标目录> +``` + +例如,下载到指定位置: +```bash +python tools/download_rife.py /path/to/rife/train_log +``` + +此脚本将: +- 从 HuggingFace 下载 RIFEv4.26 模型 +- 提取并将模型文件放置在正确的目录中 +- 清理临时文件 + +## 使用方法 + +### 配置文件设置 + +视频帧插值功能通过配置文件启用。在你的配置 JSON 文件中添加 `video_frame_interpolation` 配置块: + +```json +{ + "infer_steps": 50, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "fps": 16, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 32, + "model_path": "/path/to/rife/train_log" + } +} +``` + +### 命令行使用 + +使用包含 VFI 配置的配置文件运行推理: + +```bash +python lightx2v/infer.py \ + --model_cls wan2.1 \ + --task t2v \ + --model_path /path/to/model \ + --config_json ./configs/video_frame_interpolation/wan_t2v.json \ + --prompt "美丽的海上日落" \ + --save_result_path ./output.mp4 +``` + +### 配置参数说明 + +在 `video_frame_interpolation` 配置块中: + +- `algo`: 帧插值算法,目前支持 "rife" +- `target_fps`: 输出视频的目标帧率 +- `model_path`: RIFE 模型路径,通常为 "train_log" + +其他相关配置: +- `fps`: 源视频帧率(默认 16) + +### 配置优先级 + +系统会自动处理视频帧率配置,优先级如下: +1. `video_frame_interpolation.target_fps` - 如果启用视频帧插值,使用此帧率作为输出帧率 +2. `fps`(默认 16)- 如果未启用视频帧插值,使用此帧率;同时总是用作源帧率 + + +## 工作原理 + +### 帧插值过程 + +1. **源视频生成**: 基础模型以源 FPS 生成视频帧 +2. **帧分析**: RIFE 分析相邻帧以估计光流 +3. **中间帧生成**: 在现有帧之间生成新帧 +4. 
**时序平滑**: 插值帧创建平滑的运动过渡 + +### 技术细节 + +- **输入格式**: ComfyUI 图像张量 [N, H, W, C],范围 [0, 1] +- **输出格式**: 插值后的 ComfyUI 图像张量 [M, H, W, C],范围 [0, 1] +- **处理**: 自动填充和分辨率处理 +- **内存优化**: 高效的 GPU 内存管理 + +## 示例配置 + +### 基础帧率翻倍 + +创建配置文件 `wan_t2v_vfi_32fps.json`: + +```json +{ + "infer_steps": 50, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "seed": 42, + "sample_guide_scale": 6, + "enable_cfg": true, + "fps": 16, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 32, + "model_path": "/path/to/rife/train_log" + } +} +``` + +运行命令: +```bash +python lightx2v/infer.py \ + --model_cls wan2.1 \ + --task t2v \ + --model_path ./models/wan2.1 \ + --config_json ./wan_t2v_vfi_32fps.json \ + --prompt "一只小猫在花园里玩耍" \ + --save_result_path ./output_32fps.mp4 +``` + +### 更高帧率增强 + +创建配置文件 `wan_i2v_vfi_60fps.json`: + +```json +{ + "infer_steps": 30, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "seed": 42, + "sample_guide_scale": 6, + "enable_cfg": true, + "fps": 16, + "video_frame_interpolation": { + "algo": "rife", + "target_fps": 60, + "model_path": "/path/to/rife/train_log" + } +} +``` + +运行命令: +```bash +python lightx2v/infer.py \ + --model_cls wan2.1 \ + --task i2v \ + --model_path ./models/wan2.1 \ + --config_json ./wan_i2v_vfi_60fps.json \ + --image_path ./input.jpg \ + --prompt "平滑的相机运动" \ + --save_result_path ./output_60fps.mp4 +``` + +## 性能考虑 + +### 内存使用 + +- RIFE 处理需要额外的 GPU 内存 +- 内存使用量与视频分辨率和长度成正比 +- 对于较长的视频,考虑使用较低的分辨率 + +### 处理时间 + +- 帧插值会增加处理开销 +- 更高的目标帧率需要更多计算 +- 处理时间大致与插值帧数成正比 + +### 质量与速度权衡 + +- 更高的插值比率可能引入伪影 +- 最佳范围:2x 到 4x 帧率增加 +- 对于极端插值(>4x),考虑多次处理 + +## 最佳实践 + +### 最佳使用场景 + +- **运动密集视频**: 从帧插值中受益最多 +- **相机运动**: 更平滑的平移和缩放 +- **动作序列**: 减少运动模糊感知 +- **慢动作效果**: 创建流畅的慢动作视频 + +### 推荐设置 + +- **源 FPS**: 16-24 FPS(基础模型生成) +- **目标 FPS**: 32-60 FPS(2x 到 4x 增加) +- **分辨率**: 最高 720p 以获得最佳性能 + +### 故障排除 + +#### 常见问题 + +1. **内存不足**: 减少视频分辨率或目标 FPS +2. **输出中有伪影**: 降低插值比率 +3. **处理缓慢**: 检查 GPU 内存并考虑使用 CPU 卸载 + +#### 解决方案 + +通过修改配置文件来解决问题: + +```json +{ + // 内存问题解决:使用较低分辨率 + "target_height": 480, + "target_width": 832, + + // 质量问题解决:使用适中的插值 + "video_frame_interpolation": { + "target_fps": 24 // 而不是 60 + }, + + // 性能问题解决:启用卸载 + "cpu_offload": true +} +``` + +## 技术实现 + +LightX2V 中的 RIFE 集成包括: + +- **RIFEWrapper**: 与 ComfyUI 兼容的 RIFE 模型包装器 +- **自动模型加载**: 与推理管道的无缝集成 +- **内存优化**: 高效的张量管理和 GPU 内存使用 +- **质量保持**: 在添加帧的同时保持原始视频质量 diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8bee25d4f8f8c09f4192273cb99938badadd429e --- /dev/null +++ b/examples/README.md @@ -0,0 +1,284 @@ +# LightX2V Usage Examples + +This document introduces how to use LightX2V for video generation, including basic usage and advanced configurations. + +## 📋 Table of Contents + +- [Environment Setup](#environment-setup) +- [Basic Usage Examples](#basic-usage-examples) +- [Model Path Configuration](#model-path-configuration) +- [Creating Generator](#creating-generator) +- [Advanced Configurations](#advanced-configurations) + - [Parameter Offloading](#parameter-offloading) + - [Model Quantization](#model-quantization) + - [Parallel Inference](#parallel-inference) + - [Feature Caching](#feature-caching) + - [Lightweight VAE](#lightweight-vae) + +## 🔧 Environment Setup + +Please refer to the main project's [Quick Start Guide](../docs/EN/source/getting_started/quickstart.md) for environment setup. 
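+As a quick sanity check after setting up the environment (assuming LightX2V has been installed into it, e.g. with `pip install -v -e .` as described in the quickstart), you can confirm that the package imports correctly:
+
+```python
+import lightx2v
+
+# Print the installed version to confirm the environment is ready
+print(f"LightX2V version: {lightx2v.__version__}")
+```
+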
+ +## 🚀 Basic Usage Examples + +A minimal code example can be found in `examples/wan_t2v.py`: + +```python +from lightx2v import LightX2VPipeline + +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-T2V-14B", + model_cls="wan2.1", + task="t2v", +) + +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + height=480, + width=832, + num_frames=81, + guidance_scale=5.0, + sample_shift=5.0, +) + +seed = 42 +prompt = "Your prompt here" +negative_prompt = "" +save_result_path="/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) +``` + +## 📁 Model Path Configuration + +### Basic Configuration + +Pass the model path to `LightX2VPipeline`: + +```python +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe", # For wan2.1, use "wan2.1" + task="i2v", +) +``` + +### Specifying Multiple Model Weight Versions + +When there are multiple versions of bf16 precision DIT model safetensors files in the `model_path` directory, you need to use the following parameters to specify which weights to use: + +- **`dit_original_ckpt`**: Used to specify the original DIT weight path for models like wan2.1 and hunyuan15 +- **`low_noise_original_ckpt`**: Used to specify the low noise branch weight path for wan2.2 models +- **`high_noise_original_ckpt`**: Used to specify the high noise branch weight path for wan2.2 models + +**Usage Example:** + +```python +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe", + task="i2v", + low_noise_original_ckpt="/path/to/low_noise_model.safetensors", + high_noise_original_ckpt="/path/to/high_noise_model.safetensors", +) +``` + +## 🎛️ Creating Generator + +### Loading from Configuration File + +The generator can be loaded directly from a JSON configuration file. Configuration files are located in the `configs` directory: + +```python +pipe.create_generator(config_json="../configs/wan/wan_t2v.json") +``` + +### Creating Generator Manually + +You can also create the generator manually and configure multiple parameters: + +```python +pipe.create_generator( + attn_mode="flash_attn2", # Options: flash_attn2, flash_attn3, sage_attn2, sage_attn3 (B-architecture GPUs) + infer_steps=50, # Number of inference steps + num_frames=81, # Number of video frames + height=480, # Video height + width=832, # Video width + guidance_scale=5.0, # CFG guidance strength (CFG disabled when =1) + sample_shift=5.0, # Sample shift + fps=16, # Frame rate + aspect_ratio="16:9", # Aspect ratio + boundary=0.900, # Boundary value + boundary_step_index=2, # Boundary step index + denoising_step_list=[1000, 750, 500, 250], # Denoising step list +) +``` + +**Parameter Description:** +- **Resolution**: Specified via `height` and `width` +- **CFG**: Specified via `guidance_scale` (set to 1 to disable CFG) +- **FPS**: Specified via `fps` +- **Video Length**: Specified via `num_frames` +- **Inference Steps**: Specified via `infer_steps` +- **Sample Shift**: Specified via `sample_shift` +- **Attention Mode**: Specified via `attn_mode`, options include `flash_attn2`, `flash_attn3`, `sage_attn2`, `sage_attn3` (for B-architecture GPUs) + +## ⚙️ Advanced Configurations + +**⚠️ Important: When manually creating a generator, you can configure some advanced options. 
All advanced configurations must be specified before `create_generator()`, otherwise they will not take effect!** + +### Parameter Offloading + +Significantly reduces memory usage with almost no impact on inference speed. Suitable for RTX 30/40/50 series GPUs. + +```python +pipe.enable_offload( + cpu_offload=True, # Enable CPU offloading + offload_granularity="block", # Offload granularity: "block" or "phase" + text_encoder_offload=False, # Whether to offload text encoder + image_encoder_offload=False, # Whether to offload image encoder + vae_offload=False, # Whether to offload VAE +) +``` + +**Notes:** +- For Wan models, `offload_granularity` supports both `"block"` and `"phase"` +- For HunyuanVideo-1.5, only `"block"` is currently supported + +### Model Quantization + +Quantization can significantly reduce memory usage and accelerate inference. + +```python +pipe.enable_quantize( + dit_quantized=False, # Whether to use quantized DIT model + text_encoder_quantized=False, # Whether to use quantized text encoder + image_encoder_quantized=False, # Whether to use quantized image encoder + dit_quantized_ckpt=None, # DIT quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files) + low_noise_quantized_ckpt=None, # Wan2.2 low noise branch quantized weight path + high_noise_quantized_ckpt=None, # Wan2.2 high noise branch quantized weight path + text_encoder_quantized_ckpt=None, # Text encoder quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files) + image_encoder_quantized_ckpt=None, # Image encoder quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files) + quant_scheme="fp8-sgl", # Quantization scheme +) +``` + +**Parameter Description:** +- **`dit_quantized_ckpt`**: When the `model_path` directory doesn't contain quantized weights, or has multiple weight files, you need to specify the specific DIT quantized weight path +- **`text_encoder_quantized_ckpt`** and **`image_encoder_quantized_ckpt`**: Similarly, used to specify encoder quantized weight paths +- **`low_noise_quantized_ckpt`** and **`high_noise_quantized_ckpt`**: Used to specify dual-branch quantized weights for Wan2.2 models + +**Quantized Model Downloads:** + +- **Wan-2.1 Quantized Models**: Download from [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) +- **Wan-2.2 Quantized Models**: Download from [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) +- **HunyuanVideo-1.5 Quantized Models**: Download from [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models) + - `hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors` is the quantized weight for the text encoder + +**Usage Examples:** + +```python +# HunyuanVideo-1.5 Quantization Example +pipe.enable_quantize( + quant_scheme='fp8-sgl', + dit_quantized=True, + dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors", + text_encoder_quantized=True, + image_encoder_quantized=False, + text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors", +) + +# Wan2.1 Quantization Example +pipe.enable_quantize( + dit_quantized=True, + dit_quantized_ckpt="/path/to/wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors", +) + +# Wan2.2 Quantization Example +pipe.enable_quantize( + dit_quantized=True, + 
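+    # Wan2.2 uses two DiT branches (high-noise / low-noise), so each branch points to its own quantized checkpoint: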
low_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors", + high_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step_1030.safetensors", +) +``` + +**Quantization Scheme Reference:** For detailed information, please refer to the [Quantization Documentation](../docs/EN/source/method_tutorials/quantization.md) + +### Parallel Inference + +Supports multi-GPU parallel inference. Requires running with `torchrun`: + +```python +pipe.enable_parallel( + seq_p_size=4, # Sequence parallel size + seq_p_attn_type="ulysses", # Sequence parallel attention type +) +``` + +**Running Method:** +```bash +torchrun --nproc_per_node=4 your_script.py +``` + +### Feature Caching + +You can specify the cache method as Mag or Tea, using MagCache and TeaCache methods: + +```python +pipe.enable_cache( + cache_method='Tea', # Cache method: 'Tea' or 'Mag' + coefficients=[-3.08907507e+04, 1.67786188e+04, -3.19178643e+03, + 2.60740519e+02, -8.19205881e+00, 1.07913775e-01], # Coefficients + teacache_thresh=0.15, # TeaCache threshold +) +``` + +**Coefficient Reference:** Refer to configuration files in `configs/caching` or `configs/hunyuan_video_15/cache` directories + +### Lightweight VAE + +Using lightweight VAE can accelerate decoding and reduce memory usage. + +```python +pipe.enable_lightvae( + use_lightvae=False, # Whether to use LightVAE + use_tae=False, # Whether to use LightTAE + vae_path=None, # Path to LightVAE + tae_path=None, # Path to LightTAE +) +``` + +**Support Status:** +- **LightVAE**: Currently only supports wan2.1, wan2.2 moe +- **LightTAE**: Currently only supports wan2.1, wan2.2-ti2v, wan2.2 moe, HunyuanVideo-1.5 + +**Model Downloads:** Lightweight VAE models can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) + +- LightVAE for Wan-2.1: [lightvaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lightvaew2_1.safetensors) +- LightTAE for Wan-2.1: [lighttaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_1.safetensors) +- LightTAE for Wan-2.2-ti2v: [lighttaew2_2.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_2.safetensors) +- LightTAE for HunyuanVideo-1.5: [lighttaehy1_5.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors) + +**Usage Example:** + +```python +# Using LightTAE for HunyuanVideo-1.5 +pipe.enable_lightvae( + use_tae=True, + tae_path="/path/to/lighttaehy1_5.safetensors", + use_lightvae=False, + vae_path=None +) +``` + +## 📚 More Resources + +- [Full Documentation](https://lightx2v-en.readthedocs.io/en/latest/) +- [GitHub Repository](https://github.com/ModelTC/LightX2V) +- [HuggingFace Model Hub](https://huggingface.co/lightx2v) diff --git a/examples/README_zh.md b/examples/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..124b527f98f59bf8af4d897c354cab6e460d3fda --- /dev/null +++ b/examples/README_zh.md @@ -0,0 +1,284 @@ +# LightX2V 使用示例 + +本文档介绍如何使用 LightX2V 进行视频生成,包括基础使用和进阶配置。 + +## 📋 目录 + +- [环境安装](#环境安装) +- [基础运行示例](#基础运行示例) +- [模型路径配置](#模型路径配置) +- [创建生成器](#创建生成器) +- [进阶配置](#进阶配置) + - [参数卸载 (Offload)](#参数卸载-offload) + - [模型量化 (Quantization)](#模型量化-quantization) + - [并行推理 (Parallel Inference)](#并行推理-parallel-inference) + - [特征缓存 (Cache)](#特征缓存-cache) + - [轻量 VAE (Light VAE)](#轻量-vae-light-vae) + +## 🔧 环境安装 + +请参考主项目的[快速入门文档](../docs/ZH_CN/source/getting_started/quickstart.md)进行环境安装。 + +## 🚀 基础运行示例 + +最小化代码示例可参考 
`examples/wan_t2v.py`: + +```python +from lightx2v import LightX2VPipeline + +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-T2V-14B", + model_cls="wan2.1", + task="t2v", +) + +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + height=480, + width=832, + num_frames=81, + guidance_scale=5.0, + sample_shift=5.0, +) + +seed = 42 +prompt = "Your prompt here" +negative_prompt = "" +save_result_path="/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) +``` + +## 📁 模型路径配置 + +### 基础配置 + +将模型路径传入 `LightX2VPipeline`: + +```python +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe", # 对于 wan2.1,使用 "wan2.1" + task="i2v", +) +``` + +### 多版本模型权重指定 + +当 `model_path` 目录下存在多个不同版本的 bf16 精度 DIT 模型 safetensors 文件时,需要使用以下参数指定具体使用哪个权重: + +- **`dit_original_ckpt`**: 用于指定 wan2.1 和 hunyuan15 等模型的原始 DIT 权重路径 +- **`low_noise_original_ckpt`**: 用于指定 wan2.2 模型的低噪声分支权重路径 +- **`high_noise_original_ckpt`**: 用于指定 wan2.2 模型的高噪声分支权重路径 + +**使用示例:** + +```python +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe", + task="i2v", + low_noise_original_ckpt="/path/to/low_noise_model.safetensors", + high_noise_original_ckpt="/path/to/high_noise_model.safetensors", +) +``` + +## 🎛️ 创建生成器 + +### 从配置文件加载 + +生成器可以从 JSON 配置文件直接加载,配置文件位于 `configs` 目录: + +```python +pipe.create_generator(config_json="../configs/wan/wan_t2v.json") +``` + +### 手动创建生成器 + +也可以手动创建生成器,并配置多个参数: + +```python +pipe.create_generator( + attn_mode="flash_attn2", # 可选: flash_attn2, flash_attn3, sage_attn2, sage_attn3 (B架构显卡适用) + infer_steps=50, # 推理步数 + num_frames=81, # 视频帧数 + height=480, # 视频高度 + width=832, # 视频宽度 + guidance_scale=5.0, # CFG引导强度 (=1时弃用CFG) + sample_shift=5.0, # 采样偏移 + fps=16, # 帧率 + aspect_ratio="16:9", # 宽高比 + boundary=0.900, # 边界值 + boundary_step_index=2, # 边界步索引 + denoising_step_list=[1000, 750, 500, 250], # 去噪步列表 +) +``` + +**参数说明:** +- **分辨率**: 通过 `height` 和 `width` 指定 +- **CFG**: 通过 `guidance_scale` 指定(设置为 1 时禁用 CFG) +- **FPS**: 通过 `fps` 指定帧率 +- **视频长度**: 通过 `num_frames` 指定帧数 +- **推理步数**: 通过 `infer_steps` 指定 +- **采样偏移**: 通过 `sample_shift` 指定 +- **注意力模式**: 通过 `attn_mode` 指定,可选 `flash_attn2`, `flash_attn3`, `sage_attn2`, `sage_attn3`(B架构显卡适用) + +## ⚙️ 进阶配置 + +**⚠️ 重要提示:手动创建生成器时,可以配置一些进阶选项,所有进阶配置必须在 `create_generator()` 之前指定,否则会失效!** + +### 参数卸载 (Offload) + +显著降低显存占用,几乎不影响推理速度,适用于 RTX 30/40/50 系列显卡。 + +```python +pipe.enable_offload( + cpu_offload=True, # 启用 CPU 卸载 + offload_granularity="block", # 卸载粒度: "block" 或 "phase" + text_encoder_offload=False, # 文本编码器是否卸载 + image_encoder_offload=False, # 图像编码器是否卸载 + vae_offload=False, # VAE 是否卸载 +) +``` + +**说明:** +- 对于 Wan 模型,`offload_granularity` 支持 `"block"` 和 `"phase"` +- 对于 HunyuanVideo-1.5,目前只支持 `"block"` + +### 模型量化 (Quantization) + +量化可以显著降低显存占用并加速推理。 + +```python +pipe.enable_quantize( + dit_quantized=False, # 是否使用量化的 DIT 模型 + text_encoder_quantized=False, # 是否使用量化的文本编码器 + image_encoder_quantized=False, # 是否使用量化的图像编码器 + dit_quantized_ckpt=None, # DIT 量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定) + low_noise_quantized_ckpt=None, # Wan2.2 低噪声分支量化权重路径 + high_noise_quantized_ckpt=None, # Wan2.2 高噪声分支量化权重路径 + text_encoder_quantized_ckpt=None, # 文本编码器量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定) + image_encoder_quantized_ckpt=None, # 图像编码器量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定) + quant_scheme="fp8-sgl", # 量化方案 +) +``` + +**参数说明:** +- **`dit_quantized_ckpt`**: 当 `model_path` 目录下没有量化权重,或存在多个权重文件时,需要指定具体的 
DIT 量化权重路径 +- **`text_encoder_quantized_ckpt`** 和 **`image_encoder_quantized_ckpt`**: 类似地,用于指定编码器的量化权重路径 +- **`low_noise_quantized_ckpt`** 和 **`high_noise_quantized_ckpt`**: 用于指定 Wan2.2 模型的双分支量化权重 + +**量化模型下载:** + +- **Wan-2.1 量化模型**: 从 [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) 下载 +- **Wan-2.2 量化模型**: 从 [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) 下载 +- **HunyuanVideo-1.5 量化模型**: 从 [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models) 下载 + - `hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors` 是文本编码器的量化权重 + +**使用示例:** + +```python +# HunyuanVideo-1.5 量化示例 +pipe.enable_quantize( + quant_scheme='fp8-sgl', + dit_quantized=True, + dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors", + text_encoder_quantized=True, + image_encoder_quantized=False, + text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors", +) + +# Wan2.1 量化示例 +pipe.enable_quantize( + dit_quantized=True, + dit_quantized_ckpt="/path/to/wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors", +) + +# Wan2.2 量化示例 +pipe.enable_quantize( + dit_quantized=True, + low_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors", + high_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step_1030.safetensors", +) +``` + +**量化方案参考:** 详细说明请参考 [量化文档](../docs/ZH_CN/source/method_tutorials/quantization.md) + +### 并行推理 (Parallel Inference) + +支持多 GPU 并行推理,需要使用 `torchrun` 运行: + +```python +pipe.enable_parallel( + seq_p_size=4, # 序列并行大小 + seq_p_attn_type="ulysses", # 序列并行注意力类型 +) +``` + +**运行方式:** +```bash +torchrun --nproc_per_node=4 your_script.py +``` + +### 特征缓存 (Cache) + +可以指定缓存方法为 Mag 或 Tea,使用 MagCache 和 TeaCache 方法: + +```python +pipe.enable_cache( + cache_method='Tea', # 缓存方法: 'Tea' 或 'Mag' + coefficients=[-3.08907507e+04, 1.67786188e+04, -3.19178643e+03, + 2.60740519e+02, -8.19205881e+00, 1.07913775e-01], # 系数 + teacache_thresh=0.15, # TeaCache 阈值 +) +``` + +**系数参考:** 可参考 `configs/caching` 或 `configs/hunyuan_video_15/cache` 目录下的配置文件 + +### 轻量 VAE (Light VAE) + +使用轻量 VAE 可以加速解码并降低显存占用。 + +```python +pipe.enable_lightvae( + use_lightvae=False, # 是否使用 LightVAE + use_tae=False, # 是否使用 LightTAE + vae_path=None, # LightVAE 的路径 + tae_path=None, # LightTAE 的路径 +) +``` + +**支持情况:** +- **LightVAE**: 目前只支持 wan2.1、wan2.2 moe +- **LightTAE**: 目前只支持 wan2.1、wan2.2-ti2v、wan2.2 moe、HunyuanVideo-1.5 + +**模型下载:** 轻量 VAE 模型可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载 + +- Wan-2.1 的 LightVAE: [lightvaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lightvaew2_1.safetensors) +- Wan-2.1 的 LightTAE: [lighttaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_1.safetensors) +- Wan-2.2-ti2v 的 LightTAE: [lighttaew2_2.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_2.safetensors) +- HunyuanVideo-1.5 的 LightTAE: [lighttaehy1_5.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors) + +**使用示例:** + +```python +# 使用 HunyuanVideo-1.5 的 LightTAE +pipe.enable_lightvae( + use_tae=True, + tae_path="/path/to/lighttaehy1_5.safetensors", + use_lightvae=False, + vae_path=None +) +``` + +## 📚 更多资源 + +- [完整文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/) +- [GitHub 仓库](https://github.com/ModelTC/LightX2V) +- [HuggingFace 模型库](https://huggingface.co/lightx2v) diff --git 
a/examples/hunyuan_video/hunyuan_i2v.py b/examples/hunyuan_video/hunyuan_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..654b1f75333dda87cae3e135a7a5d22ae0301ce2 --- /dev/null +++ b/examples/hunyuan_video/hunyuan_i2v.py @@ -0,0 +1,63 @@ +""" +HunyuanVideo-1.5 image-to-video generation example with quantization. +This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for I2V generation, +including quantized model usage for reduced memory consumption. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for HunyuanVideo-1.5 I2V task +pipe = LightX2VPipeline( + model_path="/path/to/ckpts/hunyuanvideo-1.5/", + model_cls="hunyuan_video_1.5", + transformer_model_name="720p_i2v", + task="i2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_i2v_720p.json") + +# Enable offloading to significantly reduce VRAM usage with minimal speed impact +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Enable quantization for reduced memory usage +# Quantized models can be downloaded from: https://huggingface.co/lightx2v/Hy1.5-Quantized-Models +pipe.enable_quantize( + quant_scheme="fp8-sgl", + dit_quantized=True, + dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors", + text_encoder_quantized=True, + image_encoder_quantized=False, + text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors", +) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + num_frames=121, + guidance_scale=6.0, + sample_shift=7.0, + fps=24, +) + +# Generation parameters +seed = 42 +prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." +negative_prompt = "" +save_result_path = "/path/to/save_results/output2.mp4" + +# Generate video +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/hunyuan_video/hunyuan_t2v.py b/examples/hunyuan_video/hunyuan_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..7283bdede10cf9a7597b7c7551aa63935c283795 --- /dev/null +++ b/examples/hunyuan_video/hunyuan_t2v.py @@ -0,0 +1,60 @@ +""" +HunyuanVideo-1.5 text-to-video generation example. +This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for T2V generation. 
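+Block-granularity CPU offloading and the lightweight TAE decoder (lighttaehy1_5) are enabled below to reduce VRAM usage.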
+""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for HunyuanVideo-1.5 +pipe = LightX2VPipeline( + model_path="/path/to/ckpts/hunyuanvideo-1.5/", + model_cls="hunyuan_video_1.5", + transformer_model_name="720p_t2v", + task="t2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_t2v_720p.json") + +# Enable offloading to significantly reduce VRAM usage with minimal speed impact +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Use lighttae +pipe.enable_lightvae( + use_tae=True, + tae_path="/path/to/lighttaehy1_5.safetensors", + use_lightvae=False, + vae_path=None, +) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + num_frames=121, + guidance_scale=6.0, + sample_shift=9.0, + aspect_ratio="16:9", + fps=24, +) + +# Generation parameters +seed = 123 +prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." +negative_prompt = "" +save_result_path = "/path/to/save_results/output.mp4" + +# Generate video +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/hunyuan_video/hunyuan_t2v_distill.py b/examples/hunyuan_video/hunyuan_t2v_distill.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c7cdd53d73d9e0c20e6eb94d396c4db00b1078 --- /dev/null +++ b/examples/hunyuan_video/hunyuan_t2v_distill.py @@ -0,0 +1,55 @@ +""" +HunyuanVideo-1.5 text-to-video generation example. +This example demonstrates how to use LightX2V with HunyuanVideo-1.5 4-step distilled model for T2V generation. 
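+The distilled checkpoint is loaded via dit_original_ckpt and sampled in 4 denoising steps with CFG disabled (guidance_scale=1).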
+""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for HunyuanVideo-1.5 +pipe = LightX2VPipeline( + model_path="/path/to/ckpts/hunyuanvideo-1.5/", + model_cls="hunyuan_video_1.5", + transformer_model_name="480p_t2v", + task="t2v", + # 4-step distilled model ckpt + dit_original_ckpt="/path/to/hy1.5_t2v_480p_lightx2v_4step.safetensors", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_t2v_720p.json") + +# Enable offloading to significantly reduce VRAM usage with minimal speed impact +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Use lighttae +pipe.enable_lightvae( + use_tae=True, + tae_path="/path/to/lighttaehy1_5.safetensors", + use_lightvae=False, + vae_path=None, +) + +# Create generator with specified parameters +pipe.create_generator(attn_mode="sage_attn2", infer_steps=4, num_frames=81, guidance_scale=1, sample_shift=9.0, aspect_ratio="16:9", fps=16, denoising_step_list=[1000, 750, 500, 250]) + + +# Generation parameters +seed = 123 +prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." +negative_prompt = "" +save_result_path = "/path/to/save_results/output.mp4" + +# Generate video +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_animate.py b/examples/wan/wan_animate.py new file mode 100644 index 0000000000000000000000000000000000000000..317f34d71d2c10c6490c42279fd3ef60a58a98aa --- /dev/null +++ b/examples/wan/wan_animate.py @@ -0,0 +1,72 @@ +""" +Wan2.2 animate video generation example. +This example demonstrates how to use LightX2V with Wan2.2 model for animate video generation. + +First, run preprocessing: +1. Set up environment: pip install -r ../requirements_animate.txt +2. For animate mode: + python ../tools/preprocess/preprocess_data.py \ + --ckpt_path /path/to/Wan2.1-FLF2V-14B-720P/process_checkpoint \ + --video_path /path/to/video \ + --refer_path /path/to/ref_img \ + --save_path ../save_results/animate/process_results \ + --resolution_area 1280 720 \ + --retarget_flag +3. 
For replace mode: + python ../tools/preprocess/preprocess_data.py \ + --ckpt_path /path/to/Wan2.1-FLF2V-14B-720P/process_checkpoint \ + --video_path /path/to/video \ + --refer_path /path/to/ref_img \ + --save_path ../save_results/replace/process_results \ + --resolution_area 1280 720 \ + --iterations 3 \ + --k 7 \ + --w_len 1 \ + --h_len 1 \ + --replace_flag +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for animate task +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-FLF2V-14B-720P", + model_cls="wan2.2_animate", + task="animate", +) +pipe.replace_flag = True # Set to True for replace mode, False for animate mode + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="../configs/wan/wan_animate_replace.json" +# ) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=20, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=77, + guidance_scale=1, + sample_shift=5.0, + fps=30, +) + +seed = 42 +prompt = "视频中的人在做动作" +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +src_pose_path = "../save_results/animate/process_results/src_pose.mp4" +src_face_path = "../save_results/animate/process_results/src_face.mp4" +src_ref_images = "../save_results/animate/process_results/src_ref.png" +save_result_path = "/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + src_pose_path=src_pose_path, + src_face_path=src_face_path, + src_ref_images=src_ref_images, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_flf2v.py b/examples/wan/wan_flf2v.py new file mode 100644 index 0000000000000000000000000000000000000000..1f18ccf4a0b440a07f594b1f4345f2ca470d8841 --- /dev/null +++ b/examples/wan/wan_flf2v.py @@ -0,0 +1,55 @@ +""" +Wan2.1 first-last-frame-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.1 model for FLF2V generation. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for FLF2V task +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-FLF2V-14B-720P", + model_cls="wan2.1", + task="flf2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="../configs/wan/wan_flf2v.json" +# ) + +# Optional: enable offloading to significantly reduce VRAM usage +# Suitable for RTX 30/40/50 consumer GPUs +# pipe.enable_offload( +# cpu_offload=True, +# offload_granularity="block", +# text_encoder_offload=True, +# image_encoder_offload=False, +# vae_offload=False, +# ) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=40, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=5, + sample_shift=5.0, +) + +seed = 42 +prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." 
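+# The Chinese negative prompt below is the standard quality-control prompt shared by the Wan examples in this directory.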
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +image_path = "../assets/inputs/imgs/flf2v_input_first_frame-fs8.png" +last_frame_path = "../assets/inputs/imgs/flf2v_input_last_frame-fs8.png" +save_result_path = "/path/to/save_results/output.mp4" + +pipe.generate( + image_path=image_path, + last_frame_path=last_frame_path, + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_i2v.py b/examples/wan/wan_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..632566b4be0447ee8f442882b67e40f2a848ec66 --- /dev/null +++ b/examples/wan/wan_i2v.py @@ -0,0 +1,56 @@ +""" +Wan2.2 image-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for Wan2.2 I2V task +# For wan2.1, use model_cls="wan2.1" +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe", + task="i2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="../configs/wan22/wan_moe_i2v.json" +# ) + +# Enable offloading to significantly reduce VRAM usage with minimal speed impact +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For Wan models, supports both "block" and "phase" + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Create generator manually with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=40, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0) + sample_shift=5.0, +) + +# Generation parameters +seed = 42 +prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +image_path = "/path/to/img_0.jpg" +save_result_path = "/path/to/save_results/output.mp4" + +# Generate video +pipe.generate( + seed=seed, + image_path=image_path, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_i2v_distilled.py b/examples/wan/wan_i2v_distilled.py new file mode 100644 index 0000000000000000000000000000000000000000..be89488b03f3f3507bf734be6c326d801c7e5637 --- /dev/null +++ b/examples/wan/wan_i2v_distilled.py @@ -0,0 +1,57 @@ +""" +Wan2.2 distilled model image-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.2 distilled model for I2V generation. 
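+The distilled low-noise and high-noise checkpoints are supplied via low_noise_original_ckpt and high_noise_original_ckpt, and inference runs in 4 steps with CFG disabled.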
+""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for Wan2.2 distilled I2V task +# For wan2.1, use model_cls="wan2.1_distill" +pipe = LightX2VPipeline( + model_path="/path/to/wan2.2/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe_distill", + task="i2v", + # Distilled weights: For wan2.1, only need to specify dit_original_ckpt="/path/to/wan2.1_i2v_720p_lightx2v_4step.safetensors" + low_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors", + high_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="../configs/wan22/wan_moe_i2v_distill.json" +# ) + +# Enable offloading to significantly reduce VRAM usage +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Create generator manually with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=4, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=1, + sample_shift=5.0, +) + +seed = 42 +prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path = "/path/to/save_results/output.mp4" +image_path = "/path/to/img_0.jpg" + +pipe.generate( + seed=seed, + image_path=image_path, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_i2v_with_distill_loras.py b/examples/wan/wan_i2v_with_distill_loras.py new file mode 100644 index 0000000000000000000000000000000000000000..0d629d52b94cbfb6226a7db36e734b2658c27132 --- /dev/null +++ b/examples/wan/wan_i2v_with_distill_loras.py @@ -0,0 +1,62 @@ +""" +Wan2.2 distilled model with LoRA image-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.2 distilled model and LoRA for I2V generation. 
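+The 4-step distillation LoRAs are attached to the high_noise_model and low_noise_model branches via pipe.enable_lora().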
+""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for Wan2.2 distilled I2V task with LoRA +# For wan2.1, use model_cls="wan2.1_distill" +pipe = LightX2VPipeline( + model_path="/path/to/wan2.2/Wan2.2-I2V-A14B", + model_cls="wan2.2_moe_distill", + task="i2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="../configs/wan22/wan_moe_i2v_distill_with_lora.json" +# ) + +# Enable offloading to significantly reduce VRAM usage +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Load distilled LoRA weights +pipe.enable_lora( + [ + {"name": "high_noise_model", "path": "/path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0}, + {"name": "low_noise_model", "path": "/path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0}, + ] +) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=4, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=1, + sample_shift=5.0, +) + +seed = 42 +prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path = "/path/to/save_results/output.mp4" +image_path = "/path/to/img_0.jpg" + +pipe.generate( + seed=seed, + image_path=image_path, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_t2v.py b/examples/wan/wan_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2e71b27e3658f86dae33eb8d3907a1e831a8db --- /dev/null +++ b/examples/wan/wan_t2v.py @@ -0,0 +1,39 @@ +""" +Wan2.1 text-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.1 model for T2V generation. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for Wan2.1 T2V task +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-T2V-14B", + model_cls="wan2.1", + task="t2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator(config_json="../configs/wan/wan_t2v.json") + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=5.0, + sample_shift=5.0, +) + +seed = 42 +prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
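+# Standard Wan negative prompt (in Chinese); adjust or clear it to suit your own content.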
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path = "/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/examples/wan/wan_vace.py b/examples/wan/wan_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbf88c2db6848fc9fab6aaa70487226c3ec4a87 --- /dev/null +++ b/examples/wan/wan_vace.py @@ -0,0 +1,52 @@ +""" +Wan2.1 VACE (Video Animate Character Exchange) generation example. +This example demonstrates how to use LightX2V with Wan2.1 VACE model for character exchange in videos. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for VACE task +pipe = LightX2VPipeline( + model_path="/path/to/Wan2.1-VACE-1.3B", + src_ref_images="../assets/inputs/imgs/girl.png,../assets/inputs/imgs/snake.png", + model_cls="wan2.1_vace", + task="vace", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator( +# config_json="../configs/wan/wan_vace.json" +# ) + +# Optional: enable offloading to significantly reduce VRAM usage +# Suitable for RTX 30/40/50 consumer GPUs +# pipe.enable_offload( +# cpu_offload=True, +# offload_granularity="block", +# text_encoder_offload=True, +# image_encoder_offload=False, +# vae_offload=False, +# ) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=40, + height=480, # Can be set to 720 for higher resolution + width=832, # Can be set to 1280 for higher resolution + num_frames=81, + guidance_scale=5, + sample_shift=16, +) + +seed = 42 +prompt = "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" +negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path = "/path/to/save_results/output.mp4" + +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/lightx2v/__init__.py b/lightx2v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7298b043cb7a1a768579e0639eb0a2c17c82451b --- /dev/null +++ b/lightx2v/__init__.py @@ -0,0 +1,18 @@ +__version__ = "0.1.0" +__author__ = "LightX2V Contributors" +__license__ = "Apache 2.0" + +import lightx2v_platform.set_ai_device +from lightx2v import common, deploy, models, utils +from lightx2v.pipeline import LightX2VPipeline + +__all__ = [ + "__version__", + "__author__", + "__license__", + "models", + "common", + "deploy", + "utils", + "LightX2VPipeline", +] diff --git a/lightx2v/common/__init__.py b/lightx2v/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/common/modules/__init__.py b/lightx2v/common/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/common/modules/weight_module.py b/lightx2v/common/modules/weight_module.py new file mode 100644 index 0000000000000000000000000000000000000000..59d646cc57df3ccbdb72b687f357c2ced68bceb5 --- /dev/null +++ b/lightx2v/common/modules/weight_module.py @@ -0,0 +1,175 @@ +class 
WeightModule: + def __init__(self): + self._modules = {} + self._parameters = {} + + def is_empty(self): + return len(self._modules) == 0 and len(self._parameters) == 0 + + def add_module(self, name, module): + self._modules[name] = module + setattr(self, name, module) + + def register_parameter(self, name, param): + self._parameters[name] = param + setattr(self, name, param) + + def load(self, weight_dict): + for _, module in self._modules.items(): + if hasattr(module, "load"): + module.load(weight_dict) + + for _, parameter in self._parameters.items(): + if hasattr(parameter, "load"): + parameter.load(weight_dict) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + for _, param in self._parameters.items(): + if param is not None: + param.state_dict(destination) + for _, module in self._modules.items(): + if module is not None: + module.state_dict(destination) + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if destination is None: + destination = {} + for _, param in self._parameters.items(): + if param is not None: + param.load_state_dict(destination, block_index, adapter_block_index) + for _, module in self._modules.items(): + if module is not None: + module.load_state_dict(destination, block_index, adapter_block_index) + return destination + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + for _, param in self._parameters.items(): + if param is not None: + param.load_state_dict_from_disk(block_index, adapter_block_index) + for _, module in self._modules.items(): + if module is not None: + module.load_state_dict_from_disk(block_index, adapter_block_index) + + def named_parameters(self, prefix=""): + for name, param in self._parameters.items(): + if param is not None: + yield prefix + name, param + for name, module in self._modules.items(): + if module is not None: + yield from module.named_parameters(prefix + name + ".") + + def to_cpu(self): + for name, param in self._parameters.items(): + if param is not None: + if hasattr(param, "cpu"): + self._parameters[name] = param.cpu() + setattr(self, name, self._parameters[name]) + elif hasattr(param, "to_cpu"): + self._parameters[name].to_cpu() + setattr(self, name, self._parameters[name]) + for module in self._modules.values(): + if isinstance(module, WeightModuleList): + for i in range(len(module)): + for m in module[i]._modules.values(): + if m is not None and hasattr(m, "to_cpu"): + m.to_cpu() + for m in module[i]._parameters.values(): + if m is not None and hasattr(m, "to_cpu"): + m.to_cpu() + else: + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu() + + def to_cuda(self): + for name, param in self._parameters.items(): + if param is not None: + if hasattr(param, "cuda"): + self._parameters[name] = param.cuda() + elif hasattr(param, "to_cuda"): + self._parameters[name].to_cuda() + setattr(self, name, self._parameters[name]) + for module in self._modules.values(): + if isinstance(module, WeightModuleList): + for i in range(len(module)): + for m in module[i]._modules.values(): + if m is not None and hasattr(m, "to_cuda"): + m.to_cuda() + for m in module[i]._parameters.values(): + if m is not None and hasattr(m, "to_cuda"): + m.to_cuda() + else: + if module is not None and hasattr(module, "to_cuda"): + module.to_cuda() + + def to_cpu_async(self): + for name, param in self._parameters.items(): + if param is not None: + if hasattr(param, "cpu"): + self._parameters[name] = 
param.cpu(non_blocking=True) + setattr(self, name, self._parameters[name]) + elif hasattr(param, "to_cpu"): + self._parameters[name].to_cpu(non_blocking=True) + setattr(self, name, self._parameters[name]) + for module in self._modules.values(): + if isinstance(module, WeightModuleList): + for i in range(len(module)): + for m in module[i]._modules.values(): + if m is not None and hasattr(m, "to_cpu"): + m.to_cpu(non_blocking=True) + for m in module[i]._parameters.values(): + if m is not None and hasattr(m, "to_cpu"): + m.to_cpu(non_blocking=True) + else: + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu(non_blocking=True) + + def to_cuda_async(self): + for name, param in self._parameters.items(): + if param is not None: + if hasattr(param, "cuda"): + self._parameters[name] = param.cuda(non_blocking=True) + elif hasattr(param, "to_cuda"): + self._parameters[name].to_cuda(non_blocking=True) + setattr(self, name, self._parameters[name]) + for module in self._modules.values(): + if isinstance(module, WeightModuleList): + for i in range(len(module)): + for m in module[i]._modules.values(): + if m is not None and hasattr(m, "to_cuda"): + m.to_cuda(non_blocking=True) + for m in module[i]._parameters.values(): + if m is not None and hasattr(m, "to_cuda"): + m.to_cuda(non_blocking=True) + else: + if module is not None and hasattr(module, "to_cuda"): + module.to_cuda(non_blocking=True) + + +class WeightModuleList(WeightModule): + def __init__(self, modules=None): + super().__init__() + self._list = [] + if modules is not None: + for idx, module in enumerate(modules): + self.append(module) + + def append(self, module): + idx = len(self._list) + self._list.append(module) + self.add_module(str(idx), module) + + def __getitem__(self, idx): + return self._list[idx] + + def __setitem__(self, idx, module): + self._list[idx] = module + self.add_module(str(idx), module) + + def __len__(self): + return len(self._list) + + def __iter__(self): + return iter(self._list) diff --git a/lightx2v/common/offload/manager.py b/lightx2v/common/offload/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..c1671ebec66cf8a17f4c53c1b4191e836e0c3936 --- /dev/null +++ b/lightx2v/common/offload/manager.py @@ -0,0 +1,133 @@ +from concurrent.futures import ThreadPoolExecutor + +import torch +from loguru import logger +from packaging.version import parse +from tqdm import tqdm + +from lightx2v.utils.profiler import ExcludedProfilingContext +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +class WeightAsyncStreamManager(object): + def __init__(self, offload_granularity): + self.offload_granularity = offload_granularity + self.init_stream = torch_device_module.Stream(priority=0) + self.need_init_first_buffer = True + self.lazy_load = False + torch_version = parse(torch.__version__.split("+")[0]) + if AI_DEVICE == "cuda" and torch_version >= parse("2.7"): + self.cuda_load_stream = torch_device_module.Stream(priority=1) + self.compute_stream = torch_device_module.Stream(priority=1) + else: + self.cuda_load_stream = torch_device_module.Stream(priority=0) + self.compute_stream = torch_device_module.Stream(priority=-1) + + def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None): + self.need_init_first_buffer = True + if self.offload_granularity == "block": + assert blocks_cpu_buffer is not None + self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))] + elif self.offload_granularity == 
"phase": + assert phases_cpu_buffer is not None + self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))] + else: + raise NotImplementedError + + def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None): + self.need_init_first_buffer = True + if self.offload_granularity == "block": + assert blocks_cuda_buffer is not None + self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))] + elif self.offload_granularity == "phase": + assert phases_cuda_buffer is not None + self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))] + else: + raise NotImplementedError + + def init_first_buffer(self, blocks, adapter_block_idx=None): + with torch_device_module.stream(self.init_stream): + if hasattr(self, "cpu_buffers"): + self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx) + else: + if self.offload_granularity == "block": + self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx) + else: + self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx) + self.init_stream.synchronize() + self.need_init_first_buffer = False + + def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None): + with torch_device_module.stream(self.cuda_load_stream): + if hasattr(self, "cpu_buffers"): + self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx) + self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx) + else: + self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx) + + def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None): + with torch_device_module.stream(self.cuda_load_stream): + if hasattr(self, "cpu_buffers"): + self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx) + else: + self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx) + + def swap_blocks(self): + self.cuda_load_stream.synchronize() + self.compute_stream.synchronize() + self.cuda_buffers[0], self.cuda_buffers[1] = ( + self.cuda_buffers[1], + self.cuda_buffers[0], + ) + + def swap_phases(self): + self.cuda_load_stream.synchronize() + self.compute_stream.synchronize() + + @ExcludedProfilingContext("🔥 warm_up_cpu_buffers") + def warm_up_cpu_buffers(self, blocks_num): + logger.info("🔥 Warming up cpu buffers...") + for i in tqdm(range(blocks_num)): + for phase in self.cpu_buffers[0]: + phase.load_state_dict_from_disk(i, None) + for phase in self.cpu_buffers[1]: + phase.load_state_dict_from_disk(i, None) + + for phase in self.cpu_buffers[0]: + phase.load_state_dict_from_disk(0, None) + for phase in self.cpu_buffers[1]: + phase.load_state_dict_from_disk(1, None) + logger.info("✅ CPU buffers warm-up completed.") + + def init_lazy_load(self, num_workers=6): + self.lazy_load = True + self.executor = ThreadPoolExecutor(max_workers=num_workers) + self.prefetch_futures = [] + self.prefetch_block_idx = -1 + + def start_prefetch_block(self, block_idx, adapter_block_idx=None): + self.prefetch_block_idx = block_idx + self.prefetch_futures = [] + for phase in self.cpu_buffers[1]: + future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx) + self.prefetch_futures.append(future) + + def swap_cpu_buffers(self): + # wait_start = time.time() + # 
already_done = all(f.done() for f in self.prefetch_futures) + for f in self.prefetch_futures: + f.result() + # wait_time = time.time() - wait_start + # logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}") + self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]] + + def __del__(self): + if hasattr(self, "executor") and self.executor is not None: + for f in self.prefetch_futures: + if not f.done(): + f.result() + self.executor.shutdown(wait=False) + self.executor = None + logger.debug("ThreadPoolExecutor shut down successfully.") diff --git a/lightx2v/common/ops/__init__.py b/lightx2v/common/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aab2983ef2b2649da7a004b59ef96993f4fd0d54 --- /dev/null +++ b/lightx2v/common/ops/__init__.py @@ -0,0 +1,6 @@ +from .attn import * +from .conv import * +from .embedding import * +from .mm import * +from .norm import * +from .tensor import * diff --git a/lightx2v/common/ops/attn/__init__.py b/lightx2v/common/ops/attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65e23ddef56a45665ed795b3e56ef6516e7d09f4 --- /dev/null +++ b/lightx2v/common/ops/attn/__init__.py @@ -0,0 +1,10 @@ +from .flash_attn import FlashAttn2Weight, FlashAttn3Weight +from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer +from .radial_attn import RadialAttnWeight +from .ring_attn import RingAttnWeight +from .sage_attn import SageAttn2Weight, SageAttn3Weight +from .spassage_attn import SageAttnWeight +from .svg2_attn import Svg2AttnWeight +from .svg_attn import SvgAttnWeight +from .torch_sdpa import TorchSDPAWeight +from .ulysses_attn import Ulysses4090AttnWeight, UlyssesAttnWeight diff --git a/lightx2v/common/ops/attn/flash_attn.py b/lightx2v/common/ops/attn/flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8150bb034ec42ebb57d43fa31a1b045eee468e --- /dev/null +++ b/lightx2v/common/ops/attn/flash_attn.py @@ -0,0 +1,89 @@ +from loguru import logger + +try: + import flash_attn # noqa: F401 + from flash_attn.flash_attn_interface import flash_attn_varlen_func +except ImportError: + logger.info("flash_attn_varlen_func not found, please install flash_attn2 first") + flash_attn_varlen_func = None + +try: + from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 +except ImportError: + logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first") + flash_attn_varlen_func_v3 = None + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + + +@ATTN_WEIGHT_REGISTER("flash_attn2") +class FlashAttn2Weight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + if len(q.shape) == 3: + bs = 1 + elif len(q.shape) == 4: + bs = q.shape[0] + q = q.reshape(-1, q.shape[-2], q.shape[-1]) + k = k.reshape(-1, k.shape[-2], k.shape[-1]) + v = v.reshape(-1, v.shape[-2], v.shape[-1]) + x = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ).reshape(bs * max_seqlen_q, -1) + return x + + +@ATTN_WEIGHT_REGISTER("flash_attn3") +class FlashAttn3Weight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, 
+ model_cls=None, + ): + if len(q.shape) == 3: + bs = 1 + elif len(q.shape) == 4: + bs = q.shape[0] + if model_cls is not None and model_cls in ["hunyuan_video_1.5"]: + q = q.reshape(-1, q.shape[-2], q.shape[-1]) + k = k.reshape(-1, k.shape[-2], k.shape[-1]) + v = v.reshape(-1, v.shape[-2], v.shape[-1]) + x = flash_attn_varlen_func_v3( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ).reshape(bs * max_seqlen_q, -1) + return x diff --git a/lightx2v/common/ops/attn/nbhd_attn.py b/lightx2v/common/ops/attn/nbhd_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f32b88aac2dfc5d2a36554721a297c8767d3bc --- /dev/null +++ b/lightx2v/common/ops/attn/nbhd_attn.py @@ -0,0 +1,196 @@ +import torch +from loguru import logger + +try: + from magi_attention.functional import flex_flash_attn_func as magi_ffa_func +except ImportError: + magi_ffa_func = None + +try: + import flashinfer +except ImportError: + flashinfer = None + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + + +def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5, 0.056], min_width=1.0, device="cpu"): + """ + a : block num per frame + block_num : block num per col/row + attnmap_frame_num : total frame num + """ + i_indices = torch.arange(block_num, device=device).unsqueeze(1) # [block_num, 1] + j_indices = torch.arange(block_num, device=device).unsqueeze(0) # [1, block_num] + + assert len(coefficient) <= attnmap_frame_num, f"coefficient length {len(coefficient)} should <= attnmap_frame_num {attnmap_frame_num}" + width_list = [max(min_width, coefficient[i] * a) for i in range(len(coefficient))] + [min_width] * (attnmap_frame_num - len(coefficient)) + logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}") + + # attention sink frame: j <= a + mask_sink = j_indices <= a + + mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device) + for interval in range(0, attnmap_frame_num): + n = i_indices // a + mask_sparse_base_1 = (j_indices >= (n + interval) * a) & (j_indices <= (n + interval + 1) * a) + n = j_indices // a + mask_sparse_base_2 = (i_indices >= (n + interval) * a) & (i_indices <= (n + interval + 1) * a) + + width = width_list[interval] + + mask_1 = mask_sparse_base_1 & (i_indices - j_indices + (interval * a + width) >= 0) & (i_indices - j_indices + (interval * a - width) <= 0) + mask_2 = mask_sparse_base_2 & (i_indices - j_indices - (interval * a - width) >= 0) & (i_indices - j_indices - (interval * a + width) <= 0) + + mask_sparse = mask_sparse | mask_1 | mask_2 + + mask = mask_sink | mask_sparse + return mask + + +def generate_qk_ranges(mask, block_size, seqlen): + indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2] + + i_indices = indices[:, 0] # [N] + j_indices = indices[:, 1] # [N] + + q_start = i_indices * block_size # [N] + q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N] + + k_start = j_indices * block_size # [N] + k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N] + + q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2] + k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2] + + return q_ranges, k_ranges + + +@ATTN_WEIGHT_REGISTER("nbhd_attn") +class NbhdAttnWeight(AttnWeightTemplate): + block_size = 128 + seqlen = None + attnmap_frame_num = None + q_ranges = None + k_ranges = None + attn_type_map = None + coefficient = [1.0, 0.5, 0.056] + min_width = 1.0 + + def __init__(self): + 
self.config = {} + + @classmethod + @torch.compiler.disable + def prepare_mask(cls, seqlen): + if seqlen == cls.seqlen: + return + block_num = (seqlen + cls.block_size - 1) // cls.block_size + block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size + mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu") + q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen) + attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda") + q_ranges = q_ranges.to(torch.int32).to("cuda") + k_ranges = k_ranges.to(torch.int32).to("cuda") + cls.seqlen = seqlen + cls.q_ranges = q_ranges + cls.k_ranges = k_ranges + cls.attn_type_map = attn_type_map + logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}") + sparsity = 1 - mask.sum().item() / mask.numel() + logger.info(f"Attention sparsity: {sparsity}") + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + """ + q: [seqlen, head_num, head_dim] + k: [seqlen, head_num, head_dim] + v: [seqlen, head_num, head_dim] + """ + self.prepare_mask(seqlen=q.shape[0]) + out = magi_ffa_func( + q, + k, + v, + q_ranges=self.q_ranges, + k_ranges=self.k_ranges, + attn_type_map=self.attn_type_map, + auto_range_merge=True, + )[0] + return out.reshape(out.shape[0], -1) + + +@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer") +class NbhdAttnWeightFlashInfer(AttnWeightTemplate): + block_size = 128 + seqlen = None + attnmap_frame_num = None + coefficient = [1.0, 0.5, 0.056] + min_width = 1.0 + sparse_wrapper = None + + def __init__(self): + self.config = {} + + @classmethod + @torch.compiler.disable + def prepare_mask(cls, seqlen, head_num, head_dim): + if seqlen == cls.seqlen: + return + block_num = (seqlen + cls.block_size - 1) // cls.block_size + block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size + mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu") + mask = mask.unsqueeze(0).repeat(head_num, 1, 1) + block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size + block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1) + block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1) + float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") + cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2") + cls.sparse_wrapper.plan( + block_mask_map=mask, + block_row_sz=block_rowcol_size, + block_col_sz=block_rowcol_size, + num_qo_heads=head_num, + num_kv_heads=head_num, + head_dim=head_dim, + q_data_type=torch.bfloat16, + ) + cls.seqlen = seqlen + logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}") + sparsity = 1 - mask.sum().item() / mask.numel() + logger.info(f"Attention sparsity: {sparsity}") + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + """ + q: [seqlen, head_num, head_dim] + k: [seqlen, head_num, head_dim] + v: [seqlen, head_num, head_dim] + """ + self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2]) + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + out = self.sparse_wrapper.run(q, k, v) + out = out.transpose(0, 1) + return out.reshape(out.shape[0], -1) diff --git 
a/lightx2v/common/ops/attn/radial_attn.py b/lightx2v/common/ops/attn/radial_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..002f149b75f17f29ec9af6b5d47de05364e94252 --- /dev/null +++ b/lightx2v/common/ops/attn/radial_attn.py @@ -0,0 +1,185 @@ +import torch +from loguru import logger + +try: + from magi_attention.functional import flex_flash_attn_func as magi_ffa_func +except ImportError: + magi_ffa_func = None + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + + +def shrinkMaskStrict(mask, block_size=128): + seqlen = mask.shape[0] + block_num = seqlen // block_size + mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size) + col_densities = mask.sum(dim=1) / block_size + # we want the minimum non-zero column density in the block + non_zero_densities = col_densities > 0 + high_density_cols = col_densities > 1 / 3 + frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9) + block_mask = frac_high_density_cols > 0.6 + block_mask[0:0] = True + block_mask[-1:-1] = True + return block_mask + + +def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None): + assert sparse_type in ["radial"] + dist = abs(i - j) + if model_type == "wan": + if dist < 1: + return token_per_frame + if dist == 1: + return token_per_frame // 2 + elif model_type == "hunyuan": + if dist <= 1: + return token_per_frame + else: + raise ValueError(f"Unknown model type: {model_type}") + group = dist.bit_length() + decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor + threshold = block_size + if decay_length >= threshold: + return decay_length + else: + return threshold + + +def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device): + assert sparse_type in ["radial"] + dist = abs(i - j) + group = dist.bit_length() + threshold = 128 # hardcoded threshold for now, which is equal to block-size + decay_length = 2 ** token_per_frame.bit_length() / 2**group + if decay_length >= threshold: + return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool) + + split_factor = int(threshold / decay_length) + modular = dist % split_factor + if modular == 0: + return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool) + else: + return torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool) + + +def gen_log_mask_shrinked(device, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None): + """ + A more memory friendly version, we generate the attention mask of each frame pair at a time, + shrinks it, and stores it into the final result + """ + final_log_mask = torch.zeros(((s + block_size - 1) // block_size, (s + block_size - 1) // block_size), device=device, dtype=torch.bool) + token_per_frame = video_token_num // num_frame + video_text_border = video_token_num // block_size + + col_indices = torch.arange(0, token_per_frame, device=device).view(1, -1) + row_indices = torch.arange(0, token_per_frame, device=device).view(-1, 1) + final_log_mask[video_text_border:] = True + final_log_mask[:, video_text_border:] = True + for i in range(num_frame): + for j in range(num_frame): + local_mask = torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool) + if j == 0 and model_type == "wan": # this is attention sink + local_mask = 
torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool) + else: + window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type) + local_mask = torch.abs(col_indices - row_indices) <= window_width + split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device) + local_mask = torch.logical_and(local_mask, split_mask) + + remainder_row = (i * token_per_frame) % block_size + remainder_col = (j * token_per_frame) % block_size + # get the padded size + all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size + all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size + padded_local_mask = torch.zeros((all_length_row, all_length_col), device=device, dtype=torch.bool) + padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask + # shrink the mask + block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size) + # set the block mask to the final log mask + block_row_start = (i * token_per_frame) // block_size + block_col_start = (j * token_per_frame) // block_size + block_row_end = block_row_start + block_mask.shape[0] + block_col_end = block_col_start + block_mask.shape[1] + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask) + return final_log_mask + + +def generate_qk_ranges(mask, block_size, seqlen): + indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2] + + i_indices = indices[:, 0] # [N] + j_indices = indices[:, 1] # [N] + + q_start = i_indices * block_size # [N] + q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N] + + k_start = j_indices * block_size # [N] + k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N] + + q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2] + k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2] + + return q_ranges, k_ranges + + +@ATTN_WEIGHT_REGISTER("radial_attn") +class RadialAttnWeight(AttnWeightTemplate): + block_size = 128 + seqlen = None + attnmap_frame_num = None + q_ranges = None + k_ranges = None + attn_type_map = None + + def __init__(self): + self.config = {} + + @classmethod + def prepare_mask(cls, seqlen): + if seqlen == cls.seqlen: + return + mask = gen_log_mask_shrinked( + device="cuda", s=seqlen, video_token_num=seqlen, num_frame=cls.attnmap_frame_num, block_size=cls.block_size, sparse_type="radial", decay_factor=0.2, model_type="wan" + ) + q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen) + attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda") + q_ranges = q_ranges.to(torch.int32).to("cuda") + k_ranges = k_ranges.to(torch.int32).to("cuda") + cls.seqlen = seqlen + cls.q_ranges = q_ranges + cls.k_ranges = k_ranges + cls.attn_type_map = attn_type_map + logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}") + sparsity = 1 - mask.sum().item() / mask.numel() + logger.info(f"Attention sparsity: {sparsity}") + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + """ + q: [seqlen, head_num, head_dim] + k: [seqlen, head_num, head_dim] + v: [seqlen, head_num, head_dim] + """ + self.prepare_mask(seqlen=q.shape[0]) + out = magi_ffa_func( + q, + k, + v, + q_ranges=self.q_ranges, 
+            k_ranges=self.k_ranges,
+            attn_type_map=self.attn_type_map,
+            auto_range_merge=True,
+        )[0]
+        return out.reshape(out.shape[0], -1)
diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e96117f8eabea7d2ac06c2073140383c4979962
--- /dev/null
+++ b/lightx2v/common/ops/attn/ring_attn.py
@@ -0,0 +1,179 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from loguru import logger
+
+from lightx2v.utils.envs import *
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .template import AttnWeightTemplate
+from .utils.ring_comm import RingComm
+
+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
+    flash_attn_varlen_func = None
+
+
+@torch.jit.script
+def _update_out_and_lse(
+    out,
+    lse,
+    block_out,
+    block_lse,
+):
+    block_out = block_out.to(torch.float32)
+    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
+
+    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
+    # For additional context and discussion, please refer to:
+    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
+    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
+    lse = lse - F.logsigmoid(lse - block_lse)
+    return out, lse
+
+
+@ATTN_WEIGHT_REGISTER("ring")
+class RingAttnWeight(AttnWeightTemplate):
+    def __init__(self):
+        self.config = {}
+
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
+        """
+        Run ring attention over the combined image and text queries, keys, and values.
+
+        Args:
+            q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
+            k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
+            v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
+            img_qkv_len (int): length of the image part of the query/key/value
+            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
+            attention_type (str): attention backend, defaults to "flash_attn2"
+
+        Returns:
+            torch.Tensor: the computed attention output
+        """
+        assert not use_fp8_comm, "RingAttn can't support fp8 comm now."
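For reference, the `F.sigmoid`/`F.logsigmoid` update in `_update_out_and_lse` above is the usual log-sum-exp merge of two partial attention results, rewritten to avoid an extra `exp`. A minimal standalone check of that equivalence; tensor names and shapes here are illustrative only and are not part of the patch:

import torch
import torch.nn.functional as F

out = torch.randn(4, 8)          # running partial output (4 query rows)
lse = torch.randn(4, 1)          # running log-sum-exp of the scores
block_out = torch.randn(4, 8)    # output of the newly received K/V block
block_lse = torch.randn(4, 1)    # its log-sum-exp

# Reference merge: renormalize both partial results by the combined log-sum-exp.
new_lse_ref = lse + torch.log1p(torch.exp(block_lse - lse))
new_out_ref = torch.exp(lse - new_lse_ref) * out + torch.exp(block_lse - new_lse_ref) * block_out

# Sigmoid/logsigmoid form used by _update_out_and_lse.
new_out = out - torch.sigmoid(block_lse - lse) * (out - block_out)
new_lse = lse - F.logsigmoid(lse - block_lse)

assert torch.allclose(new_out, new_out_ref, atol=1e-5)
assert torch.allclose(new_lse, new_lse_ref, atol=1e-5)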
+
+        # Get the rank of the current process and the global world size
+        cur_rank = dist.get_rank(seq_p_group)
+        world_size = dist.get_world_size(seq_p_group)
+
+        if len(cu_seqlens_qkv) == 3:
+            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
+            txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
+        elif len(cu_seqlens_qkv) == 2:
+            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
+            txt_mask_len = 0
+
+        # if RING_COMM is None:
+        #     init_ring_comm()
+
+        RING_COMM = RingComm(seq_p_group)
+
+        # if len(cu_seqlens_qkv) == 3:
+        #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
+        #     txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
+        # elif len(cu_seqlens_qkv) == 2:
+        #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
+        #     txt_mask_len = None
+        q = q.unsqueeze(0)
+        k = k.unsqueeze(0)
+        v = v.unsqueeze(0)
+
+        img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
+        txt_q, txt_k, txt_v = (
+            q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
+            k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
+            v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
+        )
+
+        out, lse, next_k, next_v = None, None, None, None
+
+        if len(cu_seqlens_qkv) == 3:
+            q = torch.cat((img_q, txt_q), dim=1)
+            k = img_k
+            v = img_v
+
+        for step in range(world_size):
+            if step + 1 != world_size:
+                next_k = RING_COMM.send_recv(k)
+                next_v = RING_COMM.send_recv(v)
+                RING_COMM.commit()
+
+            if step + 1 == world_size:
+                k = torch.cat((k, txt_k), dim=1)
+                v = torch.cat((v, txt_v), dim=1)
+
+            block_out, block_lse = self.ring_attn_sub(q, k, v)
+
+            out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
+
+            if step + 1 != world_size:
+                RING_COMM.wait()
+                k = next_k
+                v = next_v
+
+        attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
+
+        if txt_mask_len > 0:
+            attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
+                q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
+                k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
+                v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
+                dropout_p=0.0,
+                softmax_scale=q.shape[-1] ** (-0.5),
+                causal=False,
+                window_size_left=-1,
+                window_size_right=-1,
+                softcap=0.0,
+                alibi_slopes=None,
+                return_softmax=False,
+            )
+
+            attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
+            attn1 = torch.cat([attn1, attn2], dim=0)
+
+        return attn1
+
+    def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
+            q,
+            k,
+            v,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
+            return_softmax=return_softmax,
+        )
+        return block_out, block_lse
+
+    def update_out_and_lse(
+        self,
+        out,
+        lse,
+        block_out,
+        block_lse,
+        slice_=None,
+    ):
+        if out is None:
+            if slice_ is not None:
+                raise RuntimeError("first update_out_and_lse should not pass slice_ args")
+            out = block_out.to(torch.float32)
+            lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
+        elif slice_ is not None:
+            slice_out, slice_lse = out[slice_], lse[slice_]
+            slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
+            out[slice_],
lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse diff --git a/lightx2v/common/ops/attn/sage_attn.py b/lightx2v/common/ops/attn/sage_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..7cde1eaf6f8b311863a9cf3850b6d04843523431 --- /dev/null +++ b/lightx2v/common/ops/attn/sage_attn.py @@ -0,0 +1,83 @@ +import torch +from loguru import logger + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + +if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]: + try: + from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn + except ImportError: + logger.info("sageattn not found, please install sageattention first") + sageattn = None +else: + try: + from sageattention import sageattn + except ImportError: + logger.info("sageattn not found, please install sageattention first") + sageattn = None + +try: + from sageattn3 import sageattn3_blackwell +except ImportError: + logger.info("sageattn3 not found, please install sageattention first") + sageattn3_blackwell = None + + +@ATTN_WEIGHT_REGISTER("sage_attn2") +class SageAttn2Weight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() + if len(q.shape) == 3: + bs = 1 + q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) + elif len(q.shape) == 4: + bs = q.shape[0] + x = sageattn( + q, + k, + v, + tensor_layout="NHD", + ).view(bs * max_seqlen_q, -1) + return x + + +@ATTN_WEIGHT_REGISTER("sage_attn3") +class SageAttn3Weight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() + if len(q.shape) == 3: + bs = 1 + q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) + elif len(q.shape) == 4: + bs = q.shape[0] + + x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1) + return x diff --git a/lightx2v/common/ops/attn/spassage_attn.py b/lightx2v/common/ops/attn/spassage_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fbf4644a5a7be634a2be957845f4cbcf3a8a3d --- /dev/null +++ b/lightx2v/common/ops/attn/spassage_attn.py @@ -0,0 +1,76 @@ +import os + +import torch + +try: + import spas_sage_attn +except ImportError: + spas_sage_attn = None + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + + +@ATTN_WEIGHT_REGISTER("spas_sage_attn") +class SageAttnWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + @classmethod + def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, tensor_layout="HND"): + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + attn_out = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout) + _, H, N, D = attn_out.shape + attn_out = attn_out.permute(2, 1, 3, 0).contiguous().view(N, H * D) + return attn_out + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + 
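A quick note on the layout handling in the `apply()` method above: the [seqlen, heads, head_dim] input is expanded to the [1, H, S, D] ("HND") layout expected by `spas_sage2_attn_meansim_cuda`, and the result is then flattened back to [seqlen, heads * head_dim]. A minimal shape walk-through with a stand-in attention op, so it runs without spas_sage_attn; the sizes are illustrative only:

import torch

S, H, D = 1024, 12, 128
q = torch.randn(S, H, D)

x = q.unsqueeze(0).transpose(1, 2)   # [S, H, D] -> [1, H, S, D]  ("HND" layout)
attn_out = x                         # stand-in for spas_sage2_attn_meansim_cuda(q, k, v, ...)
_, h, n, d = attn_out.shape
flat = attn_out.permute(2, 1, 3, 0).contiguous().view(n, h * d)
print(flat.shape)                    # torch.Size([1024, 1536]) == [S, H*D]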
+    # 1. Build the inputs
+    q = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
+    k = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
+    v = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
+
+    # 2. Compute attention directly with PyTorch
+    q_ = q.float()
+    k_ = k.float()
+    v_ = v.float()
+    attn_weights = torch.matmul(q_, k_.transpose(-2, -1)) / (128**0.5)
+    attn_weights = torch.softmax(attn_weights, dim=-1)
+    output_pt = torch.matmul(attn_weights, v_)
+
+    # 3. Compute attention with spas_sage2_attn_meansim_cuda
+    q = q.unsqueeze(0)  # shape: (1, 32760, 12, 128)
+    k = k.unsqueeze(0)
+    v = v.unsqueeze(0)
+    q = q.transpose(1, 2)  # shape: (1, 12, 32760, 128)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+    output_cuda = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout="HND")
+    output_cuda = output_cuda.float()
+
+    # 4. Take the top-left [3000, 3000] crop, first head only
+    output_pt_crop = output_pt[0, :3000, :3000].cpu().detach().numpy()
+    output_cuda_crop = output_cuda[0, 0, :3000, :3000].cpu().detach().numpy()
+
+    # 5. Save the images
+    save_dir = os.path.expanduser("~/Log/10-22/")
+    os.makedirs(save_dir, exist_ok=True)
+
+    plt.imshow(output_pt_crop, aspect="auto")
+    plt.title("PyTorch Attention (left-top 3000x3000)")
+    plt.savefig(os.path.join(save_dir, "attn.png"))
+    plt.close()
+
+    plt.imshow(output_cuda_crop, aspect="auto")
+    plt.title("spas_sage2_attn_meansim_cuda (left-top 3000x3000)")
+    plt.savefig(os.path.join(save_dir, "spas_attn.png"))
+    plt.close()
diff --git a/lightx2v/common/ops/attn/svg2_attn.py b/lightx2v/common/ops/attn/svg2_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fed193a32eb6de338ea30be9ef0578c325f6802
--- /dev/null
+++ b/lightx2v/common/ops/attn/svg2_attn.py
@@ -0,0 +1,355 @@
+from typing import Optional
+
+# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
+try:
+    import flashinfer
+except ImportError:
+    flashinfer = None
+
+import torch
+import triton
+import triton.language as tl
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+
+from .svg2_attn_utils import (
+    batch_kmeans_Euclid,
+    identify_dynamic_map,
+)
+from .template import AttnWeightTemplate
+
+
+@triton.jit
+def _permute_kernel(
+    X_ptr,
+    IDX_ptr,
+    Y_ptr,
+    S: tl.constexpr,
+    D: tl.constexpr,
+    BLOCK_S: tl.constexpr,
+):
+    """Each program permutes BLOCK_S tokens across all hidden features (D). No inner Python loop."""
+
+    pid_bh = tl.program_id(0)
+    tile_s = tl.program_id(1)
+
+    # Offsets along sequence
+    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
+    token_mask = s_offsets < S
+
+    # Gather source indices for these tokens
+    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
+    src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
+
+    # Broadcast to create 2-D pointer matrix (BLOCK_S, D)
+    d_offsets = tl.arange(0, D)
+
+    src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
+    dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]
+
+    full_mask = token_mask[:, None]
+
+    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
+    tl.store(dst_ptrs, values, mask=full_mask)
+
+
+def permute_tensor_by_labels_triton(
+    tensor: torch.Tensor,
+    labels: Optional[torch.Tensor],
+    dim: int,
+    *,
+    sorted_indices: Optional[torch.Tensor] = None,
+):
+    """
+    Permute `tensor` along `dim` according to ascending order of `labels`.
+
+    This is a Triton-accelerated replacement for the original implementation.
+    It currently supports 4-D tensors of shape [B, H, S, D] and `dim == 2`.
+ If these conditions are not met or the tensors reside on CPU, we fall back + to the reference PyTorch implementation. + """ + + # Assertions – we only support the optimized CUDA path. + assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)" + assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]" + assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors" + + B, H, S, D = tensor.shape + BH = B * H + + # Determine sorted indices + if sorted_indices is not None: + sorted_indices = sorted_indices.to(torch.int32).contiguous() + else: + assert labels is not None, "Either `labels` or `sorted_indices` must be provided." + labels = labels.to(tensor.device) + sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous() + + # Flatten tensor and allocate output + inp_flat = tensor.reshape(BH, S, D).contiguous() + out_flat = torch.empty_like(inp_flat) + + # Triton kernel tile size + BLOCK_S = 64 # number of tokens per program, tunable + + n_s_tiles = triton.cdiv(S, BLOCK_S) + grid = (BH, n_s_tiles) + + _permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4) + + permuted_tensor = out_flat.reshape(B, H, S, D) + return permuted_tensor, sorted_indices + + +@triton.jit +def _inverse_permute_kernel( + X_ptr, + IDX_ptr, + Y_ptr, + S: tl.constexpr, + D: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """Inverse permutation: scatter BLOCK_S tokens back in one shot.""" + + pid_bh = tl.program_id(0) + tile_s = tl.program_id(1) + + s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S) + token_mask = s_offsets < S + + idx_ptrs = IDX_ptr + pid_bh * S + s_offsets + src_pos_idx = s_offsets.to(tl.int32) + dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32) + + d_offsets = tl.arange(0, D) + + src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :] + dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :] + + full_mask = token_mask[:, None] + + values = tl.load(src_ptrs, mask=full_mask, other=0.0) + tl.store(dst_ptrs, values, mask=full_mask) + + +def apply_inverse_permutation_triton( + permuted_tensor: torch.Tensor, + sorted_indices: torch.Tensor, + dim: int, +): + """ + Triton implementation of inverse permutation. Inverse the permutation applied by `permute_tensor_by_labels`. + + Args: + permuted_tensor: (B, H, S, D). + sorted_indices: (B, H, S). + dim: Dimension along which to apply inverse permutation. Typically 2, meaning the sequence lengthdimension. + + Returns: + Tensor of shape (B, H, S, D). 
+ """ + + assert dim == 2, "apply_inverse_permutation currently only supports dim==2" + assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]" + assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors" + + B, H, S, D = permuted_tensor.shape + BH = B * H + + # Ensure index dtype + sorted_indices = sorted_indices.to(torch.int32).contiguous() + + # Flatten inputs + inp_flat = permuted_tensor.reshape(BH, S, D).contiguous() + out_flat = torch.empty_like(inp_flat) + + BLOCK_S = 64 + n_s_tiles = triton.cdiv(S, BLOCK_S) + grid = (BH, n_s_tiles) + + _inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4) + + original_tensor = out_flat.reshape(B, H, S, D) + return original_tensor + + +@ATTN_WEIGHT_REGISTER("svg2_attn") +class Svg2AttnWeight(AttnWeightTemplate): + centroids_init = False + num_q_centroids = 300 + num_k_centroids = 1000 + kmeans_iter_init = 50 + top_p_kmeans = 0.9 + min_kc_ratio = 0.10 + kmeans_iter_step = 2 + + def __init__(self): + self.config = {} + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + q = q.unsqueeze(0).transpose(1, 2) + k = k.unsqueeze(0).transpose(1, 2) + v = v.unsqueeze(0).transpose(1, 2) + bs, num_heads, seq_len, dim = q.size() + q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v) + + output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False) + + attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2) + + return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1) + + def dynamic_block_sparse_fwd_flashinfer( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_mask_map: torch.Tensor, + block_row_sz: torch.Tensor, + block_col_sz: torch.Tensor, + is_cpu: bool = True, + ): + """ + Launcher for the Flashinfer dynamic block sparse attention kernel. + + Args: + q (torch.Tensor): Query tensor, shape [B, H, S, D]. + k (torch.Tensor): Key tensor, shape [B, H, S, D]. + v (torch.Tensor): Value tensor, shape [B, H, S, D]. + block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Currently must on CPU. + block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Currently must on CPU. + block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Currently must on CPU. + is_cpu (bool): Whether to run on CPU. Flashinfer default is to run on CPU. We switch to GPU for faster planning. Default is True. 
+ """ + # Input shape check + B, H, S, D = q.shape + qc_num = block_row_sz.shape[-1] + kc_num = block_col_sz.shape[-1] + assert block_mask_map.shape == (B, H, qc_num, kc_num) + + assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz]) if is_cpu else True + + # Check if block_col_sz and block_row_sz are the same for each head + assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0]) + assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0]) + + # Prepare flashinfer wrapper + float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device) + vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device) + wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto") + wrapper.reset_workspace_buffer( + float_workspace_buffer=wrapper._float_workspace_buffer, + int_workspace_buffer=wrapper._int_workspace_buffer, + vector_sparse_indices_buffer=vector_sparse_indices_buffer, # Only reset this buffer size + vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer, + ) + + block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num) + block_row_sz = block_row_sz.reshape(B * H, qc_num) + block_col_sz = block_col_sz.reshape(B * H, kc_num) + + wrapper.plan( + block_mask_map=block_mask_map, + block_row_sz=block_row_sz, + block_col_sz=block_col_sz, + num_qo_heads=B * H, + num_kv_heads=B * H, + head_dim=D, + q_data_type=q.dtype, + kv_data_type=k.dtype, + ) + + # print_memory_usage("After plan") + + q = q.reshape(B * H, S, D) + k = k.reshape(B * H, S, D) + v = v.reshape(B * H, S, D) + o = wrapper.run(q, k, v) # [num_qo_heads, qo_len, head_dim] + o = o.reshape(B, H, S, D) + return o + + def semantic_aware_permutation(self, query, key, value): + cfg, num_heads, seq_len, dim = query.size() + + # 1. Kmeans clustering + qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key) + + # 2. Identify dynamic map + q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids) + k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids) + + dynamic_map = identify_dynamic_map( + qcentroids.view(cfg, num_heads, self.num_q_centroids, dim), + kcentroids.view(cfg, num_heads, self.num_k_centroids, dim), + q_cluster_sizes, + k_cluster_sizes, + self.top_p_kmeans, + self.min_kc_ratio, + ) + + # 3. 
Permute the query, key, value + q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2) + k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2) + v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices) + + return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices + + def kmeans_clustering(self, query, key): + if not self.centroids_init: + qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key) + self.centroids_init = True + else: + qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key) + + return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter + + def kmeans_init(self, query, key): + cfg, num_heads, seq_len, dim = query.size() + qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init) + klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init) + + self.q_centroids = qcentroids + self.k_centroids = kcentroids + + return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter + + def kmeans_step(self, query, key): + cfg, num_heads, seq_len, dim = query.size() + qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid( + query.view(cfg * num_heads, seq_len, dim), + n_clusters=self.num_q_centroids, + max_iters=self.kmeans_iter_step, + init_centroids=self.q_centroids, + ) + klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid( + key.view(cfg * num_heads, seq_len, dim), + n_clusters=self.num_k_centroids, + max_iters=self.kmeans_iter_step, + init_centroids=self.k_centroids, + ) + + self.q_centroids = qcentroids + self.k_centroids = kcentroids + + return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter + + +if __name__ == "__main__": + q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda() + + svg2_attn = Svg2AttnWeight() + print("Svg2AttnWeight initialized.") + + out = svg2_attn.apply(q, k, v) + print(f"out: {out.shape}, {out.dtype}, {out.device}") diff --git a/lightx2v/common/ops/attn/svg2_attn_utils.py b/lightx2v/common/ops/attn/svg2_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e65815e455daeec11375af2822ffffb8297c12f --- /dev/null +++ b/lightx2v/common/ops/attn/svg2_attn_utils.py @@ -0,0 +1,1359 @@ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +try: + from cuvs.cluster.kmeans import KMeansParams, fit +except ImportError: + KMeansParams = None + fit = None + +# --- New functions --- + + +def density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes): + """ + Calculate the density of the dynamic map. Currently only batch size = 1 and head size = 1 are supported. 
+ + Input: + dynamic_map: [cfg, num_heads, qc_num, kc_num] + q_cluster_sizes: [cfg, num_heads, qc_num] + k_cluster_sizes: [cfg, num_heads, kc_num] + """ + cfg, num_heads, qc_num, kc_num = dynamic_map.shape + + # Calculate the block size of each block + clustered_block_size = q_cluster_sizes[:, :, :, None] * k_cluster_sizes[:, :, None, :] + masked_block_size = clustered_block_size * dynamic_map + + # Calculate the density of each block + density = torch.sum(masked_block_size, dim=(2, 3)) / torch.sum(clustered_block_size, dim=(2, 3)) + return density + + +# --- Functions from analyze/kmeans_rapidai.py --- + + +def pairwise_distance(x, y): + """ + Computes pairwise squared Euclidean distance between two sets of points. + """ + x_norm = (x**2).sum(1).view(-1, 1) + y_norm = (y**2).sum(1).view(1, -1) + dist = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)), min=0.0) + return dist + + +def kmeans_predict(centroids, input_tensor): # Removed unused params argument + """ + Predict the labels for the input tensor using the centroids. + """ + input_tensor = input_tensor.to(torch.float32) + dist = pairwise_distance(input_tensor, centroids) + labels = torch.argmin(dist, dim=1) + return labels + + +def kmeans_rapidai(tensor, k, max_iter=5, tol=1e-4, init_method="Array", centroids_init=None): # Renamed centroids to centroids_init + """ + Performs K-means clustering using cuVS. + """ + + assert tensor.dtype == torch.float32, "Tensor must be float32 for cuVS KMeans" + assert tensor.ndim == 2, f"Tensor must be 2D, but got {tensor.shape}" + # assert init_method == "Array", "init_method must be 'Array' for now" + + L, D = tensor.shape + + # cuVS KMeans in RAPIDS >=23.10 uses 'centroids_init' for initial centroids + current_centroids = centroids_init + if current_centroids is None: + # Default init: cuVS handles KMeansPlusPlus if centroids_init is None and init_method is KMeansPlusPlus + # If you need to pass an empty tensor for cuVS to initialize: + current_centroids = torch.empty(k, D, device=tensor.device, dtype=torch.float32) # Or pass None + else: + assert current_centroids.dtype == torch.float32, "Initial centroids must be float32" + assert current_centroids.shape == ( + k, + D, + ), f"Initial centroids shape mismatch, got {current_centroids.shape}, expected ({k}, {D})" + # cuVS uses 'init_method="Array"' when 'centroids_init' is provided. 
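As a side note on the distance helpers defined above: `pairwise_distance` uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, and the `clamp(min=0.0)` guards against tiny negative values caused by floating-point error; `kmeans_predict` then takes the argmin over centroids. A minimal check of that expansion against `torch.cdist`, with illustrative sizes only:

import torch

x = torch.randn(5, 16)
y = torch.randn(7, 16)

# Same expansion as pairwise_distance: ||x||^2 + ||y||^2 - 2 * x @ y^T, clamped at 0.
x_norm = (x**2).sum(1).view(-1, 1)
y_norm = (y**2).sum(1).view(1, -1)
dist_sq = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, y.t()), min=0.0)

assert torch.allclose(dist_sq, torch.cdist(x, y) ** 2, atol=1e-5)
# argmin over the last dim then yields the nearest-centroid labels, as in kmeans_predict.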
+ + # import IPython; IPython.embed() + + params = KMeansParams(n_clusters=k, max_iter=max_iter, tol=tol, init_method=init_method) # Changed init_method to init + + # Call fit with centroids_init (can be None) + new_centroids, inertia, n_iter_ = fit(params, tensor, current_centroids) # Added handle=None + + labels = kmeans_predict(new_centroids, tensor) + return labels, new_centroids, n_iter_ + + +@triton.jit +def _centroid_update_kernel( + x_ptr, # *f16 [B, N, D] + cluster_ptr, # *i32 [B, N] + sum_ptr, # *f32 [B, K, D] + count_ptr, # *i32 [B, K] + B: tl.constexpr, + N: tl.constexpr, + D: tl.constexpr, + K: tl.constexpr, + BLOCK_D: tl.constexpr, # number of dims processed per program +): + """Each program processes 1 point (token) across BLOCK_D dimensions with atomics.""" + pid = tl.program_id(axis=0) + token_idx = pid # range: [0, B * N) + + # Derive (b, n) indices + b = token_idx // N + n = token_idx % N + + # Pointers to the token features and its cluster id + x_offset = (b * N + n) * D + x_ptr = x_ptr + x_offset + + cluster_idx = tl.load(cluster_ptr + b * N + n) # int32 + + # Guard for invalid cluster ids (should not happen) + cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0) + + # Base pointer for this centroid in the output sum tensor + centroid_base = (b * K + cluster_idx) * D + + # Process feature vector in chunks of BLOCK_D + offs = tl.arange(0, BLOCK_D) + for d_start in range(0, D, BLOCK_D): + mask = offs + d_start < D + feats = tl.load(x_ptr + d_start + offs, mask=mask, other=0.0) + feats = feats.to(tl.float32) + + dest_ptr = sum_ptr + centroid_base + d_start + offs + tl.atomic_add(dest_ptr, feats, mask=mask) + + # Update counts (only once per point) + tl.atomic_add(count_ptr + b * K + cluster_idx, 1) + + +def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor): + """Compute centroids using custom Triton kernel. 
+ + Args: + x_norm (Tensor): (B, N, D) normalized input vectors (float16/float32) + cluster_ids (LongTensor): (B, N) cluster assignment per point + old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x_norm) + + Returns: + Tensor: (B, K, D) updated and L2-normalized centroids (dtype == x_norm.dtype) + """ + assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device" + B, N, D = x_norm.shape + K = old_centroids.shape[1] + assert cluster_ids.shape == (B, N) + + # Allocate accumulation buffers + centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32) + centroid_counts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32) + + # Launch Triton kernel – one program per token + total_tokens = B * N + BLOCK_D = 128 # tuneable + + grid = (total_tokens,) + _centroid_update_kernel[grid]( + x_norm, + cluster_ids.to(torch.int32), + centroid_sums, + centroid_counts, + B, + N, + D, + K, + BLOCK_D=BLOCK_D, + ) + + # Compute means; keep old centroid if empty cluster + counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0) + centroids = centroid_sums / counts_f + + # For clusters with zero count, revert to old centroids + zero_mask = (centroid_counts == 0).unsqueeze(-1) + centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids) + + centroids = centroids.to(x_norm.dtype) + centroids = F.normalize(centroids, p=2, dim=-1) + return centroids + + +def torch_loop_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor): + """Reference Python implementation (double for-loop)""" + B, N, D = x_norm.shape + K = old_centroids.shape[1] + new_centroids = torch.zeros_like(old_centroids) + for b in range(B): + for k in range(K): + mask = cluster_ids[b] == k + if mask.any(): + new_centroids[b, k] = F.normalize(x_norm[b][mask].mean(dim=0, dtype=x_norm.dtype), p=2, dim=0) + else: + new_centroids[b, k] = old_centroids[b, k] + return new_centroids + + +def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor): + """Compute centroids for Euclidean KMeans using Triton. 
+ + Args: + x (Tensor): (B, N, D) input vectors (float16/float32) + cluster_ids (LongTensor): (B, N) cluster assignment per point + old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x) + + Returns: + Tensor: (B, K, D) updated centroids (dtype == x.dtype) + """ + assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device" + B, N, D = x.shape + K = old_centroids.shape[1] + assert cluster_ids.shape == (B, N) + + # Allocate accumulation buffers + centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32) + centroid_counts = torch.zeros((B, K), device=x.device, dtype=torch.int32) + + total_tokens = B * N + BLOCK_D = 128 # tuneable + grid = (total_tokens,) + + _centroid_update_kernel[grid]( + x, + cluster_ids.to(torch.int32), + centroid_sums, + centroid_counts, + B, + N, + D, + K, + BLOCK_D=BLOCK_D, + ) + + # Compute means; keep old centroid if empty cluster + counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0) + centroids = centroid_sums / counts_f + + # For clusters with zero count, revert to old centroids + zero_mask = (centroid_counts == 0).unsqueeze(-1) + centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids) + + return centroids.to(x.dtype) + + +# ------------------------------ NEW: chunk-wise centroid update (sorted ids) ------------------------------ + + +@triton.jit +def _centroid_update_chunk_kernel( + x_ptr, # *f16 / *f32 [B, N, D] – ORIGINAL ORDER + sorted_idx_ptr, # *i32 [B, N] – indices after sort + sorted_cluster_ptr, # *i32 [B, N] – cluster ids in sorted order + sum_ptr, # *f32 [B, K, D] + count_ptr, # *i32 [B, K] + B: tl.constexpr, + N: tl.constexpr, + D: tl.constexpr, + K: tl.constexpr, + BLOCK_N: tl.constexpr, # how many tokens (points) each program processes +): + """Each program processes **BLOCK_N consecutive, already-sorted tokens**. + + Because the tokens are sorted by cluster id, identical ids appear in + contiguous runs. We therefore accumulate a local sum/count for the + current run and perform **a single atomic update per run**, instead of + per-token. 
+ """ + # program indices – 2-D launch grid: (chunk_id, batch_id) + pid_chunk = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + + b = pid_b + chunk_start = pid_chunk * BLOCK_N # position of the first token handled by this program + + # Nothing to do – out of range + if chunk_start >= N: + return + + # base pointers for this batch + idx_batch_base = sorted_idx_ptr + b * N + cid_batch_base = sorted_cluster_ptr + b * N + x_batch_base = x_ptr + b * N * D # for pointer arithmetic + + # helper aranges + offs_token = tl.arange(0, BLOCK_N) + offs_dim = tl.arange(0, D) + + # first token index & validity mask + token_idx = chunk_start + offs_token + valid_tok = token_idx < N + first_token_idx = chunk_start + last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1 + + # Load first cluster id to initialise the running accumulator + first_id = tl.load(cid_batch_base + first_token_idx) + last_id = tl.load(cid_batch_base + last_token_idx) + all_ids = tl.load(cid_batch_base + token_idx, mask=valid_tok, other=-1) + + all_tokens_idxs = tl.load(idx_batch_base + token_idx, mask=valid_tok, other=-1) # [BLOCK_N] + + load_mask = all_tokens_idxs[:, None] * D + offs_dim[None, :] + + for cid in range(first_id, last_id + 1): + cluster_mask = all_ids == cid + cluster_size = tl.sum(cluster_mask.to(tl.int32)) + if cluster_size != 0: + cluster_feats = tl.load(x_batch_base + load_mask, mask=cluster_mask[:, None], other=0.0) # [BLOCK_N, D] + cluster_feats = cluster_feats.to(tl.float32) + sum_feats = tl.sum(cluster_feats, axis=0) + dest_ptr = sum_ptr + (b * K + cid) * D + offs_dim + tl.atomic_add(dest_ptr, sum_feats) + tl.atomic_add(count_ptr + b * K + cid, cluster_size) + + +# --------------------------------------------------------------------------------------------- + + +def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256): + """Fast centroid update assuming **cluster_ids are sorted along N**. + + This helper will sort the assignments (together with `x_norm`) and launch the + chunk kernel above. Compared to the naive per-token kernel it performs *one + atomic add per run of identical ids* instead of per token, providing large + speed-ups when clusters are reasonably sized. 
+ """ + assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA" + B, N, D = x_norm.shape + K = old_centroids.shape[1] + assert cluster_ids.shape == (B, N) + + # -------- sort per-batch -------- + sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1) + sorted_idx_int = sorted_idx.to(torch.int32) + + # accumulation buffers + centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32) + centroid_cnts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32) + + grid = (triton.cdiv(N, BLOCK_N), B) + _centroid_update_chunk_kernel[grid]( + x_norm, + sorted_idx_int, + sorted_cluster_ids.to(torch.int32), + centroid_sums, + centroid_cnts, + B, + N, + D, + K, + BLOCK_N=BLOCK_N, + ) + + # finalise – convert to means, handle empty clusters, renormalise + counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0) + centroids = centroid_sums / counts_f + empty_mask = (centroid_cnts == 0).unsqueeze(-1) + centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids) + centroids = centroids.to(x_norm.dtype) + centroids = F.normalize(centroids, p=2, dim=-1) + return centroids + + +def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256): + """Fast centroid update for *Euclidean* KMeans assuming cluster IDs are pre-sorted. + + Parameters + ---------- + x : Tensor [B, N, D] + Input feature vectors (no normalization assumed). + cluster_ids : LongTensor [B, N] + Cluster assignment for each point. + old_centroids : Tensor [B, K, D] + Previous centroids (used to fill empty clusters). + BLOCK_N : int, optional + Tokens per Triton program (affects occupancy/perf). + """ + assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device" + B, N, D = x.shape + K = old_centroids.shape[1] + + # Batch-wise sort of cluster assignments + sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1) + sorted_idx_int = sorted_idx.to(torch.int32) + + centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32) + centroid_cnts = torch.zeros((B, K), device=x.device, dtype=torch.int32) + + grid = (triton.cdiv(N, BLOCK_N), B) + _centroid_update_chunk_kernel[grid]( + x, # original features + sorted_idx_int, # gather indices + sorted_cluster_ids.to(torch.int32), + centroid_sums, + centroid_cnts, + B, + N, + D, + K, + BLOCK_N=BLOCK_N, + ) + + # Convert sums to means; replace empty clusters with old centroids + counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0) + centroids = centroid_sums / counts_f + empty_mask = (centroid_cnts == 0).unsqueeze(-1) + centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids) + return centroids.to(x.dtype), centroid_cnts + + +# =============================================================== +# Triton kernel: compute nearest-centroid IDs (Euclidean distance) +# Inputs: +# x : (B, N, D) float16 / float32 +# centroids : (B, K, D) same dtype as x +# x_sq : (B, N) float32 – pre-computed ||x||^2 per point +# Output: +# cluster_ids : (B, N) int32 – nearest centroid index per point +# =============================================================== + + +def _ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +# ----------------------------------------------------------------------------- +# Auto-tuning setup – explore various tile sizes / warp counts +# ----------------------------------------------------------------------------- + +_TUNE_CONFIGS = 
[triton.Config({"BLOCK_N": BN, "BLOCK_K": BK}, num_stages=4, num_warps=wp) for BN in [32, 64, 128] for BK in [32, 64, 128] for wp in [4, 8]] + + +def _cfg_keep(conf): + """Basic heuristic to prune unbalanced configs.""" + BN = conf.kwargs["BLOCK_N"] + BK = conf.kwargs["BLOCK_K"] + # Avoid tiny tiles on many warps + if BN * BK < 32 * 32 and conf.num_warps > 4: + return False + return True + + +_TUNE_CONFIGS = list(filter(_cfg_keep, _TUNE_CONFIGS)) + + +@triton.autotune(_TUNE_CONFIGS, key=["N", "K"]) +@triton.jit +def _euclid_assign_kernel( + x_ptr, # *f16 / *f32 [B, N, D] + c_ptr, # *f16 / *f32 [B, K, D] + x_sq_ptr, # *f32 [B, N] + out_ptr, # *i32 [B, N] + B: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + D: tl.constexpr, + stride_x_b: tl.constexpr, + stride_x_n: tl.constexpr, + stride_x_d: tl.constexpr, + stride_c_b: tl.constexpr, + stride_c_k: tl.constexpr, + stride_c_d: tl.constexpr, + stride_xsq_b: tl.constexpr, + stride_xsq_n: tl.constexpr, + stride_out_b: tl.constexpr, + stride_out_n: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Each program handles a tile of BLOCK_N points for a given batch element. + + The kernel iterates over the centroid dimension K in chunks of BLOCK_K and + maintains the running minimum distance as well as the corresponding index + for every point in the tile. + """ + pid_n = tl.program_id(0) # tile index along N dimension + pid_b = tl.program_id(1) # batch index + + n_start = pid_n * BLOCK_N + n_offsets = n_start + tl.arange(0, BLOCK_N) + n_mask = n_offsets < N + + # ------------------------------------------------------------------ + # Load x tile (BLOCK_N, D) + # ------------------------------------------------------------------ + offs_d = tl.arange(0, D) + # Compute pointer for x block: base + b*stride_x_b + n*stride_x_n + d*stride_x_d + x_ptrs = x_ptr + pid_b * stride_x_b + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d + x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0) + x_tile = x_tile # compute in f32 + + # Pre-load x_sq for the tile (BLOCK_N,) + xsq_ptrs = x_sq_ptr + pid_b * stride_xsq_b + n_offsets * stride_xsq_n + x_sq_tile = tl.load(xsq_ptrs, mask=n_mask, other=0.0).to(tl.float32) + + # Init best distance / index + best_dist = tl.full((BLOCK_N,), 3.4e38, tl.float32) # large number + best_idx = tl.zeros((BLOCK_N,), tl.int32) + + # ------------------------------------------------------------------ + # Iterate over centroids in chunks of BLOCK_K + # ------------------------------------------------------------------ + for k_start in range(0, K, BLOCK_K): + k_offsets = k_start + tl.arange(0, BLOCK_K) + k_mask = k_offsets < K + + # Load centroid tile (D, BLOCK_K) + c_ptrs = c_ptr + pid_b * stride_c_b + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d + c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0) + c_tile = c_tile + + # Compute centroid squared norms (BLOCK_K,) + cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32) + + # Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile + cross = tl.dot(x_tile, c_tile).to(tl.float32) # float32 + + # Squared Euclidean distance + dist = x_sq_tile[:, None] + cent_sq[None, :] - 2.0 * cross + dist = tl.maximum(dist, 0.0) + + # Mask out invalid centroid columns before reduction + dist = tl.where(k_mask[None, :], dist, 3.4e38) + + curr_min = tl.min(dist, axis=1) + curr_idx = tl.argmin(dist, axis=1) + + update = curr_min < best_dist + best_dist = tl.where(update, curr_min, best_dist) + best_idx = tl.where(update, k_start + curr_idx, 
best_idx) + + # ------------------------------------------------------------------ + # Write results + # ------------------------------------------------------------------ + out_ptrs = out_ptr + pid_b * stride_out_b + n_offsets * stride_out_n + tl.store(out_ptrs, best_idx, mask=n_mask) + + +# --------------------------------------------------------------- +# Python wrapper +# --------------------------------------------------------------- + + +def euclid_assign_triton( + x: torch.Tensor, + centroids: torch.Tensor, + x_sq: torch.Tensor, + out: torch.Tensor = None, + *, + BLOCK_N: int = 128, + BLOCK_K: int = 128, +) -> torch.Tensor: + """Return nearest-centroid indices using Triton kernel. + + Args: + x : (B, N, D) float16 / float32 (on CUDA) + centroids : (B, K, D) same dtype/device as x + x_sq : (B, N) float32 – ||x||^2 per point (on CUDA) + + Returns: + cluster_ids (B, N) int32 (callers can cast to int64 if desired) + """ + assert x.is_cuda and centroids.is_cuda and x_sq.is_cuda, "All tensors must be on CUDA" + # assert x.dtype in (torch.float16, torch.float32), "x must be fp16/fp32" + assert centroids.dtype == x.dtype, "centroids dtype mismatch" + + B, N, D = x.shape + K = centroids.shape[1] + assert centroids.shape == (B, K, D), "centroids shape mismatch" + assert x_sq.shape == (B, N), "x_sq shape mismatch" + + # x = x.contiguous() + # centroids = centroids.contiguous() + # x_sq = x_sq.contiguous() + + if out is None: + out = torch.empty((B, N), device=x.device, dtype=torch.int64) + + # Strides (in elements) + stride_x_b, stride_x_n, stride_x_d = x.stride() + stride_c_b, stride_c_k, stride_c_d = centroids.stride() + stride_xsq_b, stride_xsq_n = x_sq.stride() + stride_out_b, stride_out_n = out.stride() + + grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]), B) # noqa + + _euclid_assign_kernel[grid]( + x, + centroids, + x_sq, + out, + B, + N, + K, + D, + stride_x_b, + stride_x_n, + stride_x_d, + stride_c_b, + stride_c_k, + stride_c_d, + stride_xsq_b, + stride_xsq_n, + stride_out_b, + stride_out_n, + ) + return out + + +# 1. Euclidean +def _euclid_iter(x, x_sq, centroids): + # cent_sq = (centroids ** 2).sum(dim=-1) + # cross = torch.einsum('bnd,bkd->bnk', x, centroids) + # dist_sq = (x_sq[:,:,None] + cent_sq[:,None,:] - 2.0 * cross).clamp_min_(0.0) + + # cluster_ids = dist_sq.argmin(dim=-1) + cluster_ids = euclid_assign_triton(x, centroids, x_sq) + centroids_new, cluster_sizes = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids) + # centroids_new = triton_centroid_update_euclid(x, cluster_ids, centroids) + + # centroids_new = centroids_new.clone() # avoid CUDA graphs aliasing + + shift = (centroids_new - centroids).norm(dim=-1).max() + return centroids_new, shift, cluster_ids, cluster_sizes + + +# 2. Cosine +def _cosine_iter(x_norm, centroids): + cos_sim = torch.einsum("bnd,bkd->bnk", x_norm, centroids) + cluster_ids = cos_sim.argmax(dim=-1) + centroids_new = triton_centroid_update_cosine(x_norm, cluster_ids, centroids) + # centroids_new = centroids_new.clone() + shift = (centroids_new - centroids).norm(dim=-1).max() + return centroids_new, shift, cluster_ids + + +# 3. 
Dot-product +def _dot_iter(x, centroids): + sim = torch.einsum("bnd,bkd->bnk", x, centroids) + cluster_ids = sim.argmax(dim=-1) + centroids_new = triton_centroid_update_cosine(x, cluster_ids, centroids) + # centroids_new = centroids_new.clone() + shift = (centroids_new - centroids).norm(dim=-1).max() + return centroids_new, shift, cluster_ids + + +COMPILE_FLAG = False + +# Try to compile; if PyTorch < 2.0 or compile is not available, fallback to original function +try: + if COMPILE_FLAG: + _euclid_iter_compiled = torch.compile(_euclid_iter, dynamic=True, mode="reduce-overhead") + _cosine_iter_compiled = torch.compile(_cosine_iter, dynamic=True, mode="reduce-overhead") + _dot_iter_compiled = torch.compile(_dot_iter, dynamic=True, mode="reduce-overhead") + else: + _euclid_iter_compiled = _euclid_iter + _cosine_iter_compiled = _cosine_iter + _dot_iter_compiled = _dot_iter +except Exception: # pragma: no cover + _euclid_iter_compiled = _euclid_iter + _cosine_iter_compiled = _cosine_iter + _dot_iter_compiled = _dot_iter + + +def batch_kmeans_Euclid(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False): + """ + Batched KMeans clustering in PyTorch using Euclidean distance. + + Args: + x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims. + n_clusters: Number of clusters. + max_iters: Max number of iterations. + tol: Relative tolerance for center movement. + verbose: Print loss for each iter. + Returns: + cluster_ids: (B, N) LongTensor, cluster assignment for each point. + centroids: (B, n_clusters, D) final cluster centers. + cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster. + n_iters: actual number of iterations executed (int) + """ + B, N, D = x.shape + + # Pre-compute squared L2 norm of all points (constant during iterations) + x_sq = (x**2).sum(dim=-1) # (B, N) + + if init_centroids is None: + # Randomly select initial centers from x + indices = torch.randint(0, N, (B, n_clusters), device=x.device) + centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D) + else: + # centroids = init_centroids.clone() + centroids = init_centroids + + centroids = centroids.view(B, n_clusters, D) + + for it in range(max_iters): + # ---- compiled single iteration ---- + centroids_new, center_shift, cluster_ids, cluster_sizes = _euclid_iter_compiled(x, x_sq, centroids) + + # 4. Check for convergence + if verbose: + print(f"Iter {it}, center shift: {center_shift.item():.6f}") + if center_shift < tol: + break + # centroids = centroids_new.clone() + centroids = centroids_new + + # # --- compute cluster sizes --- + # ones = torch.ones_like(cluster_ids, dtype=torch.int64) + # cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device) + # cluster_sizes.scatter_add_(1, cluster_ids, ones) + + return cluster_ids, centroids, cluster_sizes, it + 1 + # return cluster_ids.clone(), centroids.clone(), cluster_sizes.clone(), it + 1 + + +# batch_kmeans_Euclid = torch.compile(batch_kmeans_Euclid, dynamic=True, mode="reduce-overhead") + + +def batch_kmeans_Cosine(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False): + """ + Batched KMeans clustering in PyTorch using Cosine similarity. + + Args: + x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims. + n_clusters: Number of clusters. + max_iters: Max number of iterations. + tol: Relative tolerance for center movement. + verbose: Print loss for each iter. 
+ Returns: + cluster_ids: (B, N) LongTensor, cluster assignment for each point. + centroids: (B, n_clusters, D) final cluster centers. + cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster. + n_iters: actual number of iterations executed (int) + """ + B, N, D = x.shape + + # Normalize input vectors for cosine similarity + x_norm = F.normalize(x, p=2, dim=-1) # (B, N, D) + + if init_centroids is None: + # Randomly select initial centers from x_norm + indices = torch.randint(0, N, (B, n_clusters), device=x.device) + centroids = torch.gather(x_norm, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D) + else: + centroids = init_centroids + + centroids = centroids.view(B, n_clusters, D) + centroids = F.normalize(centroids, p=2, dim=-1) # Ensure centroids are normalized + + for it in range(max_iters): + # ---- compiled single iteration ---- + centroids_new, center_shift, cluster_ids = _cosine_iter_compiled(x_norm, centroids) + + # 4. Check for convergence + if verbose: + print(f"Iter {it}, center shift: {center_shift.item():.6f}") + if center_shift < tol: + break + centroids = centroids_new.clone() + + # --- compute cluster sizes --- + ones = torch.ones_like(cluster_ids, dtype=torch.int64) + cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device) + cluster_sizes.scatter_add_(1, cluster_ids, ones) + + return cluster_ids, centroids, cluster_sizes, it + 1 + + +def batch_kmeans_Dot(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False): + """ + Batched KMeans clustering in PyTorch using raw dot-product as similarity. + + """ + B, N, D = x.shape + + if init_centroids is None: + # Randomly initialize centroids + indices = torch.randint(0, N, (B, n_clusters), device=x.device) + centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D)) + else: + centroids = init_centroids + + centroids = centroids.view(B, n_clusters, D) + + for it in range(max_iters): + # ---- compiled single iteration ---- + centroids_new, center_shift, cluster_ids = _dot_iter_compiled(x, centroids) + + # 4. 
Check for convergence + if verbose: + print(f"Iter {it} (dot), center shift: {center_shift.item():.6f}") + if center_shift < tol: + break + centroids = centroids_new.clone() + + # --- compute cluster sizes --- + ones = torch.ones_like(cluster_ids, dtype=torch.int64) + cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device) + cluster_sizes.scatter_add_(1, cluster_ids, ones) + + return cluster_ids, centroids, cluster_sizes, it + 1 + + +# --- Functions from analyze/kmeans_block_sparse_attention.py (helpers) --- + + +def permute_tensor_by_labels(tensor, labels, dim): + labels = labels.to(tensor.device) + sorted_indices = torch.argsort(labels, dim=-1) + gather_indices = sorted_indices + for i in range(dim + 1, tensor.dim()): + gather_indices = gather_indices.unsqueeze(-1) + expand_shape = list(tensor.shape) + gather_indices = gather_indices.expand(expand_shape) + permuted_tensor = torch.gather(tensor, dim, gather_indices) + return permuted_tensor, sorted_indices + + +def apply_inverse_permutation(permuted_tensor, sorted_indices, dim): + inverse_indices = torch.argsort(sorted_indices, dim=-1) + gather_indices = inverse_indices + for i in range(dim + 1, permuted_tensor.dim()): + gather_indices = gather_indices.unsqueeze(-1) + gather_indices = gather_indices.expand(permuted_tensor.shape) + original_tensor = torch.gather(permuted_tensor, dim, gather_indices) + return original_tensor + + +def weighted_softmax(scores, weights): + input_dtype = scores.dtype + scores = scores.float() + weights = weights.float() + max_score = torch.max(scores, dim=-1, keepdim=True)[0] + exp_scores = torch.exp(scores - max_score) + weighted_exp = weights * exp_scores + softmax_out = weighted_exp / torch.sum(weighted_exp, dim=-1, keepdim=True).clamp(min=1e-12) + return softmax_out.to(input_dtype) + + +def identify_dynamic_map( + query_centroids, + key_centroids, + q_cluster_sizes, + k_cluster_sizes, + p, + min_kc_ratio=0, +): + B, H, qc_num, D = query_centroids.shape + kc_num = key_centroids.shape[2] + device = query_centroids.device + + attn_scores = torch.matmul(query_centroids, key_centroids.transpose(-2, -1)) / (D**0.5) + k_weights = k_cluster_sizes.unsqueeze(-2).float() + + weighted_attn_probs = weighted_softmax(attn_scores, k_weights) + sorted_probs, sorted_indices = torch.sort(weighted_attn_probs, dim=-1, descending=True) + + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + remove_indices = cumsum_probs > p + remove_indices[..., 1:] = remove_indices[..., :-1].clone() + remove_indices[..., 0] = False + + if min_kc_ratio > 0: + preserve_length = int(min_kc_ratio * kc_num) + remove_indices[..., :preserve_length] = False + + sorted_clusters_to_keep = ~remove_indices + + dynamic_map = torch.zeros(B, H, qc_num, kc_num, dtype=torch.bool, device=device) + dynamic_map.scatter_(-1, sorted_indices, sorted_clusters_to_keep) + return dynamic_map + + +# --- Functions from analyze/dynamic_block_sparse_attention.py --- + + +def dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size): + """ + Computes dynamic block sparse attention using pure PyTorch. + + Args: + q (torch.Tensor): Query tensor, shape [B, H, S, D]. + k (torch.Tensor): Key tensor, shape [B, H, S, D]. + v (torch.Tensor): Value tensor, shape [B, H, S, D]. + dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. + qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num]. + kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num]. + + Returns: + torch.Tensor: Output tensor, shape [B, H, S, D]. 
+ """ + B, H, S, D = q.shape + qc_num = qc_size.shape[-1] + kc_num = kc_size.shape[-1] + device = q.device + dtype = q.dtype + + # Ensure sequence lengths match sum of block sizes + assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S" + assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S" + + # Precompute cumulative sizes for block indexing + # Add a 0 at the beginning for easier slicing + qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1) + kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1) + + out = torch.zeros_like(q) + scale = D**-0.5 + + # Naive implementation: Iterate through batch, head, and blocks + for b in range(B): + for h in range(H): + # Precompute start/end indices for this batch/head + q_starts = qc_cum_size[b, h, :-1] + q_ends = qc_cum_size[b, h, 1:] + k_starts = kc_cum_size[b, h, :-1] + k_ends = kc_cum_size[b, h, 1:] + + # Iterate through query blocks + for i in range(qc_num): + q_start, q_end = q_starts[i], q_ends[i] + q_block = q[b, h, q_start:q_end, :] # Shape: [qc_i, D] + + if q_block.shape[0] == 0: + continue # Skip empty blocks + + m_i = torch.full((q_block.shape[0], 1), -float("inf"), device=device, dtype=dtype) + l_i = torch.zeros((q_block.shape[0], 1), device=device, dtype=dtype) + acc_o_i = torch.zeros_like(q_block) # Shape: [qc_i, D] + + # Iterate through key/value blocks for the current query block + for j in range(kc_num): + # Check if this block needs computation + if dynamic_map[b, h, i, j]: + k_start, k_end = k_starts[j], k_ends[j] + k_block = k[b, h, k_start:k_end, :] # Shape: [kc_j, D] + v_block = v[b, h, k_start:k_end, :] # Shape: [kc_j, D] + + if k_block.shape[0] == 0: + continue # Skip empty blocks + + # Compute attention scores for the block + # QK^T: [qc_i, D] @ [D, kc_j] -> [qc_i, kc_j] + s_ij = (q_block @ k_block.transpose(-1, -2)) * scale + + # --- Online Softmax --- + # Find max score per query token in this block + m_ij = torch.max(s_ij, dim=-1, keepdim=True)[0] # Shape: [qc_i, 1] + + # Update overall max score (m_i) + m_new = torch.maximum(m_i, m_ij) # Shape: [qc_i, 1] + + # Calculate scaling factors for previous accumulator and current block + p_ij = torch.exp(s_ij - m_new) # Shape: [qc_i, kc_j] + exp_m_diff = torch.exp(m_i - m_new) # Shape: [qc_i, 1] + + # Update softmax denominator (l_i) + l_i = (l_i * exp_m_diff) + torch.sum(p_ij, dim=-1, keepdim=True) # Shape: [qc_i, 1] + + # Update output accumulator (acc_o_i) + # P_ij @ V_j: [qc_i, kc_j] @ [kc_j, D] -> [qc_i, D] + acc_o_i = (acc_o_i * exp_m_diff) + (p_ij @ v_block) # Shape: [qc_i, D] + + # Update max score for next iteration + m_i = m_new + + # Normalize the accumulated output + out[b, h, q_start:q_end, :] = acc_o_i / l_i.clamp(min=1e-12) # Avoid division by zero + + return out + + +# --- Triton Implementation --- + + +@triton.jit +def _dynamic_block_sparse_fwd_kernel( + Q, + K, + V, + Out, + dynamic_map, + qc_cum_size, + kc_cum_size, + stride_qb, + stride_qh, + stride_qs, + stride_qd, + stride_kb, + stride_kh, + stride_ks, + stride_kd, + stride_vb, + stride_vh, + stride_vs, + stride_vd, + stride_ob, + stride_oh, + stride_os, + stride_od, + stride_dmap_b, + stride_dmap_h, + stride_dmap_qc, + stride_dmap_kc, + stride_qcs_b, + stride_qcs_h, + stride_qcs_qc, + stride_kcs_b, + stride_kcs_h, + stride_kcs_kc, + B, + H, + S, + D, + scale, + QC_NUM: tl.constexpr, + KC_NUM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, 
+): + """ + Triton kernel for dynamic block sparse attention. + Each program computes attention for one query block within a batch/head. + Processes query block in chunks of BLOCK_M. + Iterates through key blocks, checking dynamic_map. + Processes key/value blocks in chunks of BLOCK_N. + Uses online softmax. + """ + # --- Grid Calculation --- + # Each program instance handles one query block for a specific batch and head + pid = tl.program_id(axis=0) + B * H * QC_NUM + + # Calculate batch, head, and query block index + pid_q_block_global = pid # 0 to B*H*QC_NUM - 1 + # pid_bh = pid // QC_NUM # Deprecated: Causes issues if QC_NUM is not constant across BH + # pid_q_block_idx = pid % QC_NUM + + # Need to map pid (0.. B*H*QC_NUM-1) back to (b, h, q_block_idx) + # q_block_idx changes fastest, then h, then b + q_block_idx = pid_q_block_global % QC_NUM + pid_h_temp = pid_q_block_global // QC_NUM + h = pid_h_temp % H + b = pid_h_temp // H + + # --- Load Q block info (start/end offsets) --- + qcs_offset = b * stride_qcs_b + h * stride_qcs_h + q_start_offset = tl.load(qc_cum_size + qcs_offset + q_block_idx * stride_qcs_qc) + q_end_offset = tl.load(qc_cum_size + qcs_offset + (q_block_idx + 1) * stride_qcs_qc) + q_block_size = q_end_offset - q_start_offset + + # Early exit if the query block is empty + if q_block_size == 0: + return + + # --- Pointers setup --- + q_ptr_base = Q + b * stride_qb + h * stride_qh + q_start_offset * stride_qs + k_ptr_base = K + b * stride_kb + h * stride_kh + v_ptr_base = V + b * stride_vb + h * stride_vh + out_ptr_base = Out + b * stride_ob + h * stride_oh + q_start_offset * stride_os + dmap_ptr = dynamic_map + b * stride_dmap_b + h * stride_dmap_h + q_block_idx * stride_dmap_qc + kcs_ptr = kc_cum_size + b * stride_kcs_b + h * stride_kcs_h + + # --- Iterate over the query block rows in chunks of BLOCK_M --- + offs_qm = tl.arange(0, BLOCK_M) # Query block row offsets [0, 1, ..., BLOCK_M-1] + offs_d = tl.arange(0, BLOCK_D) # Dimension offsets [0, 1, ..., BLOCK_D-1] + + for q_chunk_start in range(0, q_block_size, BLOCK_M): + q_chunk_rows = offs_qm + q_chunk_start + q_rows_mask = q_chunk_rows < q_block_size # Mask for valid rows in this Q chunk [BLOCK_M] + + # --- Initialize accumulators for this Q chunk --- + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # Max score + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # Sum of exp(scores - max) + acc_o = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) # Accumulated output + + # --- Load Q chunk --- + q_ptr = q_ptr_base + q_chunk_rows[:, None] * stride_qs + offs_d[None, :] + # Mask ensures we don't read out of bounds for the query block or dimension D + mask_q = q_rows_mask[:, None] & (offs_d[None, :] < D) + q_chunk = tl.load(q_ptr, mask=mask_q, other=0.0) # Shape: [BLOCK_M, BLOCK_D] + + # --- Inner loop over K blocks (columns in the block sparse map) --- + for k_block_idx in range(KC_NUM): + # --- Check dynamic_map: Is this block active? 
--- + is_active = tl.load(dmap_ptr + k_block_idx * stride_dmap_kc) + if is_active: # Process block only if it's active + # --- Load K block info (start/end offsets) --- + k_start_offset = tl.load(kcs_ptr + k_block_idx * stride_kcs_kc) + k_end_offset = tl.load(kcs_ptr + (k_block_idx + 1) * stride_kcs_kc) + k_block_size = k_end_offset - k_start_offset + + # Skip if the key block is empty (inside the active block check) + if k_block_size > 0: + k_block_ptr_base = k_ptr_base + k_start_offset * stride_ks + v_block_ptr_base = v_ptr_base + k_start_offset * stride_vs + + # --- Loop over K block chunks (size BLOCK_N) --- + offs_kn = tl.arange(0, BLOCK_N) # Key block row offsets [0, ..., BLOCK_N-1] + for k_chunk_start in range(0, k_block_size, BLOCK_N): + k_chunk_rows = offs_kn + k_chunk_start + k_rows_mask = k_chunk_rows < k_block_size # Mask for valid rows in this K/V chunk [BLOCK_N] + + # --- Load K, V chunks --- + k_ptr = k_block_ptr_base + k_chunk_rows[:, None] * stride_ks + offs_d[None, :] + v_ptr = v_block_ptr_base + k_chunk_rows[:, None] * stride_vs + offs_d[None, :] + + # Mask ensures we don't read out of bounds for the key block or dimension D + mask_kv = k_rows_mask[:, None] & (offs_d[None, :] < D) + k_chunk = tl.load(k_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D] + v_chunk = tl.load(v_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D] + + # --- Compute Scores (Attention) --- + # QK^T: [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N] + s_ij_chunk = tl.dot(q_chunk, k_chunk.T) * scale + + # IMPORTANT: Mask out scores corresponding to padding in K before max/softmax + # Set scores for invalid K elements to -inf + s_ij_chunk = tl.where(k_rows_mask[None, :], s_ij_chunk, -float("inf")) + # Mask out scores for invalid Q elements as well (although q_chunk elements are 0, avoid potential issues) + s_ij_chunk = tl.where(q_rows_mask[:, None], s_ij_chunk, -float("inf")) + + # --- Online Softmax Update --- + # Current max for this Q-K chunk interaction + m_ij_chunk = tl.max(s_ij_chunk, axis=1) # Shape: [BLOCK_M] + + # Update overall max (across K chunks seen so far for this Q chunk) + m_new = tl.maximum(m_i, m_ij_chunk) # Shape: [BLOCK_M] + + # Calculate scaled probabilities P_ij = exp(S_ij - m_new) + p_ij_chunk = tl.exp(s_ij_chunk - m_new[:, None]) # Shape: [BLOCK_M, BLOCK_N] + # Zero out probabilities for masked K elements before summing + p_ij_chunk = tl.where(k_rows_mask[None, :], p_ij_chunk, 0.0) + + # Calculate scaling factor for previous accumulator state + exp_m_diff = tl.exp(m_i - m_new) # Shape: [BLOCK_M] + + # Update sum accumulator (denominator L) + l_i_chunk = tl.sum(p_ij_chunk, axis=1) # Sum probabilities for this chunk, shape [BLOCK_M] + l_i = (l_i * exp_m_diff) + l_i_chunk # Shape: [BLOCK_M] + + # Update output accumulator O + # P_ij @ V_j: [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D] + # Ensure p_ij_chunk is the correct dtype for dot product + p_ij_chunk_casted = p_ij_chunk.to(V.dtype.element_ty) + o_chunk = tl.dot(p_ij_chunk_casted, v_chunk) # Shape: [BLOCK_M, BLOCK_D] + + acc_o = (acc_o * exp_m_diff[:, None]) + o_chunk # Shape: [BLOCK_M, BLOCK_D] + + # Update max for the next K chunk/block + m_i = m_new + # End of 'if is_active:' block + # --- End of loop over K blocks --- + + # --- Finalize output for this Q chunk --- + # Normalize the accumulated output: O = acc_o / l_i + # Add epsilon to l_i to avoid division by zero + l_i_safe = tl.where(l_i == 0, 1.0, l_i) # Avoid 0/0 -> NaN + o_final_chunk = acc_o / (l_i_safe[:, None]) + 
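        # Invariant at this point: acc_o holds sum_j exp(s_ij - m_i) * v_j and l_i holds
        # sum_j exp(s_ij - m_i) over the active, non-empty key blocks seen so far,
        # so the division recovers softmax(s) @ V restricted to those blocks.
        # l_i == 0 means no such block contributed; the tl.where below pins the
        # output to 0 for that case (a defensive guard, since acc_o is also 0 there).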
o_final_chunk = tl.where(l_i[:, None] == 0, 0.0, o_final_chunk) # Ensure output is 0 if l_i was 0 + + # --- Write output chunk to global memory --- + out_ptr = out_ptr_base + q_chunk_rows[:, None] * stride_os + offs_d[None, :] + # Mask ensures we don't write out of bounds for the query block or dimension D + mask_out = q_rows_mask[:, None] & (offs_d[None, :] < D) + tl.store(out_ptr, o_final_chunk.to(Out.dtype.element_ty), mask=mask_out) + + # --- (Optional: Write L and M stats if needed) --- + # Example: + # l_ptr = L + b * stride_lb + h * stride_lh + (q_start_offset + q_chunk_rows) * stride_ls + # tl.store(l_ptr, l_i, mask=q_rows_mask) + # m_ptr = M + ... + # tl.store(m_ptr, m_i, mask=q_rows_mask) + + # --- End of loop over Q chunks --- + + +def dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size): + """ + Launcher for the Triton dynamic block sparse attention kernel. + + Args: + q (torch.Tensor): Query tensor, shape [B, H, S, D]. + k (torch.Tensor): Key tensor, shape [B, H, S, D]. + v (torch.Tensor): Value tensor, shape [B, H, S, D]. + dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. + qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num]. + kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num]. + + Returns: + torch.Tensor: Output tensor, shape [B, H, S, D]. + """ + B, H, S, D = q.shape + qc_num = qc_size.shape[-1] + kc_num = kc_size.shape[-1] + dtype = q.dtype + + # Assertions and checks + assert q.is_cuda and k.is_cuda and v.is_cuda, "Inputs must be CUDA tensors" + assert dynamic_map.is_cuda and qc_size.is_cuda and kc_size.is_cuda + assert q.dtype == k.dtype == v.dtype, "Input dtypes must match" + assert dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + assert D in [16, 32, 64, 128], "Head dimension D must be 16, 32, 64, or 128 for efficient Triton dot" + # Ensure sequence lengths match sum of block sizes (check on one batch/head for simplicity) + assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S" + assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S" + # Ensure dynamic_map is boolean + assert dynamic_map.dtype == torch.bool + + # Calculate scale factor (using float32 for stability) + scale = D**-0.5 + + # Precompute cumulative sizes (on CPU/GPU, keep on device) + qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1).int() + kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1).int() + + # Output tensor + out = torch.empty_like(q) + + # Triton kernel config + # BLOCK_M/N can be tuned. Larger blocks may increase occupancy but need more shared memory. + # Let's start with reasonably sized blocks. + BLOCK_D = D + if S <= 512: # Smaller sequence, smaller blocks might be ok + BLOCK_M = 64 + BLOCK_N = 64 + elif S <= 1024: + BLOCK_M = 64 + BLOCK_N = 64 + else: # Larger sequence, potentially larger blocks + BLOCK_M = 128 # Or keep 64? 
Test + BLOCK_N = 64 + + # Adjust block size if sequence length is smaller + BLOCK_M = min(BLOCK_M, S) + BLOCK_N = min(BLOCK_N, S) + + # Launch grid: One program per query block per batch/head + grid = (B * H * qc_num,) + + # Call the kernel + _dynamic_block_sparse_fwd_kernel[grid]( + q, + k, + v, + out, + dynamic_map, + qc_cum_size, + kc_cum_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dynamic_map.stride(0), + dynamic_map.stride(1), + dynamic_map.stride(2), + dynamic_map.stride(3), + qc_cum_size.stride(0), + qc_cum_size.stride(1), + qc_cum_size.stride(2), + kc_cum_size.stride(0), + kc_cum_size.stride(1), + kc_cum_size.stride(2), + B, + H, + S, + D, + scale, + QC_NUM=qc_num, + KC_NUM=kc_num, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + # num_warps=4 # Can tune this + ) + + return out + + +# ---------------- Batch wrapper for cuVS KMeans ----------------- + + +def batch_kmeans_rapidai(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False): + """Batched K-Means using RAPIDS cuVS implementation. + + Args: + x (Tensor): (B, N, D) float32 tensor on CUDA. + n_clusters (int): K. + max_iters (int): maximum iterations. + tol (float): tolerance. + init_centroids (Tensor|None): optional initial centroids (B,K,D) float32. + verbose (bool): print per-batch info. + + Returns: + cluster_ids (B, N) LongTensor + centroids (B, K, D) float32 + cluster_sizes (B, K) LongTensor + n_iters_list (List[int]) iterations per batch + """ + B, N, D = x.shape + if init_centroids is not None: + assert init_centroids.shape == (B, n_clusters, D) + + cluster_ids_list = [] + centroids_list = [] + # cluster_sizes_list = [] + n_iters_list = [] + + x_float = x.float() + if init_centroids is not None: + init_centroids_float = init_centroids.float() + + for b in range(B): + xb = x_float[b] + if init_centroids is None: + centroids_init_b = None + init_method = "KMeansPlusPlus" + else: + centroids_init_b = init_centroids_float[b] + init_method = "Array" + labels_b, centroids_b, n_iter_b = kmeans_rapidai(xb, n_clusters, max_iter=max_iters, tol=tol, init_method=init_method, centroids_init=centroids_init_b) + + cluster_ids_list.append(labels_b.to(torch.int64)) # (N,) + centroids_list.append(centroids_b) + # cluster_sizes_b = torch.bincount(labels_b, minlength=n_clusters).to(torch.int64) + # cluster_sizes_list.append(cluster_sizes_b) + # n_iters_list.append(n_iter_b) + # if verbose: + # print(f"Batch {b}: iters={n_iter_b}, cluster sizes min={cluster_sizes_b.min().item()} max={cluster_sizes_b.max().item()}") + + cluster_ids = torch.stack(cluster_ids_list, dim=0) # (B,N) + centroids = torch.stack(centroids_list, dim=0).to(x.dtype) # (B,K,D) + # cluster_sizes = torch.stack(cluster_sizes_list, dim=0) # (B,K) + # --- compute cluster sizes --- + ones = torch.ones_like(cluster_ids, dtype=torch.int64) + cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device) + cluster_sizes.scatter_add_(1, cluster_ids, ones) + + return cluster_ids, centroids, cluster_sizes, n_iters_list diff --git a/lightx2v/common/ops/attn/svg_attn.py b/lightx2v/common/ops/attn/svg_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..66db7d1b2b34c7e7e7eec4fc5155adebf2b27c8a --- /dev/null +++ b/lightx2v/common/ops/attn/svg_attn.py @@ -0,0 +1,409 @@ +import math +from functools import 
lru_cache +from math import ceil + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from loguru import logger +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + + +@triton.jit +def wan_hidden_states_placement_kernel( + hidden_states_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size + hidden_states_out_ptr, # [cfg, num_heads, seq_len, head_dim] + best_mask_idx_ptr, # [cfg, num_heads] + hidden_states_stride_b, + hidden_states_stride_h, + hidden_states_stride_s, + hidden_states_stride_d, + mask_idx_stride_b, + mask_idx_stride_h, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + context_length: tl.constexpr, + num_frame: tl.constexpr, + frame_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # Copy hidden_states to output + # range: [b, h, block_id * block_size: block_id * block_size + block_size, :] + cfg = tl.program_id(0) + head = tl.program_id(1) + block_id = tl.program_id(2) + + start_id = block_id * BLOCK_SIZE + end_id = start_id + BLOCK_SIZE + end_id = tl.where(end_id > seq_len, seq_len, end_id) + + # Load best mask idx (0 is spatial, 1 is temporal) + is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h) + + offset_token = tl.arange(0, BLOCK_SIZE) + start_id + offset_mask = offset_token < seq_len + offset_d = tl.arange(0, head_dim) + + if is_temporal: + patch_id = offset_token // num_frame + frame_id = offset_token - patch_id * num_frame + offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id) + + offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d + offset_hidden_states = hidden_states_ptr + offset_load + + offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d + offset_hidden_states_out = hidden_states_out_ptr + offset_store + + # Maybe tune the pipeline here + hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None]) + tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None]) + else: + offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d + offset_hidden_states = hidden_states_ptr + offset_load + + offset_store = offset_load + offset_hidden_states_out = hidden_states_out_ptr + offset_store + + # Maybe tune the pipeline here + hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None]) + tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None]) + + +def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size): + cfg, num_heads, seq_len, head_dim = hidden_states.shape + BLOCK_SIZE = 128 + assert seq_len == context_length + num_frame * frame_size + + grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) + + wan_hidden_states_placement_kernel[grid]( + hidden_states, + hidden_states_out, + best_mask_idx, + hidden_states.stride(0), + hidden_states.stride(1), + hidden_states.stride(2), + hidden_states.stride(3), + best_mask_idx.stride(0), + best_mask_idx.stride(1), + 
seq_len, + head_dim, + context_length, + num_frame, + frame_size, + BLOCK_SIZE, + ) + + return hidden_states_out + + +@triton.jit +def wan_sparse_head_placement_kernel( + query_ptr, + key_ptr, + value_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size + query_out_ptr, + key_out_ptr, + value_out_ptr, # [cfg, num_heads, seq_len, head_dim] + best_mask_idx_ptr, # [cfg, num_heads] + query_stride_b, + query_stride_h, + query_stride_s, + query_stride_d, + mask_idx_stride_b, + mask_idx_stride_h, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + context_length: tl.constexpr, + num_frame: tl.constexpr, + frame_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # Copy query, key, value to output + # range: [b, h, block_id * block_size: block_id * block_size + block_size, :] + cfg = tl.program_id(0) + head = tl.program_id(1) + block_id = tl.program_id(2) + + start_id = block_id * BLOCK_SIZE + end_id = start_id + BLOCK_SIZE + end_id = tl.where(end_id > seq_len, seq_len, end_id) + + # Load best mask idx (0 is spatial, 1 is temporal) + is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h) + + offset_token = tl.arange(0, BLOCK_SIZE) + start_id + offset_mask = offset_token < seq_len + offset_d = tl.arange(0, head_dim) + + if is_temporal: + frame_id = offset_token // frame_size + patch_id = offset_token - frame_id * frame_size + offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id) + + offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d + offset_query = query_ptr + offset_load + offset_key = key_ptr + offset_load + offset_value = value_ptr + offset_load + + offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d + offset_query_out = query_out_ptr + offset_store + offset_key_out = key_out_ptr + offset_store + offset_value_out = value_out_ptr + offset_store + + # Maybe tune the pipeline here + query = tl.load(offset_query, mask=offset_mask[:, None]) + tl.store(offset_query_out, query, mask=offset_mask[:, None]) + key = tl.load(offset_key, mask=offset_mask[:, None]) + tl.store(offset_key_out, key, mask=offset_mask[:, None]) + value = tl.load(offset_value, mask=offset_mask[:, None]) + tl.store(offset_value_out, value, mask=offset_mask[:, None]) + + else: + offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d + offset_query = query_ptr + offset_load + offset_key = key_ptr + offset_load + offset_value = value_ptr + offset_load + + offset_store = offset_load + offset_query_out = query_out_ptr + offset_store + offset_key_out = key_out_ptr + offset_store + offset_value_out = value_out_ptr + offset_store + + # Maybe tune the pipeline here + query = tl.load(offset_query, mask=offset_mask[:, None]) + tl.store(offset_query_out, query, mask=offset_mask[:, None]) + key = tl.load(offset_key, mask=offset_mask[:, None]) + tl.store(offset_key_out, key, mask=offset_mask[:, None]) + value = tl.load(offset_value, mask=offset_mask[:, None]) + tl.store(offset_value_out, value, mask=offset_mask[:, None]) + + +def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size): + cfg, num_heads, seq_len, head_dim = query.shape + BLOCK_SIZE = 128 + assert 
seq_len == context_length + num_frame * frame_size + + grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) + + wan_sparse_head_placement_kernel[grid]( + query, + key, + value, + query_out, + key_out, + value_out, + best_mask_idx, + query.stride(0), + query.stride(1), + query.stride(2), + query.stride(3), + best_mask_idx.stride(0), + best_mask_idx.stride(1), + seq_len, + head_dim, + context_length, + num_frame, + frame_size, + BLOCK_SIZE, + ) + + +def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2): + def round_to_multiple(idx): + return ceil(idx / 128) * 128 + + def temporal_mask_mod(b, h, q_idx, kv_idx): + two_frame = round_to_multiple(mul * token_per_frame) + temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame + + # return temporal_head_mask + first_frame_mask = kv_idx < token_per_frame + video_mask = first_frame_mask | temporal_head_mask + return video_mask + + return temporal_mask_mod + + +@lru_cache +def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False): + block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile) + return block_mask + + +def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2): + assert diag_width == multiplier, f"{diag_width} is not equivalent to {multiplier}" + seq_len = context_length + num_frame * frame_size + mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier) + block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True) + return block_mask + + +def sparsity_to_width(sparsity, context_length, num_frame, frame_size): + seq_len = context_length + num_frame * frame_size + total_elements = seq_len**2 + + sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements + + width = seq_len * (1 - math.sqrt(1 - sparsity)) + width_frame = width / frame_size + + return width_frame + + +def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size): + attention_mask = torch.zeros((context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu") + + # TODO: fix hard coded mask + if mask_name == "spatial": + pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu") + + pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink + + block_size, block_thres = 128, frame_size * 2 + num_block = math.ceil(num_frame * frame_size / block_size) + for i in range(num_block): + for j in range(num_block): + if abs(i - j) < block_thres // block_size: + pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1 + attention_mask = pixel_attn_mask + else: + pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu") + + pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink + + block_size, block_thres = 128, frame_size * 2 + num_block = math.ceil(num_frame * frame_size / block_size) + for i in range(num_block): + for j in range(num_block): + if abs(i - j) < block_thres // block_size: + pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1 + + pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(frame_size * num_frame, frame_size * num_frame) 
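        # Re-index the banded mask: the entry at token index (patch * num_frame + frame)
        # moves to (frame * frame_size + patch), i.e. the frame-major order of the raw
        # QK scores used in sample_mse. Note the reshape above only works when
        # context_length == 0 (see the hard-coded-mask TODO).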
+ attention_mask = pixel_attn_mask + + attention_mask = attention_mask[:sample_mse_max_row].cuda() + return attention_mask + + +@ATTN_WEIGHT_REGISTER("svg_attn") +class SvgAttnWeight(AttnWeightTemplate): + head_num = None + head_dim = None + sample_mse_max_row = None + num_sampled_rows = None + context_length = None + attnmap_frame_num = None + seqlen = None + sparsity = None + mask_name_list = ["spatial", "temporal"] + attention_masks = None + block_mask = None + + @classmethod + def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity): + cls.head_num = head_num + cls.head_dim = head_dim + cls.sample_mse_max_row = sample_mse_max_row + cls.num_sampled_rows = num_sampled_rows + cls.context_length = context_length + cls.sparsity = sparsity + torch._dynamo.config.cache_size_limit = 192 * 3 + torch._dynamo.config.accumulated_cache_size_limit = 192 * 3 + logger.info( + f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}" + ) + + def __init__(self): + self.config = {} + self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs") + + @classmethod + def prepare_mask(cls, seqlen): + # Use class attributes so updates affect all instances of this class + if seqlen == cls.seqlen: + return + frame_size = seqlen // cls.attnmap_frame_num + cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list] + multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size) + cls.block_mask = prepare_flexattention( + 1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier + ) + cls.seqlen = seqlen + logger.info(f"SvgAttnWeight Update: seqlen={seqlen}") + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + q = q.unsqueeze(0).transpose(1, 2) + k = k.unsqueeze(0).transpose(1, 2) + v = v.unsqueeze(0).transpose(1, 2) + bs, num_heads, seq_len, dim = q.size() + + self.prepare_mask(seq_len) + sampled_mses = self.sample_mse(q, k, v) + best_mask_idx = torch.argmin(sampled_mses, dim=0) + + output_hidden_states = torch.zeros_like(q) + query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) + + query_out, key_out, value_out = self.fast_sparse_head_placement( + q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num + ) + + hidden_states = self.sparse_attention(query_out, key_out, value_out) + wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num) + + return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1) + + def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size): + wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size) + return query_out, key_out, value_out + + def sample_mse(self, query, key, value): + cfg, 
num_heads, seq_len, dim = query.size() + num_sampled_rows = min(self.num_sampled_rows, seq_len) + sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,)) + sampled_q = query[:, :, sampled_rows, :] + sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5) + + sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1) + sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value) # (1, seq_len, dim) + + sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype) + + # Only have Tri-diagonal and Striped + for mask_idx, attn_mask in enumerate(self.attention_masks): + sampled_attention_mask = attn_mask[sampled_rows, :] + sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf")) + sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1) + sampled_hidden_states = torch.matmul(sampled_attn_weights, value) + mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3)) + sampled_mses[mask_idx] = mse + + return sampled_mses + + +if __name__ == "__main__": + q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda() + + SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25) + svg_attn = SvgAttnWeight() + print("SvgAttnWeight initialized.") + + out = svg_attn.apply(q, k, v) + print(f"out: {out.shape}, {out.dtype}, {out.device}") diff --git a/lightx2v/common/ops/attn/template.py b/lightx2v/common/ops/attn/template.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1f447e9f7b693b530743cad8e3aa72459685ff --- /dev/null +++ b/lightx2v/common/ops/attn/template.py @@ -0,0 +1,35 @@ +from abc import ABCMeta, abstractmethod + + +class AttnWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name): + self.weight_name = weight_name + self.config = {} + + def load(self, weight_dict): + pass + + @abstractmethod + def apply(self, input_tensor): + pass + + def set_config(self, config=None): + if config is not None: + self.config = config + + def to_cpu(self, non_blocking=False): + pass + + def to_cuda(self, non_blocking=False): + pass + + def state_dict(self, destination=None): + if destination is None: + destination = {} + return destination + + def load_state_dict(self, destination, block_index, adapter_block_inde=None): + return {} + + def load_state_dict_from_disk(self, block_index, adapter_block_inde=None): + pass diff --git a/lightx2v/common/ops/attn/torch_sdpa.py b/lightx2v/common/ops/attn/torch_sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..65a002b6ae16f52975e0bf476fa99ab26dfdcce0 --- /dev/null +++ b/lightx2v/common/ops/attn/torch_sdpa.py @@ -0,0 +1,39 @@ +import torch +import torch.nn.functional as F + +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER + +from .template import AttnWeightTemplate + + +@ATTN_WEIGHT_REGISTER("torch_sdpa") +class TorchSDPAWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def apply( + self, + q, + k, + v, + drop_rate=0, + attn_mask=None, + causal=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + model_cls=None, + ): + if q.ndim == 3: + q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v 
= v.transpose(1, 2) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + x = x.transpose(1, 2) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out.squeeze(0) diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..197998b6e75055c24b028b43ec891decc8f0825a --- /dev/null +++ b/lightx2v/common/ops/attn/ulysses_attn.py @@ -0,0 +1,415 @@ +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + +from .template import AttnWeightTemplate +from .utils.all2all import all2all_head2seq, all2all_seq2head + + +@ATTN_WEIGHT_REGISTER("ulysses") +class UlyssesAttnWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False): + """ + 执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。 + + 参数: + q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims] + k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims] + v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims] + img_qkv_len (int): 图像查询、键和值的长度 + cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息 + attention_type (str): 注意力类型,默认为 "flash_attn2" + + 返回: + torch.Tensor: 计算得到的注意力结果 + """ + if len(q.shape) == 4: + q = q.reshape(-1, q.shape[-2], q.shape[-1]) + k = k.reshape(-1, k.shape[-2], k.shape[-1]) + v = v.reshape(-1, v.shape[-2], v.shape[-1]) + + # 获取当前进程的排名和全局进程数 + world_size = dist.get_world_size(seq_p_group) + cur_rank = dist.get_rank(seq_p_group) + + # 获取序列长度和文本相关的长度 + seq_len = q.shape[0] + if len(cu_seqlens_qkv) == 3: + txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 + txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度 + elif len(cu_seqlens_qkv) == 2: + txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 + txt_mask_len = None + + # 获取查询张量的头数和隐藏维度 + _, heads, hidden_dims = q.shape + shard_heads = heads // world_size # 每个进程处理的头数 + shard_seqlen = img_qkv_len # 每个进程处理的序列长度 + + # 分割图像和文本的查询、键和值 + img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous() + txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous() + + # 将图像的查询、键和值转换为头的格式 + if use_fp8_comm: + original_dtype = img_q.dtype + original_shape = img_q.shape + img_q_fp8, q_scale = quant_fp8_vllm(img_q.reshape(-1, original_shape[-1])) + img_k_fp8, k_scale = quant_fp8_vllm(img_k.reshape(-1, original_shape[-1])) + img_v_fp8, v_scale = quant_fp8_vllm(img_v.reshape(-1, original_shape[-1])) + img_q_fp8 = all2all_seq2head(img_q_fp8.reshape(original_shape), group=seq_p_group) + img_k_fp8 = all2all_seq2head(img_k_fp8.reshape(original_shape), group=seq_p_group) + img_v_fp8 = all2all_seq2head(img_v_fp8.reshape(original_shape), group=seq_p_group) + q_scale = all2all_seq2head(q_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group) + k_scale = all2all_seq2head(k_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group) + v_scale = all2all_seq2head(v_scale.reshape(original_shape[0], 
original_shape[1], 1), group=seq_p_group) + img_q = dequant_fp8_vllm(img_q_fp8, q_scale, original_dtype) + img_k = dequant_fp8_vllm(img_k_fp8, k_scale, original_dtype) + img_v = dequant_fp8_vllm(img_v_fp8, v_scale, original_dtype) + else: + img_q = all2all_seq2head(img_q, group=seq_p_group) + img_k = all2all_seq2head(img_k, group=seq_p_group) + img_v = all2all_seq2head(img_v, group=seq_p_group) + + # 处理文本的查询、键和值,选择当前进程的头 + txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] + txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] + txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] + + # 合并图像和文本的查询、键和值 + q = torch.cat((img_q, txt_q), dim=0) + k = torch.cat((img_k, txt_k), dim=0) + v = torch.cat((img_v, txt_v), dim=0) + + # 初始化累积序列长度张量 + cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE) + s = txt_qkv_len + img_q.shape[0] # 计算文本和图像的总长度 + s1 = s # 当前样本的结束位置 + cu_seqlens_qkv[1] = s1 # 设置累积序列长度 + if txt_mask_len: + s2 = txt_mask_len + img_q.shape[0] # 文本掩码的结束位置 + cu_seqlens_qkv = torch.cat(cu_seqlens_qkv, s2) + max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0] # 最大序列长度 + + # 调用注意力函数计算注意力结果 + # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv) + attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls) + + # 分割图像和文本的注意力结果 + img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,] + + # 收集所有进程的文本注意力结果 + gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)] + dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) + + img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm) + + txt_attn = torch.cat(gathered_txt_attn, dim=1) # 合并所有进程的文本注意力结果 + + # 合并图像和文本的注意力结果 + attn = torch.cat([img_attn, txt_attn], dim=0) + + return attn # 返回最终的注意力结果 + + @torch.compiler.disable + def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm): + img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # 重塑图像注意力结果 + + # 将头的格式转换回序列格式 + if use_fp8_comm: + original_dtype = img_attn.dtype + original_shape = img_attn.shape + img_attn_fp8, attn_scale = quant_fp8_vllm(img_attn.reshape(-1, original_shape[-1])) + img_attn_fp8 = all2all_head2seq(img_attn_fp8.reshape(original_shape), group=seq_p_group) + attn_scale = all2all_head2seq(attn_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group) + img_attn = dequant_fp8_vllm(img_attn_fp8, attn_scale, original_dtype) + else: + img_attn = all2all_head2seq(img_attn, group=seq_p_group) + + img_attn = img_attn.reshape(shard_seqlen, -1) # 重塑为 [shard_seqlen, -1] 形状 + return img_attn + + +@ATTN_WEIGHT_REGISTER("ulysses-4090") +class Ulysses4090AttnWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + self.rounds = [] + + def generate_round_robin_pairs(self, seq_p_group=None): + """ + 生成循环赛配对表,并确保每个配对中的第一个元素小于第二个 + 这样我们可以用简单的规则确定通信顺序 + """ + cur_rank = dist.get_rank(seq_p_group) + world_size = dist.get_world_size(seq_p_group) + if world_size % 2 != 0: + raise ValueError("world_size必须是偶数,奇数情况需要特殊处理") + + teams = list(range(world_size)) + for _ in range(world_size - 1): + round_schedule = {} + for i in range(world_size // 2): + 
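                # Circle-method pairing: entry i of the (rotated) `teams` list meets entry
                # world_size - 1 - i this round. The smaller rank in each pair is flagged
                # True and posts its send before its receive in load_balanced_all_to_all;
                # the larger rank does the opposite, giving every pair a consistent
                # point-to-point ordering.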
team1, team2 = teams[i], teams[world_size - 1 - i] + smaller, larger = min(team1, team2), max(team1, team2) + round_schedule[smaller] = (larger, True) + round_schedule[larger] = (smaller, False) + self.rounds.append(round_schedule) + # 旋转列表(固定第一个元素) + teams = [teams[0]] + [teams[-1]] + teams[1:-1] + + # if cur_rank == 0: + # self.print_pairing_schedule(seq_p_group) + + def print_pairing_schedule(self, seq_p_group): + """打印通信调度表""" + world_size = dist.get_world_size(seq_p_group) + logger.info("循环赛通信调度表:") + logger.info("=" * 50) + for i, round_schedule in enumerate(self.rounds): + logger.info(f"第 {i + 1} 轮:") + for cur_rank in range(world_size): + partner, is_smaller_in_pair = round_schedule[cur_rank] + logger.info(f" 进程 {cur_rank} ←→ 进程 {partner}") + logger.info("=" * 50) + + def load_balanced_all_to_all(self, shards, seq_p_group=None): + """ + 负载均衡all-to-all通信实现 + """ + world_size = dist.get_world_size(seq_p_group) + cur_rank = dist.get_rank(seq_p_group) + global_rank = dist.get_global_rank(seq_p_group, cur_rank) + cfg_p_group_index = global_rank // world_size + + # 准备接收缓冲区 + gathered_shards = [None] * world_size + for target_rank in range(world_size): + if target_rank != cur_rank: + gathered_shards[target_rank] = torch.empty_like(shards[target_rank]) + else: + gathered_shards[cur_rank] = shards[cur_rank] + + for i, round_schedule in enumerate(self.rounds): + # 查找当前进程在本轮的配对 + partner = None + is_smaller_in_pair = False + if cur_rank in round_schedule: + partner, is_smaller_in_pair = round_schedule[cur_rank] + + # 如果没有找到配对,说明本轮当前进程空闲 + if partner is None: + continue + + # 计算全局rank + partner_global_rank = cfg_p_group_index * world_size + partner + + if is_smaller_in_pair: + # 当前进程是配对中的较小者,先发送后接收 + send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group) + recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group) + send_req.wait() + recv_req.wait() + else: + # 当前进程是配对中的较大者,先接收后发送 + recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group) + send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group) + recv_req.wait() + send_req.wait() + + return gathered_shards + + def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False): + """ + 执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。 + + 参数: + q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims] + k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims] + v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims] + img_qkv_len (int): 图像查询、键和值的长度 + cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息 + attention_type (str): 注意力类型,默认为 "flash_attn2" + + 返回: + torch.Tensor: 计算得到的注意力结果 + """ + if len(self.rounds) == 0: + self.generate_round_robin_pairs(seq_p_group) + + if len(q.shape) == 4: + q = q.reshape(-1, q.shape[-2], q.shape[-1]) + k = k.reshape(-1, k.shape[-2], k.shape[-1]) + v = v.reshape(-1, v.shape[-2], v.shape[-1]) + # 获取当前进程的排名和全局进程数 + world_size = dist.get_world_size(seq_p_group) + cur_rank = dist.get_rank(seq_p_group) + global_world_size = dist.get_world_size() + global_rank = dist.get_global_rank(seq_p_group, cur_rank) + cfg_p_group_index = global_rank // world_size + + # 获取序列长度和文本相关的长度 + seq_len = q.shape[0] + if len(cu_seqlens_qkv) == 3: + txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 + txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度 + elif len(cu_seqlens_qkv) == 2: + txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 + 
txt_mask_len = None + + # 获取查询张量的头数和隐藏维度 + _, heads, hidden_dims = q.shape + shard_heads = heads // world_size # 每个进程处理的头数 + shard_seqlen = img_qkv_len # 每个进程处理的序列长度 + + # 分割图像和文本的查询、键和值 + img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous() + txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous() + + # 计算每个进程应该持有的头数分片 + num_heads = img_q.shape[1] + shard_heads = num_heads // world_size + + # 将 image QKV 拼接后,按头维度切分成 N 份,每份大小为 D/N + img_qkv = torch.stack([img_q, img_k, img_v], dim=0) + qkv_shards = [img_qkv[:, :, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)] + qkv_dtype = img_qkv.dtype + + if use_fp8_comm: + qkv_fp8_byte_tensors = [] + qkv_fp8_bytes = 0 + qkv_fp8_dtype = None + qkv_scale_dtype = None + for i in range(world_size): + qkv_fp8, qkv_scale = quant_fp8_vllm(qkv_shards[i].reshape(-1, hidden_dims)) + if i == 0: + qkv_fp8_bytes = qkv_fp8.numel() * qkv_fp8.element_size() + qkv_fp8_dtype = qkv_fp8.dtype + qkv_scale_dtype = qkv_scale.dtype + qkv_fp8_byte_tensors.append(torch.cat([qkv_fp8.contiguous().reshape(-1).view(torch.uint8), qkv_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0)) + + gathered_qkv_fp8_byte_tensors = self.load_balanced_all_to_all(qkv_fp8_byte_tensors, seq_p_group) + + gathered_q_shards = [] + gathered_k_shards = [] + gathered_v_shards = [] + for i in range(world_size): + qkv_fp8_byte_tensor = gathered_qkv_fp8_byte_tensors[i] + qkv_fp8 = qkv_fp8_byte_tensor[:qkv_fp8_bytes].view(qkv_fp8_dtype).reshape(3, -1, hidden_dims) + qkv_scale = qkv_fp8_byte_tensor[qkv_fp8_bytes:].view(qkv_scale_dtype).reshape(3, -1, 1) + q_shards_new = dequant_fp8_vllm(qkv_fp8[0], qkv_scale[0], qkv_dtype).reshape(-1, shard_heads, hidden_dims) + k_shards_new = dequant_fp8_vllm(qkv_fp8[1], qkv_scale[1], qkv_dtype).reshape(-1, shard_heads, hidden_dims) + v_shards_new = dequant_fp8_vllm(qkv_fp8[2], qkv_scale[2], qkv_dtype).reshape(-1, shard_heads, hidden_dims) + gathered_q_shards.append(q_shards_new) + gathered_k_shards.append(k_shards_new) + gathered_v_shards.append(v_shards_new) + else: + gathered_qkv_byte_tensors = self.load_balanced_all_to_all(qkv_shards, seq_p_group) + + gathered_q_shards = [] + gathered_k_shards = [] + gathered_v_shards = [] + for i in range(world_size): + qkv_tensor = gathered_qkv_byte_tensors[i].view(qkv_dtype).reshape(3, -1, shard_heads, hidden_dims) + gathered_q_shards.append(qkv_tensor[0]) + gathered_k_shards.append(qkv_tensor[1]) + gathered_v_shards.append(qkv_tensor[2]) + + # 拼接所有分片 (在序列维度上) + # 每个 gathered_*_shards[i] 的形状是 (seq_len/N, num_heads/N, head_dim) + # 拼接后形状是 (seq_len, num_heads/N, head_dim) + img_q = torch.cat(gathered_q_shards, dim=0) + img_k = torch.cat(gathered_k_shards, dim=0) + img_v = torch.cat(gathered_v_shards, dim=0) + + # 处理文本的查询、键和值,选择当前进程的头 + txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] + txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] + txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] + + # 合并图像和文本的查询、键和值 + q = torch.cat((img_q, txt_q), dim=0) + k = torch.cat((img_k, txt_k), dim=0) + v = torch.cat((img_v, txt_v), dim=0) + + # 初始化累积序列长度张量 + cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda") + s = txt_qkv_len + img_q.shape[0] # 计算文本和图像的总长度 + s1 = s # 当前样本的结束位置 + cu_seqlens_qkv[1] = s1 # 设置累积序列长度 + if txt_mask_len: + s2 = txt_mask_len + img_q.shape[0] # 
文本掩码的结束位置 + cu_seqlens_qkv = torch.cat(cu_seqlens_qkv, s2) + max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0] # 最大序列长度 + + # 调用注意力函数计算注意力结果 + # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv) + attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls) + + # 分割图像和文本的注意力结果 + img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,] + + # 收集所有进程的文本注意力结果 + gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)] + dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) + + img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm) + + txt_attn = torch.cat(gathered_txt_attn, dim=1) # 合并所有进程的文本注意力结果 + + # 合并图像和文本的注意力结果 + attn = torch.cat([img_attn, txt_attn], dim=0) + + return attn # 返回最终的注意力结果 + + @torch.compiler.disable + def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm): + cur_rank = dist.get_rank(seq_p_group) + global_world_size = dist.get_world_size() + global_rank = dist.get_global_rank(seq_p_group, cur_rank) + cfg_p_group_index = global_rank // world_size + + img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # 重塑图像注意力结果 + attn_dtype = img_attn.dtype + + # 按序列维度切分成 N 份 + attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)] + + if use_fp8_comm: + attn_fp8_byte_tensors = [] + attn_fp8_bytes = 0 + attn_fp8_dtype = None + attn_scale_dtype = None + for i in range(world_size): + attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, hidden_dims)) + if i == 0: + attn_fp8_bytes = attn_fp8.numel() * attn_fp8.element_size() + attn_fp8_dtype = attn_fp8.dtype + attn_scale_dtype = attn_scale.dtype + attn_fp8_byte_tensors.append(torch.cat([attn_fp8.contiguous().reshape(-1).view(torch.uint8), attn_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0)) + + gathered_attn_fp8_byte_tensors = self.load_balanced_all_to_all(attn_fp8_byte_tensors, seq_p_group) + + gathered_attn_shards = [] + for i in range(world_size): + attn_fp8_byte_tensor = gathered_attn_fp8_byte_tensors[i] + attn_fp8 = attn_fp8_byte_tensor[:attn_fp8_bytes].view(attn_fp8_dtype).reshape(-1, hidden_dims) + attn_scale = attn_fp8_byte_tensor[attn_fp8_bytes:].view(attn_scale_dtype).reshape(-1, 1) + attn_shards_new = dequant_fp8_vllm(attn_fp8, attn_scale, attn_dtype).reshape(-1, shard_heads, hidden_dims) + gathered_attn_shards.append(attn_shards_new) + + else: + gathered_attn_shards = self.load_balanced_all_to_all(attn_shards, seq_p_group) + + # 拼接所有分片 (在头维度上) + img_attn = torch.cat(gathered_attn_shards, dim=1) + img_attn = img_attn.reshape(shard_seqlen, -1) # 重塑为 [shard_seqlen, -1] 形状 + + return img_attn diff --git a/lightx2v/common/ops/attn/utils/all2all.py b/lightx2v/common/ops/attn/utils/all2all.py new file mode 100644 index 0000000000000000000000000000000000000000..757ce74e8abaa1fae6c44d0b7251106ab8823137 --- /dev/null +++ b/lightx2v/common/ops/attn/utils/all2all.py @@ -0,0 +1,89 @@ +import torch +import torch._dynamo as dynamo +import torch.distributed as dist + + +@dynamo.disable +def all2all_seq2head(input, group=None): + """ + 将输入张量从 [seq_len/N, heads, hidden_dims] 转换为 [seq_len, heads/N, hidden_dims] 的格式。 + + 参数: 
+ input (torch.Tensor): 输入张量,形状为 [seq_len/N, heads, hidden_dims] + + 返回: + torch.Tensor: 转换后的输出张量,形状为 [seq_len, heads/N, hidden_dims] + """ + # 确保输入是一个3D张量 + assert input.dim() == 3, f"input must be 3D tensor" + + # 获取当前进程的世界大小 + world_size = dist.get_world_size(group=group) + + # 获取输入张量的形状 + shard_seq_len, heads, hidden_dims = input.shape + seq_len = shard_seq_len * world_size # 计算总序列长度 + shard_heads = heads // world_size # 计算每个进程处理的头数 + + # 重塑输入张量以便进行 all-to-all 操作 + input_t = ( + input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims) # 重塑为 [shard_seq_len, world_size, shard_heads, hidden_dims] + .transpose(0, 1) # 转置以便进行 all-to-all 操作 + .contiguous() # 确保内存连续 + ) + + # 创建一个与输入张量相同形状的输出张量 + output = torch.empty_like(input_t) + + # 执行 all-to-all 操作,将输入张量的内容分发到所有进程 + dist.all_to_all_single(output, input_t, group=group) + + # 重塑输出张量为 [seq_len, heads/N, hidden_dims] 形状 + output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous() + + return output # 返回转换后的输出张量 + + +@dynamo.disable +def all2all_head2seq(input, group=None): + """ + 将输入张量从 [seq_len, heads/N, hidden_dims] 转换为 [seq_len/N, heads, hidden_dims] 的格式。 + + 参数: + input (torch.Tensor): 输入张量,形状为 [seq_len, heads/N, hidden_dims] + + 返回: + torch.Tensor: 转换后的输出张量,形状为 [seq_len/N, heads, hidden_dims] + """ + # 确保输入是一个3D张量 + assert input.dim() == 3, f"input must be 3D tensor" + + # 获取当前进程的世界大小 + world_size = dist.get_world_size(group=group) + + # 获取输入张量的形状 + seq_len, shard_heads, hidden_dims = input.shape + heads = shard_heads * world_size # 计算总头数 + shard_seq_len = seq_len // world_size # 计算每个进程处理的序列长度 + + # 重塑输入张量以便进行 all-to-all 操作 + input_t = ( + input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims) # 重塑为 [world_size, shard_seq_len, shard_heads, hidden_dims] + .transpose(1, 2) # 转置以便进行 all-to-all 操作 + .contiguous() # 确保内存连续 + .reshape(world_size, shard_heads, shard_seq_len, hidden_dims) # 再次重塑为 [world_size, shard_heads, shard_seq_len, hidden_dims] + ) + + # 创建一个与输入张量相同形状的输出张量 + output = torch.empty_like(input_t) + + # 执行 all-to-all 操作,将输入张量的内容分发到所有进程 + dist.all_to_all_single(output, input_t, group=group) + + # 重塑输出张量为 [heads, shard_seq_len, hidden_dims] 形状 + output = output.reshape(heads, shard_seq_len, hidden_dims) + + # 转置输出张量并重塑为 [shard_seq_len, heads, hidden_dims] 形状 + output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims) + + return output # 返回转换后的输出张量 diff --git a/lightx2v/common/ops/attn/utils/ring_comm.py b/lightx2v/common/ops/attn/utils/ring_comm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0f30a463b53eecc7e7d5598bf424256c0a8aa9 --- /dev/null +++ b/lightx2v/common/ops/attn/utils/ring_comm.py @@ -0,0 +1,46 @@ +from typing import Optional + +import torch +import torch.distributed as dist + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup = None): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + # logger.info(f"send_recv: 
empty_like {to_send.shape}") + else: + res = recv_tensor + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/lightx2v/common/ops/conv/__init__.py b/lightx2v/common/ops/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4cf739e6027771a21e11d8a629cf444546e1163 --- /dev/null +++ b/lightx2v/common/ops/conv/__init__.py @@ -0,0 +1,2 @@ +from .conv2d import * +from .conv3d import * diff --git a/lightx2v/common/ops/conv/conv2d.py b/lightx2v/common/ops/conv/conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d6665087ad03468d532aa1b42810903e5c773562 --- /dev/null +++ b/lightx2v/common/ops/conv/conv2d.py @@ -0,0 +1,61 @@ +from abc import ABCMeta, abstractmethod + +import torch + +from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + + +class Conv2dWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name, bias_name, stride, padding, dilation, groups): + self.weight_name = weight_name + self.bias_name = bias_name + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.config = {} + + @abstractmethod + def load(self, weight_dict): + pass + + @abstractmethod + def apply(self, input_tensor): + pass + + def set_config(self, config=None): + if config is not None: + self.config = config + + +@CONV2D_WEIGHT_REGISTER("Default") +class Conv2dWeight(Conv2dWeightTemplate): + def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1): + super().__init__(weight_name, bias_name, stride, padding, dilation, groups) + + def load(self, weight_dict): + self.weight = weight_dict[self.weight_name].to(AI_DEVICE) + self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None + + def apply(self, input_tensor): + input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + return input_tensor + + def to_cpu(self, non_blocking=False): + self.weight = self.weight.cpu(non_blocking=non_blocking) + if self.bias is not None: + self.bias = self.bias.cpu(non_blocking=non_blocking) + + def to_cuda(self, non_blocking=False): + self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking) + if self.bias is not None: + self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.weight.cpu().detach().clone() + if self.bias is not None: + destination[self.bias_name] = self.bias.cpu().detach().clone() + return destination diff --git a/lightx2v/common/ops/conv/conv3d.py b/lightx2v/common/ops/conv/conv3d.py new file mode 100644 index 0000000000000000000000000000000000000000..4e45f3391ec48cf67008ba5748c9065677328ec7 --- /dev/null +++ b/lightx2v/common/ops/conv/conv3d.py @@ -0,0 +1,94 @@ +from abc import ABCMeta, abstractmethod + +import 
torch + +from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + + +class Conv3dWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1): + self.weight_name = weight_name + self.bias_name = bias_name + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.config = {} + + @abstractmethod + def load(self, weight_dict): + pass + + @abstractmethod + def apply(self, input_tensor): + pass + + def set_config(self, config=None): + if config is not None: + self.config = config + + +@CONV3D_WEIGHT_REGISTER("Default") +class Conv3dWeight(Conv3dWeightTemplate): + def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1): + super().__init__(weight_name, bias_name, stride, padding, dilation, groups) + + def load(self, weight_dict): + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + if self.bias_name is not None: + bias_shape = weight_dict[self.bias_name].shape + bias_dtype = weight_dict[self.bias_name].dtype + self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype) + self.pin_bias.copy_(weight_dict[self.bias_name]) + else: + self.bias = None + self.pin_bias = None + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + if self.bias_name is not None: + self.bias = weight_dict[self.bias_name] + else: + self.bias = None + + def apply(self, input_tensor): + input_tensor = torch.nn.functional.conv3d( + input_tensor, + weight=self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + return input_tensor + + def to_cuda(self, non_blocking=False): + self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_bias") and self.pin_bias is not None: + self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_weight"): + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + if self.bias is not None: + self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu() + else: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + if hasattr(self, "bias") and self.bias is not None: + self.bias = self.bias.to("cpu", non_blocking=non_blocking) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight # .cpu().detach().clone().contiguous() + if self.bias_name is not None: + destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias # .cpu().detach().clone() + return destination diff --git a/lightx2v/common/ops/embedding/__init__.py b/lightx2v/common/ops/embedding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df915a383076749252e8c22370e58537918c5f36 --- /dev/null +++ b/lightx2v/common/ops/embedding/__init__.py @@ -0,0 +1 @@ +from .embedding_weight import * diff --git a/lightx2v/common/ops/embedding/embedding_weight.py b/lightx2v/common/ops/embedding/embedding_weight.py new 
file mode 100644 index 0000000000000000000000000000000000000000..8268d18188176f6bd17ed1a31eec324c9075dbb6 --- /dev/null +++ b/lightx2v/common/ops/embedding/embedding_weight.py @@ -0,0 +1,72 @@ +import re +from abc import ABCMeta + +import torch +import torch.nn.functional as F + +from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + + +class EmbeddingWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + self.weight_name = weight_name + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.is_post_adapter = is_post_adapter + self.config = {} + + def load(self, weight_dict): + if not self.lazy_load: + if self.create_cuda_buffer: + self.weight_cuda_buffer = weight_dict[self.weight_name].to(AI_DEVICE) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + + def to_cuda(self, non_blocking=False): + self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_weight"): + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + else: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + if weight_name not in destination: + self.weight = None + return + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + + +@EMBEDDING_WEIGHT_REGISTER("Default") +class EmbeddingWeight(EmbeddingWeightTemplate): + def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None): + super().__init__(weight_name, lazy_load=lazy_load, lazy_load_file=lazy_load_file) + + def apply(self, input_indices): + output = F.embedding(input=input_indices, weight=self.weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False) + + return output diff --git a/lightx2v/common/ops/mm/__init__.py b/lightx2v/common/ops/mm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9898101b00e91590a78d636e2066ff1234a4d8 --- /dev/null +++ b/lightx2v/common/ops/mm/__init__.py @@ -0,0 +1 @@ +from .mm_weight import * diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..330a50e0252432f12c5c6dad9414ad167f1f445f --- /dev/null +++ b/lightx2v/common/ops/mm/mm_weight.py @@ -0,0 +1,1325 @@ +import os +import re +from abc import ABCMeta, 
abstractmethod +from pathlib import Path + +import torch +from safetensors import safe_open + +from lightx2v.utils.envs import * +from lightx2v.utils.ggml_tensor import GGMLTensor +from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tensor +from lightx2v.utils.global_paras import CALIB +from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer +from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + +try: + from lightx2v_kernel.gemm import ( + cutlass_scaled_mxfp4_mm, + cutlass_scaled_mxfp6_mxfp8_mm, + cutlass_scaled_mxfp8_mm, + cutlass_scaled_nvfp4_mm, + scaled_mxfp4_quant, + scaled_mxfp6_quant, + scaled_mxfp8_quant, + scaled_nvfp4_quant, + ) +except ImportError: + scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None + scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None + scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None + scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None + +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None + +try: + import sgl_kernel +except ImportError: + sgl_kernel = None + +try: + from q8_kernels.functional.linear import q8_linear +except ImportError: + q8_linear = None + +try: + from q8_kernels.functional.linear import fp8_linear +except ImportError: + fp8_linear = None + +try: + import deep_gemm +except ImportError: + deep_gemm = None + +try: + from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax +except ImportError: + quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None + +try: + import gguf +except ImportError: + gguf = None + +try: + import marlin_cuda_quant +except ImportError: + marlin_cuda_quant = None + + +class MMWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + self.weight_name = weight_name + self.bias_name = bias_name + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.is_post_adapter = is_post_adapter + self.config = {} + + @abstractmethod + def load(self, weight_dict): + pass + + @abstractmethod + def apply(self): + pass + + def set_config(self, config={}): + self.config = config + + def to_cuda(self, non_blocking=False): + self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_weight_scale"): + self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_bias") and self.pin_bias is not None: + self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_weight"): + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + if hasattr(self, "weight_scale_name"): + self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu() + if self.bias is not None: + self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu() + else: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + if hasattr(self, "weight_scale"): + self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking) + if hasattr(self, "bias") and self.bias is not None: + self.bias = self.bias.to("cpu", non_blocking=non_blocking) + + 
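The MMWeightTemplate above fixes the contract that every matmul-weight wrapper in mm_weight.py follows: load() stages the (usually pre-transposed) weight either on the device or in pinned host memory, to_cuda()/to_cpu() shuttle it with non-blocking copies so blocks can be offloaded and prefetched, and apply() runs the actual GEMM. Below is a minimal, self-contained sketch of that pattern, assuming only stock PyTorch; ToyMMWeight and the "blocks.0.ffn.*" keys are illustrative names, not identifiers from this diff.

import torch


class ToyMMWeight:
    """Minimal stand-in for a matmul weight wrapper (illustrative only)."""

    def __init__(self, weight_name, bias_name=None):
        self.weight_name = weight_name
        self.bias_name = bias_name

    def load(self, weight_dict):
        w = weight_dict[self.weight_name]
        # Stage the weight on the host, pre-transposed so apply() can call torch.mm directly.
        # Pin the memory only when a GPU is present; pinning enables async H2D copies.
        pin = torch.cuda.is_available()
        self.pin_weight = torch.empty(w.shape, dtype=w.dtype, pin_memory=pin).copy_(w).t()
        self.weight = self.pin_weight  # stays on the CPU until to_cuda() is called
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

    def to_cuda(self, non_blocking=True):
        self.weight = self.pin_weight.to("cuda", non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.to("cuda", non_blocking=non_blocking)

    def apply(self, x):
        out = torch.empty((x.shape[0], self.weight.shape[1]), dtype=x.dtype, device=x.device)
        if self.bias is None:
            return torch.mm(x, self.weight, out=out)
        return torch.addmm(self.bias, x, self.weight, out=out)


if __name__ == "__main__":
    weights = {"blocks.0.ffn.weight": torch.randn(8, 16), "blocks.0.ffn.bias": torch.randn(8)}
    mm = ToyMMWeight("blocks.0.ffn.weight", "blocks.0.ffn.bias")
    mm.load(weights)
    if torch.cuda.is_available():
        mm.to_cuda()
    x = torch.randn(4, 16, device=mm.weight.device)
    print(mm.apply(x).shape)  # torch.Size([4, 8])

Pinned (page-locked) host memory is what makes the non_blocking=True transfers genuinely asynchronous, which is the point of the pin_weight / create_cpu_buffer paths in the real implementation: the next block's weights can be copied to the GPU while the current block is still computing.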
+@MM_WEIGHT_REGISTER("Default") +class MMWeight(MMWeightTemplate): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + + def load(self, weight_dict): + if self.create_cuda_buffer: + self._load_cuda_buffers(weight_dict) + elif self.create_cpu_buffer: + self._load_cpu_pin_buffers() + else: + self._load_default_tensors(weight_dict) + + def _get_source_tensor(self, source_name, weight_dict=None): + if self.lazy_load: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + return lazy_load_file.get_tensor(source_name) + return weight_dict[source_name] + + def _create_pin_tensor(self, tensor, transpose=False): + pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype) + pin_tensor = pin_tensor.copy_(tensor) + if transpose: + pin_tensor = pin_tensor.t() + del tensor + return pin_tensor + + def _load_cuda_buffers(self, weight_dict): + self.weight_cuda_buffer = self._get_source_tensor(self.weight_name, weight_dict).t().to(AI_DEVICE) + if self.bias_name is not None: + self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE) + + def _load_cpu_pin_buffers(self): + if self.lazy_load: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + weight_tensor = lazy_load_file.get_tensor(self.weight_name) + self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True) + + if self.bias_name is not None: + bias_tensor = lazy_load_file.get_tensor(self.bias_name) + self.pin_bias = self._create_pin_tensor(bias_tensor) + else: + self.bias = None + self.pin_bias = None + + def _load_default_tensors(self, weight_dict): + if not self.lazy_load: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_tensor = weight_dict[self.weight_name] + self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True) + + if self.bias_name is not None: + bias_tensor = weight_dict[self.bias_name] + self.pin_bias = self._create_pin_tensor(bias_tensor) + else: + self.bias = None + self.pin_bias = None + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name].t() + self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None + + def apply(self, input_tensor): + shape = (input_tensor.shape[0], self.weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + if self.bias is None: + return torch.mm(input_tensor, self.weight, out=output_tensor) + return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + if self.bias_name is not None: + destination[self.bias_name] = self.pin_bias if 
hasattr(self, "pin_bias") else self.bias + return destination + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + if self.bias_name is not None: + if self.is_post_adapter: + assert adapter_block_index is not None + self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1) + else: + self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + weight_tensor = lazy_load_file.get_tensor(self.weight_name).t() + self.pin_weight = self.pin_weight.copy_(weight_tensor) + del weight_tensor + + if self.bias_name is not None: + bias_tensor = lazy_load_file.get_tensor(self.bias_name) + self.pin_bias.copy_(bias_tensor) + del bias_tensor + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + if weight_name not in destination: + self.weight = None + return + + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + + if self.bias_name is not None: + if self.is_post_adapter: + assert adapter_block_index is not None + bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1) + else: + bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True) + else: + self.bias = None + + +@MM_WEIGHT_REGISTER("Default-Force-FP32") +class MMWeightForceFP32(MMWeight): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + + def load(self, weight_dict): + if not self.lazy_load: + super().load(weight_dict) + self.weight = self.weight.to(torch.float32) + if hasattr(self, "bias") and self.bias is not None: + self.bias = self.bias.to(torch.float32) + + +class MMWeightQuantTemplate(MMWeightTemplate): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale" + self.load_func = None + self.weight_need_transpose = True + self.act_quant_func = None + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.infer_dtype = GET_DTYPE() + self.bias_force_fp32 = False + + # ========================= + # weight load functions + # ========================= + def load(self, weight_dict): + 
self.load_quantized(weight_dict) + if self.weight_need_transpose: + if hasattr(self, "weight") and self.weight is not None: + self.weight = self.weight.t() + if hasattr(self, "pin_weight") and self.pin_weight is not None: + self.pin_weight = self.pin_weight.t() + if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None: + self.weight_cuda_buffer = self.weight_cuda_buffer.t() + + def load_quantized(self, weight_dict): + if self.create_cuda_buffer: + self._load_cuda_buffers(weight_dict) + elif self.create_cpu_buffer: + self._load_cpu_pin_buffers() + else: + self._load_default_tensors(weight_dict) + + def _load_cuda_buffers(self, weight_dict): + if self.lazy_load: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source: + self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load) + self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load) + else: + source = weight_dict + self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load) + self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load) + + def _get_cuda_tensor_pair(self, source, is_lazy): + if is_lazy: + weight = source.get_tensor(self.weight_name).to(AI_DEVICE) + scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE) + else: + weight = source[self.weight_name].to(AI_DEVICE) + scale = source[self.weight_scale_name].float().to(AI_DEVICE) + return weight, scale + + def _get_cuda_bias_tensor(self, source, is_lazy): + if self.bias_name is None: + return None + if is_lazy: + bias = source.get_tensor(self.bias_name) + dtype = self.infer_dtype + else: + bias = source[self.bias_name] + dtype = bias.dtype + if self.bias_force_fp32: + bias = bias.to(torch.float32) + else: + bias = bias.to(dtype) + return bias.to(AI_DEVICE) + + def _load_cpu_pin_buffers(self): + self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True) + self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True) + self.bias = None + + def _get_cpu_pin_tensor_pair(self, source, is_lazy): + if is_lazy: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source: + weight_tensor = source.get_tensor(self.weight_name) + scale_tensor = source.get_tensor(self.weight_scale_name) + scale_dtype = torch.float + pin_weight = self._create_pin_tensor(weight_tensor) + pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype) + else: + weight_tensor = source[self.weight_name] + scale_tensor = source[self.weight_scale_name] + scale_dtype = torch.float + pin_weight = self._create_pin_tensor(weight_tensor) + pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype) + return pin_weight, pin_scale + + def _get_cpu_pin_bias_tensor(self, source, is_lazy): + if self.bias_name is None: + return None + if is_lazy: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors") + 
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source: + bias_tensor = source.get_tensor(self.bias_name) + if not self.bias_force_fp32: + bias_tensor = bias_tensor.to(self.infer_dtype) + if self.bias_force_fp32: + bias_tensor = bias_tensor.to(torch.float32) + return self._create_pin_tensor(bias_tensor) + else: + bias_tensor = source[self.bias_name] + if self.bias_force_fp32: + bias_tensor = bias_tensor.to(torch.float32) + return self._create_pin_tensor(bias_tensor) + + def _create_pin_tensor(self, tensor, dtype=None): + dtype = dtype or tensor.dtype + pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype) + pin_tensor.copy_(tensor) + del tensor + return pin_tensor + + def _load_default_tensors(self, weight_dict): + if not self.lazy_load: + self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict) + self._load_default_bias(weight_dict) + else: + self.bias = None + self.pin_bias = None + + def _get_device_tensor_pair(self, source): + device = source[self.weight_name].device + if device.type == "cpu": + pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False) + return None, None, pin_weight, pin_scale + else: + return source[self.weight_name], source[self.weight_scale_name].float(), None, None + + def _load_default_bias(self, source): + if self.bias_name is None: + self.bias = None + self.pin_bias = None + self.bias_cuda_buffer = None + return + + if self.create_cuda_buffer: + self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False) + self.bias = None + self.pin_bias = None + else: + bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name] + device = bias_tensor.device + if device.type == "cpu": + self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False) + self.bias = None + else: + self.bias = bias_tensor + self.pin_bias = None + + def load_fp8_perchannel_sym(self, weight_dict): + if self.config.get("weight_auto_quant", False): + self.weight = weight_dict[self.weight_name].to(torch.float32) + w_quantizer = FloatQuantizer("e4m3", True, "per_channel") + self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight) + self.weight = self.weight.to(torch.float8_e4m3fn) + self.weight_scale = self.weight_scale.to(torch.float32) + else: + self.load_quantized(weight_dict) + + def load_int8_perchannel_sym(self, weight_dict): + if self.config.get("weight_auto_quant", False): + self.weight = weight_dict[self.weight_name].to(torch.float32) + w_quantizer = IntegerQuantizer(8, True, "per_channel") + self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight) + self.weight = self.weight.to(torch.int8) + self.weight_scale = self.weight_scale.to(torch.float32) + else: + self.load_quantized(weight_dict) + + def load_mxfp4(self, weight_dict): + if self.config.get("weight_auto_quant", False): + device = weight_dict[self.weight_name].device + self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16) + self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight) + self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + 
weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + + def load_mxfp6(self, weight_dict): + if self.config.get("weight_auto_quant", False): + device = weight_dict[self.weight_name].device + self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16) + self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight) + self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + + def load_mxfp8(self, weight_dict): + if self.config.get("weight_auto_quant", False): + device = weight_dict[self.weight_name].device + self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16) + self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight) + self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + + def load_nvfp4(self, weight_dict): + device = weight_dict[self.weight_name].device + + input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")] + input_global_scale = (2688.0 / input_absmax).to(torch.float32) + weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"] + alpha = 1.0 / (input_global_scale * weight_global_scale) + + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = 
torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + + input_global_scale_shape = input_global_scale.shape + input_global_scale_dtype = input_global_scale.dtype + self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype) + self.pin_input_global_scale.copy_(input_global_scale) + + alpha_shape = alpha.shape + alpha_dtype = alpha.dtype + self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype) + self.pin_alpha.copy_(alpha) + + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + self.input_global_scale = input_global_scale + self.alpha = alpha + + if self.bias_name is not None: + if self.create_cuda_buffer: + self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE) + else: + device = weight_dict[self.bias_name].device + if device.type == "cpu": + bias_shape = weight_dict[self.bias_name].shape + bias_dtype = weight_dict[self.bias_name].dtype + self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype) + self.pin_bias.copy_(weight_dict[self.bias_name]) + else: + self.bias = weight_dict[self.bias_name] + else: + self.bias = None + self.pin_bias = None + + def load_fp8_perblock128_sym(self, weight_dict): + if self.config.get("weight_auto_quant", False): + self.weight = weight_dict[self.weight_name] + self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight) + else: + self.load_quantized(weight_dict) + + def per_block_cast_to_fp8(self, x): + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128), + dtype=x.dtype, + device=x.device, + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + # ========================= + # act quant kernels + # ========================= + def act_quant_int8_perchannel_sym_torchao(self, x): + input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x) + return input_tensor_quant, input_tensor_scale + + def act_quant_fp8_perchannel_sym_vllm(self, x): + input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True) + return input_tensor_quant, input_tensor_scale + + def act_quant_fp8_perchannel_sym_sgl(self, x): + m, k = x.shape + input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False) + input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False) + sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale) + return input_tensor_quant, input_tensor_scale + + def act_quant_int8_perchannel_sym_vllm(self, x): + input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True) + return input_tensor_quant, input_tensor_scale + + def act_quant_nvfp4(self, x): + input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale) + return input_tensor_quant, input_tensor_scale + + def act_quant_mxfp4(self, x): + input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x) + 
return input_tensor_quant, input_tensor_scale + + def act_quant_mxfp8(self, x): + input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x) + return input_tensor_quant, input_tensor_scale + + def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x): + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + def act_quant_fp8_perchannelgroup128_sym_sgl(self, x): + m, k = x.shape + input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False) + input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False) + sgl_kernel.sgl_per_token_group_quant_fp8( + x, + input_tensor_quant, + input_tensor_scale, + group_size=128, + eps=1e-10, + fp8_min=-448.0, + fp8_max=448.0, + ) + return input_tensor_quant, input_tensor_scale + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + if self.bias_name is not None: + destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias + destination[self.weight_scale_name] = self.pin_weight_scale if hasattr(self, "pin_weight_scale") else self.weight_scale + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1) + + if weight_name not in destination: + self.weight = None + return + + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + self.weight_scale = self.weight_scale_cuda_buffer.copy_(destination[weight_scale_name], non_blocking=True) + + if self.bias_name is not None: + bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True) + else: + self.bias = None + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + if self.is_post_adapter: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1) + else: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1) + + if self.bias_name is not None: + if self.is_post_adapter: + assert adapter_block_index is not None + self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1) + else: + self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors") + with 
safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + if self.weight_need_transpose: + weight_tensor = lazy_load_file.get_tensor(self.weight_name).t() + else: + weight_tensor = lazy_load_file.get_tensor(self.weight_name) + + self.pin_weight = self.pin_weight.copy_(weight_tensor) + del weight_tensor + + weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name) + self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor) + del weight_scale_tensor + + if self.bias_name is not None: + bias_tensor = lazy_load_file.get_tensor(self.bias_name) + self.pin_bias.copy_(bias_tensor) + del bias_tensor + + +@MM_WEIGHT_REGISTER("fp8-vllm") +class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate): + """ + Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm + + Quant MM: + Weight: fp8 perchannel sym + Act: fp8 perchannel dynamic sym + Kernel: vllm + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_fp8_perchannel_sym + self.weight_need_transpose = True + self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm + + def apply(self, input_tensor): + shape = (input_tensor.shape[0], self.weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + torch.ops._C.cutlass_scaled_mm( + output_tensor, + input_tensor_quant, + self.weight, + input_tensor_scale, + self.weight_scale, + self.bias if self.bias is not None else None, + ) + return output_tensor + + +@MM_WEIGHT_REGISTER("int8-vllm") +class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate): + """ + Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm + + Quant MM: + Weight: int8 perchannel sym + Act: int8 perchannel dynamic sym + Kernel: vllm + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_int8_perchannel_sym + self.weight_need_transpose = True + self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm + + def apply(self, input_tensor): + shape = (input_tensor.shape[0], self.weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + torch.ops._C.cutlass_scaled_mm( + output_tensor, + input_tensor_quant, + self.weight, + input_tensor_scale, + self.weight_scale, + self.bias if self.bias is not None else None, + ) + return output_tensor + + +@MM_WEIGHT_REGISTER("mxfp4") +class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate): + """ + Name: W-mxfp4-A-mxfp4-dynamic + + Quant MM: + Weight: mxfp4 + Act: mxfp4 + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, 
lazy_load_file, is_post_adapter) + self.load_func = self.load_mxfp4 + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_mxfp4 + self.set_alpha() + + def set_alpha(self): + self.alpha = torch.tensor(1.0, dtype=torch.float32) + + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + self.alpha = self.alpha.to(self.weight.device) + output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + +@MM_WEIGHT_REGISTER("mxfp6-mxfp8") +class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate): + """ + Name: W-mxfp6-A-mxfp8-dynamic + + Quant MM: + Weight: mxfp6 + Act: mxfp8 + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_mxfp6 + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_mxfp8 + self.set_alpha() + + def set_alpha(self): + self.alpha = torch.tensor(1.0, dtype=torch.float32) + + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + self.alpha = self.alpha.to(self.weight.device) + output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + +@MM_WEIGHT_REGISTER("mxfp8") +class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate): + """ + Name: W-mxfp8-A-mxfp8-dynamic + + Quant MM: + Weight: mxfp8 + Act: mxfp8 + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_mxfp8 + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_mxfp8 + self.set_alpha() + + def set_alpha(self): + self.alpha = torch.tensor(1.0, dtype=torch.float32) + + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + self.alpha = self.alpha.to(self.weight.device) + output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + +@MM_WEIGHT_REGISTER("nvfp4") +class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate): + """ + Name: W-nvfp4-A-nvfp4-dynamic + + Quant MM: + Weight: nvfp4 + Act: nvfp4 + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_nvfp4 + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_nvfp4 + + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + def to_cuda(self, non_blocking=False): + self.weight = 
self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_weight_scale"): + self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking) + self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking) + self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_bias") and self.pin_bias is not None: + self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_weight"): + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + if hasattr(self, "weight_scale_name"): + self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu() + self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu() + self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu() + if self.bias is not None: + self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu() + else: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + if hasattr(self, "weight_scale"): + self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking) + self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking) + self.alpha = self.alpha.to("cpu", non_blocking=non_blocking) + if hasattr(self, "bias") and self.bias is not None: + self.bias = self.bias.to("cpu", non_blocking=non_blocking) + + +@MM_WEIGHT_REGISTER("Calib") +class MMCalibNvfp4(MMWeight): + """ + Name: calib + + Calib: + absmax: torch.max(torch.abs(input_tensor)) + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.running_absmax = None + self.count = 0 + self.decay = 0.9 + + def apply(self, input_tensor): + shape = (input_tensor.shape[0], self.weight.shape[1]) + dtype, device = input_tensor.dtype, input_tensor.device + + current_absmax = torch.max(torch.abs(input_tensor)).to("cpu") + if self.count % 2 == 0: + if self.running_absmax is None: + self.running_absmax = current_absmax + else: + self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax + CALIB["absmax"][self.weight_name] = self.running_absmax + self.count = self.count + 1 + + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + if self.bias is None: + return torch.mm(input_tensor, self.weight, out=output_tensor) + return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor) + + +@MM_WEIGHT_REGISTER("fp8-q8f") +class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate): + """ + Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F + + Quant MM: + Weight: fp8 perchannel sym + Act: fp8 perchannel dynamic sym + Kernel: Q8F + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_fp8_perchannel_sym + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm + self.bias_force_fp32 = True + + def apply(self, 
input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = fp8_linear( + input_tensor_quant, + self.weight, + self.bias.float() if self.bias is not None else None, + input_tensor_scale, + self.weight_scale, + out_dtype=self.infer_dtype, + ) + return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor + + +@MM_WEIGHT_REGISTER("int8-q8f") +class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate): + """ + Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F + + Quant MM: + Weight: int8 perchannel sym + Act: int8 perchannel dynamic sym + Kernel: Q8F + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_int8_perchannel_sym + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm + + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = q8_linear( + input_tensor_quant, + self.weight, + self.bias.float() if self.bias is not None else None, + input_tensor_scale, + self.weight_scale, + fuse_gelu=False, + out_dtype=self.infer_dtype, + ) + return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor + + +@MM_WEIGHT_REGISTER("fp8-b128-deepgemm") +class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate): + """ + Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl + + Quant MM: + Weight: fp8 perblock 128x128 sym + Act: fp8 pertoken-pergroup group=128 dynamic sym + Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_fp8_perblock128_sym + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl + + def apply(self, input_tensor): + shape = (input_tensor.shape[0], self.weight.shape[0]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + deep_gemm.gemm_fp8_fp8_bf16_nt( + (input_tensor_quant, input_tensor_scale), + (self.weight, self.weight_scale), + output_tensor, + ) + if hasattr(self, "bias") and self.bias is not None: + output_tensor.add_(self.bias) + return output_tensor + + +@MM_WEIGHT_REGISTER("fp8-sgl") +class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate): + """ + Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl + + Quant MM: + Weight: fp8 perchannel sym + Act: fp8 perchannel dynamic sym + Kernel: Sgl-kernel + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_fp8_perchannel_sym + self.weight_need_transpose = True + 
self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl + + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = sgl_kernel.fp8_scaled_mm( + input_tensor_quant, + self.weight, + input_tensor_scale, + self.weight_scale, + self.infer_dtype, + self.bias if self.bias is not None else None, + ) + return output_tensor + + +@MM_WEIGHT_REGISTER("int8-sgl") +class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate): + """ + Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm + + Quant MM: + Weight: int8 perchannel sym + Act: int8 perchannel dynamic sym + Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_int8_perchannel_sym + self.weight_need_transpose = True + self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm + + def apply(self, input_tensor): + shape = (input_tensor.shape[0], self.weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = sgl_kernel.int8_scaled_mm( + input_tensor_quant, + self.weight, + input_tensor_scale, + self.weight_scale, + self.infer_dtype, + self.bias if self.bias is not None else None, + ) + return output_tensor + + +@MM_WEIGHT_REGISTER("int8-torchao") +class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate): + """ + Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao + + Quant MM: + Weight: int8 perchannel sym + Act: int8 perchannel dynamic sym + Kernel: Torchao + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_int8_perchannel_sym + self.weight_need_transpose = True + self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao + + def apply(self, input_tensor): + input_tensor = input_tensor + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype) + if self.bias is not None: + output_tensor = output_tensor + self.bias + + return output_tensor + + +class MMWeightGGUFTemplate(MMWeightTemplate): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + + def load(self, weight_dict): + if not self.lazy_load: + assert not self.create_cuda_buffer, "GGUF Unsupported offload block" + self.weight = weight_dict[self.weight_name] + + weight_shape = self.weight.shape + weight_dtype = self.weight.dtype + + if isinstance(self.weight, GGMLTensor): + self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, 
dtype=weight_dtype, gguf_type=self.weight.gguf_type) + self.pin_weight.copy_from(self.weight) + else: + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + if self.bias_name is not None: + self.bias = weight_dict[self.bias_name] + if isinstance(self.bias, GGMLTensor): + self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type) + self.pin_bias.copy_from(self.bias) + else: + self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) + self.pin_bias.copy_(weight_dict[self.bias_name]) + else: + self.bias = None + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + if weight_name not in destination: + self.weight = None + return + + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + + if self.bias_name is not None: + if self.is_post_adapter: + assert adapter_block_index is not None + bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1) + else: + bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True) + else: + self.bias = None + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + if self.bias_name is not None: + destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias + + return destination + + def get_weight(self, tensor, dtype): + if tensor is None: + return + + weight = gguf_dequantize_tensor(tensor, dtype) + if isinstance(weight, GGMLTensor): + weight = torch.Tensor(weight) + + return weight + + def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None): + if input_tensor is not None: + if dtype is None: + dtype = getattr(input_tensor, "dtype", torch.float32) + + bias = None + if self.bias is not None: + bias = self.get_weight(self.bias, dtype) + + weight = self.get_weight(self.weight, dtype) + return weight, bias + + def apply(self, input_tensor): + weight, bias = self.cast_bias_weight(input_tensor) + return torch.nn.functional.linear(input_tensor, weight, bias) + + +@MM_WEIGHT_REGISTER("gguf-BF16") +class MMWeightGGUFBF16(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.BF16 + + +@MM_WEIGHT_REGISTER("gguf-Q8_0") +class MMWeightGGUFQ80(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q8_0 + + +@MM_WEIGHT_REGISTER("gguf-Q6_K") +class MMWeightGGUFQ6K(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q6_K + + +@MM_WEIGHT_REGISTER("gguf-Q5_K_S") +class MMWeightGGUFQ5KS(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q6_K + + +@MM_WEIGHT_REGISTER("gguf-Q5_K_M") +class MMWeightGGUFQ5KM(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q6_K + + +@MM_WEIGHT_REGISTER("gguf-Q5_1") +class MMWeightGGUFQ51(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q5_1 + + +@MM_WEIGHT_REGISTER("gguf-Q5_0") +class MMWeightGGUFQ50(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q5_0 
+ + +@MM_WEIGHT_REGISTER("gguf-Q4_K_M") +class MMWeightGGUFQ4KM(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q5_0 + + +@MM_WEIGHT_REGISTER("gguf-Q4_K_S") +class MMWeightGGUFQ4KS(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q4_K + + +@MM_WEIGHT_REGISTER("gguf-Q4_1") +class MMWeightGGUFQ41(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q4_1 + + +@MM_WEIGHT_REGISTER("gguf-Q4_0") +class MMWeightGGUFQ40(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q4_0 + + +@MM_WEIGHT_REGISTER("gguf-Q3_K_M") +class MMWeightGGUFQ3KM(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q3_K + + +@MM_WEIGHT_REGISTER("gguf-Q3_K_S") +class MMWeightGGUFQ3KS(MMWeightGGUFTemplate): + qtype = gguf.GGMLQuantizationType.Q2_K + + +@MM_WEIGHT_REGISTER("int4-g128-marlin") +class MMWeightWint4group128Marlin(MMWeightQuantTemplate): + """ + Name: "W-int4-group128-sym-Marlin + + Quant int4 x FP16: + Weight: int4 pergroup sym + Kernel: Marlin + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_quantized + + def load(self, weight_dict): + assert not self.lazy_load + self.load_func(weight_dict) + self.workspace = weight_dict[f"{self.weight_name}_workspace"] + + if self.bias_name is not None: + bias_shape = weight_dict[self.bias_name].shape + bias_dtype = weight_dict[self.bias_name].dtype + self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype) + self.bias.copy_(weight_dict[self.bias_name]) + else: + self.bias = None + + def apply(self, input_tensor): + output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device) + marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1) + if hasattr(self, "bias") and self.bias is not None: + output_tensor.add_(self.bias) + return output_tensor diff --git a/lightx2v/common/ops/norm/__init__.py b/lightx2v/common/ops/norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e724797374c92b0ca99f38778067216867174e0 --- /dev/null +++ b/lightx2v/common/ops/norm/__init__.py @@ -0,0 +1,2 @@ +from .layer_norm_weight import * +from .rms_norm_weight import * diff --git a/lightx2v/common/ops/norm/layer_norm_weight.py b/lightx2v/common/ops/norm/layer_norm_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..38db511751673035761c339d10b151dc555e1305 --- /dev/null +++ b/lightx2v/common/ops/norm/layer_norm_weight.py @@ -0,0 +1,220 @@ +import os +import re +from abc import ABCMeta, abstractmethod +from pathlib import Path + +import torch +from safetensors import safe_open + +from lightx2v.utils.envs import * +from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + +from .triton_ops import norm_infer + + +class LNWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + self.weight_name = weight_name + self.bias_name = bias_name + self.eps = eps + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.lazy_load = lazy_load + 
self.lazy_load_file = lazy_load_file + self.is_post_adapter = is_post_adapter + self.config = {} + self.infer_dtype = GET_DTYPE() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + + def load(self, weight_dict): + if self.create_cuda_buffer: + self._load_cuda_buffers(weight_dict) + elif self.create_cpu_buffer: + self._load_cpu_pin_buffers() + else: + self._load_default_tensors(weight_dict) + + def _load_default_tensors(self, weight_dict): + if not self.lazy_load and self.weight_name is not None: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_tensor = weight_dict[self.weight_name] + self.pin_weight = self._create_cpu_pin_tensor(weight_tensor) + bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None + self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None + self.bias = None + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None + else: + self.weight = None + self.bias = None + + def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False): + if name is None: + return None + if self.lazy_load: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + tensor = lazy_load_file.get_tensor(name) + if use_infer_dtype: + tensor = tensor.to(self.infer_dtype) + else: + tensor = weight_dict[name] + return tensor + + def _create_cpu_pin_tensor(self, tensor): + if tensor is None: + return None + pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype) + pin_tensor.copy_(tensor) + del tensor + return pin_tensor + + def _load_cuda_buffers(self, weight_dict): + weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load) + if weight_tensor is not None: + self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE) + + bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load) + if bias_tensor is not None: + self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE) + + def _load_cpu_pin_buffers(self): + weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True) + if weight_tensor is not None: + self.pin_weight = self._create_cpu_pin_tensor(weight_tensor) + else: + self.weight = None + + bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True) + if bias_tensor is not None: + self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) + else: + self.bias = None + self.pin_bias = None + + @abstractmethod + def apply(self, input_tensor): + pass + + def set_config(self, config=None): + if config is not None: + self.config = config + + def to_cuda(self, non_blocking=False): + if hasattr(self, "pin_weight") and self.pin_weight is not None: + self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + else: + self.weight = None + if hasattr(self, "pin_bias") and self.pin_bias is not None: + self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking) + else: + self.bias = None + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_weight") and self.pin_weight is not None: + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + if self.bias is not None: + self.bias = self.pin_bias.copy_(self.bias, 
non_blocking=non_blocking).cpu() + elif hasattr(self, "weight") and self.weight is not None: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + if hasattr(self, "bias") and self.bias is not None: + self.bias = self.bias.to("cpu", non_blocking=non_blocking) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + if self.weight_name is not None: + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + if self.bias_name is not None: + destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.weight_name is not None: + if self.is_post_adapter: + assert adapter_block_index is not None + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + if weight_name not in destination: + self.weight = None + return + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + else: + self.weight = None + + if self.bias_name is not None: + bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True) + else: + self.bias = None + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + if self.weight_name is not None: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors") + if self.is_post_adapter: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype) + self.pin_weight = self.pin_weight.copy_(weight_tensor) + del weight_tensor + + if self.bias_name is not None: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors") + if self.is_post_adapter: + assert adapter_block_index is not None + self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1) + else: + self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype) + self.pin_bias.copy_(bias_tensor) + del bias_tensor + + +@LN_WEIGHT_REGISTER("Default") +class LNWeight(LNWeightTemplate): + def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def apply(self, input_tensor): + if self.sensitive_layer_dtype != self.infer_dtype: + input_tensor = torch.nn.functional.layer_norm( + input_tensor.float(), + (input_tensor.shape[-1],), + self.weight, + self.bias, + self.eps, + 
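+                # normalization runs in float32 whenever the sensitive-layer dtype differs from the
+                # inference dtype; the result is cast back to the inference dtype right below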
).to(self.infer_dtype) + else: + input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps) + + return input_tensor + + +@LN_WEIGHT_REGISTER("Triton") +class LNWeight(LNWeightTemplate): + def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def apply(self, input_tensor): + input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps) + return input_tensor diff --git a/lightx2v/common/ops/norm/rms_norm_weight.py b/lightx2v/common/ops/norm/rms_norm_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..98d60ef2cba0aab9495e7f690f17d301eb82a81d --- /dev/null +++ b/lightx2v/common/ops/norm/rms_norm_weight.py @@ -0,0 +1,204 @@ +import os +import re +from abc import ABCMeta, abstractmethod +from pathlib import Path + +import torch +from safetensors import safe_open + +from lightx2v.utils.envs import * +from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + +try: + import sgl_kernel +except ImportError: + sgl_kernel = None + + +class RMSWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + self.weight_name = weight_name + self.eps = eps + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.is_post_adapter = is_post_adapter + self.infer_dtype = GET_DTYPE() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + self.config = {} + + def load(self, weight_dict): + if self.create_cuda_buffer: + self._load_cuda_buffer(weight_dict) + elif self.create_cpu_buffer: + self._load_cpu_pin_buffer() + else: + self._load_default_tensors(weight_dict) + + def _load_default_tensors(self, weight_dict): + if not self.lazy_load: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_tensor = weight_dict[self.weight_name] + self.pin_weight = self._create_cpu_pin_weight(weight_tensor) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + + def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False): + if self.lazy_load: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + tensor = lazy_load_file.get_tensor(self.weight_name) + if use_infer_dtype: + tensor = tensor.to(self.infer_dtype) + else: + tensor = weight_dict[self.weight_name] + return tensor + + def _create_cpu_pin_weight(self, tensor): + pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype) + pin_tensor.copy_(tensor) + del tensor + return pin_tensor + + def _load_cuda_buffer(self, weight_dict): + weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load) + self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE) + + def _load_cpu_pin_buffer(self): + weight_tensor = self._get_weight_tensor(use_infer_dtype=True) + self.pin_weight = 
self._create_cpu_pin_weight(weight_tensor) + + @abstractmethod + def apply(self, input_tensor): + pass + + def set_config(self, config=None): + if config is not None: + self.config = config + + def to_cuda(self, non_blocking=False): + self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_weight"): + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + else: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + + if weight_name not in destination: + self.weight = None + return + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + if self.is_post_adapter: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + else: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype) + self.pin_weight = self.pin_weight.copy_(weight_tensor) + del weight_tensor + + +@RMS_WEIGHT_REGISTER("Default") +class RMSWeight(RMSWeightTemplate): + def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def apply(self, input_tensor): + if GET_SENSITIVE_DTYPE() != GET_DTYPE(): + input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight + else: + input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight + return input_tensor + + +@RMS_WEIGHT_REGISTER("sgl-kernel") +class RMSWeightSgl(RMSWeight): + def __init__( + self, + weight_name, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + is_post_adapter=False, + eps=1e-6, + ): + super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def apply(self, input_tensor): + if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype: + input_tensor = input_tensor.contiguous() + orig_shape = input_tensor.shape + input_tensor = input_tensor.view(-1, orig_shape[-1]) + input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape) + else: + # sgl_kernel is not available or dtype!=torch.bfloat16/float16, fallback to default implementation + if 
self.sensitive_layer_dtype != self.infer_dtype: + input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype) + input_tensor = (input_tensor * self.weight).to(self.infer_dtype) + else: + input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps) + input_tensor = input_tensor * self.weight + + return input_tensor + + +@RMS_WEIGHT_REGISTER("fp32_variance") +class RMSWeightFP32(RMSWeight): + def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def apply(self, input_tensor): + input_dtype = input_tensor.dtype + variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = input_tensor * torch.rsqrt(variance + self.eps) + + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + if self.weight is not None: + hidden_states = hidden_states * self.weight + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +@RMS_WEIGHT_REGISTER("self_forcing") +class RMSWeightSF(RMSWeight): + def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6): + super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def apply(self, x): + return self._norm(x.float()).type_as(x) * self.weight diff --git a/lightx2v/common/ops/norm/triton_ops.py b/lightx2v/common/ops/norm/triton_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0ad507bc467600547dd491a012e5ca22fc6cb4 --- /dev/null +++ b/lightx2v/common/ops/norm/triton_ops.py @@ -0,0 +1,900 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang + +# TODO: for temporary usage, expecting a refactor +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from torch import Tensor + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_warps=4), + triton.Config({"BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_warps=8), + ], + key=["inner_dim"], +) +@triton.jit +def _fused_scale_shift_4d_kernel( + output_ptr, + normalized_ptr, + scale_ptr, + shift_ptr, + rows, + inner_dim, + seq_len, + num_frames, + frame_seqlen, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col_offsets < inner_dim + + # Pointers for normalized and output + row_base = pid_row * inner_dim + norm_ptrs = normalized_ptr + row_base + col_offsets + out_ptrs = output_ptr + row_base + col_offsets + + # Pointers for scale and shift for 4D + b_idx = pid_row // seq_len + t_idx = pid_row % seq_len + frame_idx_in_batch = t_idx // frame_seqlen + + scale_row_idx = b_idx * num_frames + frame_idx_in_batch + scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets + shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets + + normalized = 
tl.load(norm_ptrs, mask=mask, other=0.0) + scale = tl.load(scale_ptrs, mask=mask, other=0.0) + shift = tl.load(shift_ptrs, mask=mask, other=0.0) + + one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype) + output = normalized * (one + scale) + shift + + tl.store(out_ptrs, output, mask=mask) + + +@triton.jit +def fuse_scale_shift_kernel_blc_opt( + x_ptr, + shift_ptr, + scale_ptr, + y_ptr, + B, + L, + C, + stride_x_b, + stride_x_l, + stride_x_c, + stride_s_b, + stride_s_l, + stride_s_c, + stride_sc_b, + stride_sc_l, + stride_sc_c, + SCALE_IS_SCALAR: tl.constexpr, + SHIFT_IS_SCALAR: tl.constexpr, + BLOCK_L: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + + l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_l = l_offsets < L + mask_c = c_offsets < C + mask = mask_l[:, None] & mask_c[None, :] + + x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c + x = tl.load(x_ptr + x_off, mask=mask, other=0) + + if SHIFT_IS_SCALAR: + shift_val = tl.load(shift_ptr) + shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype) + else: + s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c + shift = tl.load(shift_ptr + s_off, mask=mask, other=0) + + if SCALE_IS_SCALAR: + scale_val = tl.load(scale_ptr) + scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype) + else: + sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c + scale = tl.load(scale_ptr + sc_off, mask=mask, other=0) + + y = x * (1 + scale) + shift + tl.store(y_ptr + x_off, y, mask=mask) + + +def fuse_scale_shift_kernel( + x: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + block_l: int = 128, + block_c: int = 128, +): + assert x.is_cuda and scale.is_cuda + assert x.is_contiguous() + + B, L, C = x.shape + output = torch.empty_like(x) + + if scale.dim() == 4: + # scale/shift: [B, F, 1, C] + rows = B * L + x_2d = x.view(rows, C) + output_2d = output.view(rows, C) + grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa + num_frames = scale.shape[1] + assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift" + frame_seqlen = L // num_frames + + # Compact [B, F, C] without the singleton dim into [B*F, C] + scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous() + shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous() + + _fused_scale_shift_4d_kernel[grid]( + output_2d, + x_2d, + scale_reshaped, + shift_reshaped, + rows, + C, + L, + num_frames, + frame_seqlen, + ) + else: + # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L + # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C]) + # Also support scalar (0D or 1-element) + if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1): + scale_blc = scale.reshape(1) + elif scale.dim() == 2: + scale_blc = scale[:, None, :] + elif scale.dim() == 3: + scale_blc = scale + else: + raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D") + + if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1): + shift_blc = shift.reshape(1) + elif shift.dim() == 2: + shift_blc = shift[:, None, :] + elif shift.dim() == 3: + shift_blc = shift + else: + # broadcast later via expand if possible + shift_blc = shift + + need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1 + need_shift_scalar = shift_blc.dim() == 
1 and shift_blc.numel() == 1 + + if not need_scale_scalar: + scale_exp = scale_blc.expand(B, L, C) + s_sb, s_sl, s_sc = scale_exp.stride() + else: + s_sb = s_sl = s_sc = 0 + + if not need_shift_scalar: + shift_exp = shift_blc.expand(B, L, C) + sh_sb, sh_sl, sh_sc = shift_exp.stride() + else: + sh_sb = sh_sl = sh_sc = 0 + + # If both scalars and both zero, copy fast-path + if need_scale_scalar and need_shift_scalar: + if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0): + output.copy_(x) + return output + + grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) + fuse_scale_shift_kernel_blc_opt[grid]( + x, + shift_blc if need_shift_scalar else shift_exp, + scale_blc if need_scale_scalar else scale_exp, + output, + B, + L, + C, + x.stride(0), + x.stride(1), + x.stride(2), + sh_sb, + sh_sl, + sh_sc, + s_sb, + s_sl, + s_sc, + SCALE_IS_SCALAR=need_scale_scalar, + SHIFT_IS_SCALAR=need_shift_scalar, + BLOCK_L=block_l, + BLOCK_C=block_c, + num_warps=4, + num_stages=2, + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2), + triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4), + triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4), + triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8), + ], + key=["head_size", "interleaved"], +) +@triton.jit +def _rotary_embedding_kernel( + output_ptr, + x_ptr, + cos_ptr, + sin_ptr, + num_heads, + head_size, + num_tokens, + stride_x_row, + stride_cos_row, + stride_sin_row, + interleaved: tl.constexpr, + BLOCK_HS_HALF: tl.constexpr, +): + row_idx = tl.program_id(0) + token_idx = (row_idx // num_heads) % num_tokens + + x_row_ptr = x_ptr + row_idx * stride_x_row + cos_row_ptr = cos_ptr + token_idx * stride_cos_row + sin_row_ptr = sin_ptr + token_idx * stride_sin_row + output_row_ptr = output_ptr + row_idx * stride_x_row + + # half size for x1 and x2 + head_size_half = head_size // 2 + + for block_start in range(0, head_size_half, BLOCK_HS_HALF): + offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF) + mask = offsets_half < head_size_half + + cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0) + sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0) + + offsets_x1 = 2 * offsets_half + offsets_x2 = 2 * offsets_half + 1 + + x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0) + x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0) + + x1_fp32 = x1_vals.to(tl.float32) + x2_fp32 = x2_vals.to(tl.float32) + cos_fp32 = cos_vals.to(tl.float32) + sin_fp32 = sin_vals.to(tl.float32) + o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32) + o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32) + + tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask) + tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask) + + +def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + output = torch.empty_like(x) + + if x.dim() > 3: + bsz, num_tokens, num_heads, head_size = x.shape + else: + num_tokens, num_heads, head_size = x.shape + bsz = 1 + + assert head_size % 2 == 0, "head_size must be divisible by 2" + + x_reshaped = x.view(-1, head_size) + output_reshaped = output.view(-1, head_size) + + # num_tokens per head, 1 token per block + grid = (bsz * num_tokens * num_heads,) + + if interleaved and cos.shape[-1] == head_size: + cos = cos[..., ::2].contiguous() + sin = sin[..., ::2].contiguous() + else: + cos = cos.contiguous() + sin = sin.contiguous() + + 
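+    # One Triton program handles one (batch, token, head) row; cos/sin are indexed per token and the
+    # kernel rotates adjacent channel pairs (2i, 2i + 1), reading head_size // 2 cos/sin values per token.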
_rotary_embedding_kernel[grid]( + output_reshaped, + x_reshaped, + cos, + sin, + num_heads, + head_size, + num_tokens, + x_reshaped.stride(0), + cos.stride(0), + sin.stride(0), + interleaved, + ) + + return output + + +# RMSNorm-fp32 +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + if not torch.cuda.is_available(): + return [] + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block] + # return [triton.Config({}, num_warps=8)] + + +# Copied from flash-attn +@triton.autotune( + configs=triton_autotune_configs(), + key=[ + "N", + "HAS_RESIDUAL", + "STORE_RESIDUAL_OUT", + "IS_RMS_NORM", + "HAS_BIAS", + "HAS_WEIGHT", + "HAS_X1", + "HAS_W1", + "HAS_B1", + ], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + DROPOUT_MASK1, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
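+    # One program normalizes one full row of the (M, N) input: BLOCK_N is sized to cover the whole
+    # feature dimension, so the mean/variance (or RMS) reduction, the optional residual/dropout/rowscale
+    # handling, and the affine transform all happen in this single pass.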
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + y = x_hat * w + b if HAS_BIAS else x_hat * w + else: + y = x_hat + b if HAS_BIAS else x_hat + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. 
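+    # out / residual_out are allocated here when the caller does not provide them; if no separate
+    # residual buffer is needed, residual_out stays None and the input x is handed back in its place
+    # at the end of this function.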
+ if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None: + residual_dtype = residual.dtype + if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None): + residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +def _layer_norm_fwd_impl( + x: Tensor, + weight: Optional[Tensor], + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + assert out.shape == x.shape + assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None + else: + dropout_mask, dropout_mask1 = None, None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with 
torch.cuda.device(x.device.index): + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( + x, + out, + weight if weight is not None else x, # unused when HAS_WEIGHT == False + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + dropout_mask1, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, + ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 + + +class LayerNormFn: + @staticmethod + def forward( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) + if residual is not None: + assert residual.shape == x_shape_og + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) + # weight can be None when elementwise_affine=False for LayerNorm + if weight is not None: + weight = weight.contiguous() + bias = maybe_contiguous(bias) + weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + out_dtype=out_dtype, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out, + ) + y = y.reshape(x_shape_og) + return y + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) + + +@triton.jit +def _norm_infer_kernel( + X, + Y, + W, + B, + stride_x_row, + 
stride_y_row, + M, + N, + eps, + IS_RMS_NORM: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_WEIGHT: + W += 0 + if HAS_BIAS: + B += 0 + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32) + y = x_hat * w + else: + y = x_hat + if HAS_BIAS: + b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32) + y += b + tl.store(Y + cols, y, mask=cols < N) + + +def norm_infer( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, + is_rms_norm: bool = False, + out: Optional[Tensor] = None, +): + M, N = x.shape + assert x.stride(-1) == 1 + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.shape == (N,) + assert bias.stride(-1) == 1 + if out is None: + out = torch.empty_like(x) + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_N // 256, 1), 8) + _norm_infer_kernel[(M,)]( + x, + out, + weight if weight is not None else x, # dummy when HAS_WEIGHT=False + bias if bias is not None else x, # dummy when HAS_BIAS=False + x.stride(0), + out.stride(0), + M, + N, + eps, + IS_RMS_NORM=is_rms_norm, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + ) + return out + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) diff --git a/lightx2v/common/ops/tensor/__init__.py b/lightx2v/common/ops/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa13faea725555a9f71d9a73a5a18bc4d1a97b0d --- /dev/null +++ b/lightx2v/common/ops/tensor/__init__.py @@ -0,0 +1 @@ +from .tensor import DefaultTensor diff --git a/lightx2v/common/ops/tensor/tensor.py b/lightx2v/common/ops/tensor/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..069d5e3ac9b13360b381b7029d40e0b3c8984d75 --- /dev/null +++ b/lightx2v/common/ops/tensor/tensor.py @@ -0,0 +1,110 @@ +import os +import re +from pathlib import Path + +import torch +from safetensors import safe_open + +from lightx2v.utils.envs import * +from lightx2v.utils.registry_factory import TENSOR_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + + +@TENSOR_REGISTER("Default") +class DefaultTensor: + def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, 
is_post_adapter=False): + self.tensor_name = tensor_name + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.is_post_adapter = is_post_adapter + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.infer_dtype = GET_DTYPE() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + + def load(self, weight_dict): + if self.create_cuda_buffer: + self._load_cuda_buffer(weight_dict) + elif self.create_cpu_buffer: + self._load_cpu_pin_buffer() + else: + self._load_default_tensors(weight_dict) + + def _load_default_tensors(self, weight_dict): + if not self.lazy_load: + device = weight_dict[self.tensor_name].device + if device.type == "cpu": + tensor = weight_dict[self.tensor_name] + self.pin_tensor = self._create_cpu_pin_tensor(tensor) + del weight_dict[self.tensor_name] + else: + self.tensor = weight_dict[self.tensor_name] + + def _get_tensor(self, weight_dict=None, use_infer_dtype=False): + if self.lazy_load: + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors") + with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + tensor = lazy_load_file.get_tensor(self.tensor_name) + if use_infer_dtype: + tensor = tensor.to(self.infer_dtype) + else: + tensor = weight_dict[self.tensor_name] + return tensor + + def _create_cpu_pin_tensor(self, tensor): + pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype) + pin_tensor.copy_(tensor) + del tensor + return pin_tensor + + def _load_cuda_buffer(self, weight_dict): + tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load) + self.tensor_cuda_buffer = tensor.to(AI_DEVICE) + + def _load_cpu_pin_buffer(self): + tensor = self._get_tensor(use_infer_dtype=True) + self.pin_tensor = self._create_cpu_pin_tensor(tensor) + + def to_cuda(self, non_blocking=False): + self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, "pin_tensor"): + self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu() + else: + self.tensor = self.tensor.to("cpu", non_blocking=non_blocking) + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1) + else: + tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1) + if tensor_name not in destination: + self.tensor = None + return + self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True) + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + if self.is_post_adapter: + assert adapter_block_index is not None + self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1) + else: + self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1) + if Path(self.lazy_load_file).is_file(): + lazy_load_file_path = self.lazy_load_file + else: + lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors") + 
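+        # The re.sub above rewrites the first ".<digits>" segment of the tensor name to the target
+        # block index, e.g. "blocks.0.xxx" -> f"blocks.{block_index}.xxx" (illustrative name); blocks
+        # live either in one consolidated safetensors file or in per-block block_{i}.safetensors files.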
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file: + tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype) + self.pin_tensor = self.pin_tensor.copy_(tensor) + del tensor diff --git a/lightx2v/common/transformer_infer/transformer_infer.py b/lightx2v/common/transformer_infer/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..424f79b00568c9f19ccf532a11fff01b55bd97b1 --- /dev/null +++ b/lightx2v/common/transformer_infer/transformer_infer.py @@ -0,0 +1,46 @@ +import math +from abc import ABC, abstractmethod + + +class BaseTransformerInfer(ABC): + @abstractmethod + def infer(self): + pass + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + self.scheduler.transformer_infer = self + + +class BaseTaylorCachingTransformerInfer(BaseTransformerInfer): + @abstractmethod + def infer_calculating(self): + pass + + @abstractmethod + def infer_using_cache(self): + pass + + @abstractmethod + def get_taylor_step_diff(self): + pass + + # 1. when fully calcualted, stored in cache + def derivative_approximation(self, block_cache, module_name, out): + if module_name not in block_cache: + block_cache[module_name] = {0: out} + else: + step_diff = self.get_taylor_step_diff() + + previous_out = block_cache[module_name][0] + block_cache[module_name][0] = out + block_cache[module_name][1] = (out - previous_out) / step_diff + + def taylor_formula(self, tensor_dict): + x = self.get_taylor_step_diff() + + output = 0 + for i in range(len(tensor_dict)): + output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i) + + return output diff --git a/lightx2v/deploy/__init__.py b/lightx2v/deploy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/deploy/common/__init__.py b/lightx2v/deploy/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/deploy/common/aliyun.py b/lightx2v/deploy/common/aliyun.py new file mode 100644 index 0000000000000000000000000000000000000000..6aaff6b47d27964d0cc05d2daea0d7c60fa20772 --- /dev/null +++ b/lightx2v/deploy/common/aliyun.py @@ -0,0 +1,81 @@ +import asyncio +import json +import os +import sys + +from alibabacloud_dypnsapi20170525 import models as dypnsapi_models +from alibabacloud_dypnsapi20170525.client import Client +from alibabacloud_tea_openapi import models as openapi_models +from alibabacloud_tea_util import models as util_models +from loguru import logger + + +class AlibabaCloudClient: + def __init__(self): + config = openapi_models.Config( + access_key_id=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"), + access_key_secret=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"), + https_proxy=os.getenv("auth_https_proxy", None), + ) + self.client = Client(config) + self.runtime = util_models.RuntimeOptions() + + def check_ok(self, res, prefix): + logger.info(f"{prefix}: {res}") + if not isinstance(res, dict) or "statusCode" not in res or res["statusCode"] != 200: + logger.warning(f"{prefix}: error response: {res}") + return False + if "body" not in res or "Code" not in res["body"] or "Success" not in res["body"]: + logger.warning(f"{prefix}: error body: {res}") + return False + if res["body"]["Code"] != "OK" or res["body"]["Success"] is not True: + logger.warning(f"{prefix}: sms error: {res}") + return False + return True + + async def send_sms(self, phone_number): + try: + req = 
dypnsapi_models.SendSmsVerifyCodeRequest( + phone_number=phone_number, + sign_name="速通互联验证服务", + template_code="100001", + template_param=json.dumps({"code": "##code##", "min": "5"}), + valid_time=300, + ) + res = await self.client.send_sms_verify_code_with_options_async(req, self.runtime) + ok = self.check_ok(res.to_map(), "AlibabaCloudClient send sms") + logger.info(f"AlibabaCloudClient send sms for {phone_number}: {ok}") + return ok + + except Exception as e: + logger.warning(f"AlibabaCloudClient send sms for {phone_number}: {e}") + return False + + async def check_sms(self, phone_number, verify_code): + try: + req = dypnsapi_models.CheckSmsVerifyCodeRequest( + phone_number=phone_number, + verify_code=verify_code, + ) + res = await self.client.check_sms_verify_code_with_options_async(req, self.runtime) + ok = self.check_ok(res.to_map(), "AlibabaCloudClient check sms") + logger.info(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {ok}") + return ok + + except Exception as e: + logger.warning(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {e}") + return False + + +async def test(args): + assert len(args) in [1, 2], "Usage: python aliyun_sms.py [verify_code]" + phone_number = args[0] + client = AlibabaCloudClient() + if len(args) == 1: + await client.send_sms(phone_number) + else: + await client.check_sms(phone_number, args[1]) + + +if __name__ == "__main__": + asyncio.run(test(sys.argv[1:])) diff --git a/lightx2v/deploy/common/audio_separator.py b/lightx2v/deploy/common/audio_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..36f06d6ef8257259b517549185256aaee44b0d25 --- /dev/null +++ b/lightx2v/deploy/common/audio_separator.py @@ -0,0 +1,376 @@ +# -*- coding: utf-8 -*- +""" +Audio Source Separation Module +Separates different voice tracks in audio, supports multi-person audio separation +""" + +import base64 +import io +import os +import tempfile +import traceback +from collections import defaultdict +from typing import Dict, Optional, Union + +import torch +import torchaudio +from loguru import logger + +# Import pyannote.audio for speaker diarization +from pyannote.audio import Audio, Pipeline + +_origin_torch_load = torch.load + + +def our_torch_load(checkpoint_file, *args, **kwargs): + kwargs["weights_only"] = False + return _origin_torch_load(checkpoint_file, *args, **kwargs) + + +class AudioSeparator: + """ + Audio separator for separating different voice tracks in audio using pyannote.audio + Supports multi-person conversation separation, maintains duration (other speakers' tracks are empty) + """ + + def __init__( + self, + model_path: str = None, + device: str = None, + sample_rate: int = 16000, + ): + """ + Initialize audio separator + + Args: + model_path: Model path (if using custom model), default uses pyannote/speaker-diarization-community-1 + device: Device ('cpu', 'cuda', etc.), None for auto selection + sample_rate: Target sample rate, default 16000 + """ + self.sample_rate = sample_rate + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") + self._init_pyannote(model_path) + + def _init_pyannote(self, model_path: str = None): + """Initialize pyannote.audio pipeline""" + try: + huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN") + model_name = model_path or "pyannote/speaker-diarization-community-1" + + try: + torch.load = our_torch_load + # Try loading with token if available + if huggingface_token: + self.pipeline = 
Pipeline.from_pretrained(model_name, token=huggingface_token) + else: + # Try without token (may work for public models) + self.pipeline = Pipeline.from_pretrained(model_name) + except Exception as e: + if "gated" in str(e).lower() or "token" in str(e).lower(): + raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}") + raise RuntimeError(f"Failed to load pyannote model: {e}") + finally: + torch.load = _origin_torch_load + + # Move pipeline to specified device + if self.device: + self.pipeline.to(torch.device(self.device)) + + # Initialize Audio helper for waveform loading + self.pyannote_audio = Audio() + + logger.info("Initialized pyannote.audio speaker diarization pipeline") + except Exception as e: + logger.error(f"Failed to initialize pyannote: {e}") + raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}") + + def separate_speakers( + self, + audio_path: Union[str, bytes], + num_speakers: Optional[int] = None, + min_speakers: int = 1, + max_speakers: int = 5, + ) -> Dict: + """ + Separate different speakers in audio + + Args: + audio_path: Audio file path or bytes data + num_speakers: Specified number of speakers, None for auto detection + min_speakers: Minimum number of speakers + max_speakers: Maximum number of speakers + + Returns: + Dict containing: + - speakers: List of speaker audio segments, each containing: + - speaker_id: Speaker ID (0, 1, 2, ...) + - audio: torch.Tensor audio data [channels, samples] + - segments: List of (start_time, end_time) tuples + - sample_rate: Sample rate + """ + try: + # Load audio + if isinstance(audio_path, bytes): + # 尝试从字节数据推断音频格式 + # 检查是否是 WAV 格式(RIFF 头) + is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE" + # 检查是否是 MP3 格式(ID3 或 MPEG 头) + is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3" + + # 根据格式选择后缀 + if is_wav: + suffix = ".wav" + elif is_mp3: + suffix = ".mp3" + else: + # 默认尝试 WAV,如果失败会抛出错误 + suffix = ".wav" + + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file: + tmp_file.write(audio_path) + tmp_audio_path = tmp_file.name + try: + result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers) + finally: + # 确保临时文件被删除 + try: + os.unlink(tmp_audio_path) + except Exception as e: + logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}") + return result + else: + return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers) + + except Exception as e: + logger.error(f"Speaker separation failed: {traceback.format_exc()}") + raise RuntimeError(f"Audio separation error: {e}") + + def _separate_speakers_internal( + self, + audio_path: str, + num_speakers: Optional[int] = None, + min_speakers: int = 1, + max_speakers: int = 5, + ) -> Dict: + """Internal method: execute speaker separation""" + + # Load audio + waveform, original_sr = torchaudio.load(audio_path) + if original_sr != self.sample_rate: + resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate) + waveform = resampler(waveform) + + # Convert to mono if stereo + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + # Ensure waveform is float32 and normalized (pyannote expects this format) + if waveform.dtype != torch.float32: + waveform = waveform.float() + + # Ensure waveform is in range [-1, 1] (normalize if needed) + if waveform.abs().max() > 1.0: + waveform = waveform / waveform.abs().max() + + if 
self.pipeline is None: + raise RuntimeError("Pyannote pipeline not initialized") + + return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers) + + def _separate_with_pyannote( + self, + audio_path: str, + waveform: torch.Tensor, + num_speakers: Optional[int], + min_speakers: int, + max_speakers: int, + ) -> Dict: + """Use pyannote.audio for speaker diarization""" + try: + # Use waveform dict to avoid AudioDecoder dependency issues + # Pipeline can accept either file path or waveform dict + # Using waveform dict is more reliable when torchcodec is not properly installed + audio_input = { + "waveform": waveform, + "sample_rate": self.sample_rate, + } + + # Run speaker diarization + output = self.pipeline( + audio_input, + min_speakers=min_speakers if num_speakers is None else num_speakers, + max_speakers=max_speakers if num_speakers is None else num_speakers, + ) + + # Extract audio segments for each speaker + speakers_dict = defaultdict(list) + for turn, speaker in output.speaker_diarization: + print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}") + start_time = turn.start + end_time = turn.end + start_sample = int(start_time * self.sample_rate) + end_sample = int(end_time * self.sample_rate) + + # Extract audio segment for this time period + segment_audio = waveform[:, start_sample:end_sample] + speakers_dict[speaker].append((start_time, end_time, segment_audio)) + + # Generate complete audio for each speaker (other speakers' segments are empty) + speakers = [] + audio_duration = waveform.shape[1] / self.sample_rate + num_samples = waveform.shape[1] + + for speaker_id, segments in speakers_dict.items(): + # Create zero-filled audio + speaker_audio = torch.zeros_like(waveform) + + # Fill in this speaker's segments + for start_time, end_time, segment_audio in segments: + start_sample = int(start_time * self.sample_rate) + end_sample = int(end_time * self.sample_rate) + # Ensure no out-of-bounds + end_sample = min(end_sample, num_samples) + segment_len = end_sample - start_sample + if segment_len > 0 and segment_audio.shape[1] > 0: + actual_len = min(segment_len, segment_audio.shape[1]) + speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len] + + speakers.append( + { + "speaker_id": speaker_id, + "audio": speaker_audio, + "segments": [(s[0], s[1]) for s in segments], + "sample_rate": self.sample_rate, + } + ) + + logger.info(f"Separated audio into {len(speakers)} speakers using pyannote") + return {"speakers": speakers, "method": "pyannote"} + + except Exception as e: + logger.error(f"Pyannote separation failed: {e}") + raise RuntimeError(f"Audio separation failed: {e}") + + def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None): + """ + Save speaker audio to file + + Args: + speaker_audio: Audio tensor [channels, samples] + output_path: Output path + sample_rate: Sample rate, if None uses self.sample_rate + """ + sr = sample_rate if sample_rate else self.sample_rate + torchaudio.save(output_path, speaker_audio, sr) + logger.info(f"Saved speaker audio to {output_path}") + + def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str: + """ + Convert speaker audio tensor to base64 encoded string without saving to file + + Args: + speaker_audio: Audio tensor [channels, samples] + sample_rate: Sample rate, if None uses self.sample_rate + format: Audio format (default: "wav") + + Returns: + Base64 
encoded audio string + """ + sr = sample_rate if sample_rate else self.sample_rate + + # Use BytesIO to save audio to memory + buffer = io.BytesIO() + torchaudio.save(buffer, speaker_audio, sr, format=format) + + # Get the audio bytes + audio_bytes = buffer.getvalue() + + # Encode to base64 + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + + logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes") + return audio_base64 + + def separate_and_save( + self, + audio_path: Union[str, bytes], + output_dir: str, + num_speakers: Optional[int] = None, + min_speakers: int = 1, + max_speakers: int = 5, + ) -> Dict: + """ + Separate audio and save to files + + Args: + audio_path: Input audio path or bytes data + output_dir: Output directory + num_speakers: Specified number of speakers + min_speakers: Minimum number of speakers + max_speakers: Maximum number of speakers + + Returns: + Separation result dictionary, containing output file paths + """ + os.makedirs(output_dir, exist_ok=True) + + result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers) + + output_paths = [] + for speaker in result["speakers"]: + speaker_id = speaker["speaker_id"] + output_path = os.path.join(output_dir, f"{speaker_id}.wav") + self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"]) + output_paths.append(output_path) + speaker["output_path"] = output_path + + result["output_paths"] = output_paths + return result + + +def separate_audio_tracks( + audio_path: str, + output_dir: str = None, + num_speakers: int = None, + model_path: str = None, +) -> Dict: + """ + Convenience function: separate different audio tracks + + Args: + audio_path: Audio file path + output_dir: Output directory, if None does not save files + num_speakers: Number of speakers + model_path: Model path (optional) + + Returns: + Separation result dictionary + """ + separator = AudioSeparator(model_path=model_path) + + if output_dir: + return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers) + else: + return separator.separate_speakers(audio_path, num_speakers=num_speakers) + + +if __name__ == "__main__": + # Test code + import sys + + if len(sys.argv) < 2: + print("Usage: python audio_separator.py [output_dir] [num_speakers]") + sys.exit(1) + + audio_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio" + num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None + + separator = AudioSeparator() + result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers) + + print(f"Separated audio into {len(result['speakers'])} speakers:") + for speaker in result["speakers"]: + print(f" Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments") + if "output_path" in speaker: + print(f" Saved to: {speaker['output_path']}") diff --git a/lightx2v/deploy/common/face_detector.py b/lightx2v/deploy/common/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..24240e330979994711df70767cf8c1a828ab0502 --- /dev/null +++ b/lightx2v/deploy/common/face_detector.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +""" +Face Detection Module using YOLO +Supports detecting faces in images, including human faces, animal faces, anime faces, sketches, etc. 
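+
+Illustrative usage (the file name and threshold below are placeholders):
+
+    detector = FaceDetector(conf_threshold=0.3)
+    result = detector.detect_faces("portrait.jpg", return_image=True)
+    print(f"found {len(result['faces'])} face(s)")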
+""" + +import io +import traceback +from typing import Dict, List, Union + +import numpy as np +from PIL import Image, ImageDraw +from loguru import logger +from ultralytics import YOLO + + +class FaceDetector: + """ + Face detection using YOLO models + Supports detecting: human faces, animal faces, anime faces, sketch faces, etc. + """ + + def __init__(self, model_path: str = None, conf_threshold: float = 0.25, device: str = None): + """ + Initialize face detector + + Args: + model_path: YOLO model path, if None uses default pretrained model + conf_threshold: Confidence threshold, default 0.25 + device: Device ('cpu', 'cuda', '0', '1', etc.), None for auto selection + """ + + self.conf_threshold = conf_threshold + self.device = device + + if model_path is None: + # Use YOLO11 pretrained model, can detect COCO dataset classes (including person) + # Or use dedicated face detection model + logger.info("Loading default YOLO11n model for face detection") + try: + self.model = YOLO("yolo11n.pt") # Lightweight model + except Exception as e: + logger.warning(f"Failed to load default model, trying yolov8n: {e}") + self.model = YOLO("yolov8n.pt") + else: + logger.info(f"Loading YOLO model from {model_path}") + self.model = YOLO(model_path) + + # Person class ID in COCO dataset is 0 + # YOLO can detect person, for more precise face detection, recommend using dedicated face detection models + # Such as YOLOv8-face or RetinaFace, can be specified via model_path parameter + # First use YOLO to detect person region, then can further detect faces within + self.target_classes = { + "person": 0, # Face (by detecting person class) + # Can be extended to detect animal faces (cat, dog, etc.) and other classes + } + + def detect_faces( + self, + image: Union[str, Image.Image, bytes, np.ndarray], + return_image: bool = False, + ) -> Dict: + """ + Detect faces in image + + Args: + image: Input image, can be path, PIL Image, bytes or numpy array + return_image: Whether to return annotated image with detection boxes + return_boxes: Whether to return detection box information + + Returns: + Dict containing: + - faces: List of face detection results, each containing: + - bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates) + - confidence: Confidence score (0.0-1.0) + - class_id: Class ID + - class_name: Class name + - image (optional): PIL Image with detection boxes drawn (if return_image=True) + """ + try: + # Load image + if isinstance(image, str): + img = Image.open(image).convert("RGB") + elif isinstance(image, bytes): + img = Image.open(io.BytesIO(image)).convert("RGB") + elif isinstance(image, np.ndarray): + img = Image.fromarray(image).convert("RGB") + elif isinstance(image, Image.Image): + img = image.convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + # Use YOLO for detection + # Note: YOLO by default detects person, we focus on person detection + # For more precise face detection, can train or use dedicated face detection models + results = self.model.predict( + source=img, + conf=self.conf_threshold, + device=self.device, + verbose=False, + ) + + faces = [] + annotated_img = img.copy() if return_image else None + + if len(results) > 0: + result = results[0] + boxes = result.boxes + + if boxes is not None and len(boxes) > 0: + for i in range(len(boxes)): + # Get bounding box coordinates (xyxy format) + bbox = boxes.xyxy[i].cpu().numpy().tolist() + confidence = float(boxes.conf[i].cpu().numpy()) + class_id = int(boxes.cls[i].cpu().numpy()) + + # 
Get class name + class_name = result.names.get(class_id, "unknown") + + # Process target classes (person, etc.) + # For person, the entire body box contains face region + # For more precise face detection, can: + # 1. Use dedicated face detection models (RetinaFace, YOLOv8-face) + # 2. Further use face detection model within current person box + # 3. Use specifically trained multi-class detection models (faces, animal faces, anime faces, etc.) + if class_id in self.target_classes.values(): + face_info = { + "bbox": bbox, # [x1, y1, x2, y2] - absolute pixel coordinates + "confidence": confidence, + "class_id": class_id, + "class_name": class_name, + } + faces.append(face_info) + + # Draw annotations on image if needed + if return_image and annotated_img is not None: + draw = ImageDraw.Draw(annotated_img) + x1, y1, x2, y2 = bbox + # Draw bounding box + draw.rectangle( + [x1, y1, x2, y2], + outline="red", + width=2, + ) + # Draw label + label = f"{class_name} {confidence:.2f}" + draw.text((x1, y1 - 15), label, fill="red") + + result_dict = {"faces": faces} + + if return_image and annotated_img is not None: + result_dict["image"] = annotated_img + + logger.info(f"Detected {len(faces)} faces in image") + return result_dict + + except Exception as e: + logger.error(f"Face detection failed: {traceback.format_exc()}") + raise RuntimeError(f"Face detection error: {e}") + + def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict: + """ + Detect faces from byte data + + Args: + image_bytes: Image byte data + **kwargs: Additional parameters passed to detect_faces + + Returns: + Detection result dictionary + """ + return self.detect_faces(image_bytes, **kwargs) + + def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]: + """ + Extract detected face regions + + Args: + image: Input image + expand_ratio: Bounding box expansion ratio to include more context + + Returns: + List of extracted face region images + """ + result = self.detect_faces(image) + faces = result["faces"] + + # Load original image + if isinstance(image, str): + img = Image.open(image).convert("RGB") + elif isinstance(image, bytes): + img = Image.open(io.BytesIO(image)).convert("RGB") + elif isinstance(image, Image.Image): + img = image.convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + face_regions = [] + img_width, img_height = img.size + + for face in faces: + x1, y1, x2, y2 = face["bbox"] + + # Expand bounding box + width = x2 - x1 + height = y2 - y1 + expand_x = width * expand_ratio + expand_y = height * expand_ratio + + x1 = max(0, int(x1 - expand_x)) + y1 = max(0, int(y1 - expand_y)) + x2 = min(img_width, int(x2 + expand_x)) + y2 = min(img_height, int(y2 + expand_y)) + + # Crop region + face_region = img.crop((x1, y1, x2, y2)) + face_regions.append(face_region) + + return face_regions + + def count_faces(self, image: Union[str, Image.Image, bytes]) -> int: + """ + Count number of faces in image + + Args: + image: Input image + + Returns: + Number of detected faces + """ + result = self.detect_faces(image, return_image=False) + return len(result["faces"]) + + +def detect_faces_in_image( + image_path: str, + model_path: str = None, + conf_threshold: float = 0.25, + return_image: bool = False, +) -> Dict: + """ + Convenience function: detect faces in image + + Args: + image_path: Image path + model_path: YOLO model path + conf_threshold: Confidence threshold + return_image: Whether to return annotated image + + 
Returns: + Detection result dictionary containing: + - faces: List of face detection results with bbox coordinates [x1, y1, x2, y2] + - image (optional): Annotated image with detection boxes + """ + detector = FaceDetector(model_path=model_path, conf_threshold=conf_threshold) + return detector.detect_faces(image_path, return_image=return_image) + + +if __name__ == "__main__": + # Test code + import sys + + if len(sys.argv) < 2: + print("Usage: python face_detector.py ") + sys.exit(1) + + image_path = sys.argv[1] + detector = FaceDetector() + result = detector.detect_faces(image_path, return_image=True) + + print(f"Detected {len(result['faces'])} faces:") + for i, face in enumerate(result["faces"]): + print(f" Face {i + 1}: {face}") + + output_path = "detected_faces.png" + result["image"].save(output_path) + print(f"Annotated image saved to: {output_path}") diff --git a/lightx2v/deploy/common/pipeline.py b/lightx2v/deploy/common/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4868ddbe8ab719b722a430a5762d7cc8c26b6c --- /dev/null +++ b/lightx2v/deploy/common/pipeline.py @@ -0,0 +1,167 @@ +import json +import sys + +from loguru import logger + + +class Pipeline: + def __init__(self, pipeline_json_file): + self.pipeline_json_file = pipeline_json_file + x = json.load(open(pipeline_json_file)) + self.data = x["data"] + self.meta = x["meta"] + self.inputs = {} + self.outputs = {} + self.temps = {} + self.model_lists = [] + self.types = {} + self.queues = set() + self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {}) + self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {}) + self.tidy_pipeline() + + def init_dict(self, base, task, model_cls): + if task not in base: + base[task] = {} + if model_cls not in base[task]: + base[task][model_cls] = {} + + # tidy each task item eg, ['t2v', 'wan2.1', 'multi_stage'] + def tidy_task(self, task, model_cls, stage, v3): + out2worker = {} + out2num = {} + cur_inps = set() + cur_temps = set() + cur_types = {} + for worker_name, worker_item in v3.items(): + prevs = [] + for inp in worker_item["inputs"]: + cur_types[inp] = self.get_type(inp) + if inp in out2worker: + prevs.append(out2worker[inp]) + out2num[inp] -= 1 + if out2num[inp] <= 0: + cur_temps.add(inp) + else: + cur_inps.add(inp) + worker_item["previous"] = prevs + + for out in worker_item["outputs"]: + cur_types[out] = self.get_type(out) + out2worker[out] = worker_name + if out not in out2num: + out2num[out] = 0 + out2num[out] += 1 + + if "queue" not in worker_item: + worker_item["queue"] = "-".join([task, model_cls, stage, worker_name]) + self.queues.add(worker_item["queue"]) + + cur_outs = [out for out, num in out2num.items() if num > 0] + self.inputs[task][model_cls][stage] = list(cur_inps) + self.outputs[task][model_cls][stage] = cur_outs + self.temps[task][model_cls][stage] = list(cur_temps) + self.types[task][model_cls][stage] = cur_types + + # tidy previous dependence workers and queue name + def tidy_pipeline(self): + for task, v1 in self.data.items(): + for model_cls, v2 in v1.items(): + for stage, v3 in v2.items(): + self.init_dict(self.inputs, task, model_cls) + self.init_dict(self.outputs, task, model_cls) + self.init_dict(self.temps, task, model_cls) + self.init_dict(self.types, task, model_cls) + self.tidy_task(task, model_cls, stage, v3) + self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage}) + logger.info(f"pipelines: {json.dumps(self.data, indent=4)}") + logger.info(f"inputs: 
{self.inputs}") + logger.info(f"outputs: {self.outputs}") + logger.info(f"temps: {self.temps}") + logger.info(f"types: {self.types}") + logger.info(f"model_lists: {self.model_lists}") + logger.info(f"queues: {self.queues}") + + def get_item_by_keys(self, keys): + item = self.data + for k in keys: + if k not in item: + raise Exception(f"{keys} are not in {self.pipeline_json_file}!") + item = item[k] + return item + + # eg. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder'] + def get_worker(self, keys): + return self.get_item_by_keys(keys) + + # eg. keys: ['t2v', 'wan2.1', 'multi_stage'] + def get_workers(self, keys): + return self.get_item_by_keys(keys) + + # eg. keys: ['t2v', 'wan2.1', 'multi_stage'] + def get_inputs(self, keys): + item = self.inputs + for k in keys: + if k not in item: + raise Exception(f"{keys} are not in inputs!") + item = item[k] + return item + + # eg. keys: ['t2v', 'wan2.1', 'multi_stage'] + def get_outputs(self, keys): + item = self.outputs + for k in keys: + if k not in item: + raise Exception(f"{keys} are not in outputs!") + item = item[k] + return item + + # eg. keys: ['t2v', 'wan2.1', 'multi_stage'] + def get_temps(self, keys): + item = self.temps + for k in keys: + if k not in item: + raise Exception(f"{keys} are not in temps!") + item = item[k] + return item + + # eg. keys: ['t2v', 'wan2.1', 'multi_stage'] + def get_types(self, keys): + item = self.types + for k in keys: + if k not in item: + raise Exception(f"{keys} are not in types!") + item = item[k] + return item + + def check_item_by_keys(self, keys): + item = self.data + for k in keys: + if k not in item: + return False + item = item[k] + return True + + def get_model_lists(self): + return self.model_lists + + def get_type(self, name): + return self.meta["special_types"].get(name, "OBJECT") + + def get_monitor_config(self): + return self.meta["monitor"] + + def get_queues(self): + return self.queues + + def inner_model_name(self, name): + return self.model_name_outer_to_inner.get(name, name) + + def outer_model_name(self, name): + return self.model_name_inner_to_outer.get(name, name) + + +if __name__ == "__main__": + pipeline = Pipeline(sys.argv[1]) + print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"])) + print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"])) diff --git a/lightx2v/deploy/common/podcasts.py b/lightx2v/deploy/common/podcasts.py new file mode 100644 index 0000000000000000000000000000000000000000..52d255dc23afa864b90cc496faeb175c4a8be998 --- /dev/null +++ b/lightx2v/deploy/common/podcasts.py @@ -0,0 +1,696 @@ +# -*- coding: utf-8 -*- + +import asyncio +import io +import json +import os +import struct +import uuid +from dataclasses import dataclass +from enum import IntEnum +from typing import Callable, List, Optional + +import websockets +from loguru import logger +from pydub import AudioSegment + + +# Protocol definitions (from podcasts_protocols) +class MsgType(IntEnum): + """Message type enumeration""" + + Invalid = 0 + FullClientRequest = 0b1 + AudioOnlyClient = 0b10 + FullServerResponse = 0b1001 + AudioOnlyServer = 0b1011 + FrontEndResultServer = 0b1100 + Error = 0b1111 + ServerACK = AudioOnlyServer + + +class MsgTypeFlagBits(IntEnum): + """Message type flag bits""" + + NoSeq = 0 + PositiveSeq = 0b1 + LastNoSeq = 0b10 + NegativeSeq = 0b11 + WithEvent = 0b100 + + +class VersionBits(IntEnum): + """Version bits""" + + Version1 = 1 + + +class HeaderSizeBits(IntEnum): + """Header size bits""" + + HeaderSize4 = 1 + HeaderSize8 = 2 + HeaderSize12 = 3 + HeaderSize16 = 
4 + + +class SerializationBits(IntEnum): + """Serialization method bits""" + + Raw = 0 + JSON = 0b1 + Thrift = 0b11 + Custom = 0b1111 + + +class CompressionBits(IntEnum): + """Compression method bits""" + + None_ = 0 + Gzip = 0b1 + Custom = 0b1111 + + +class EventType(IntEnum): + """Event type enumeration""" + + None_ = 0 + StartConnection = 1 + StartTask = 1 + FinishConnection = 2 + FinishTask = 2 + ConnectionStarted = 50 + TaskStarted = 50 + ConnectionFailed = 51 + TaskFailed = 51 + ConnectionFinished = 52 + TaskFinished = 52 + StartSession = 100 + CancelSession = 101 + FinishSession = 102 + SessionStarted = 150 + SessionCanceled = 151 + SessionFinished = 152 + SessionFailed = 153 + UsageResponse = 154 + ChargeData = 154 + TaskRequest = 200 + UpdateConfig = 201 + AudioMuted = 250 + SayHello = 300 + TTSSentenceStart = 350 + TTSSentenceEnd = 351 + TTSResponse = 352 + TTSEnded = 359 + PodcastRoundStart = 360 + PodcastRoundResponse = 361 + PodcastRoundEnd = 362 + PodcastEnd = 363 + + +@dataclass +class Message: + """Message object""" + + version: VersionBits = VersionBits.Version1 + header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4 + type: MsgType = MsgType.Invalid + flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq + serialization: SerializationBits = SerializationBits.JSON + compression: CompressionBits = CompressionBits.None_ + event: EventType = EventType.None_ + session_id: str = "" + connect_id: str = "" + sequence: int = 0 + error_code: int = 0 + payload: bytes = b"" + + @classmethod + def from_bytes(cls, data: bytes) -> "Message": + """Create message object from bytes""" + if len(data) < 3: + raise ValueError(f"Data too short: expected at least 3 bytes, got {len(data)}") + type_and_flag = data[1] + msg_type = MsgType(type_and_flag >> 4) + flag = MsgTypeFlagBits(type_and_flag & 0b00001111) + msg = cls(type=msg_type, flag=flag) + msg.unmarshal(data) + return msg + + def marshal(self) -> bytes: + """Serialize message to bytes""" + buffer = io.BytesIO() + header = [ + (self.version << 4) | self.header_size, + (self.type << 4) | self.flag, + (self.serialization << 4) | self.compression, + ] + header_size = 4 * self.header_size + if padding := header_size - len(header): + header.extend([0] * padding) + buffer.write(bytes(header)) + writers = self._get_writers() + for writer in writers: + writer(buffer) + return buffer.getvalue() + + def unmarshal(self, data: bytes) -> None: + """Deserialize message from bytes""" + buffer = io.BytesIO(data) + version_and_header_size = buffer.read(1)[0] + self.version = VersionBits(version_and_header_size >> 4) + self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111) + buffer.read(1) + serialization_compression = buffer.read(1)[0] + self.serialization = SerializationBits(serialization_compression >> 4) + self.compression = CompressionBits(serialization_compression & 0b00001111) + header_size = 4 * self.header_size + read_size = 3 + if padding_size := header_size - read_size: + buffer.read(padding_size) + readers = self._get_readers() + for reader in readers: + reader(buffer) + remaining = buffer.read() + if remaining: + raise ValueError(f"Unexpected data after message: {remaining}") + + def _get_writers(self) -> List[Callable[[io.BytesIO], None]]: + """Get list of writer functions""" + writers = [] + if self.flag == MsgTypeFlagBits.WithEvent: + writers.extend([self._write_event, self._write_session_id]) + if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, 
MsgType.AudioOnlyServer]: + if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]: + writers.append(self._write_sequence) + elif self.type == MsgType.Error: + writers.append(self._write_error_code) + else: + raise ValueError(f"Unsupported message type: {self.type}") + writers.append(self._write_payload) + return writers + + def _get_readers(self) -> List[Callable[[io.BytesIO], None]]: + """Get list of reader functions""" + readers = [] + if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]: + if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]: + readers.append(self._read_sequence) + elif self.type == MsgType.Error: + readers.append(self._read_error_code) + if self.flag == MsgTypeFlagBits.WithEvent: + readers.extend([self._read_event, self._read_session_id, self._read_connect_id]) + readers.append(self._read_payload) + return readers + + def _write_event(self, buffer: io.BytesIO) -> None: + buffer.write(struct.pack(">i", self.event)) + + def _write_session_id(self, buffer: io.BytesIO) -> None: + if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed]: + return + session_id_bytes = self.session_id.encode("utf-8") + size = len(session_id_bytes) + if size > 0xFFFFFFFF: + raise ValueError(f"Session ID size ({size}) exceeds max(uint32)") + buffer.write(struct.pack(">I", size)) + if size > 0: + buffer.write(session_id_bytes) + + def _write_sequence(self, buffer: io.BytesIO) -> None: + buffer.write(struct.pack(">i", self.sequence)) + + def _write_error_code(self, buffer: io.BytesIO) -> None: + buffer.write(struct.pack(">I", self.error_code)) + + def _write_payload(self, buffer: io.BytesIO) -> None: + size = len(self.payload) + if size > 0xFFFFFFFF: + raise ValueError(f"Payload size ({size}) exceeds max(uint32)") + buffer.write(struct.pack(">I", size)) + buffer.write(self.payload) + + def _read_event(self, buffer: io.BytesIO) -> None: + event_bytes = buffer.read(4) + if event_bytes: + self.event = EventType(struct.unpack(">i", event_bytes)[0]) + + def _read_session_id(self, buffer: io.BytesIO) -> None: + if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]: + return + size_bytes = buffer.read(4) + if size_bytes: + size = struct.unpack(">I", size_bytes)[0] + if size > 0: + session_id_bytes = buffer.read(size) + if len(session_id_bytes) == size: + self.session_id = session_id_bytes.decode("utf-8") + + def _read_connect_id(self, buffer: io.BytesIO) -> None: + if self.event in [EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]: + size_bytes = buffer.read(4) + if size_bytes: + size = struct.unpack(">I", size_bytes)[0] + if size > 0: + self.connect_id = buffer.read(size).decode("utf-8") + + def _read_sequence(self, buffer: io.BytesIO) -> None: + sequence_bytes = buffer.read(4) + if sequence_bytes: + self.sequence = struct.unpack(">i", sequence_bytes)[0] + + def _read_error_code(self, buffer: io.BytesIO) -> None: + error_code_bytes = buffer.read(4) + if error_code_bytes: + self.error_code = struct.unpack(">I", error_code_bytes)[0] + + def _read_payload(self, buffer: io.BytesIO) -> None: + size_bytes = buffer.read(4) + if size_bytes: + size = struct.unpack(">I", size_bytes)[0] + if size > 0: + self.payload = buffer.read(size) + + +async def 
receive_message(websocket: websockets.WebSocketClientProtocol) -> Message: + """Receive message from websocket""" + try: + data = await websocket.recv() + if isinstance(data, str): + raise ValueError(f"Unexpected text message: {data}") + elif isinstance(data, bytes): + msg = Message.from_bytes(data) + # logger.debug(f"Received: {msg}") + return msg + else: + raise ValueError(f"Unexpected message type: {type(data)}") + except Exception as e: + logger.error(f"Failed to receive message: {e}") + raise + + +async def wait_for_event(websocket: websockets.WebSocketClientProtocol, msg_type: MsgType, event_type: EventType) -> Message: + """Wait for specific event""" + while True: + msg = await receive_message(websocket) + if msg.type != msg_type or msg.event != event_type: + raise ValueError(f"Unexpected message: {msg}") + if msg.type == msg_type and msg.event == event_type: + return msg + + +async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None: + """Start connection""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.StartConnection + msg.payload = b"{}" + logger.debug(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None: + """Finish connection""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.FinishConnection + msg.payload = b"{}" + logger.debug(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def start_session(websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str) -> None: + """Start session""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.StartSession + msg.session_id = session_id + msg.payload = payload + logger.debug(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def finish_session(websocket: websockets.WebSocketClientProtocol, session_id: str) -> None: + """Finish session""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.FinishSession + msg.session_id = session_id + msg.payload = b"{}" + logger.debug(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +class PodcastRoundPostProcessor: + def __init__(self, session_id, data_manager): + self.session_id = session_id + self.data_manager = data_manager + + self.temp_merged_audio_name = "merged_audio.mp3" + self.output_merged_audio_name = f"{session_id}-merged_audio.mp3" + self.subtitle_timestamps = [] # 记录字幕时间戳 + self.current_audio_duration = 0.0 # 当前音频时长 + self.merged_audio = None # 用于存储合并的音频对象 + self.merged_audio_bytes = None + + async def init(self): + if self.data_manager: + await self.data_manager.create_podcast_temp_session_dir(self.session_id) + + async def postprocess_round(self, current_round, voice, audio, podcast_texts): + text = "" + if podcast_texts: + text = podcast_texts[-1].get("text", "") + logger.debug(f"Processing round: {current_round}, voice: {voice}, text: {text}, audio: {len(audio)} bytes") + + new_segment = AudioSegment.from_mp3(io.BytesIO(bytes(audio))) + round_duration = len(new_segment) / 1000.0 + + if self.merged_audio is None: + self.merged_audio = new_segment + else: + self.merged_audio = self.merged_audio + new_segment + + # 保存合并后的音频到临时文件(用于前端实时访问) + merged_io = io.BytesIO() + self.merged_audio.export(merged_io, format="mp3") + self.merged_audio_bytes = merged_io.getvalue() + if self.data_manager: + 
await self.data_manager.save_podcast_temp_session_file(self.session_id, self.temp_merged_audio_name, self.merged_audio_bytes) + merged_file_size = len(self.merged_audio_bytes) + + # 记录字幕时间戳 + self.subtitle_timestamps.append( + { + "start": self.current_audio_duration, + "end": self.current_audio_duration + round_duration, + "text": text, + "speaker": voice, + } + ) + self.current_audio_duration += round_duration + logger.debug(f"Merged audio updated: {merged_file_size} bytes, duration: {self.current_audio_duration:.2f}s") + + return { + "url": f"/api/v1/podcast/audio?session_id={self.session_id}&filename={self.temp_merged_audio_name}", + "size": merged_file_size, + "duration": self.current_audio_duration, + "round": current_round, + "text": text, + "speaker": voice, + } + + async def postprocess_final(self): + if self.data_manager: + await self.data_manager.save_podcast_output_file(self.output_merged_audio_name, self.merged_audio_bytes) + return { + "subtitles": self.subtitle_timestamps, + "audio_name": self.output_merged_audio_name, + } + + async def cleanup(self): + if self.data_manager: + await self.data_manager.clear_podcast_temp_session_dir(self.session_id) + self.data_manager = None + + +class VolcEnginePodcastClient: + """ + VolcEngine Podcast客户端 + + 支持多种播客类型: + - action=0: 文本转播客 + - action=3: NLP文本转播客 + - action=4: 提示词生成播客 + """ + + def __init__(self): + self.endpoint = "wss://openspeech.bytedance.com/api/v3/sami/podcasttts" + self.appid = os.getenv("VOLCENGINE_PODCAST_APPID") + self.access_token = os.getenv("VOLCENGINE_PODCAST_ACCESS_TOKEN") + self.app_key = "aGjiRDfUWi" + self.proxy = os.getenv("HTTPS_PROXY", None) + if self.proxy: + logger.info(f"volcengine podcast use proxy: {self.proxy}") + + async def podcast_request( + self, + session_id: str, + data_manager=None, + text: str = "", + input_url: str = "", + prompt_text: str = "", + nlp_texts: str = "", + action: int = 0, + resource_id: str = "volc.service_type.10050", + encoding: str = "mp3", + input_id: str = "test_podcast", + speaker_info: str = '{"random_order":false}', + use_head_music: bool = False, + use_tail_music: bool = False, + only_nlp_text: bool = False, + return_audio_url: bool = False, + skip_round_audio_save: bool = False, + on_round_complete: Optional[Callable] = None, + ): + """ + 执行播客请求 + + Args: + text: 输入文本 (action=0时使用) + input_url: Web URL或文件URL (action=0时使用) + prompt_text: 提示词文本 (action=4时必须) + nlp_texts: NLP文本 (action=3时必须) + action: 播客类型 (0/3/4) + resource_id: 音频资源ID + encoding: 音频格式 (mp3/wav) + input_id: 唯一输入标识 + speaker_info: 播客说话人信息 + use_head_music: 是否使用开头音乐 + use_tail_music: 是否使用结尾音乐 + only_nlp_text: 是否只返回播客文本 + return_audio_url: 是否返回音频URL + skip_round_audio_save: 是否跳过单轮音频保存 + output_dir: 输出目录 + on_round_complete: 轮次完成回调函数 + """ + if not self.appid or not self.access_token: + logger.error("APP ID or Access Key is required") + return None, None + + headers = { + "X-Api-App-Id": self.appid, + "X-Api-App-Key": self.app_key, + "X-Api-Access-Key": self.access_token, + "X-Api-Resource-Id": resource_id, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + is_podcast_round_end = True + audio_received = False + last_round_id = -1 + task_id = "" + websocket = None + retry_num = 5 + audio = bytearray() + voice = "" + current_round = 0 + podcast_texts = [] + + post_processor = PodcastRoundPostProcessor(session_id, data_manager) + await post_processor.init() + + try: + while retry_num > 0: + # 建立WebSocket连接 + websocket = await websockets.connect(self.endpoint, additional_headers=headers) + 
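+                # (Added note) Authentication is carried by the X-Api-* headers above.
+                # If a previous attempt dropped mid-podcast, the retry_info block attached
+                # to req_params below asks the server to resume generation from
+                # last_finished_round_id instead of starting a new podcast.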
logger.debug(f"WebSocket connected: {websocket.response.headers}") + + # 构建请求参数 + if input_url: + req_params = { + "input_id": input_id, + "nlp_texts": json.loads(nlp_texts) if nlp_texts else None, + "prompt_text": prompt_text, + "action": action, + "use_head_music": use_head_music, + "use_tail_music": use_tail_music, + "input_info": { + "input_url": input_url, + "return_audio_url": return_audio_url, + "only_nlp_text": only_nlp_text, + }, + "speaker_info": json.loads(speaker_info) if speaker_info else None, + "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0}, + } + else: + req_params = { + "input_id": input_id, + "input_text": text, + "nlp_texts": json.loads(nlp_texts) if nlp_texts else None, + "prompt_text": prompt_text, + "action": action, + "use_head_music": use_head_music, + "use_tail_music": use_tail_music, + "input_info": { + "input_url": input_url, + "return_audio_url": return_audio_url, + "only_nlp_text": only_nlp_text, + }, + "speaker_info": json.loads(speaker_info) if speaker_info else None, + "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0}, + } + + logger.debug(f"Request params: {json.dumps(req_params, indent=2, ensure_ascii=False)}") + + if not is_podcast_round_end: + req_params["retry_info"] = {"retry_task_id": task_id, "last_finished_round_id": last_round_id} + + # Start connection + await start_connection(websocket) + await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted) + + session_id = str(uuid.uuid4()) + if not task_id: + task_id = session_id + + # Start session + await start_session(websocket, json.dumps(req_params).encode(), session_id) + await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted) + + # Finish session + await finish_session(websocket, session_id) + + while True: + msg = await receive_message(websocket) + + # 音频数据块 + if msg.type == MsgType.AudioOnlyServer and msg.event == EventType.PodcastRoundResponse: + if not audio_received and audio: + audio_received = True + audio.extend(msg.payload) + + # 错误信息 + elif msg.type == MsgType.Error: + raise RuntimeError(f"Server error: {msg.payload.decode()}") + + elif msg.type == MsgType.FullServerResponse: + # 播客 round 开始 + if msg.event == EventType.PodcastRoundStart: + data = json.loads(msg.payload.decode()) + if data.get("text"): + filtered_payload = {"text": data.get("text"), "speaker": data.get("speaker")} + podcast_texts.append(filtered_payload) + voice = data.get("speaker") + current_round = data.get("round_id") + if current_round == -1: + voice = "head_music" + if current_round == 9999: + voice = "tail_music" + is_podcast_round_end = False + logger.debug(f"New round started: {data}") + + # 播客 round 结束 + if msg.event == EventType.PodcastRoundEnd: + data = json.loads(msg.payload.decode()) + logger.debug(f"Podcast round end: {data}") + if data.get("is_error"): + break + is_podcast_round_end = True + last_round_id = current_round + if audio: + round_info = await post_processor.postprocess_round(current_round, voice, audio, podcast_texts) + if on_round_complete: + await on_round_complete(round_info) + audio.clear() + + # 播客结束 + if msg.event == EventType.PodcastEnd: + data = json.loads(msg.payload.decode()) + logger.info(f"Podcast end: {data}") + + # 会话结束 + if msg.event == EventType.SessionFinished: + break + + if not audio_received and not only_nlp_text: + raise RuntimeError("No audio data received") + + # 保持连接 + await finish_connection(websocket) + await wait_for_event(websocket, 
MsgType.FullServerResponse, EventType.ConnectionFinished) + + # 播客结束, 保存最终音频文件 + if is_podcast_round_end: + podcast_info = await post_processor.postprocess_final() + return podcast_info + else: + logger.error(f"Current podcast not finished, resuming from round {last_round_id}") + retry_num -= 1 + await asyncio.sleep(1) + if websocket: + await websocket.close() + + finally: + await post_processor.cleanup() + if websocket: + await websocket.close() + return None + + +async def test(args): + """ + Podcast测试函数 + + Args: + args: dict, 包含所有podcast参数 + """ + client = VolcEnginePodcastClient() + + # 设置默认参数 + params = { + "text": "", + "input_url": "https://zhuanlan.zhihu.com/p/607822576", + "prompt_text": "", + "nlp_texts": "", + "action": 0, + "resource_id": "volc.service_type.10050", + "encoding": "mp3", + "input_id": "test_podcast", + "speaker_info": '{"random_order":false}', + "use_head_music": False, + "use_tail_music": False, + "only_nlp_text": False, + "return_audio_url": True, + "skip_round_audio_save": False, + "output_dir": "output", + } + + # 覆盖默认参数 + if args: + params.update(args) + + await client.podcast_request(**params) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--text", default="", help="Input text Use when action in [0]") + parser.add_argument("--input_url", default="", help="Web url or file url Use when action in [0]") + parser.add_argument("--prompt_text", default="", help="Input Prompt Text must not empty when action in [4]") + parser.add_argument("--nlp_texts", default="", help="Input NLP Texts must not empty when action in [3]") + parser.add_argument("--resource_id", default="volc.service_type.10050", help="Audio Resource ID") + parser.add_argument("--encoding", default="mp3", choices=["mp3", "wav"], help="Audio format") + parser.add_argument("--input_id", default="test_podcast", help="Unique input identifier") + parser.add_argument("--speaker_info", default='{"random_order":false}', help="Podcast Speaker Info") + parser.add_argument("--use_head_music", default=False, action="store_true", help="Enable head music") + parser.add_argument("--use_tail_music", default=False, action="store_true", help="Enable tail music") + parser.add_argument("--only_nlp_text", default=False, action="store_true", help="Enable only podcast text when action in [0, 4]") + parser.add_argument("--return_audio_url", default=False, action="store_true", help="Enable return audio url that can download") + parser.add_argument("--action", default=0, type=int, choices=[0, 3, 4], help="different podcast type") + parser.add_argument("--skip_round_audio_save", default=False, action="store_true", help="skip round audio save") + parser.add_argument("--output_dir", default="output", help="Output directory") + + args = parser.parse_args() + + kwargs = {k: v for k, v in vars(args).items() if v is not None and not (isinstance(v, bool) and not v)} + + asyncio.run(test(kwargs)) diff --git a/lightx2v/deploy/common/utils.py b/lightx2v/deploy/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6258cc9bc1f4eb53f47793358fb2a552241b91b6 --- /dev/null +++ b/lightx2v/deploy/common/utils.py @@ -0,0 +1,253 @@ +import asyncio +import base64 +import io +import os +import subprocess +import tempfile +import time +import traceback +from datetime import datetime + +import httpx +import torchaudio +from PIL import Image +from loguru import logger + +FMT = "%Y-%m-%d %H:%M:%S" + + +def current_time(): + return datetime.now().timestamp() 
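+
+# Illustrative round-trip (added comment; values are examples only): the two
+# helpers below convert between POSIX timestamps and FMT-formatted strings:
+#   s = time2str(current_time())   # e.g. "2025-01-01 08:30:00"
+#   t = str2time(s)                # back to a float timestamp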
+ + +def time2str(t): + d = datetime.fromtimestamp(t) + return d.strftime(FMT) + + +def str2time(s): + d = datetime.strptime(s, FMT) + return d.timestamp() + + +def try_catch(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception: + logger.error(f"Error in {func.__name__}:") + traceback.print_exc() + return None + + return wrapper + + +def class_try_catch(func): + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except Exception: + logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:") + traceback.print_exc() + return None + + return wrapper + + +def class_try_catch_async(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except Exception: + logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:") + traceback.print_exc() + return None + + return wrapper + + +def data_name(x, task_id): + if x == "input_image": + x = x + ".png" + elif x == "input_video": + x = x + ".mp4" + elif x == "output_video": + x = x + ".mp4" + return f"{task_id}-{x}" + + +async def fetch_resource(url, timeout): + logger.info(f"Begin to download resource from url: {url}") + t0 = time.time() + async with httpx.AsyncClient() as client: + async with client.stream("GET", url, timeout=timeout) as response: + response.raise_for_status() + ans_bytes = [] + async for chunk in response.aiter_bytes(chunk_size=1024 * 1024): + ans_bytes.append(chunk) + if len(ans_bytes) > 128: + raise Exception(f"url {url} recv data is too big") + content = b"".join(ans_bytes) + logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds") + return content + + +# check, resize, read rotate meta info +def format_image_data(data, max_size=1280): + image = Image.open(io.BytesIO(data)).convert("RGB") + exif = image.getexif() + changed = False + w, h = image.size + assert w > 0 and h > 0, "image is empty" + logger.info(f"load image: {w}x{h}, exif: {exif}") + + if w > max_size or h > max_size: + ratio = max_size / max(w, h) + w = int(w * ratio) + h = int(h * ratio) + image = image.resize((w, h)) + logger.info(f"resize image to: {image.size}") + changed = True + + orientation_key = 274 + if orientation_key and orientation_key in exif: + orientation = exif[orientation_key] + if orientation == 2: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + elif orientation == 3: + image = image.rotate(180, expand=True) + elif orientation == 4: + image = image.transpose(Image.FLIP_TOP_BOTTOM) + elif orientation == 5: + image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True) + elif orientation == 6: + image = image.rotate(270, expand=True) + elif orientation == 7: + image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True) + elif orientation == 8: + image = image.rotate(90, expand=True) + + # reset orientation to 1 + if orientation != 1: + logger.info(f"reset orientation from {orientation} to 1") + exif[orientation_key] = 1 + changed = True + + if not changed: + return data + output = io.BytesIO() + image.save(output, format=image.format or "JPEG", exif=exif.tobytes()) + return output.getvalue() + + +def media_to_wav(data): + with tempfile.NamedTemporaryFile() as fin: + fin.write(data) + fin.flush() + cmd = ["ffmpeg", "-i", fin.name, "-f", "wav", "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2", "pipe:1"] + p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + assert p.returncode == 0, f"media to wav failed: {p.stderr.decode()}" + return p.stdout + 
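+
+# Illustrative usage of media_to_wav (added comment; the file path is hypothetical):
+# ffmpeg normalizes any readable container/codec into 16-bit PCM, 44.1 kHz, stereo
+# WAV bytes, which format_audio_data (below) then opens with torchaudio to verify
+# the audio is non-empty and has a valid sample rate.
+#   with open("voice_note.m4a", "rb") as f:
+#       wav_bytes = media_to_wav(f.read())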
+ +def format_audio_data(data): + if len(data) < 4: + raise ValueError("Audio file too short") + data = media_to_wav(data) + waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10) + logger.info(f"load audio: {waveform.size()}, {sample_rate}") + assert waveform.numel() > 0, "audio is empty" + assert sample_rate > 0, "audio sample rate is not valid" + return data + + +async def preload_data(inp, inp_type, typ, val): + try: + if typ == "url": + timeout = int(os.getenv("REQUEST_TIMEOUT", "5")) + data = await fetch_resource(val, timeout=timeout) + elif typ == "base64": + # Decode base64 in background thread to avoid blocking event loop + data = await asyncio.to_thread(base64.b64decode, val) + # For multi-person audio directory, val should be a dict with file structure + elif typ == "directory": + data = {} + for fname, b64_data in val.items(): + data[fname] = await asyncio.to_thread(base64.b64decode, b64_data) + return {"type": "directory", "data": data} + elif typ == "stream": + # no bytes data need to be saved by data_manager + data = None + else: + raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!") + + # check if valid image bytes + if inp_type == "IMAGE": + data = await asyncio.to_thread(format_image_data, data) + elif inp_type == "AUDIO": + if typ != "stream" and typ != "directory": + data = await asyncio.to_thread(format_audio_data, data) + elif inp_type == "VIDEO": + # Video data doesn't need special formatting, just validate it's not empty + if len(data) == 0: + raise ValueError("Video file is empty") + logger.info(f"load video: {len(data)} bytes") + else: + raise Exception(f"cannot parse inp_type={inp_type} data") + return data + + except Exception as e: + raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!") + + +async def load_inputs(params, raw_inputs, types): + inputs_data = {} + for inp in raw_inputs: + item = params.pop(inp) + bytes_data = await preload_data(inp, types[inp], item["type"], item["data"]) + + # Handle multi-person audio directory + if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory": + fs = [] + for fname, fdata in bytes_data["data"].items(): + inputs_data[f"{inp}/{fname}"] = fdata + fs.append(f"{inp}/{fname}") + params["extra_inputs"] = {inp: fs} + elif bytes_data is not None: + inputs_data[inp] = bytes_data + else: + params[inp] = item + return inputs_data + + +def check_params(params, raw_inputs, raw_outputs, types): + stream_audio = os.getenv("STREAM_AUDIO", "0") == "1" + stream_video = os.getenv("STREAM_VIDEO", "0") == "1" + for x in raw_inputs + raw_outputs: + if x in params and "type" in params[x]: + if params[x]["type"] == "stream": + if types[x] == "AUDIO": + assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1" + elif types[x] == "VIDEO": + assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1" + elif params[x]["type"] == "directory": + # Multi-person audio directory is only supported for AUDIO type + assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}" + + +if __name__ == "__main__": + # https://github.com/recurser/exif-orientation-examples + exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples" + out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs" + os.makedirs(out_dir, exist_ok=True) + + for base_name in ["Landscape", "Portrait"]: + for i in range(9): + fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg") + fout_name = 
os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg") + logger.info(f"format image: {fin_name} -> {fout_name}") + with open(fin_name, "rb") as f: + data = f.read() + data = format_image_data(data) + with open(fout_name, "wb") as f: + f.write(data) diff --git a/lightx2v/deploy/common/va_controller.py b/lightx2v/deploy/common/va_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..688bce66fa4de2cab9f15121414d53fa66b3c027 --- /dev/null +++ b/lightx2v/deploy/common/va_controller.py @@ -0,0 +1,202 @@ +import math +import os + +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims +from lightx2v_platform.base.global_var import AI_DEVICE + + +class NextControl: + def __init__(self, action: str, data: any = None): + # action: switch, data: prev_video tensor + # action: wait, data: None + # action: fetch, data: None + self.action = action + self.data = data + + +class VAController: + def __init__(self, model_runner): + self.reader = None + self.recorder = None + self.rank = 0 + self.world_size = 1 + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size + self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size + self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None) + self.init_recorder() + self.init_reader(model_runner) + + def init_base(self, config, input_info, has_vfi_model, has_vsr_model): + self.audio_path = input_info.audio_path + self.output_video_path = input_info.save_result_path + if isinstance(self.output_video_path, dict): + self.output_video_path = self.output_video_path["data"] + + self.audio_sr = config.get("audio_sr", 16000) + self.target_fps = config.get("target_fps", 16) + self.max_num_frames = config.get("target_video_length", 81) + self.prev_frame_length = config.get("prev_frame_length", 5) + + self.record_fps = config.get("target_fps", 16) + if "video_frame_interpolation" in config and has_vfi_model: + self.record_fps = config["video_frame_interpolation"]["target_fps"] + self.record_fps = config.get("record_fps", self.record_fps) + + self.tgt_h = input_info.target_shape[0] + self.tgt_w = input_info.target_shape[1] + self.record_h, self.record_w = self.tgt_h, self.tgt_w + if "video_super_resolution" in config and has_vsr_model: + _, _, self.record_w, self.record_h = compute_scaled_and_target_dims( + self.record_w, + self.record_h, + scale=config["video_super_resolution"]["scale"], + multiple=128, + ) + + # how many frames to publish stream as a batch + self.slice_frame = config.get("slice_frame", self.prev_frame_length) + # estimate the max infer seconds, for immediate switch with local omni + slice_interval = self.slice_frame / self.record_fps + est_max_infer_secs = config.get("est_max_infer_secs", 0.6) + self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval) + self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1 + + def init_recorder(self): + if not self.output_video_path or self.rank != self.target_recorder_rank: + return + logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}") + whip_shared_path = os.getenv("WHIP_SHARED_LIB", None) + if whip_shared_path and self.output_video_path.startswith("http"): + from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder + + 
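+            # (Added note) WHIP/x264 path: when WHIP_SHARED_LIB is set and the output
+            # target is an HTTP(S) endpoint, frames are pushed to the livestream through
+            # the x264-based recorder; otherwise the default VARecorder branch below is used.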
self.recorder = X264VARecorder( + whip_shared_path=whip_shared_path, + livestream_url=self.output_video_path, + fps=self.record_fps, + sample_rate=self.audio_sr, + slice_frame=self.slice_frame, + prev_frame=self.prev_frame_length, + ) + else: + from lightx2v.deploy.common.va_recorder import VARecorder + + self.recorder = VARecorder( + livestream_url=self.output_video_path, + fps=self.record_fps, + sample_rate=self.audio_sr, + slice_frame=self.slice_frame, + prev_frame=self.prev_frame_length, + ) + + def init_reader(self, model_runner=None): + if not isinstance(self.audio_path, dict): + return + assert self.audio_path["type"] == "stream", f"unexcept audio_path: {self.audio_path}" + segment_duration = self.max_num_frames / self.target_fps + prev_duration = self.prev_frame_length / self.target_fps + omni_work_dir = os.getenv("OMNI_WORK_DIR", None) + if omni_work_dir: + from lightx2v.deploy.common.va_reader_omni import OmniVAReader + + self.reader = OmniVAReader( + rank=self.rank, + world_size=self.world_size, + stream_url=self.audio_path["data"], + sample_rate=self.audio_sr, + segment_duration=segment_duration, + prev_duration=prev_duration, + target_rank=self.target_reader_rank, + model_runner=model_runner, + huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None), + ) + else: + from lightx2v.deploy.common.va_reader import VAReader + + self.reader = VAReader( + rank=self.rank, + world_size=self.world_size, + stream_url=self.audio_path["data"], + sample_rate=self.audio_sr, + segment_duration=segment_duration, + prev_duration=prev_duration, + target_rank=self.target_reader_rank, + ) + + def start(self): + self.reader.start() + if self.rank == self.target_recorder_rank: + assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}" + self.recorder.start(self.record_w, self.record_h) + if self.world_size > 1: + dist.barrier() + + def next_control(self): + from lightx2v.deploy.common.va_reader_omni import OmniVAReader + + if isinstance(self.reader, OmniVAReader): + return self.omni_reader_next_control() + return NextControl(action="fetch") + + def before_control(self): + from lightx2v.deploy.common.va_reader_omni import OmniVAReader + + if isinstance(self.reader, OmniVAReader): + self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE) + self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE) + self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE) + + def omni_reader_next_control(self): + immediate_switch = self.reader.get_immediate_switch() + if immediate_switch == 1: + # truncate the stream buffer to keep the max infer time length + # and broadcast the prev video tensor to all ranks + if self.rank == self.target_recorder_rank: + logger.warning(f"runner recv immediate switch, truncate stream buffer") + video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx) + if video_tensor is not None: + self.flag_tensor.fill_(1) + self.prev_tensor.copy_(video_tensor) + else: + self.flag_tensor.fill_(0) + dist.broadcast(self.flag_tensor, src=self.target_recorder_rank) + if self.flag_tensor.item() == 1: + dist.broadcast(self.prev_tensor, src=self.target_recorder_rank) + return NextControl(action="switch", data=self.prev_tensor) + else: + # get the length of stream buffer, broadcast to all ranks + if self.rank == self.target_recorder_rank: + stream_buffer_length = self.recorder.get_buffer_stream_size() + 
self.len_tensor.copy_(stream_buffer_length) + dist.broadcast(self.len_tensor, src=self.target_recorder_rank) + buffer_length = self.len_tensor.item() + # stream buffer is enough, skip infer + if buffer_length >= self.min_stay_queue_num: + return NextControl(action="wait") + return NextControl(action="fetch") + + def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor): + if self.recorder.realtime: + self.recorder.buffer_stream(images, audios, gen_video) + else: + self.recorder.pub_livestream(images, audios) + + def clear(self): + self.len_tensor = None + self.flag_tensor = None + self.prev_tensor = None + if self.reader is not None: + self.reader.stop() + self.reader = None + if self.recorder is not None: + self.recorder.stop() + self.recorder = None + + def __del__(self): + self.clear() diff --git a/lightx2v/deploy/common/va_reader.py b/lightx2v/deploy/common/va_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..6579d51f0bc08272be9a5957734a5685222786d1 --- /dev/null +++ b/lightx2v/deploy/common/va_reader.py @@ -0,0 +1,274 @@ +import os +import queue +import signal +import subprocess +import threading +import time +import traceback + +import numpy as np +import torch +import torch.distributed as dist +from loguru import logger + + +class VAReader: + def __init__( + self, + rank: int, + world_size: int, + stream_url: str, + segment_duration: float = 5.0, + sample_rate: int = 16000, + audio_channels: int = 1, + buffer_size: int = 1, + prev_duration: float = 0.3125, + target_rank: int = 0, + ): + self.rank = rank + self.world_size = world_size + self.stream_url = stream_url + self.segment_duration = segment_duration + self.sample_rate = sample_rate + self.audio_channels = audio_channels + self.prev_duration = prev_duration + # int16 = 2 bytes + self.chunk_size = int(self.segment_duration * self.sample_rate) * 2 + self.prev_size = int(self.prev_duration * self.sample_rate) * 2 + self.prev_chunk = None + self.buffer_size = buffer_size + + self.audio_queue = queue.Queue(maxsize=self.buffer_size) + self.audio_thread = None + self.ffmpeg_process = None + self.bytes_buffer = bytearray() + + self.target_rank = target_rank % self.world_size + + self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda") + self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda") + + logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}") + logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz") + + def start(self): + if self.rank == self.target_rank: + if self.stream_url.startswith("rtmp://"): + self.start_ffmpeg_process_rtmp() + elif self.stream_url.startswith("http"): + self.start_ffmpeg_process_whep() + else: + raise Exception(f"Unsupported stream URL: {self.stream_url}") + self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True) + self.audio_thread.start() + logger.info(f"VAReader {self.rank}/{self.world_size} started successfully") + else: + logger.info(f"VAReader {self.rank}/{self.world_size} wait only") + if self.world_size > 1: + logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier") + dist.barrier() + logger.info(f"VAReader {self.rank}/{self.world_size} end barrier") + + def start_ffmpeg_process_rtmp(self): + """Start ffmpeg process read audio from stream""" + ffmpeg_cmd = [ + "ffmpeg", + "-i", + self.stream_url, + "-vn", + # "-acodec", + # "pcm_s16le", + "-ar", + str(self.sample_rate), + "-ac", + 
str(self.audio_channels), + "-f", + "s16le", + "-", + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0) + logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg process: {e}") + raise + + def start_ffmpeg_process_whep(self): + """Start gstream process read audio from stream""" + ffmpeg_cmd = [ + "gst-launch-1.0", + "-q", + "whepsrc", + f"whep-endpoint={self.stream_url}", + "video-caps=none", + "!rtpopusdepay", + "!opusdec", + "plc=false", + "!audioconvert", + "!audioresample", + f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}", + "!fdsink", + "fd=1", + ] + try: + self.ffmpeg_process = subprocess.Popen( + ffmpeg_cmd, + stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + bufsize=0, + ) + logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg process: {e}") + raise + + def audio_worker(self): + logger.info("Audio pull worker thread started") + try: + while True: + if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None: + logger.warning("FFmpeg process exited, audio worker thread stopped") + break + self.fetch_audio_data() + time.sleep(0.01) + except: # noqa + logger.error(f"Audio pull worker error: {traceback.format_exc()}") + finally: + logger.warning("Audio pull worker thread stopped") + + def fetch_audio_data(self): + """Fetch audio data from ffmpeg process""" + try: + audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size) + if not audio_bytes: + return + self.bytes_buffer.extend(audio_bytes) + # logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes") + + if len(self.bytes_buffer) >= self.chunk_size: + audio_data = self.bytes_buffer[: self.chunk_size] + self.bytes_buffer = self.bytes_buffer[self.chunk_size :] + + # first chunk, read original 81 frames + # for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames + if self.prev_chunk is None: + logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}") + self.chunk_size -= self.prev_size + else: + audio_data = self.prev_chunk + audio_data + self.prev_chunk = audio_data[-self.prev_size :] + + try: + self.audio_queue.put_nowait(audio_data) + except queue.Full: + logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk") + self.audio_queue.get_nowait() + self.audio_queue.put_nowait(audio_data) + logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}") + + except: # noqa + logger.error(f"Fetch audio data error: {traceback.format_exc()}") + + def braodcast_audio_data(self, audio_data): + if self.rank == self.target_rank: + if audio_data is None: + self.flag_tensor.fill_(0) + else: + self.flag_tensor.fill_(1) + self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8)) + logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}") + + dist.broadcast(self.flag_tensor, src=self.target_rank) + if self.flag_tensor.item() == 0: + return None + + dist.broadcast(self.audio_tensor, src=self.target_rank) + if self.rank != self.target_rank: + logger.info(f"rank {self.rank} recv audio_tensor: 
{self.audio_tensor.shape}") + audio_data = self.audio_tensor.cpu().numpy().tobytes() + return audio_data + + def bytes_to_ndarray(self, audio_data): + if audio_data is None: + return None + audio_data = np.frombuffer(audio_data, dtype=np.int16) + audio_data = audio_data.astype(np.float32) / 32768.0 + logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}") + return audio_data + + def get_audio_segment(self, timeout: float = 1.0): + audio_data = None + if self.rank == self.target_rank: + try: + audio_data = self.audio_queue.get(timeout=timeout) + except: # noqa + logger.warning(f"Failed to get audio segment: {traceback.format_exc()}") + if self.world_size > 1: + audio_data = self.braodcast_audio_data(audio_data) + audio_data = self.bytes_to_ndarray(audio_data) + return audio_data + + def stop(self): + # Stop ffmpeg process + if self.ffmpeg_process: + self.ffmpeg_process.send_signal(signal.SIGINT) + try: + self.ffmpeg_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.ffmpeg_process.kill() + logger.warning("FFmpeg reader process stopped") + + # Wait for threads to finish + if self.audio_thread and self.audio_thread.is_alive(): + self.audio_thread.join(timeout=5) + if self.audio_thread.is_alive(): + logger.error("Audio pull thread did not stop gracefully") + + while self.audio_queue and self.audio_queue.qsize() > 0: + self.audio_queue.get_nowait() + self.audio_queue = None + logger.warning("Audio pull queue cleaned") + + def __del__(self): + self.stop() + + +if __name__ == "__main__": + WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) + RANK = int(os.environ.get("RANK", 0)) + if WORLD_SIZE > 1: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) + logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}") + + reader = VAReader( + RANK, + WORLD_SIZE, + # "rtmp://localhost/live/test_audio", + "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000", + segment_duration=1.0, + sample_rate=16000, + audio_channels=1, + prev_duration=1 / 16, + ) + reader.start() + fail_count = 0 + max_fail_count = 2 + + try: + while True: + audio_data = reader.get_audio_segment(timeout=2) + if audio_data is not None: + # logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]") + fail_count = 0 + else: + fail_count += 1 + if fail_count > max_fail_count: + logger.warning("Failed to get audio chunk, stop reader") + reader.stop() + break + time.sleep(0.95) + finally: + reader.stop() diff --git a/lightx2v/deploy/common/va_reader_omni.py b/lightx2v/deploy/common/va_reader_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..cdec87abd0226df993e65d716b3cdf3cae18bcc9 --- /dev/null +++ b/lightx2v/deploy/common/va_reader_omni.py @@ -0,0 +1,508 @@ +import datetime +import json +import os +import random +import subprocess +import threading +import time +import traceback +from collections import deque +from copy import deepcopy + +import jsonschema +import numpy as np +import torch +import torch.distributed as dist +import zmq +from loguru import logger + +try: + from bson import BSON +except ImportError: + BSON = None + logger.warning("BSON is not installed") +from scipy.signal import resample + + +class AudioInfo: + def __init__(self, info: dict): + self.sample_count = info["sample_count"] + self.sample_rate = info["sample_rate"] + self.channel_count = 
info["channel_count"] + self.sample_fmt = info["sample_fmt"] + self.pts = info["pts"] + + def is_spec_equal(self, other: "AudioInfo") -> bool: + return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count + + def duration(self) -> datetime.timedelta: + return datetime.timedelta(seconds=self.sample_count / self.sample_rate) + + def __str__(self): + return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts) + + +class ByteBuffer: + def __init__(self): + self.buffer = deque() + self.current_size = 0 + # is the audio belonging to current turn finished + self.audio_finished = False + + def add(self, byte_data: bytes): + self.buffer.append(byte_data) + self.current_size += len(byte_data) + + def get(self, size=1024): + data = bytearray() + + while size > 0 and len(self.buffer) > 0: + chunk = self.buffer.popleft() + if len(chunk) <= size: + # 如果当前数据小于size,则将当前数据全部添加到data中 + data.extend(chunk) + self.current_size -= len(chunk) + size -= len(chunk) + else: + # 如果当前数据大于size,则将当前数据的一部分添加到data中,剩余部分留在缓冲区 + data.extend(chunk[:size]) + self.buffer.appendleft(chunk[size:]) # 剩余部分留在缓冲区 + self.current_size -= size + size = 0 + + return bytes(data) + + def mark_finished(self): + self.audio_finished = True + + def has_more_voice(self): + return not self.audio_finished + + def __len__(self): + return self.current_size + + +class ChatAdapter: + def __init__( + self, + omni_work_dir: str, + whep_url: str, + session_id: str, + account: str, + config_files: list[str], + config_schema_path: str, + seg_duration: float, + model_runner, + huoshan_tts_voice_type, + ): + assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist" + self.omni_work_dir = omni_work_dir + self.context = zmq.Context() + self.w2f_socket = self.context.socket(zmq.PULL) + self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket) + self.f2w_socket = self.context.socket(zmq.PUSH) + self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket) + self.recv_thread = None + self.audio_buffer = ByteBuffer() + self.audio_info = None + self.chat_server_cmd = [ + os.path.join(self.omni_work_dir, "bin", "seko-chatter"), + "--session-id", + session_id, + "--account", + account, + "--whep-server-url", + whep_url, + "--w2f-endpoint", + self.w2f_url, + "--f2w-endpoint", + self.f2w_url, + "--config-files", + *config_files, + ] + override_config = {} + if huoshan_tts_voice_type is not None: + logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}") + override_config["TTS"] = { + "default_voice_info": { + "voice_type": huoshan_tts_voice_type, + "provider": "huoshan_stream_tts", + } + } + with open(config_schema_path, "r") as f: + schema = json.load(f) + jsonschema.validate(instance=override_config, schema=schema) + if override_config is not None: + self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)]) + self.chatter_proc = None + + self.seg_duration = seg_duration + self.reset_prev = False + self.status = "blank" + self.immediate_switch = 0 + self.model_runner = model_runner + + def launch_chat_server(self): + env = { + "RUST_LOG": "info,duplex_server=debug,backend_5o=debug", + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"), + "PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"), + } + self.chatter_proc = 
subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir) + + @staticmethod + def select_and_bind(socket: zmq.Socket) -> str: + # randomly select a port between 1024 and 6553 + retry_count = 20 + err = None + while retry_count > 0: + try: + port = random.randint(1024, 65535) + # port = 5555 + url = f"tcp://localhost:{port}" + socket.bind(url) + return url + except zmq.error.ZMQError as e: + retry_count -= 1 + err = e + raise err + + # immediate switch to status, discard prev_bytes, set immediate_switch to 1 + def immediate_switch_to(self, status): + logger.warning(f"VA reader immediate switch to {status}") + self.reset_prev = True + self.status = status + self.immediate_switch = 1 + if self.model_runner is not None: + self.model_runner.pause_signal = True + logger.warning(f"Model runner pause signal set to True") + + def recv_loop(self): + while True: + try: + message = self.w2f_socket.recv() + except Exception: + logger.error(f"Error receiving message: {traceback.format_exc()}") + break + try: + message = BSON.decode(message) + msg_type = message["type"] + logger.debug("Received message type: {}".format(msg_type)) + if msg_type == "AgentAudio": + audio = message["audio"] + if audio["type"] != "Pcm": + logger.error("Unsupported audio type: {}".format(audio["type"])) + continue + pcm_data = audio["data"] + audio_info = AudioInfo(audio["info"]) + logger.debug("Received audio with duration: {}".format(audio_info.duration())) + if self.audio_info is None: + self.audio_info = audio_info + else: + # check if the audio info is the same + if not self.audio_info.is_spec_equal(audio_info): + raise ValueError("Audio info mismatch") + self.audio_buffer.add(pcm_data) + # if status is blank and has voice, set immediate switch to 1 + if self.status == "blank" and self.has_voice(self.seg_duration): + self.immediate_switch_to("voice") + elif msg_type == "AgentStartPlay": + logger.debug("Received AgentStartPlay, create new audio buffer") + self.audio_buffer = ByteBuffer() + elif msg_type == "AgentEndPlay": + logger.debug("Received AgentEndPlay, mark audio finished") + self.audio_buffer.mark_finished() + elif msg_type == "ClearAgentAudio": + logger.warning("Received ClearAgentAudio, clear audio buffer") + self.audio_buffer = None + self.audio_info = None + if self.status == "voice": + self.status = "blank" + # self.immediate_switch_to("blank") + except Exception as e: + logger.error("Error decoding message: {}, continue".format(e)) + continue + logger.warning("recv loop interrupted") + + def start(self): + self.launch_chat_server() + self.recv_thread = threading.Thread(target=self.recv_loop) + self.recv_thread.start() + + def has_voice(self, duration) -> bool: + if self.audio_info is None or self.audio_buffer.current_size == 0: + return False + bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2 # S16LE assumed + # if not has enough bytes and maybe has more voice, return False + if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice(): + logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}") + return False + return bytes_count + + def get_audio(self, fetch_duration) -> (bytes, AudioInfo): + bytes_count = self.has_voice(fetch_duration) + if bytes_count is False: + return None + pcm_data = self.audio_buffer.get(bytes_count) + + # the actual sample count fetched + sample_count = len(pcm_data) // (self.audio_info.channel_count * 2) + 
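The byte accounting above assumes S16LE PCM, i.e. 2 bytes per sample per channel. A minimal standalone sketch of the same duration-to-bytes conversion used by has_voice()/get_audio(); the helper names below are illustrative and not part of this patch:

BYTES_PER_SAMPLE = 2  # S16LE: signed 16-bit little-endian PCM


def pcm_bytes_for_duration(duration_s: float, sample_rate: int, channels: int) -> int:
    # Bytes needed to hold `duration_s` seconds of S16LE PCM.
    return round(duration_s * sample_rate) * channels * BYTES_PER_SAMPLE


def pcm_samples_in_bytes(n_bytes: int, channels: int) -> int:
    # Per-channel sample count contained in `n_bytes` of S16LE PCM.
    return n_bytes // (channels * BYTES_PER_SAMPLE)


if __name__ == "__main__":
    # 17/16 s of 16 kHz mono audio -> 34000 bytes -> 17000 samples per channel
    n = pcm_bytes_for_duration(17 / 16, 16000, 1)
    assert n == 34000 and pcm_samples_in_bytes(n, 1) == 17000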
logger.debug("Fetched {} bytes audio".format(sample_count)) + logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size)) + audio_info = deepcopy(self.audio_info) + audio_info.sample_count = sample_count + return (pcm_data, audio_info) + + def stop(self): + self.model_runner = None + if self.chatter_proc is not None: + self.chatter_proc.terminate() + self.chatter_proc.wait() + self.chatter_proc = None + self.w2f_socket.close() + self.f2w_socket.close() + + def __del__(self): + self.stop() + + +class OmniVAReader: + def __init__( + self, + rank: int, + world_size: int, + stream_url: str, + segment_duration: float = 5.0625, + sample_rate: int = 16000, + audio_channels: int = 1, + buffer_size: int = 1, + prev_duration: float = 0.3125, + target_rank: int = 0, + model_runner=None, + huoshan_tts_voice_type=None, + ): + self.rank = rank + self.world_size = world_size + self.stream_url = stream_url + self.segment_duration = segment_duration + self.sample_rate = sample_rate + + self.audio_channels = audio_channels + self.prev_duration = prev_duration + self.all_seg_sample_count = int(self.segment_duration * self.sample_rate) + self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate) + self.prev_seg_chunk = None + + self.target_rank = target_rank % self.world_size + self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda") + self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda") + chunk_size = int(self.segment_duration * self.sample_rate) * 2 + self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda") + self.chat_adapter = None + self.model_runner = model_runner + self.huoshan_tts_voice_type = huoshan_tts_voice_type + + assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader" + logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}") + logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz") + + def init_omni_env(self): + self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/") + self.session_id = os.getenv("OMNI_SESSION_ID", "") + self.account = os.getenv("OMNI_ACCOUNT", "") + self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",") + self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None) + assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist" + assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required" + logger.info( + f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}" + ) + + def start(self): + if self.rank == self.target_rank: + self.init_omni_env() + assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader" + self.chat_adapter = ChatAdapter( + omni_work_dir=self.omni_work_dir, + whep_url=self.stream_url, + session_id=self.session_id, + account=self.account, + config_files=self.config_files, + config_schema_path=self.config_schema_path, + seg_duration=self.segment_duration, + model_runner=self.model_runner, + huoshan_tts_voice_type=self.huoshan_tts_voice_type, + ) + self.chat_adapter.start() + logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully") + else: + logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only") + if self.world_size > 1: + logger.info(f"OmniVAReader 
{self.rank}/{self.world_size} wait barrier") + dist.barrier() + logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier") + + def braodcast_audio_data(self, audio_data): + if self.rank == self.target_rank: + if audio_data is None: + self.flag_tensor.fill_(0) + else: + self.flag_tensor.fill_(1) + self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8)) + # logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}") + + dist.broadcast(self.flag_tensor, src=self.target_rank) + if self.flag_tensor.item() == 0: + return None + + dist.broadcast(self.audio_tensor, src=self.target_rank) + if self.rank != self.target_rank: + # logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}") + audio_data = self.audio_tensor.cpu().numpy().tobytes() + return audio_data + + def bytes_to_ndarray(self, audio_data): + if audio_data is None: + return None + audio_data = np.frombuffer(audio_data, dtype=np.int16) + audio_data = audio_data.astype(np.float32) / 32768.0 + # logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}") + return audio_data + + def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info): + audio = np.frombuffer(audio_data, dtype=np.int16) + sample_count = audio_info.sample_count + assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}" + # convert to mono + if audio_info.channel_count > 1: + audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1) + + # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}") + if audio_info.sample_rate != self.sample_rate: + sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate) + audio = resample(audio, sample_count).astype(np.int16) + # logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}") + logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}") + return audio, sample_count + + def prepare_audio_data(self, chat_audio_result): + sample_count = 0 + audio = np.array([], dtype=np.int16) + + # convert chat audio result to mono and target sample rate + if chat_audio_result is not None: + audio_data, audio_info = chat_audio_result + audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info) + + # if is not the first segment, concat with previous segment + if self.prev_seg_chunk is not None: + audio = np.concatenate([self.prev_seg_chunk, audio]) + sample_count = len(audio) + assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}" + + # pad 0 to the audio to make it the same length as all_seg_sample_count + if sample_count < self.all_seg_sample_count: + pad_count = self.all_seg_sample_count - sample_count + # logger.info(f"pad {pad_count} samples to audio") + audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0) + sample_count = len(audio) + + # update prev seg chunk + self.prev_seg_chunk = audio[-self.prev_seg_sample_count :] + # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}") + return audio.tobytes() + + def get_fetch_duration(self): + fetch_duration = self.segment_duration + # after immediate switch, reset prev seg chunk + if 
self.chat_adapter.reset_prev: + self.prev_seg_chunk = None + self.chat_adapter.reset_prev = False + logger.warning(f"Reset prev seg chunk") + # first segment, fetch segment_duration, else fetch segment_duration - prev_duration + if self.prev_seg_chunk is not None: + fetch_duration -= self.prev_duration + return fetch_duration + + def get_audio_segment(self): + audio_data = None + if self.rank == self.target_rank: + try: + fetch_duration = self.get_fetch_duration() + # logger.info(f"Get segment, fetch_duration: {fetch_duration}") + if self.chat_adapter.status == "voice": + audio_result = self.chat_adapter.get_audio(fetch_duration) + audio_data = self.prepare_audio_data(audio_result) + # think all voice segments inferred, naturally switch to blank + if audio_result is None: + logger.info(f"Think all voice segments inferred, naturally switch to blank") + self.chat_adapter.status = "blank" + else: + audio_data = self.prepare_audio_data(None) + except Exception as e: + logger.warning(f"Failed to get voice segment: {e}") + return None + if self.world_size > 1: + audio_data = self.braodcast_audio_data(audio_data) + audio_data = self.bytes_to_ndarray(audio_data) + return audio_data + + def get_immediate_switch(self): + if self.rank == self.target_rank: + if self.chat_adapter.immediate_switch == 1: + self.immediate_switch_tensor.fill_(1) + # reset immediate switch + self.chat_adapter.immediate_switch = 0 + else: + self.immediate_switch_tensor.fill_(0) + dist.broadcast(self.immediate_switch_tensor, src=self.target_rank) + immediate_switch = self.immediate_switch_tensor.item() + return immediate_switch + + def stop(self): + self.model_runner = None + if self.chat_adapter is not None: + self.chat_adapter.stop() + self.chat_adapter = None + logger.warning("OmniVAReader stopped") + + def __del__(self): + self.stop() + + +if __name__ == "__main__": + WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) + RANK = int(os.environ.get("RANK", 0)) + if WORLD_SIZE > 1: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) + logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}") + + reader = OmniVAReader( + RANK, + WORLD_SIZE, + "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000", + segment_duration=17 / 16, + sample_rate=16000, + audio_channels=1, + prev_duration=1 / 16, + ) + reader.start() + fail_count = 0 + max_fail_count = 100000000 + + try: + while True: + audio_data = reader.get_audio_segment(timeout=1) + if audio_data is not None: + logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]") + fail_count = 0 + else: + fail_count += 1 + if fail_count > max_fail_count: + logger.warning("Failed to get audio chunk, stop reader") + reader.stop() + break + time.sleep(0.95) + finally: + reader.stop() diff --git a/lightx2v/deploy/common/va_recorder.py b/lightx2v/deploy/common/va_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..de551d338098ea40d649c00b1441450baa27894a --- /dev/null +++ b/lightx2v/deploy/common/va_recorder.py @@ -0,0 +1,657 @@ +import os +import queue +import socket +import subprocess +import threading +import time +import traceback + +import numpy as np +import torch +import torchaudio as ta +from loguru import logger + + +def pseudo_random(a, b): + x = str(time.time()).split(".")[1] + y = int(float("0." 
+ x) * 1000000) + return a + (y % (b - a + 1)) + + +class VARecorder: + def __init__( + self, + livestream_url: str, + fps: float = 16.0, + sample_rate: int = 16000, + slice_frame: int = 1, + prev_frame: int = 1, + ): + self.livestream_url = livestream_url + self.fps = fps + self.sample_rate = sample_rate + self.audio_port = pseudo_random(32000, 40000) + self.video_port = self.audio_port + 1 + self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error") + logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}") + + self.width = None + self.height = None + self.stoppable_t = None + self.realtime = False + if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"): + self.realtime = True + + # ffmpeg process for mix video and audio data and push to livestream + self.ffmpeg_process = None + + # TCP connection objects + self.audio_socket = None + self.video_socket = None + self.audio_conn = None + self.video_conn = None + self.audio_thread = None + self.video_thread = None + + # queue for send data to ffmpeg process + self.audio_queue = queue.Queue() + self.video_queue = queue.Queue() + + # buffer for stream data + self.audio_samples_per_frame = round(self.sample_rate / self.fps) + self.stream_buffer = [] + self.stream_buffer_lock = threading.Lock() + self.stop_schedule = False + self.schedule_thread = None + self.slice_frame = slice_frame + self.prev_frame = prev_frame + assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than previous frame" + + def init_sockets(self): + # TCP socket for send and recv video and audio data + self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.video_socket.bind(("127.0.0.1", self.video_port)) + self.video_socket.listen(1) + + self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.audio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.audio_socket.bind(("127.0.0.1", self.audio_port)) + self.audio_socket.listen(1) + + def audio_worker(self): + try: + logger.info("Waiting for ffmpeg to connect to audio socket...") + self.audio_conn, _ = self.audio_socket.accept() + logger.info(f"Audio connection established from {self.audio_conn.getpeername()}") + fail_time, max_fail_time = 0, 10 + while True: + try: + if self.audio_queue is None: + break + data = self.audio_queue.get() + if data is None: + logger.info("Audio thread received stop signal") + break + # Convert audio data to 16-bit integer format + audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16) + try: + self.audio_conn.send(audios[None].cpu().numpy().tobytes()) + except (BrokenPipeError, OSError, ConnectionResetError) as e: + logger.info(f"Audio connection closed, stopping worker: {type(e).__name__}") + return + fail_time = 0 + except (BrokenPipeError, OSError, ConnectionResetError): + logger.info("Audio connection closed during queue processing") + break + except Exception: + logger.error(f"Send audio data error: {traceback.format_exc()}") + fail_time += 1 + if fail_time > max_fail_time: + logger.error(f"Audio push worker thread failed {fail_time} times, stopping...") + break + except Exception: + logger.error(f"Audio push worker thread error: {traceback.format_exc()}") + 
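The worker above clamps and rounds float audio in [-1, 1] to int16 before writing it to the ffmpeg TCP socket. A small numpy-only sketch of that conversion and of its inverse (the reader-side bytes_to_ndarray), assuming native little-endian byte order; the function names are illustrative:

import numpy as np


def float_to_s16le_bytes(audio: np.ndarray) -> bytes:
    # float32 audio in [-1.0, 1.0] -> signed 16-bit PCM bytes (clamp, then round).
    pcm = np.clip(np.round(audio * 32767), -32768, 32767).astype(np.int16)
    return pcm.tobytes()


def s16le_bytes_to_float(data: bytes) -> np.ndarray:
    # Inverse mapping used on the pull side before feeding the model.
    return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0


if __name__ == "__main__":
    x = np.linspace(-1.0, 1.0, 16000, dtype=np.float32)
    y = s16le_bytes_to_float(float_to_s16le_bytes(x))
    assert np.max(np.abs(x - y)) < 1e-3  # error bounded by the quantization step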
finally: + logger.info("Audio push worker thread stopped") + + def video_worker(self): + try: + logger.info("Waiting for ffmpeg to connect to video socket...") + self.video_conn, _ = self.video_socket.accept() + logger.info(f"Video connection established from {self.video_conn.getpeername()}") + fail_time, max_fail_time = 0, 10 + packet_secs = 1.0 / self.fps + while True: + try: + if self.video_queue is None: + break + data = self.video_queue.get() + if data is None: + logger.info("Video thread received stop signal") + break + + # Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg + for i in range(data.shape[0]): + t0 = time.time() + frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + try: + self.video_conn.send(frame.tobytes()) + except (BrokenPipeError, OSError, ConnectionResetError) as e: + logger.info(f"Video connection closed, stopping worker: {type(e).__name__}") + return + if self.realtime and i < data.shape[0] - 1: + time.sleep(max(0, packet_secs - (time.time() - t0))) + + fail_time = 0 + except (BrokenPipeError, OSError, ConnectionResetError): + logger.info("Video connection closed during queue processing") + break + except Exception: + logger.error(f"Send video data error: {traceback.format_exc()}") + fail_time += 1 + if fail_time > max_fail_time: + logger.error(f"Video push worker thread failed {fail_time} times, stopping...") + break + except Exception: + logger.error(f"Video push worker thread error: {traceback.format_exc()}") + finally: + logger.info("Video push worker thread stopped") + + def start_ffmpeg_process_local(self): + """Start ffmpeg process that connects to our TCP sockets""" + ffmpeg_cmd = [ + "ffmpeg", + "-fflags", + "nobuffer", + "-analyzeduration", + "0", + "-probesize", + "32", + "-flush_packets", + "1", + "-f", + "s16le", + "-ar", + str(self.sample_rate), + "-ac", + "1", + "-i", + f"tcp://127.0.0.1:{self.audio_port}", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", + "-color_range", + "pc", + "-colorspace", + "rgb", + "-color_primaries", + "bt709", + "-color_trc", + "iec61966-2-1", + "-r", + str(self.fps), + "-s", + f"{self.width}x{self.height}", + "-i", + f"tcp://127.0.0.1:{self.video_port}", + "-ar", + "44100", + "-b:v", + "4M", + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-tune", + "zerolatency", + "-g", + f"{self.fps}", + "-pix_fmt", + "yuv420p", + "-f", + "mp4", + self.livestream_url, + "-y", + "-loglevel", + self.ffmpeg_log_level, + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd) + logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg: {e}") + + def start_ffmpeg_process_rtmp(self): + """Start ffmpeg process that connects to our TCP sockets""" + ffmpeg_cmd = [ + "ffmpeg", + "-re", + "-f", + "s16le", + "-ar", + str(self.sample_rate), + "-ac", + "1", + "-i", + f"tcp://127.0.0.1:{self.audio_port}", + "-f", + "rawvideo", + "-re", + "-pix_fmt", + "rgb24", + "-r", + str(self.fps), + "-s", + f"{self.width}x{self.height}", + "-i", + f"tcp://127.0.0.1:{self.video_port}", + "-ar", + "44100", + "-b:v", + "2M", + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-tune", + "zerolatency", + "-g", + f"{self.fps}", + "-pix_fmt", + "yuv420p", + "-f", + "flv", + self.livestream_url, + "-y", + "-loglevel", + self.ffmpeg_log_level, + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd) + logger.info(f"FFmpeg streaming started with PID: 
{self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg: {e}") + + def start_ffmpeg_process_whip(self): + """Start ffmpeg process that connects to our TCP sockets""" + ffmpeg_cmd = [ + "ffmpeg", + "-re", + "-fflags", + "nobuffer", + "-analyzeduration", + "0", + "-probesize", + "32", + "-flush_packets", + "1", + "-f", + "s16le", + "-ar", + str(self.sample_rate), + "-ac", + "1", + "-ch_layout", + "mono", + "-i", + f"tcp://127.0.0.1:{self.audio_port}", + "-f", + "rawvideo", + "-re", + "-pix_fmt", + "rgb24", + "-r", + str(self.fps), + "-s", + f"{self.width}x{self.height}", + "-i", + f"tcp://127.0.0.1:{self.video_port}", + "-ar", + "48000", + "-c:a", + "libopus", + "-ac", + "2", + "-b:v", + "2M", + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-tune", + "zerolatency", + "-g", + f"{self.fps}", + "-pix_fmt", + "yuv420p", + "-threads", + "1", + "-bf", + "0", + "-f", + "whip", + self.livestream_url, + "-y", + "-loglevel", + self.ffmpeg_log_level, + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd) + logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg: {e}") + + def start(self, width: int, height: int): + self.set_video_size(width, height) + duration = 1.0 + frames = int(self.fps * duration) + samples = int(self.sample_rate * (frames / self.fps)) + self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16)) + time.sleep(duration) + + def set_video_size(self, width: int, height: int): + if self.width is not None and self.height is not None: + assert self.width == width and self.height == height, "Video size already set" + return + self.width = width + self.height = height + self.init_sockets() + if self.livestream_url.startswith("rtmp://"): + self.start_ffmpeg_process_rtmp() + elif self.livestream_url.startswith("http"): + self.start_ffmpeg_process_whip() + else: + self.start_ffmpeg_process_local() + self.audio_thread = threading.Thread(target=self.audio_worker) + self.video_thread = threading.Thread(target=self.video_worker) + self.audio_thread.start() + self.video_thread.start() + if self.realtime: + self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer) + self.schedule_thread.start() + + # Publish ComfyUI Image tensor and audio tensor to livestream + def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor): + N, height, width, C = images.shape + M = audios.reshape(-1).shape[0] + assert C == 3, "Input must be [N, H, W, C] with C=3" + + logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]") + audio_frames = round(M * self.fps / self.sample_rate) + if audio_frames != N: + logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}") + + self.set_video_size(width, height) + self.audio_queue.put(audios) + self.video_queue.put(images) + logger.info(f"Published {N} frames and {M} audio samples") + + self.stoppable_t = time.time() + M / self.sample_rate + 3 + + def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor): + N, height, width, C = images.shape + M = audios.reshape(-1).shape[0] + assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame" + assert C == 3, "Input must be [N, H, W, C] with C=3" + + audio_frames = round(M * self.fps / self.sample_rate) + if 
audio_frames != N: + logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}") + self.set_video_size(width, height) + + # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}") + rets = [] + for i in range(0, N, self.slice_frame): + end_frame = i + self.slice_frame + img = images[i:end_frame] + aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame] + gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame] + rets.append((img, aud, gen)) + + with self.stream_buffer_lock: + origin_size = len(self.stream_buffer) + self.stream_buffer.extend(rets) + logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments") + + def get_buffer_stream_size(self): + return len(self.stream_buffer) + + def truncate_stream_buffer(self, size: int): + with self.stream_buffer_lock: + self.stream_buffer = self.stream_buffer[:size] + logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments") + if len(self.stream_buffer) > 0: + return self.stream_buffer[-1][2] # return the last video tensor + else: + return None + + def schedule_stream_buffer(self): + schedule_interval = self.slice_frame / self.fps + logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds") + t = None + while True: + try: + if self.stop_schedule: + break + img, aud, gen = None, None, None + with self.stream_buffer_lock: + if len(self.stream_buffer) > 0: + img, aud, gen = self.stream_buffer.pop(0) + + if t is not None: + wait_secs = schedule_interval - (time.time() - t) + if wait_secs > 0: + time.sleep(wait_secs) + t = time.time() + + if img is not None and aud is not None: + self.audio_queue.put(aud) + self.video_queue.put(img) + # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish") + del gen + self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3 + else: + logger.warning(f"No stream buffer to schedule") + except Exception: + logger.error(f"Schedule stream buffer error: {traceback.format_exc()}") + break + logger.info("Schedule stream buffer thread stopped") + + def stop(self, wait=True): + if wait and self.stoppable_t: + t = self.stoppable_t - time.time() + if t > 0: + logger.warning(f"Waiting for {t} seconds to stop ...") + time.sleep(t) + self.stoppable_t = None + + if self.schedule_thread: + self.stop_schedule = True + self.schedule_thread.join(timeout=5) + if self.schedule_thread and self.schedule_thread.is_alive(): + logger.error(f"Schedule thread did not stop after 5s") + + # Send stop signals to queues + if self.audio_queue: + self.audio_queue.put(None) + if self.video_queue: + self.video_queue.put(None) + + # Wait for threads to finish processing queued data (increased timeout) + queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames + if self.audio_thread and self.audio_thread.is_alive(): + self.audio_thread.join(timeout=queue_timeout) + if self.audio_thread.is_alive(): + logger.error(f"Audio push thread did not stop after {queue_timeout}s") + if self.video_thread and self.video_thread.is_alive(): + self.video_thread.join(timeout=queue_timeout) + if self.video_thread.is_alive(): + logger.error(f"Video push thread did not stop after {queue_timeout}s") + + # Shutdown connections to signal EOF to FFmpeg + # shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed + if self.audio_conn: + try: + self.audio_conn.getpeername() + self.audio_conn.shutdown(socket.SHUT_WR) + 
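stop() half-closes each TCP connection with shutdown(SHUT_WR) before close() so that ffmpeg sees EOF on its input and can finalize the container while any buffered data is still delivered. A self-contained illustration of that half-close semantics (not part of the patch):

import socket

writer, reader = socket.socketpair()
writer.sendall(b"last chunk")
writer.shutdown(socket.SHUT_WR)  # send FIN: the peer sees EOF after draining the buffer

received = b""
while True:
    chunk = reader.recv(1024)
    if not chunk:  # empty read == EOF, which is what ffmpeg waits for
        break
    received += chunk
assert received == b"last chunk"
writer.close()
reader.close()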
logger.info("Audio connection shutdown initiated") + except OSError: + # Connection already closed, skip shutdown + pass + + if self.video_conn: + try: + self.video_conn.getpeername() + self.video_conn.shutdown(socket.SHUT_WR) + logger.info("Video connection shutdown initiated") + except OSError: + # Connection already closed, skip shutdown + pass + + if self.ffmpeg_process: + is_local_file = not self.livestream_url.startswith(("rtmp://", "http")) + # Local MP4 files need time to write moov atom and finalize the container + timeout_seconds = 30 if is_local_file else 10 + logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})") + logger.info(f"FFmpeg output: {self.livestream_url}") + + try: + returncode = self.ffmpeg_process.wait(timeout=timeout_seconds) + if returncode == 0: + logger.info(f"FFmpeg process exited successfully (exit code: {returncode})") + else: + logger.warning(f"FFmpeg process exited with non-zero code: {returncode}") + except subprocess.TimeoutExpired: + logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...") + try: + self.ffmpeg_process.terminate() # SIGTERM + returncode = self.ffmpeg_process.wait(timeout=5) + logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})") + except subprocess.TimeoutExpired: + logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...") + self.ffmpeg_process.kill() + self.ffmpeg_process.wait() # Wait for kill to complete + logger.error("FFmpeg process killed with SIGKILL") + finally: + self.ffmpeg_process = None + + if self.audio_conn: + try: + self.audio_conn.close() + except Exception as e: + logger.debug(f"Error closing audio connection: {e}") + finally: + self.audio_conn = None + + if self.video_conn: + try: + self.video_conn.close() + except Exception as e: + logger.debug(f"Error closing video connection: {e}") + finally: + self.video_conn = None + + if self.audio_socket: + try: + self.audio_socket.close() + except Exception as e: + logger.debug(f"Error closing audio socket: {e}") + finally: + self.audio_socket = None + + if self.video_socket: + try: + self.video_socket.close() + except Exception as e: + logger.debug(f"Error closing video socket: {e}") + finally: + self.video_socket = None + + if self.audio_queue: + while self.audio_queue.qsize() > 0: + try: + self.audio_queue.get_nowait() + except: # noqa + break + if self.video_queue: + while self.video_queue.qsize() > 0: + try: + self.video_queue.get_nowait() + except: # noqa + break + self.audio_queue = None + self.video_queue = None + logger.info("VARecorder stopped and resources cleaned up") + + def __del__(self): + self.stop(wait=False) + + +def create_simple_video(frames=10, height=480, width=640): + video_data = [] + for i in range(frames): + frame = np.zeros((height, width, 3), dtype=np.float32) + stripe_height = height // 8 + colors = [ + [1.0, 0.0, 0.0], # 红色 + [0.0, 1.0, 0.0], # 绿色 + [0.0, 0.0, 1.0], # 蓝色 + [1.0, 1.0, 0.0], # 黄色 + [1.0, 0.0, 1.0], # 洋红 + [0.0, 1.0, 1.0], # 青色 + [1.0, 1.0, 1.0], # 白色 + [0.5, 0.5, 0.5], # 灰色 + ] + for j, color in enumerate(colors): + start_y = j * stripe_height + end_y = min((j + 1) * stripe_height, height) + frame[start_y:end_y, :] = color + offset = int((i / frames) * width) + frame = np.roll(frame, offset, axis=1) + frame = torch.tensor(frame, dtype=torch.float32) + video_data.append(frame) + return torch.stack(video_data, dim=0) + + +if __name__ == "__main__": + sample_rate = 16000 + fps = 16 + width = 
640 + height = 480 + + recorder = VARecorder( + # livestream_url="rtmp://localhost/live/test", + # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000", + livestream_url="/path/to/output_video.mp4", + fps=fps, + sample_rate=sample_rate, + ) + + audio_path = "/path/to/test_b_2min.wav" + audio_array, ori_sr = ta.load(audio_path) + audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000) + audio_array = audio_array.reshape(-1) + secs = audio_array.shape[0] // sample_rate + interval = 1 + + for i in range(0, secs, interval): + logger.info(f"{i} / {secs} s") + start = i * sample_rate + end = (i + interval) * sample_rate + cur_audio_array = audio_array[start:end] + logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}") + + num_frames = int(interval * fps) + images = create_simple_video(num_frames, height, width) + logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}") + + recorder.pub_livestream(images, cur_audio_array) + time.sleep(interval) + recorder.stop() diff --git a/lightx2v/deploy/common/va_recorder_x264.py b/lightx2v/deploy/common/va_recorder_x264.py new file mode 100644 index 0000000000000000000000000000000000000000..93a82a7e6619af0d18cfe44c66dd4a5d44b51500 --- /dev/null +++ b/lightx2v/deploy/common/va_recorder_x264.py @@ -0,0 +1,321 @@ +import ctypes +import queue +import threading +import time +import traceback + +import numpy as np +import torch +import torchaudio as ta +from loguru import logger +from scipy.signal import resample + + +class X264VARecorder: + def __init__( + self, + whip_shared_path: str, + livestream_url: str, + fps: float = 16.0, + sample_rate: int = 16000, + slice_frame: int = 1, + prev_frame: int = 1, + ): + assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream" + self.livestream_url = livestream_url + self.fps = fps + self.sample_rate = sample_rate + + self.width = None + self.height = None + self.stoppable_t = None + + # only enable whip shared api for whip http livestream + self.whip_shared_path = whip_shared_path + self.whip_shared_lib = None + self.whip_shared_handle = None + + assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream" + self.realtime = True + + # queue for send data to whip shared api + self.queue = queue.Queue() + self.worker_thread = None + + # buffer for stream data + self.target_sample_rate = 48000 + self.target_samples_per_frame = round(self.target_sample_rate / self.fps) + self.target_chunks_per_frame = self.target_samples_per_frame * 2 + self.stream_buffer = [] + self.stream_buffer_lock = threading.Lock() + self.stop_schedule = False + self.schedule_thread = None + self.slice_frame = slice_frame + self.prev_frame = prev_frame + assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than previous frame" + + def worker(self): + try: + fail_time, max_fail_time = 0, 10 + packet_secs = 1.0 / self.fps + while True: + try: + if self.queue is None: + break + data = self.queue.get() + if data is None: + logger.info("Worker thread received stop signal") + break + audios, images = data + + for i in range(images.shape[0]): + t0 = time.time() + cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten() + audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)) + 
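The ctypes hand-off above passes a raw pointer into the vendored WHIP shared library; the buffer must be C-contiguous and must stay alive for the duration of the call. A minimal sketch of the numpy-to-ctypes mechanics, using ctypes.memmove as a stand-in for the library call (the real pushWhipRawAudioFrame signature is defined by go_whxp.so and is not reproduced here):

import ctypes

import numpy as np

samples = np.ascontiguousarray(np.arange(8, dtype=np.int16))  # must be C-contiguous
dst = (ctypes.c_int16 * samples.size)()  # destination buffer owned by ctypes

src_ptr = samples.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
ctypes.memmove(dst, src_ptr, samples.nbytes)  # stand-in for the shared-library push

assert list(dst) == list(samples)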
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame) + + cur_video = images[i].flatten() + video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)) + self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height) + + if self.realtime and i < images.shape[0] - 1: + time.sleep(max(0, packet_secs - (time.time() - t0))) + + fail_time = 0 + except: # noqa + logger.error(f"Send audio data error: {traceback.format_exc()}") + fail_time += 1 + if fail_time > max_fail_time: + logger.error(f"Audio push worker thread failed {fail_time} times, stopping...") + break + except: # noqa + logger.error(f"Audio push worker thread error: {traceback.format_exc()}") + finally: + logger.info("Audio push worker thread stopped") + + def start_libx264_whip_shared_api(self, width: int, height: int): + self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path) + + # define function argtypes and restype + self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] + self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p + + self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int] + self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int] + + self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p] + + whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8")) + self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height)) + logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}") + + def convert_data(self, audios, images): + # Convert audio data to 16-bit integer format + audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1) + # Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg + image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + + logger.info(f"image_datas: {image_datas.shape} {image_datas.dtype} {image_datas.min()} {image_datas.max()}") + reample_audios = resample(audio_datas, int(len(audio_datas) * 48000 / self.sample_rate)) + stereo_audios = np.stack([reample_audios, reample_audios], axis=-1).astype(np.int16).reshape(-1) + return stereo_audios, image_datas + + def start(self, width: int, height: int): + self.set_video_size(width, height) + + def set_video_size(self, width: int, height: int): + if self.width is not None and self.height is not None: + assert self.width == width and self.height == height, "Video size already set" + return + self.width = width + self.height = height + self.start_libx264_whip_shared_api(width, height) + self.worker_thread = threading.Thread(target=self.worker) + self.worker_thread.start() + if self.realtime: + self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer) + self.schedule_thread.start() + + def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor): + N, height, width, C = images.shape + M = audios.reshape(-1).shape[0] + assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame" + assert C == 3, "Input must be [N, H, W, C] with C=3" + + audio_frames = round(M * self.fps / self.sample_rate) + if audio_frames != N: + logger.warning(f"Video and audio frames mismatch, {N} vs 
{audio_frames}") + self.set_video_size(width, height) + audio_datas, image_datas = self.convert_data(audios, images) + + # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}") + rets = [] + for i in range(0, N, self.slice_frame): + end_frame = i + self.slice_frame + img = image_datas[i:end_frame] + aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame] + gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame] + rets.append((img, aud, gen)) + + with self.stream_buffer_lock: + origin_size = len(self.stream_buffer) + self.stream_buffer.extend(rets) + logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments") + + def get_buffer_stream_size(self): + return len(self.stream_buffer) + + def truncate_stream_buffer(self, size: int): + with self.stream_buffer_lock: + self.stream_buffer = self.stream_buffer[:size] + logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments") + if len(self.stream_buffer) > 0: + return self.stream_buffer[-1][2] # return the last video tensor + else: + return None + + def schedule_stream_buffer(self): + schedule_interval = self.slice_frame / self.fps + logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds") + t = None + while True: + try: + if self.stop_schedule: + break + img, aud, gen = None, None, None + with self.stream_buffer_lock: + if len(self.stream_buffer) > 0: + img, aud, gen = self.stream_buffer.pop(0) + + if t is not None: + wait_secs = schedule_interval - (time.time() - t) + if wait_secs > 0: + time.sleep(wait_secs) + t = time.time() + + if img is not None and aud is not None: + self.queue.put((aud, img)) + # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish") + del gen + self.stoppable_t = time.time() + img.shape[0] / self.fps + 3 + else: + logger.warning(f"No stream buffer to schedule") + except Exception: + logger.error(f"Schedule stream buffer error: {traceback.format_exc()}") + break + logger.info("Schedule stream buffer thread stopped") + + def stop(self, wait=True): + if wait and self.stoppable_t: + t = self.stoppable_t - time.time() + if t > 0: + logger.warning(f"Waiting for {t} seconds to stop ...") + time.sleep(t) + self.stoppable_t = None + + if self.schedule_thread: + self.stop_schedule = True + self.schedule_thread.join(timeout=5) + if self.schedule_thread and self.schedule_thread.is_alive(): + logger.error(f"Schedule thread did not stop after 5s") + + # Send stop signals to queues + if self.queue: + self.queue.put(None) + + # Wait for threads to finish + if self.worker_thread and self.worker_thread.is_alive(): + self.worker_thread.join(timeout=5) + if self.worker_thread.is_alive(): + logger.warning("Worker thread did not stop gracefully") + + # Destroy WHIP shared API + if self.whip_shared_lib and self.whip_shared_handle: + self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle) + self.whip_shared_handle = None + self.whip_shared_lib = None + logger.warning("WHIP shared API destroyed") + + def __del__(self): + self.stop() + + +def create_simple_video(frames=10, height=480, width=640): + video_data = [] + for i in range(frames): + frame = np.zeros((height, width, 3), dtype=np.float32) + stripe_height = height // 8 + colors = [ + [1.0, 0.0, 0.0], # 红色 + [0.0, 1.0, 0.0], # 绿色 + [0.0, 0.0, 1.0], # 蓝色 + [1.0, 1.0, 0.0], # 黄色 + [1.0, 0.0, 1.0], # 洋红 + [0.0, 1.0, 1.0], # 青色 + [1.0, 1.0, 1.0], # 白色 + [0.5, 0.5, 0.5], # 灰色 + ] + for j, color in 
enumerate(colors): + start_y = j * stripe_height + end_y = min((j + 1) * stripe_height, height) + frame[start_y:end_y, :] = color + offset = int((i / frames) * width) + frame = np.roll(frame, offset, axis=1) + frame = torch.tensor(frame, dtype=torch.float32) + video_data.append(frame) + return torch.stack(video_data, dim=0) + + +if __name__ == "__main__": + sample_rate = 16000 + fps = 16 + width = 452 + height = 352 + + recorder = X264VARecorder( + whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so", + livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000", + fps=fps, + sample_rate=sample_rate, + ) + recorder.start(width, height) + + # time.sleep(5) + audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav" + audio_array, ori_sr = ta.load(audio_path) + audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000) + audio_array = audio_array.numpy().reshape(-1) + secs = audio_array.shape[0] // sample_rate + interval = 1 + space = 10 + + i = 0 + while i < space: + t0 = time.time() + logger.info(f"space {i} / {space} s") + cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32) + num_frames = int(interval * fps) + images = create_simple_video(num_frames, height, width) + recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images) + i += interval + time.sleep(interval - (time.time() - t0)) + + started = True + + i = 0 + while i < secs: + t0 = time.time() + start = int(i * sample_rate) + end = int((i + interval) * sample_rate) + cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32) + num_frames = int(interval * fps) + images = create_simple_video(num_frames, height, width) + logger.info(f"{i} / {secs} s") + if started: + logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!") + started = False + recorder.buffer_stream(images, cur_audio_array, images) + i += interval + time.sleep(interval - (time.time() - t0)) + + recorder.stop() diff --git a/lightx2v/deploy/common/video_recorder.py b/lightx2v/deploy/common/video_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..e58a18ede880d5faafc33af4bb00be6355f7ed81 --- /dev/null +++ b/lightx2v/deploy/common/video_recorder.py @@ -0,0 +1,422 @@ +import os +import queue +import socket +import subprocess +import threading +import time +import traceback + +import numpy as np +import torch +from loguru import logger + + +def pseudo_random(a, b): + x = str(time.time()).split(".")[1] + y = int(float("0." 
+ x) * 1000000) + return a + (y % (b - a + 1)) + + +class VideoRecorder: + def __init__( + self, + livestream_url: str, + fps: float = 16.0, + ): + self.livestream_url = livestream_url + self.fps = fps + self.video_port = pseudo_random(32000, 40000) + self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error") + logger.info(f"VideoRecorder video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}") + + self.width = None + self.height = None + self.stoppable_t = None + self.realtime = True + + # ffmpeg process for video data and push to livestream + self.ffmpeg_process = None + + # TCP connection objects + self.video_socket = None + self.video_conn = None + self.video_thread = None + + # queue for send data to ffmpeg process + self.video_queue = queue.Queue() + + def init_sockets(self): + # TCP socket for send and recv video data + self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.video_socket.bind(("127.0.0.1", self.video_port)) + self.video_socket.listen(1) + + def video_worker(self): + try: + logger.info("Waiting for ffmpeg to connect to video socket...") + self.video_conn, _ = self.video_socket.accept() + logger.info(f"Video connection established from {self.video_conn.getpeername()}") + fail_time, max_fail_time = 0, 10 + packet_secs = 1.0 / self.fps + while True: + try: + if self.video_queue is None: + break + data = self.video_queue.get() + if data is None: + logger.info("Video thread received stop signal") + break + + # Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg + for i in range(data.shape[0]): + t0 = time.time() + frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + try: + self.video_conn.send(frame.tobytes()) + except (BrokenPipeError, OSError, ConnectionResetError) as e: + logger.info(f"Video connection closed, stopping worker: {type(e).__name__}") + return + if self.realtime: + time.sleep(max(0, packet_secs - (time.time() - t0))) + + fail_time = 0 + except (BrokenPipeError, OSError, ConnectionResetError): + logger.info("Video connection closed during queue processing") + break + except Exception: + logger.error(f"Send video data error: {traceback.format_exc()}") + fail_time += 1 + if fail_time > max_fail_time: + logger.error(f"Video push worker thread failed {fail_time} times, stopping...") + break + except Exception: + logger.error(f"Video push worker thread error: {traceback.format_exc()}") + finally: + logger.info("Video push worker thread stopped") + + def start_ffmpeg_process_local(self): + """Start ffmpeg process that connects to our TCP sockets""" + ffmpeg_cmd = [ + "ffmpeg", + "-fflags", + "nobuffer", + "-analyzeduration", + "0", + "-probesize", + "32", + "-flush_packets", + "1", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", + "-color_range", + "pc", + "-colorspace", + "rgb", + "-color_primaries", + "bt709", + "-color_trc", + "iec61966-2-1", + "-r", + str(self.fps), + "-s", + f"{self.width}x{self.height}", + "-i", + f"tcp://127.0.0.1:{self.video_port}", + "-b:v", + "4M", + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-tune", + "zerolatency", + "-g", + f"{self.fps}", + "-pix_fmt", + "yuv420p", + "-f", + "mp4", + self.livestream_url, + "-y", + "-loglevel", + self.ffmpeg_log_level, + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd) + logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}") + 
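With "-f rawvideo -pix_fmt rgb24 -s WxH", ffmpeg slices the incoming TCP byte stream purely by frame size, so every frame written by video_worker must be exactly width * height * 3 bytes. A short sketch of that framing arithmetic (illustrative, not part of the patch):

import numpy as np

WIDTH, HEIGHT = 640, 480
FRAME_BYTES = WIDTH * HEIGHT * 3  # rgb24: 3 bytes per pixel -> 921600 bytes per frame

frame = (np.random.rand(HEIGHT, WIDTH, 3) * 255).astype(np.uint8)
payload = frame.tobytes()
assert len(payload) == FRAME_BYTES

# The receiving side (ffmpeg) reassembles frames from fixed-size slices of the stream.
restored = np.frombuffer(payload, dtype=np.uint8).reshape(HEIGHT, WIDTH, 3)
assert np.array_equal(restored, frame)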
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg: {e}") + + def start_ffmpeg_process_rtmp(self): + """Start ffmpeg process that connects to our TCP sockets""" + ffmpeg_cmd = [ + "ffmpeg", + "-f", + "rawvideo", + "-re", + "-pix_fmt", + "rgb24", + "-r", + str(self.fps), + "-s", + f"{self.width}x{self.height}", + "-i", + f"tcp://127.0.0.1:{self.video_port}", + "-b:v", + "2M", + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-tune", + "zerolatency", + "-g", + f"{self.fps}", + "-pix_fmt", + "yuv420p", + "-f", + "flv", + self.livestream_url, + "-y", + "-loglevel", + self.ffmpeg_log_level, + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd) + logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg: {e}") + + def start_ffmpeg_process_whip(self): + """Start ffmpeg process that connects to our TCP sockets""" + ffmpeg_cmd = [ + "ffmpeg", + "-re", + "-fflags", + "nobuffer", + "-analyzeduration", + "0", + "-probesize", + "32", + "-flush_packets", + "1", + "-f", + "rawvideo", + "-re", + "-pix_fmt", + "rgb24", + "-r", + str(self.fps), + "-s", + f"{self.width}x{self.height}", + "-i", + f"tcp://127.0.0.1:{self.video_port}", + "-b:v", + "2M", + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-tune", + "zerolatency", + "-g", + f"{self.fps}", + "-pix_fmt", + "yuv420p", + "-threads", + "1", + "-bf", + "0", + "-f", + "whip", + self.livestream_url, + "-y", + "-loglevel", + self.ffmpeg_log_level, + ] + try: + self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd) + logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}") + logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}") + except Exception as e: + logger.error(f"Failed to start FFmpeg: {e}") + + def start(self, width: int, height: int): + self.set_video_size(width, height) + duration = 1.0 + self.pub_video(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16)) + time.sleep(duration) + + def set_video_size(self, width: int, height: int): + if self.width is not None and self.height is not None: + assert self.width == width and self.height == height, "Video size already set" + return + self.width = width + self.height = height + self.init_sockets() + if self.livestream_url.startswith("rtmp://"): + self.start_ffmpeg_process_rtmp() + elif self.livestream_url.startswith("http"): + self.start_ffmpeg_process_whip() + else: + self.start_ffmpeg_process_local() + self.realtime = False + self.video_thread = threading.Thread(target=self.video_worker) + self.video_thread.start() + + # Publish ComfyUI Image tensor to livestream + def pub_video(self, images: torch.Tensor): + N, height, width, C = images.shape + assert C == 3, "Input must be [N, H, W, C] with C=3" + + logger.info(f"Publishing video [{N}x{width}x{height}]") + + self.set_video_size(width, height) + self.video_queue.put(images) + logger.info(f"Published {N} frames") + + self.stoppable_t = time.time() + N / self.fps + 3 + + def stop(self, wait=True): + if wait and self.stoppable_t: + t = self.stoppable_t - time.time() + if t > 0: + logger.warning(f"Waiting for {t} seconds to stop ...") + time.sleep(t) + self.stoppable_t = None + + # Send stop signals to queues + if self.video_queue: + self.video_queue.put(None) + + # Wait for threads to finish processing queued data (increased timeout) + queue_timeout = 30 # Increased from 5s to 30s to 
allow sufficient time for large video frames + if self.video_thread and self.video_thread.is_alive(): + self.video_thread.join(timeout=queue_timeout) + if self.video_thread.is_alive(): + logger.error(f"Video push thread did not stop after {queue_timeout}s") + + # Shutdown connections to signal EOF to FFmpeg + # shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed + if self.video_conn: + try: + self.video_conn.getpeername() + self.video_conn.shutdown(socket.SHUT_WR) + logger.info("Video connection shutdown initiated") + except OSError: + # Connection already closed, skip shutdown + pass + + if self.ffmpeg_process: + is_local_file = not self.livestream_url.startswith(("rtmp://", "http")) + # Local MP4 files need time to write moov atom and finalize the container + timeout_seconds = 30 if is_local_file else 10 + logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})") + logger.info(f"FFmpeg output: {self.livestream_url}") + + try: + returncode = self.ffmpeg_process.wait(timeout=timeout_seconds) + if returncode == 0: + logger.info(f"FFmpeg process exited successfully (exit code: {returncode})") + else: + logger.warning(f"FFmpeg process exited with non-zero code: {returncode}") + except subprocess.TimeoutExpired: + logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...") + try: + self.ffmpeg_process.terminate() # SIGTERM + returncode = self.ffmpeg_process.wait(timeout=5) + logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})") + except subprocess.TimeoutExpired: + logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...") + self.ffmpeg_process.kill() + self.ffmpeg_process.wait() # Wait for kill to complete + logger.error("FFmpeg process killed with SIGKILL") + finally: + self.ffmpeg_process = None + + if self.video_conn: + try: + self.video_conn.close() + except Exception as e: + logger.debug(f"Error closing video connection: {e}") + finally: + self.video_conn = None + + if self.video_socket: + try: + self.video_socket.close() + except Exception as e: + logger.debug(f"Error closing video socket: {e}") + finally: + self.video_socket = None + + if self.video_queue: + while self.video_queue.qsize() > 0: + try: + self.video_queue.get_nowait() + except: # noqa + break + self.video_queue = None + logger.info("VideoRecorder stopped and resources cleaned up") + + def __del__(self): + self.stop(wait=False) + + +def create_simple_video(frames=10, height=480, width=640): + video_data = [] + for i in range(frames): + frame = np.zeros((height, width, 3), dtype=np.float32) + stripe_height = height // 8 + colors = [ + [1.0, 0.0, 0.0], # 红色 + [0.0, 1.0, 0.0], # 绿色 + [0.0, 0.0, 1.0], # 蓝色 + [1.0, 1.0, 0.0], # 黄色 + [1.0, 0.0, 1.0], # 洋红 + [0.0, 1.0, 1.0], # 青色 + [1.0, 1.0, 1.0], # 白色 + [0.5, 0.5, 0.5], # 灰色 + ] + for j, color in enumerate(colors): + start_y = j * stripe_height + end_y = min((j + 1) * stripe_height, height) + frame[start_y:end_y, :] = color + offset = int((i / frames) * width) + frame = np.roll(frame, offset, axis=1) + frame = torch.tensor(frame, dtype=torch.float32) + video_data.append(frame) + return torch.stack(video_data, dim=0) + + +if __name__ == "__main__": + fps = 16 + width = 640 + height = 480 + + recorder = VideoRecorder( + # livestream_url="rtmp://localhost/live/test", + # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000", + 
livestream_url="/path/to/output_video.mp4", + fps=fps, + ) + + secs = 10 # 10秒视频 + interval = 1 + + for i in range(0, secs, interval): + logger.info(f"{i} / {secs} s") + + num_frames = int(interval * fps) + images = create_simple_video(num_frames, height, width) + logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}") + + recorder.pub_video(images) + time.sleep(interval) + recorder.stop() diff --git a/lightx2v/deploy/common/volcengine_tts.py b/lightx2v/deploy/common/volcengine_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e60d3670215973bb3e1506baff1b40e26d0975 --- /dev/null +++ b/lightx2v/deploy/common/volcengine_tts.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- + +import asyncio +import base64 +import json +import os +import sys + +import aiohttp +from loguru import logger + + +class VolcEngineTTSClient: + """ + VolcEngine TTS客户端 + + 参数范围说明: + - speech_rate: -50~100 (100代表2倍速, -50代表0.5倍速, 0为正常语速) + - loudness_rate: -50~100 (100代表2倍音量, -50代表0.5倍音量, 0为正常音量) + - emotion_scale: 1-5 + """ + + def __init__(self, voices_list_file=None): + self.url = "https://openspeech.bytedance.com/api/v3/tts/unidirectional" + self.appid = os.getenv("VOLCENGINE_TTS_APPID") + self.access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN") + self.proxy = os.getenv("HTTPS_PROXY", None) + if self.proxy: + logger.info(f"volcengine tts use proxy: {self.proxy}") + if voices_list_file is not None: + with open(voices_list_file, "r", encoding="utf-8") as f: + self.voices_list = json.load(f) + else: + self.voices_list = None + + def get_voice_list(self): + return self.voices_list + + async def tts_http_stream(self, headers, params, audio_save_path): + """执行TTS流式请求""" + try: + logger.info(f"volcengine tts params: {params}") + audio_data = bytearray() + total_audio_size = 0 + + async with aiohttp.ClientSession() as session: + async with session.post(self.url, json=params, headers=headers, proxy=self.proxy) as response: + response.raise_for_status() + async for chunk in response.content: + if not chunk: + continue + try: + data = json.loads(chunk.decode("utf-8").strip()) + if data.get("code", 0) == 0 and "data" in data and data["data"]: + chunk_audio = base64.b64decode(data["data"]) + audio_size = len(chunk_audio) + total_audio_size += audio_size + audio_data.extend(chunk_audio) + continue + if data.get("code", 0) == 0 and "sentence" in data and data["sentence"]: + continue + if data.get("code", 0) == 20000000: + break + if data.get("code", 0) > 0: + logger.warning(f"volcengine tts error response: {data}") + break + except Exception as e: + logger.warning(f"Failed to parse volcengine tts chunk: {e}") + + # save audio file + if audio_data: + with open(audio_save_path, "wb") as f: + f.write(audio_data) + logger.info(f"audio saved to {audio_save_path}, audio size: {len(audio_data) / 1024:.2f} KB") + # set correct permissions + os.chmod(audio_save_path, 0o644) + return True + else: + logger.warning("No tts audio data received") + return False + + except Exception as e: + logger.warning(f"VolcEngineTTSClient tts request failed: {e}") + return False + + async def tts_request( + self, + text, + voice_type="zh_female_vv_uranus_bigtts", + context_texts="", + emotion="", + emotion_scale=4, + speech_rate=0, + loudness_rate=0, + pitch=0, + output="tts_output.mp3", + resource_id="seed-tts-2.0", + app_key="aGjiRDfUWi", + uid="123123", + format="mp3", + sample_rate=24000, + enable_timestamp=True, + ): + """ + 执行TTS请求 + + Args: + text: 要转换的文本 + voice_type: 声音类型 + emotion: 情感类型 + 
emotion_scale: 情感强度 (1-5) + speech_rate: 语速调节 (-50~100, 100代表2倍速, -50代表0.5倍速, 0为正常语速) + loudness_rate: 音量调节 (-50~100, 100代表2倍音量, -50代表0.5倍音量, 0为正常音量) + pitch: 音调调节 (-12~12, 12代表高音调, -12代表低音调, 0为正常音调) + output: 输出文件路径 + resource_id: 资源ID + app_key: 应用密钥 + uid: 用户ID + format: 音频格式 + sample_rate: 采样率 + enable_timestamp: 是否启用时间戳 + """ + # 验证参数范围 + if not (-50 <= speech_rate <= 100): + logger.warning(f"speech_rate {speech_rate} 超出有效范围 [-50, 100],将使用默认值 0") + speech_rate = 0 + + if not (-50 <= loudness_rate <= 100): + logger.warning(f"loudness_rate {loudness_rate} 超出有效范围 [-50, 100],将使用默认值 0") + loudness_rate = 0 + + if not (1 <= emotion_scale <= 5): + logger.warning(f"emotion_scale {emotion_scale} 超出有效范围 [1, 5],将使用默认值 3") + emotion_scale = 3 + + if not (-12 <= pitch <= 12): + logger.warning(f"pitch {pitch} 超出有效范围 [-12, 12],将使用默认值 0") + pitch = 0 + + headers = { + "X-Api-App-Id": self.appid, + "X-Api-Access-Key": self.access_token, + "X-Api-Resource-Id": resource_id, + "X-Api-App-Key": app_key, + "Content-Type": "application/json", + "Connection": "keep-alive", + } + additions = json.dumps( + {"explicit_language": "zh", "disable_markdown_filter": True, "enable_timestamp": True, "context_texts": [context_texts] if context_texts else None, "post_process": {"pitch": pitch}} + ) + payload = { + "user": {"uid": uid}, + "req_params": { + "text": text, + "speaker": voice_type, + "audio_params": { + "format": format, + "sample_rate": sample_rate, + "enable_timestamp": enable_timestamp, + "emotion": emotion, + "emotion_scale": emotion_scale, + "speech_rate": speech_rate, + "loudness_rate": loudness_rate, + }, + "additions": additions, + }, + } + success = await self.tts_http_stream(headers=headers, params=payload, audio_save_path=output) + if success: + logger.info(f"VolcEngineTTSClient tts request for '{text}': success") + else: + logger.warning(f"VolcEngineTTSClient tts request for '{text}': failed") + return success + + +async def test(args): + """ + TTS测试函数 + + Args: + args: list, e.g. [text, voice_type, emotion, emotion_scale, speech_rate, loudness_rate, output, resource_id, app_key, uid, format, sample_rate, enable_timestamp] + Provide as many as needed, from left to right. 
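
    Example (positional order follows the params dict below: text, voice_type,
    context_texts, emotion, emotion_scale, speech_rate, loudness_rate, pitch,
    output, ...), assuming VOLCENGINE_TTS_APPID and VOLCENGINE_TTS_ACCESS_TOKEN
    are set in the environment:

        python volcengine_tts.py "Hello from lightx2v" zh_female_vv_uranus_bigtts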
+ + Parameter ranges: + - speech_rate: -50~100 (100代表2倍速, -50代表0.5倍速, 0为正常语速) + - loudness_rate: -50~100 (100代表2倍音量, -50代表0.5倍音量, 0为正常音量) + - emotion_scale: 1-5 + - pitch: -12~12 (12代表高音调, -12代表低音调, 0为正常音调) + """ + client = VolcEngineTTSClient() + # 设置默认参数 + params = { + "text": "", + "voice_type": "zh_female_vv_uranus_bigtts", + "context_texts": "", + "emotion": "", + "emotion_scale": 4, + "speech_rate": 0, + "loudness_rate": 0, + "pitch": 12, + "output": "tts_output.mp3", + "resource_id": "seed-tts-2.0", + "app_key": "aGjiRDfUWi", + "uid": "123123", + "format": "mp3", + "sample_rate": 24000, + "enable_timestamp": True, + } + keys = list(params.keys()) + # 覆盖默认参数 + for i, arg in enumerate(args): + # 类型转换 + if keys[i] == "sample_rate": + params[keys[i]] = int(arg) + elif keys[i] == "enable_timestamp": + # 支持多种布尔输入 + params[keys[i]] = str(arg).lower() in ("1", "true", "yes", "on") + else: + params[keys[i]] = arg + + await client.tts_request( + params["text"], + params["voice_type"], + params["context_texts"], + params["emotion"], + params["emotion_scale"], + params["speech_rate"], + params["loudness_rate"], + params["pitch"], + params["output"], + params["resource_id"], + params["app_key"], + params["uid"], + params["format"], + params["sample_rate"], + params["enable_timestamp"], + ) + + +if __name__ == "__main__": + asyncio.run(test(sys.argv[1:])) diff --git a/lightx2v/deploy/data_manager/__init__.py b/lightx2v/deploy/data_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14ee9d3388f24d9e763c02600972371273ab9933 --- /dev/null +++ b/lightx2v/deploy/data_manager/__init__.py @@ -0,0 +1,248 @@ +import io +import json +import os + +import torch +from PIL import Image + +from lightx2v.deploy.common.utils import class_try_catch_async + + +class BaseDataManager: + def __init__(self): + self.template_images_dir = None + self.template_audios_dir = None + self.template_videos_dir = None + self.template_tasks_dir = None + self.podcast_temp_session_dir = None + self.podcast_output_dir = None + + async def init(self): + pass + + async def close(self): + pass + + def fmt_path(self, base, filename, abs_path=None): + if abs_path: + return abs_path + else: + return os.path.join(base, filename) + + def to_device(self, data, device): + if isinstance(data, dict): + return {key: self.to_device(value, device) for key, value in data.items()} + elif isinstance(data, list): + return [self.to_device(item, device) for item in data] + elif isinstance(data, torch.Tensor): + return data.to(device) + else: + return data + + async def save_bytes(self, bytes_data, filename, abs_path=None): + raise NotImplementedError + + async def load_bytes(self, filename, abs_path=None): + raise NotImplementedError + + async def delete_bytes(self, filename, abs_path=None): + raise NotImplementedError + + async def presign_url(self, filename, abs_path=None): + return None + + async def recurrent_save(self, data, prefix): + if isinstance(data, dict): + return {k: await self.recurrent_save(v, f"{prefix}-{k}") for k, v in data.items()} + elif isinstance(data, list): + return [await self.recurrent_save(v, f"{prefix}-{idx}") for idx, v in enumerate(data)] + elif isinstance(data, torch.Tensor): + save_path = prefix + ".pt" + await self.save_tensor(data, save_path) + return save_path + elif isinstance(data, Image.Image): + save_path = prefix + ".png" + await self.save_image(data, save_path) + return save_path + else: + return data + + async def recurrent_load(self, data, device, prefix): + if 
isinstance(data, dict): + return {k: await self.recurrent_load(v, device, f"{prefix}-{k}") for k, v in data.items()} + elif isinstance(data, list): + return [await self.recurrent_load(v, device, f"{prefix}-{idx}") for idx, v in enumerate(data)] + elif isinstance(data, str) and data == prefix + ".pt": + return await self.load_tensor(data, device) + elif isinstance(data, str) and data == prefix + ".png": + return await self.load_image(data) + else: + return data + + async def recurrent_delete(self, data, prefix): + if isinstance(data, dict): + return {k: await self.recurrent_delete(v, f"{prefix}-{k}") for k, v in data.items()} + elif isinstance(data, list): + return [await self.recurrent_delete(v, f"{prefix}-{idx}") for idx, v in enumerate(data)] + elif isinstance(data, str) and data == prefix + ".pt": + await self.delete_bytes(data) + elif isinstance(data, str) and data == prefix + ".png": + await self.delete_bytes(data) + + @class_try_catch_async + async def save_object(self, data, filename): + data = await self.recurrent_save(data, filename) + bytes_data = json.dumps(data, ensure_ascii=False).encode("utf-8") + await self.save_bytes(bytes_data, filename) + + @class_try_catch_async + async def load_object(self, filename, device): + bytes_data = await self.load_bytes(filename) + data = json.loads(bytes_data.decode("utf-8")) + data = await self.recurrent_load(data, device, filename) + return data + + @class_try_catch_async + async def delete_object(self, filename): + bytes_data = await self.load_bytes(filename) + data = json.loads(bytes_data.decode("utf-8")) + await self.recurrent_delete(data, filename) + await self.delete_bytes(filename) + + @class_try_catch_async + async def save_tensor(self, data: torch.Tensor, filename): + buffer = io.BytesIO() + torch.save(data.to("cpu"), buffer) + await self.save_bytes(buffer.getvalue(), filename) + + @class_try_catch_async + async def load_tensor(self, filename, device): + bytes_data = await self.load_bytes(filename) + buffer = io.BytesIO(bytes_data) + t = torch.load(io.BytesIO(bytes_data)) + t = t.to(device) + return t + + @class_try_catch_async + async def save_image(self, data: Image.Image, filename): + buffer = io.BytesIO() + data.save(buffer, format="PNG") + await self.save_bytes(buffer.getvalue(), filename) + + @class_try_catch_async + async def load_image(self, filename): + bytes_data = await self.load_bytes(filename) + buffer = io.BytesIO(bytes_data) + img = Image.open(buffer).convert("RGB") + return img + + def get_delete_func(self, type): + maps = { + "TENSOR": self.delete_bytes, + "IMAGE": self.delete_bytes, + "OBJECT": self.delete_object, + "VIDEO": self.delete_bytes, + } + return maps[type] + + def get_template_dir(self, template_type): + if template_type == "audios": + return self.template_audios_dir + elif template_type == "images": + return self.template_images_dir + elif template_type == "videos": + return self.template_videos_dir + elif template_type == "tasks": + return self.template_tasks_dir + else: + raise ValueError(f"Invalid template type: {template_type}") + + @class_try_catch_async + async def list_template_files(self, template_type): + template_dir = self.get_template_dir(template_type) + if template_dir is None: + return [] + return await self.list_files(base_dir=template_dir) + + @class_try_catch_async + async def load_template_file(self, template_type, filename): + template_dir = self.get_template_dir(template_type) + if template_dir is None: + return None + return await self.load_bytes(None, 
abs_path=os.path.join(template_dir, filename)) + + @class_try_catch_async + async def template_file_exists(self, template_type, filename): + template_dir = self.get_template_dir(template_type) + if template_dir is None: + return None + return await self.file_exists(None, abs_path=os.path.join(template_dir, filename)) + + @class_try_catch_async + async def delete_template_file(self, template_type, filename): + template_dir = self.get_template_dir(template_type) + if template_dir is None: + return None + return await self.delete_bytes(None, abs_path=os.path.join(template_dir, filename)) + + @class_try_catch_async + async def save_template_file(self, template_type, filename, bytes_data): + template_dir = self.get_template_dir(template_type) + if template_dir is None: + return None + abs_path = os.path.join(template_dir, filename) + return await self.save_bytes(bytes_data, None, abs_path=abs_path) + + @class_try_catch_async + async def presign_template_url(self, template_type, filename): + template_dir = self.get_template_dir(template_type) + if template_dir is None: + return None + return await self.presign_url(None, abs_path=os.path.join(template_dir, filename)) + + @class_try_catch_async + async def list_podcast_temp_session_files(self, session_id): + session_dir = os.path.join(self.podcast_temp_session_dir, session_id) + return await self.list_files(base_dir=session_dir) + + @class_try_catch_async + async def save_podcast_temp_session_file(self, session_id, filename, bytes_data): + fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename) + await self.save_bytes(bytes_data, None, abs_path=fpath) + + @class_try_catch_async + async def load_podcast_temp_session_file(self, session_id, filename): + fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename) + return await self.load_bytes(None, abs_path=fpath) + + @class_try_catch_async + async def delete_podcast_temp_session_file(self, session_id, filename): + fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename) + return await self.delete_bytes(None, abs_path=fpath) + + @class_try_catch_async + async def save_podcast_output_file(self, filename, bytes_data): + fpath = os.path.join(self.podcast_output_dir, filename) + await self.save_bytes(bytes_data, None, abs_path=fpath) + + @class_try_catch_async + async def load_podcast_output_file(self, filename): + fpath = os.path.join(self.podcast_output_dir, filename) + return await self.load_bytes(None, abs_path=fpath) + + @class_try_catch_async + async def delete_podcast_output_file(self, filename): + fpath = os.path.join(self.podcast_output_dir, filename) + return await self.delete_bytes(None, abs_path=fpath) + + @class_try_catch_async + async def presign_podcast_output_url(self, filename): + fpath = os.path.join(self.podcast_output_dir, filename) + return await self.presign_url(None, abs_path=fpath) + + +# Import data manager implementations +from .local_data_manager import LocalDataManager # noqa +from .s3_data_manager import S3DataManager # noqa + +__all__ = ["BaseDataManager", "LocalDataManager", "S3DataManager"] diff --git a/lightx2v/deploy/data_manager/local_data_manager.py b/lightx2v/deploy/data_manager/local_data_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..6e11e6cbadef5b7eccd6b9e65de657798b7d2924 --- /dev/null +++ b/lightx2v/deploy/data_manager/local_data_manager.py @@ -0,0 +1,120 @@ +import asyncio +import os +import shutil + +from loguru import logger + +from lightx2v.deploy.common.utils import 
class_try_catch_async +from lightx2v.deploy.data_manager import BaseDataManager + + +class LocalDataManager(BaseDataManager): + def __init__(self, local_dir, template_dir): + super().__init__() + self.local_dir = local_dir + self.name = "local" + if not os.path.exists(self.local_dir): + os.makedirs(self.local_dir) + if template_dir: + self.template_images_dir = os.path.join(template_dir, "images") + self.template_audios_dir = os.path.join(template_dir, "audios") + self.template_videos_dir = os.path.join(template_dir, "videos") + self.template_tasks_dir = os.path.join(template_dir, "tasks") + assert os.path.exists(self.template_images_dir), f"{self.template_images_dir} not exists!" + assert os.path.exists(self.template_audios_dir), f"{self.template_audios_dir} not exists!" + assert os.path.exists(self.template_videos_dir), f"{self.template_videos_dir} not exists!" + assert os.path.exists(self.template_tasks_dir), f"{self.template_tasks_dir} not exists!" + + # podcast temp session dir and output dir + self.podcast_temp_session_dir = os.path.join(self.local_dir, "podcast_temp_session") + self.podcast_output_dir = os.path.join(self.local_dir, "podcast_output") + os.makedirs(self.podcast_temp_session_dir, exist_ok=True) + os.makedirs(self.podcast_output_dir, exist_ok=True) + + @class_try_catch_async + async def save_bytes(self, bytes_data, filename, abs_path=None): + out_path = self.fmt_path(self.local_dir, filename, abs_path) + with open(out_path, "wb") as fout: + fout.write(bytes_data) + return True + + @class_try_catch_async + async def load_bytes(self, filename, abs_path=None): + inp_path = self.fmt_path(self.local_dir, filename, abs_path) + with open(inp_path, "rb") as fin: + return fin.read() + + @class_try_catch_async + async def delete_bytes(self, filename, abs_path=None): + inp_path = self.fmt_path(self.local_dir, filename, abs_path) + os.remove(inp_path) + logger.info(f"deleted local file {filename}") + return True + + @class_try_catch_async + async def file_exists(self, filename, abs_path=None): + filename = self.fmt_path(self.local_dir, filename, abs_path) + return os.path.exists(filename) + + @class_try_catch_async + async def list_files(self, base_dir=None): + prefix = base_dir if base_dir else self.local_dir + return os.listdir(prefix) + + @class_try_catch_async + async def create_podcast_temp_session_dir(self, session_id): + dir_path = os.path.join(self.podcast_temp_session_dir, session_id) + os.makedirs(dir_path, exist_ok=True) + return dir_path + + @class_try_catch_async + async def clear_podcast_temp_session_dir(self, session_id): + session_dir = os.path.join(self.podcast_temp_session_dir, session_id) + if os.path.isdir(session_dir): + shutil.rmtree(session_dir) + logger.info(f"cleared podcast temp session dir {session_dir}") + return True + + +async def test(): + import torch + from PIL import Image + + m = LocalDataManager("/data/nvme1/liuliang1/lightx2v/local_data", None) + await m.init() + + img = Image.open("/data/nvme1/liuliang1/lightx2v/assets/img_lightx2v.png") + tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0") + + await m.save_image(img, "test_img.png") + print(await m.load_image("test_img.png")) + + await m.save_tensor(tensor, "test_tensor.pt") + print(await m.load_tensor("test_tensor.pt", "cuda:0")) + + await m.save_object( + { + "images": [img, img], + "tensor": tensor, + "list": [ + [2, 0, 5, 5], + { + "1": "hello world", + "2": "world", + "3": img, + "t": tensor, + }, + "0609", + ], + }, + "test_object.json", + ) + print(await 
m.load_object("test_object.json", "cuda:0")) + + await m.get_delete_func("OBJECT")("test_object.json") + await m.get_delete_func("TENSOR")("test_tensor.pt") + await m.get_delete_func("IMAGE")("test_img.png") + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/lightx2v/deploy/data_manager/s3_data_manager.py b/lightx2v/deploy/data_manager/s3_data_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0d928fc197bb5d5bb871e7e4199d8ebd31bd8fc9 --- /dev/null +++ b/lightx2v/deploy/data_manager/s3_data_manager.py @@ -0,0 +1,254 @@ +import asyncio +import hashlib +import json +import os + +import aioboto3 +import tos +from botocore.client import Config +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.deploy.data_manager import BaseDataManager + + +class S3DataManager(BaseDataManager): + def __init__(self, config_string, template_dir, max_retries=3): + super().__init__() + self.name = "s3" + self.config = json.loads(config_string) + self.max_retries = max_retries + self.bucket_name = self.config["bucket_name"] + self.aws_access_key_id = self.config["aws_access_key_id"] + self.aws_secret_access_key = self.config["aws_secret_access_key"] + self.endpoint_url = self.config["endpoint_url"] + self.base_path = self.config["base_path"] + self.connect_timeout = self.config.get("connect_timeout", 60) + self.read_timeout = self.config.get("read_timeout", 60) + self.write_timeout = self.config.get("write_timeout", 10) + self.addressing_style = self.config.get("addressing_style", None) + self.region = self.config.get("region", None) + self.cdn_url = self.config.get("cdn_url", "") + self.session = None + self.s3_client = None + self.presign_client = None + if template_dir: + self.template_images_dir = os.path.join(template_dir, "images") + self.template_audios_dir = os.path.join(template_dir, "audios") + self.template_videos_dir = os.path.join(template_dir, "videos") + self.template_tasks_dir = os.path.join(template_dir, "tasks") + + # podcast temp session dir and output dir + self.podcast_temp_session_dir = os.path.join(self.base_path, "podcast_temp_session") + self.podcast_output_dir = os.path.join(self.base_path, "podcast_output") + + async def init_presign_client(self): + # init tos client for volces.com + if "volces.com" in self.endpoint_url: + self.presign_client = tos.TosClientV2( + self.aws_access_key_id, + self.aws_secret_access_key, + self.endpoint_url.replace("tos-s3-", "tos-"), + self.region, + ) + + async def init(self): + for i in range(self.max_retries): + try: + logger.info(f"S3DataManager init with config: {self.config} (attempt {i + 1}/{self.max_retries}) ...") + s3_config = {"payload_signing_enabled": True} + if self.addressing_style: + s3_config["addressing_style"] = self.addressing_style + self.session = aioboto3.Session() + self.s3_client = await self.session.client( + "s3", + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + endpoint_url=self.endpoint_url, + config=Config( + signature_version="s3v4", + s3=s3_config, + connect_timeout=self.connect_timeout, + read_timeout=self.read_timeout, + parameter_validation=False, + max_pool_connections=50, + ), + ).__aenter__() + + try: + await self.s3_client.head_bucket(Bucket=self.bucket_name) + logger.info(f"check bucket {self.bucket_name} success") + except Exception as e: + logger.info(f"check bucket {self.bucket_name} error: {e}, try to create it...") + await 
self.s3_client.create_bucket(Bucket=self.bucket_name) + + await self.init_presign_client() + logger.info(f"Successfully init S3 bucket: {self.bucket_name} with timeouts - connect: {self.connect_timeout}s, read: {self.read_timeout}s, write: {self.write_timeout}s") + return + except Exception as e: + logger.warning(f"Failed to connect to S3: {e}") + await asyncio.sleep(1) + + async def close(self): + if self.s3_client: + await self.s3_client.__aexit__(None, None, None) + if self.session: + self.session = None + + @class_try_catch_async + async def save_bytes(self, bytes_data, filename, abs_path=None): + filename = self.fmt_path(self.base_path, filename, abs_path) + content_sha256 = hashlib.sha256(bytes_data).hexdigest() + await self.s3_client.put_object( + Bucket=self.bucket_name, + Key=filename, + Body=bytes_data, + ChecksumSHA256=content_sha256, + ContentType="application/octet-stream", + ) + return True + + @class_try_catch_async + async def load_bytes(self, filename, abs_path=None): + filename = self.fmt_path(self.base_path, filename, abs_path) + response = await self.s3_client.get_object(Bucket=self.bucket_name, Key=filename) + return await response["Body"].read() + + @class_try_catch_async + async def delete_bytes(self, filename, abs_path=None): + filename = self.fmt_path(self.base_path, filename, abs_path) + await self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename) + logger.info(f"deleted s3 file {filename}") + return True + + @class_try_catch_async + async def file_exists(self, filename, abs_path=None): + filename = self.fmt_path(self.base_path, filename, abs_path) + try: + await self.s3_client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except Exception: + return False + + @class_try_catch_async + async def list_files(self, base_dir=None): + if base_dir: + prefix = self.fmt_path(self.base_path, None, abs_path=base_dir) + else: + prefix = self.base_path + prefix = prefix + "/" if not prefix.endswith("/") else prefix + + # Handle pagination for S3 list_objects_v2 + files = [] + continuation_token = None + page = 1 + + while True: + list_kwargs = {"Bucket": self.bucket_name, "Prefix": prefix, "MaxKeys": 1000} + if continuation_token: + list_kwargs["ContinuationToken"] = continuation_token + response = await self.s3_client.list_objects_v2(**list_kwargs) + + if "Contents" in response: + page_files = [] + for obj in response["Contents"]: + # Remove the prefix from the key to get just the filename + key = obj["Key"] + if key.startswith(prefix): + filename = key[len(prefix) :] + if filename: # Skip empty filenames (the directory itself) + page_files.append(filename) + files.extend(page_files) + else: + logger.warning(f"[S3DataManager.list_files] Page {page}: No files found in this page.") + + # Check if there are more pages + if response.get("IsTruncated", False): + continuation_token = response.get("NextContinuationToken") + page += 1 + else: + break + return files + + @class_try_catch_async + async def presign_url(self, filename, abs_path=None): + filename = self.fmt_path(self.base_path, filename, abs_path) + if self.cdn_url: + return f"{self.cdn_url}/{filename}" + + if self.presign_client: + expires = self.config.get("presign_expires", 24 * 60 * 60) + out = await asyncio.to_thread(self.presign_client.pre_signed_url, tos.HttpMethodType.Http_Method_Get, self.bucket_name, filename, expires) + return out.signed_url + else: + return None + + @class_try_catch_async + async def create_podcast_temp_session_dir(self, session_id): + pass + + 
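    # Unlike LocalDataManager, S3 has no real directories: creating a podcast temp
    # session dir is a no-op (the key prefix comes into existence with the first object
    # written under it), and clearing one below means listing every object under
    # podcast_temp_session/<session_id>/ and deleting them one by one.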
@class_try_catch_async + async def clear_podcast_temp_session_dir(self, session_id): + session_dir = os.path.join(self.podcast_temp_session_dir, session_id) + fs = await self.list_files(base_dir=session_dir) + logger.info(f"clear podcast temp session dir {session_dir} with files: {fs}") + for f in fs: + await self.delete_bytes(f, abs_path=os.path.join(session_dir, f)) + + +async def test(): + import torch + from PIL import Image + + s3_config = { + "aws_access_key_id": "xxx", + "aws_secret_access_key": "xx", + "endpoint_url": "xxx", + "bucket_name": "xxx", + "base_path": "xxx", + "connect_timeout": 10, + "read_timeout": 10, + "write_timeout": 10, + } + + m = S3DataManager(json.dumps(s3_config), None) + await m.init() + + img = Image.open("../../../assets/img_lightx2v.png") + tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0") + + await m.save_image(img, "test_img.png") + print(await m.load_image("test_img.png")) + + await m.save_tensor(tensor, "test_tensor.pt") + print(await m.load_tensor("test_tensor.pt", "cuda:0")) + + await m.save_object( + { + "images": [img, img], + "tensor": tensor, + "list": [ + [2, 0, 5, 5], + { + "1": "hello world", + "2": "world", + "3": img, + "t": tensor, + }, + "0609", + ], + }, + "test_object.json", + ) + print(await m.load_object("test_object.json", "cuda:0")) + + print("all files:", await m.list_files()) + await m.get_delete_func("OBJECT")("test_object.json") + await m.get_delete_func("TENSOR")("test_tensor.pt") + await m.get_delete_func("IMAGE")("test_img.png") + print("after delete all files", await m.list_files()) + await m.close() + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/lightx2v/deploy/queue_manager/__init__.py b/lightx2v/deploy/queue_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dee49b849bbf38c7584a797eb718f7bc0cdf6be1 --- /dev/null +++ b/lightx2v/deploy/queue_manager/__init__.py @@ -0,0 +1,25 @@ +class BaseQueueManager: + def __init__(self): + pass + + async def init(self): + pass + + async def close(self): + pass + + async def put_subtask(self, subtask): + raise NotImplementedError + + async def get_subtasks(self, queue, max_batch, timeout): + raise NotImplementedError + + async def pending_num(self, queue): + raise NotImplementedError + + +# Import queue manager implementations +from .local_queue_manager import LocalQueueManager # noqa +from .rabbitmq_queue_manager import RabbitMQQueueManager # noqa + +__all__ = ["BaseQueueManager", "LocalQueueManager", "RabbitMQQueueManager"] diff --git a/lightx2v/deploy/queue_manager/local_queue_manager.py b/lightx2v/deploy/queue_manager/local_queue_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8752953ecf2bd78dc6acc38e1027a15ee04b32 --- /dev/null +++ b/lightx2v/deploy/queue_manager/local_queue_manager.py @@ -0,0 +1,113 @@ +import asyncio +import json +import os +import time +import traceback + +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.deploy.queue_manager import BaseQueueManager + + +class LocalQueueManager(BaseQueueManager): + def __init__(self, local_dir): + self.local_dir = local_dir + if not os.path.exists(self.local_dir): + os.makedirs(self.local_dir) + + async def get_conn(self): + pass + + async def del_conn(self): + pass + + async def declare_queue(self, queue): + pass + + @class_try_catch_async + async def put_subtask(self, subtask): + out_name = self.get_filename(subtask["queue"]) + keys = ["queue", 
"task_id", "worker_name", "inputs", "outputs", "params"] + msg = json.dumps({k: subtask[k] for k in keys}) + "\n" + logger.info(f"Local published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {subtask['queue']}") + with open(out_name, "a") as fout: + fout.write(msg) + return True + + def read_first_line(self, queue): + out_name = self.get_filename(queue) + if not os.path.exists(out_name): + return None + lines = [] + with open(out_name) as fin: + lines = fin.readlines() + if len(lines) <= 0: + return None + subtask = json.loads(lines[0]) + msgs = "".join(lines[1:]) + fout = open(out_name, "w") + fout.write(msgs) + fout.close() + return subtask + + @class_try_catch_async + async def get_subtasks(self, queue, max_batch, timeout): + try: + t0 = time.time() + subtasks = [] + while True: + subtask = self.read_first_line(queue) + if subtask: + subtasks.append(subtask) + if len(subtasks) >= max_batch: + return subtasks + else: + continue + else: + if len(subtasks) > 0: + return subtasks + if time.time() - t0 > timeout: + return None + await asyncio.sleep(1) + except asyncio.CancelledError: + logger.warning(f"local queue get_subtasks for {queue} cancelled") + return None + except: # noqa + logger.warning(f"local queue get_subtasks for {queue} failed: {traceback.format_exc()}") + return None + + def get_filename(self, queue): + return os.path.join(self.local_dir, f"{queue}.jsonl") + + @class_try_catch_async + async def pending_num(self, queue): + out_name = self.get_filename(queue) + if not os.path.exists(out_name): + return 0 + lines = [] + with open(out_name) as fin: + lines = fin.readlines() + return len(lines) + + +async def test(): + q = LocalQueueManager("/data/nvme1/liuliang1/lightx2v/local_queue") + await q.init() + subtask = { + "task_id": "test-subtask-id", + "queue": "test_queue", + "worker_name": "test_worker", + "inputs": {}, + "outputs": {}, + "params": {}, + } + await q.put_subtask(subtask) + await asyncio.sleep(5) + for i in range(2): + subtask = await q.get_subtasks("test_queue", 3, 5) + print("get subtask:", subtask) + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/lightx2v/deploy/queue_manager/rabbitmq_queue_manager.py b/lightx2v/deploy/queue_manager/rabbitmq_queue_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fe8c3f180b4cebe9bfb58381055c0d8bacd42372 --- /dev/null +++ b/lightx2v/deploy/queue_manager/rabbitmq_queue_manager.py @@ -0,0 +1,124 @@ +import asyncio +import json +import traceback + +import aio_pika +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.deploy.queue_manager import BaseQueueManager + + +class RabbitMQQueueManager(BaseQueueManager): + def __init__(self, conn_url, max_retries=3): + self.conn_url = conn_url + self.max_retries = max_retries + self.conn = None + self.chan = None + self.queues = set() + + async def init(self): + await self.get_conn() + + async def close(self): + await self.del_conn() + + async def get_conn(self): + if self.chan and self.conn: + return + for i in range(self.max_retries): + try: + logger.info(f"Connect to RabbitMQ (attempt {i + 1}/{self.max_retries}..)") + self.conn = await aio_pika.connect_robust(self.conn_url) + self.chan = await self.conn.channel() + self.queues = set() + await self.chan.set_qos(prefetch_count=10) + logger.info("Successfully connected to RabbitMQ") + return + except Exception as e: + logger.warning(f"Failed to connect to RabbitMQ: {e}") + if i < self.max_retries - 1: + await asyncio.sleep(1) + 
else: + raise + + async def declare_queue(self, queue): + if queue not in self.queues: + await self.get_conn() + await self.chan.declare_queue(queue, durable=True) + self.queues.add(queue) + return await self.chan.get_queue(queue) + + @class_try_catch_async + async def put_subtask(self, subtask): + queue = subtask["queue"] + await self.declare_queue(queue) + keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"] + msg = json.dumps({k: subtask[k] for k in keys}).encode("utf-8") + message = aio_pika.Message(body=msg, delivery_mode=aio_pika.DeliveryMode.PERSISTENT, content_type="application/json") + await self.chan.default_exchange.publish(message, routing_key=queue) + logger.info(f"Rabbitmq published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {queue}") + return True + + async def get_subtasks(self, queue, max_batch, timeout): + try: + q = await self.declare_queue(queue) + subtasks = [] + async with q.iterator() as qiter: + async for message in qiter: + await message.ack() + subtask = json.loads(message.body.decode("utf-8")) + subtasks.append(subtask) + if len(subtasks) >= max_batch: + return subtasks + while True: + message = await q.get(no_ack=False, fail=False) + if message: + await message.ack() + subtask = json.loads(message.body.decode("utf-8")) + subtasks.append(subtask) + if len(subtasks) >= max_batch: + return subtasks + else: + return subtasks + except asyncio.CancelledError: + logger.warning(f"rabbitmq get_subtasks for {queue} cancelled") + return None + except: # noqa + logger.warning(f"rabbitmq get_subtasks for {queue} failed: {traceback.format_exc()}") + return None + + @class_try_catch_async + async def pending_num(self, queue): + q = await self.declare_queue(queue) + return q.declaration_result.message_count + + async def del_conn(self): + if self.chan: + await self.chan.close() + if self.conn: + await self.conn.close() + + +async def test(): + conn_url = "amqp://username:password@127.0.0.1:5672" + q = RabbitMQQueueManager(conn_url) + await q.init() + subtask = { + "task_id": "test-subtask-id", + "queue": "test_queue", + "worker_name": "test_worker", + "inputs": {}, + "outputs": {}, + "params": {}, + } + await q.put_subtask(subtask) + await asyncio.sleep(5) + for i in range(2): + subtask = await q.get_subtasks("test_queue", 3, 5) + print("get subtask:", subtask) + await q.close() + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/lightx2v/deploy/server/__init__.py b/lightx2v/deploy/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/deploy/server/__main__.py b/lightx2v/deploy/server/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..348fe074b58ff0454e8b2bc6e18e074ff9a82d87 --- /dev/null +++ b/lightx2v/deploy/server/__main__.py @@ -0,0 +1,1490 @@ +import argparse +import asyncio +import base64 +import copy +import json +import mimetypes +import os +import re +import tempfile +import traceback +import uuid +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import Depends, FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.staticfiles import StaticFiles +from loguru import logger +from pydantic import BaseModel + +from 
lightx2v.deploy.common.audio_separator import AudioSeparator +from lightx2v.deploy.common.face_detector import FaceDetector +from lightx2v.deploy.common.pipeline import Pipeline +from lightx2v.deploy.common.podcasts import VolcEnginePodcastClient +from lightx2v.deploy.common.utils import check_params, data_name, fetch_resource, format_image_data, load_inputs +from lightx2v.deploy.common.volcengine_tts import VolcEngineTTSClient +from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager +from lightx2v.deploy.queue_manager import LocalQueueManager, RabbitMQQueueManager +from lightx2v.deploy.server.auth import AuthManager +from lightx2v.deploy.server.metrics import MetricMonitor +from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus +from lightx2v.deploy.server.redis_monitor import RedisServerMonitor +from lightx2v.deploy.task_manager import FinishedStatus, LocalTaskManager, PostgresSQLTaskManager, TaskStatus +from lightx2v.utils.service_utils import ProcessManager + +# ========================= +# Pydantic Models +# ========================= + + +class TTSRequest(BaseModel): + text: str + voice_type: str + context_texts: str = "" + emotion: str = "" + emotion_scale: int = 3 + speech_rate: int = 0 + pitch: int = 0 + loudness_rate: int = 0 + resource_id: str = "seed-tts-1.0" + + +class RefreshTokenRequest(BaseModel): + refresh_token: str + + +# ========================= +# FastAPI Related Code +# ========================= + +model_pipelines = None +task_manager = None +data_manager = None +queue_manager = None +server_monitor = None +auth_manager = None +metrics_monitor = MetricMonitor() +volcengine_tts_client = None +volcengine_podcast_client = None +face_detector = None +audio_separator = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await task_manager.init() + await task_manager.mark_server_restart() + await data_manager.init() + await queue_manager.init() + await server_monitor.init() + asyncio.create_task(server_monitor.loop()) + yield + await server_monitor.close() + await queue_manager.close() + await data_manager.close() + await task_manager.close() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + logger.error(f"HTTP Exception: {exc.status_code} - {exc.detail} for {request.url}") + return JSONResponse(status_code=exc.status_code, content={"message": exc.detail}) + + +static_dir = os.path.join(os.path.dirname(__file__), "static") +app.mount("/static", StaticFiles(directory=static_dir), name="static") + +# 添加assets目录的静态文件服务 +assets_dir = os.path.join(os.path.dirname(__file__), "static", "assets") +app.mount("/assets", StaticFiles(directory=assets_dir), name="assets") +security = HTTPBearer() + + +async def verify_user_access(credentials: HTTPAuthorizationCredentials = Depends(security)): + token = credentials.credentials + payload = auth_manager.verify_jwt_token(token) + user_id = payload.get("user_id", None) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user") + user = await task_manager.query_user(user_id) + # logger.info(f"Verfiy user access: {payload}") + if user is None or user["user_id"] != user_id: + raise HTTPException(status_code=401, detail="Invalid user") + return user + + +async def verify_user_access_from_query(request: Request): + """从查询参数中验证用户访问权限""" + 
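    # Accepts the JWT either as a standard "Authorization: Bearer <token>" header or, as
    # a fallback, as a "?token=..." query parameter. The query-parameter form is what the
    # /assets/task/result and /assets/task/input download endpoints rely on when a client
    # cannot set request headers (e.g. a plain <video>/<img> src or a direct download link).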
# 首先尝试从 Authorization 头部获取 token + auth_header = request.headers.get("Authorization") + token = None + + if auth_header and auth_header.startswith("Bearer "): + token = auth_header[7:] # 移除 "Bearer " 前缀 + else: + # 如果没有 Authorization 头部,尝试从查询参数获取 + token = request.query_params.get("token") + + payload = auth_manager.verify_jwt_token(token) + user_id = payload.get("user_id", None) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user") + user = await task_manager.query_user(user_id) + if user is None or user["user_id"] != user_id: + raise HTTPException(status_code=401, detail="Invalid user") + return user + + +async def verify_worker_access(credentials: HTTPAuthorizationCredentials = Depends(security)): + token = credentials.credentials + if not auth_manager.verify_worker_token(token): + raise HTTPException(status_code=403, detail="Invalid worker token") + return True + + +def error_response(e, code): + return JSONResponse({"message": f"error: {e}!"}, status_code=code) + + +def format_user_response(user): + return { + "user_id": user.get("user_id"), + "id": user.get("id"), + "source": user.get("source"), + "username": user.get("username") or "", + "email": user.get("email") or "", + "homepage": user.get("homepage") or "", + "avatar_url": user.get("avatar_url") or "", + } + + +def guess_file_type(name, default_type): + content_type, _ = mimetypes.guess_type(name) + if content_type is None: + content_type = default_type + return content_type + + +@app.get("/", response_class=HTMLResponse) +async def root(): + with open(os.path.join(static_dir, "index.html"), "r", encoding="utf-8") as f: + return HTMLResponse(content=f.read()) + + +@app.get("/sitemap.xml", response_class=HTMLResponse) +async def sitemap(): + with open(os.path.join(os.path.dirname(__file__), "frontend", "dist", "sitemap.xml"), "r", encoding="utf-8") as f: + return HTMLResponse(content=f.read()) + + +@app.get("/auth/login/github") +async def github_auth(request: Request): + client_id = auth_manager.github_client_id + redirect_uri = f"{request.base_url}" + auth_url = f"https://github.com/login/oauth/authorize?client_id={client_id}&redirect_uri={redirect_uri}" + return {"auth_url": auth_url} + + +@app.get("/auth/callback/github") +async def github_callback(request: Request): + try: + code = request.query_params.get("code") + if not code: + return error_response("Missing authorization code", 400) + user_info = await auth_manager.auth_github(code) + user_id = await task_manager.create_user(user_info) + user_info["user_id"] = user_id + user_response = format_user_response(user_info) + access_token, refresh_token = auth_manager.create_tokens(user_response) + logger.info(f"GitHub callback: user_info: {user_response}, access token issued") + return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_response} + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/auth/login/google") +async def google_auth(request: Request): + client_id = auth_manager.google_client_id + redirect_uri = auth_manager.google_redirect_uri + auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?client_id={client_id}&redirect_uri={redirect_uri}&response_type=code&scope=openid%20email%20profile&access_type=offline" + logger.info(f"Google auth: auth_url: {auth_url}") + return {"auth_url": auth_url} + + +@app.get("/auth/callback/google") +async def google_callback(request: Request): + try: + code = request.query_params.get("code") + if not code: + return 
error_response("Missing authorization code", 400) + user_info = await auth_manager.auth_google(code) + user_id = await task_manager.create_user(user_info) + user_info["user_id"] = user_id + user_response = format_user_response(user_info) + access_token, refresh_token = auth_manager.create_tokens(user_response) + logger.info(f"Google callback: user_info: {user_response}, access token issued") + return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_response} + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/auth/login/sms") +async def sms_auth(request: Request): + try: + phone_number = request.query_params.get("phone_number") + if not phone_number: + return error_response("Missing phone number", 400) + ok = await auth_manager.send_sms(phone_number) + if not ok: + return error_response("SMS send failed", 400) + return {"msg": "SMS send successfully"} + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/auth/callback/sms") +async def sms_callback(request: Request): + try: + phone_number = request.query_params.get("phone_number") + verify_code = request.query_params.get("verify_code") + if not phone_number or not verify_code: + return error_response("Missing phone number or verify code", 400) + user_info = await auth_manager.check_sms(phone_number, verify_code) + if not user_info: + return error_response("SMS verify failed", 400) + + user_id = await task_manager.create_user(user_info) + user_info["user_id"] = user_id + user_response = format_user_response(user_info) + access_token, refresh_token = auth_manager.create_tokens(user_response) + logger.info(f"SMS callback: user_info: {user_response}, access token issued") + return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_response} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.post("/auth/refresh") +async def refresh_access_token(request: RefreshTokenRequest): + try: + payload = auth_manager.verify_refresh_token(request.refresh_token) + user_id = payload.get("user_id") + if not user_id: + raise HTTPException(status_code=401, detail="Invalid refresh token") + user = await task_manager.query_user(user_id) + if user is None or user.get("user_id") != user_id: + raise HTTPException(status_code=401, detail="Invalid user") + user_info = format_user_response(user) + access_token, refresh_token = auth_manager.create_tokens(user_info) + return {"access_token": access_token, "refresh_token": refresh_token, "user_info": user_info} + except HTTPException as exc: + raise exc + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +async def prepare_subtasks(task_id): + # schedule next subtasks and pend, put to message queue + subtasks = await task_manager.next_subtasks(task_id) + for sub in subtasks: + logger.info(f"Prepare ready subtask: ({task_id}, {sub['worker_name']})") + r = await queue_manager.put_subtask(sub) + assert r, "put subtask to queue error" + await server_monitor.pending_subtasks_add(sub["queue"], sub["task_id"]) + + +def format_task(task): + task["status"] = task["status"].name + task["model_cls"] = model_pipelines.outer_model_name(task["model_cls"]) + + +@app.get("/api/v1/model/list") +async def api_v1_model_list(user=Depends(verify_user_access)): + try: + return {"models": model_pipelines.get_model_lists()} + except Exception as e: + traceback.print_exc() + return error_response(str(e), 
500) + + +@app.post("/api/v1/task/submit") +async def api_v1_task_submit(request: Request, user=Depends(verify_user_access)): + task_id = None + try: + msg = await server_monitor.check_user_busy(user["user_id"], active_new_task=True) + if msg is not True: + return error_response(msg, 400) + params = await request.json() + keys = [params.pop("task"), params.pop("model_cls"), params.pop("stage")] + keys[1] = model_pipelines.inner_model_name(keys[1]) + assert len(params["prompt"]) > 0, "valid prompt is required" + + # get worker infos, model input names + workers = model_pipelines.get_workers(keys) + inputs = model_pipelines.get_inputs(keys) + outputs = model_pipelines.get_outputs(keys) + types = model_pipelines.get_types(keys) + check_params(params, inputs, outputs, types) + + # check if task can be published to queues + queues = [v["queue"] for v in workers.values()] + wait_time = await server_monitor.check_queue_busy(keys, queues) + if wait_time is None: + return error_response(f"Queue busy, please try again later", 500) + + # process multimodal inputs data + inputs_data = await load_inputs(params, inputs, types) + + # init task (we need task_id before preprocessing to save processed files) + task_id = await task_manager.create_task(keys, workers, params, inputs, outputs, user["user_id"]) + logger.info(f"Submit task: {task_id} {params}") + + # save multimodal inputs data + for inp, data in inputs_data.items(): + await data_manager.save_bytes(data, data_name(inp, task_id)) + + await prepare_subtasks(task_id) + return {"task_id": task_id, "workers": workers, "params": params, "wait_time": wait_time} + + except Exception as e: + traceback.print_exc() + if task_id: + await task_manager.finish_subtasks(task_id, TaskStatus.FAILED, fail_msg=f"submit failed: {e}") + return error_response(str(e), 500) + + +@app.get("/api/v1/task/query") +async def api_v1_task_query(request: Request, user=Depends(verify_user_access)): + try: + if "task_ids" in request.query_params: + task_ids = request.query_params["task_ids"].split(",") + tasks = [] + for task_id in task_ids: + task_id = task_id.strip() + if task_id: + task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False) + if task is not None: + task["subtasks"] = await server_monitor.format_subtask(subtasks) + format_task(task) + tasks.append(task) + return {"tasks": tasks} + + # 单个任务查询 + task_id = request.query_params["task_id"] + task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False) + if task is None: + return error_response(f"Task {task_id} not found", 404) + task["subtasks"] = await server_monitor.format_subtask(subtasks) + format_task(task) + return task + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/task/list") +async def api_v1_task_list(request: Request, user=Depends(verify_user_access)): + try: + user_id = user["user_id"] + page = int(request.query_params.get("page", 1)) + page_size = int(request.query_params.get("page_size", 10)) + assert page > 0 and page_size > 0, "page and page_size must be greater than 0" + status_filter = request.query_params.get("status", None) + + query_params = {"user_id": user_id} + if status_filter and status_filter != "ALL": + query_params["status"] = TaskStatus[status_filter.upper()] + + total_tasks = await task_manager.list_tasks(count=True, **query_params) + total_pages = (total_tasks + page_size - 1) // page_size + page_info = {"page": page, "page_size": page_size, "total": total_tasks, 
"total_pages": total_pages} + if page > total_pages: + return {"tasks": [], "pagination": page_info} + + query_params["offset"] = (page - 1) * page_size + query_params["limit"] = page_size + + tasks = await task_manager.list_tasks(**query_params) + for task in tasks: + format_task(task) + + return {"tasks": tasks, "pagination": page_info} + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/task/result_url") +async def api_v1_task_result_url(request: Request, user=Depends(verify_user_access)): + try: + name = request.query_params["name"] + task_id = request.query_params["task_id"] + task = await task_manager.query_task(task_id, user_id=user["user_id"]) + assert task is not None, f"Task {task_id} not found" + assert task["status"] == TaskStatus.SUCCEED, f"Task {task_id} not succeed" + assert name in task["outputs"], f"Output {name} not found in task {task_id}" + assert name not in task["params"], f"Output {name} is a stream" + + url = await data_manager.presign_url(task["outputs"][name]) + if url is None: + url = f"./assets/task/result?task_id={task_id}&name={name}" + return {"url": url} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/task/input_url") +async def api_v1_task_input_url(request: Request, user=Depends(verify_user_access)): + try: + name = request.query_params["name"] + task_id = request.query_params["task_id"] + filename = request.query_params.get("filename", None) + + task = await task_manager.query_task(task_id, user_id=user["user_id"]) + assert task is not None, f"Task {task_id} not found" + assert name in task["inputs"], f"Input {name} not found in task {task_id}" + if name in task["params"]: + return error_response(f"Input {name} is a stream", 400) + + # eg, multi person audio directory input + if filename is not None: + extra_inputs = task["params"]["extra_inputs"][name] + name = f"{name}/{filename}" + assert name in task["inputs"], f"Extra input {name} not found in task {task_id}" + assert name in extra_inputs, f"Filename {filename} not found in extra inputs" + + url = await data_manager.presign_url(task["inputs"][name]) + if url is None: + url = f"./assets/task/input?task_id={task_id}&name={name}" + if filename is not None: + url += f"&filename={filename}" + return {"url": url} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/assets/task/result") +async def assets_task_result(request: Request, user=Depends(verify_user_access_from_query)): + try: + name = request.query_params["name"] + task_id = request.query_params["task_id"] + task = await task_manager.query_task(task_id, user_id=user["user_id"]) + assert task is not None, f"Task {task_id} not found" + assert task["status"] == TaskStatus.SUCCEED, f"Task {task_id} not succeed" + assert name in task["outputs"], f"Output {name} not found in task {task_id}" + assert name not in task["params"], f"Output {name} is a stream" + data = await data_manager.load_bytes(task["outputs"][name]) + + # set correct Content-Type + content_type = guess_file_type(name, "application/octet-stream") + headers = {"Content-Disposition": f'attachment; filename="{name}"'} + headers["Cache-Control"] = "public, max-age=3600" + return Response(content=data, media_type=content_type, headers=headers) + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/assets/task/input") +async def assets_task_input(request: Request, 
user=Depends(verify_user_access_from_query)): + try: + name = request.query_params["name"] + task_id = request.query_params["task_id"] + filename = request.query_params.get("filename", None) + + task = await task_manager.query_task(task_id, user_id=user["user_id"]) + assert task is not None, f"Task {task_id} not found" + assert name in task["inputs"], f"Input {name} not found in task {task_id}" + if name in task["params"]: + return error_response(f"Input {name} is a stream", 400) + + # eg, multi person audio directory input + if filename is not None: + extra_inputs = task["params"]["extra_inputs"][name] + name = f"{name}/{filename}" + assert name in task["inputs"], f"Extra input {name} not found in task {task_id}" + assert name in extra_inputs, f"Filename {filename} not found in extra inputs" + data = await data_manager.load_bytes(task["inputs"][name]) + + # set correct Content-Type + content_type = guess_file_type(name, "application/octet-stream") + headers = {"Content-Disposition": f'attachment; filename="{name}"'} + headers["Cache-Control"] = "public, max-age=3600" + return Response(content=data, media_type=content_type, headers=headers) + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/task/cancel") +async def api_v1_task_cancel(request: Request, user=Depends(verify_user_access)): + try: + task_id = request.query_params["task_id"] + ret = await task_manager.cancel_task(task_id, user_id=user["user_id"]) + logger.warning(f"Task {task_id} cancelled: {ret}") + if ret is True: + return {"msg": "Task cancelled successfully"} + else: + return error_response({"error": f"Task {task_id} cancel failed: {ret}"}, 400) + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/task/resume") +async def api_v1_task_resume(request: Request, user=Depends(verify_user_access)): + try: + task_id = request.query_params["task_id"] + + task = await task_manager.query_task(task_id, user_id=user["user_id"]) + keys = [task["task_type"], task["model_cls"], task["stage"]] + if not model_pipelines.check_item_by_keys(keys): + return error_response(f"Model {keys} is not supported now, please submit a new task", 400) + + ret = await task_manager.resume_task(task_id, user_id=user["user_id"], all_subtask=False) + if ret is True: + await prepare_subtasks(task_id) + return {"msg": "ok"} + else: + return error_response(f"Task {task_id} resume failed: {ret}", 400) + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.delete("/api/v1/task/delete") +async def api_v1_task_delete(request: Request, user=Depends(verify_user_access)): + try: + task_id = request.query_params["task_id"] + + task = await task_manager.query_task(task_id, user["user_id"], only_task=True) + if not task: + return error_response("Task not found", 404) + + if task["status"] not in FinishedStatus: + return error_response("Only finished tasks can be deleted", 400) + + # delete task record + success = await task_manager.delete_task(task_id, user["user_id"]) + if success: + logger.info(f"Task {task_id} deleted by user {user['user_id']}") + return JSONResponse({"message": "Task deleted successfully"}) + else: + return error_response("Failed to delete task", 400) + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.post("/api/v1/worker/fetch") +async def api_v1_worker_fetch(request: Request, valid=Depends(verify_worker_access)): + try: + params = await request.json() + 
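        # Expected body (mirroring the pops below): worker_keys identifying the pipeline
        # worker, worker_identity naming this worker instance, plus optional max_batch and
        # timeout, e.g. {"worker_keys": [...], "worker_identity": "gpu-worker-0",
        # "max_batch": 1, "timeout": 5}. The handler long-polls the matching queue and
        # cancels the fetch if the worker disconnects before any subtask arrives.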
logger.info(f"Worker fetching: {params}") + keys = params.pop("worker_keys") + identity = params.pop("worker_identity") + max_batch = params.get("max_batch", 1) + timeout = params.get("timeout", 5) + + # check client disconnected + async def check_client(request, fetch_task, identity, queue): + while True: + msg = await request.receive() + if msg["type"] == "http.disconnect": + logger.warning(f"Worker {identity} {queue} disconnected, req: {request.client}, {msg}") + fetch_task.cancel() + await server_monitor.worker_update(queue, identity, WorkerStatus.DISCONNECT) + return + await asyncio.sleep(1) + + # get worker info + worker = model_pipelines.get_worker(keys) + await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.FETCHING) + + fetch_task = asyncio.create_task(queue_manager.get_subtasks(worker["queue"], max_batch, timeout)) + check_task = asyncio.create_task(check_client(request, fetch_task, identity, worker["queue"])) + try: + subtasks = await asyncio.wait_for(fetch_task, timeout=timeout) + except asyncio.TimeoutError: + subtasks = [] + fetch_task.cancel() + check_task.cancel() + + subtasks = [] if subtasks is None else subtasks + for sub in subtasks: + await server_monitor.pending_subtasks_sub(sub["queue"], sub["task_id"]) + valid_subtasks = await task_manager.run_subtasks(subtasks, identity) + valids = [sub["task_id"] for sub in valid_subtasks] + + if len(valid_subtasks) > 0: + await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.FETCHED) + logger.info(f"Worker {identity} {keys} {request.client} fetched {valids}") + else: + await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.DISCONNECT) + return {"subtasks": valid_subtasks} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.post("/api/v1/worker/report") +async def api_v1_worker_report(request: Request, valid=Depends(verify_worker_access)): + try: + params = await request.json() + logger.info(f"{params}") + task_id = params.pop("task_id") + worker_name = params.pop("worker_name") + status = TaskStatus[params.pop("status")] + identity = params.pop("worker_identity") + queue = params.pop("queue") + fail_msg = params.pop("fail_msg", None) + await server_monitor.worker_update(queue, identity, WorkerStatus.REPORT) + + ret = await task_manager.finish_subtasks(task_id, status, worker_identity=identity, worker_name=worker_name, fail_msg=fail_msg, should_running=True) + + # not all subtasks finished, prepare new ready subtasks + if ret not in [TaskStatus.SUCCEED, TaskStatus.FAILED]: + await prepare_subtasks(task_id) + + # all subtasks succeed, delete temp data + elif ret == TaskStatus.SUCCEED: + logger.info(f"Task {task_id} succeed") + task = await task_manager.query_task(task_id) + keys = [task["task_type"], task["model_cls"], task["stage"]] + temps = model_pipelines.get_temps(keys) + for temp in temps: + type = model_pipelines.get_type(temp) + name = data_name(temp, task_id) + await data_manager.get_delete_func(type)(name) + + elif ret == TaskStatus.FAILED: + logger.warning(f"Task {task_id} failed") + + return {"msg": "ok"} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.post("/api/v1/worker/ping/subtask") +async def api_v1_worker_ping_subtask(request: Request, valid=Depends(verify_worker_access)): + try: + params = await request.json() + logger.info(f"{params}") + task_id = params.pop("task_id") + worker_name = params.pop("worker_name") + identity = params.pop("worker_identity") + 
queue = params.pop("queue") + + task = await task_manager.query_task(task_id) + if task is None or task["status"] != TaskStatus.RUNNING: + return {"msg": "delete"} + + assert await task_manager.ping_subtask(task_id, worker_name, identity) + await server_monitor.worker_update(queue, identity, WorkerStatus.PING) + return {"msg": "ok"} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/metrics") +async def api_v1_monitor_metrics(): + try: + return Response(content=metrics_monitor.get_metrics(), media_type="text/plain") + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/template/asset_url/{template_type}/{filename}") +async def api_v1_template_asset_url(template_type: str, filename: str): + """get template asset URL - no authentication required""" + try: + url = await data_manager.presign_template_url(template_type, filename) + if url is None: + url = f"./assets/template/{template_type}/{filename}" + headers = {"Cache-Control": "public, max-age=3600"} + return Response(content=json.dumps({"url": url}), media_type="application/json", headers=headers) + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +# Template API endpoints +@app.get("/assets/template/{template_type}/{filename}") +async def assets_template(template_type: str, filename: str): + """get template file - no authentication required""" + try: + if not await data_manager.template_file_exists(template_type, filename): + return error_response(f"template file {template_type} {filename} not found", 404) + data = await data_manager.load_template_file(template_type, filename) + + # set media type according to file type + if template_type == "images": + if filename.lower().endswith(".png"): + media_type = "image/png" + elif filename.lower().endswith((".jpg", ".jpeg")): + media_type = "image/jpeg" + else: + media_type = "application/octet-stream" + elif template_type == "audios": + if filename.lower().endswith(".mp3"): + media_type = "audio/mpeg" + elif filename.lower().endswith(".wav"): + media_type = "audio/wav" + else: + media_type = "application/octet-stream" + elif template_type == "videos": + if filename.lower().endswith(".mp4"): + media_type = "video/mp4" + elif filename.lower().endswith(".webm"): + media_type = "video/webm" + elif filename.lower().endswith(".avi"): + media_type = "video/x-msvideo" + else: + media_type = "video/mp4" # default to mp4 + else: + media_type = "application/octet-stream" + + headers = {"Cache-Control": "public, max-age=3600"} + return Response(content=data, media_type=media_type, headers=headers) + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/template/list") +async def api_v1_template_list(request: Request): + """get template file list (support pagination) - no authentication required""" + try: + # check page params + page = int(request.query_params.get("page", 1)) + page_size = int(request.query_params.get("page_size", 12)) + if page < 1 or page_size < 1: + return error_response("page and page_size must be greater than 0", 400) + # limit page size + page_size = min(page_size, 100) + + all_images = await data_manager.list_template_files("images") + all_audios = await data_manager.list_template_files("audios") + all_videos = await data_manager.list_template_files("videos") + all_images = [] if all_images is None else all_images + all_audios = [] if all_audios is None else all_audios + all_videos = [] 
if all_videos is None else all_videos + + # 创建图片文件名(不含扩展名)到图片信息的映射 + all_images_sorted = sorted(all_images) + image_map = {} # 文件名(不含扩展名) -> {"filename": 完整文件名, "url": URL} + for img_name in all_images_sorted: + img_name_without_ext = img_name.rsplit(".", 1)[0] if "." in img_name else img_name + url = await data_manager.presign_template_url("images", img_name) + if url is None: + url = f"./assets/template/images/{img_name}" + image_map[img_name_without_ext] = {"filename": img_name, "url": url} + + # 创建音频文件名(不含扩展名)到音频信息的映射 + all_audios_sorted = sorted(all_audios) + audio_map = {} # 文件名(不含扩展名) -> {"filename": 完整文件名, "url": URL} + for audio_name in all_audios_sorted: + audio_name_without_ext = audio_name.rsplit(".", 1)[0] if "." in audio_name else audio_name + url = await data_manager.presign_template_url("audios", audio_name) + if url is None: + url = f"./assets/template/audios/{audio_name}" + audio_map[audio_name_without_ext] = {"filename": audio_name, "url": url} + + # 合并音频和图片模板,基于文件名前缀匹配 + # 获取所有唯一的基础文件名(不含扩展名) + all_base_names = set(list(image_map.keys()) + list(audio_map.keys())) + all_base_names_sorted = sorted(all_base_names) + + # 构建合并后的模板列表 + merged_templates = [] + for base_name in all_base_names_sorted: + template_item = { + "id": base_name, # 使用基础文件名作为ID + "image": image_map.get(base_name), + "audio": audio_map.get(base_name), + } + merged_templates.append(template_item) + + # 分页处理 + total = len(merged_templates) + total_pages = (total + page_size - 1) // page_size if total > 0 else 1 + + paginated_templates = [] + if page <= total_pages: + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + paginated_templates = merged_templates[start_idx:end_idx] + + # 为了保持向后兼容,仍然返回images和audios字段(但可能为空) + # 同时添加新的merged字段 + return { + "templates": { + "images": [], # 保持向后兼容,但设为空 + "audios": [], # 保持向后兼容,但设为空 + "videos": [], # 保持向后兼容 + "merged": paginated_templates, # 新的合并列表 + }, + "pagination": {"page": page, "page_size": page_size, "total": total, "total_pages": total_pages}, + } + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/template/tasks") +async def api_v1_template_tasks(request: Request): + """get template task list (support pagination) - no authentication required""" + try: + # check page params + page = int(request.query_params.get("page", 1)) + page_size = int(request.query_params.get("page_size", 12)) + category = request.query_params.get("category", None) + search = request.query_params.get("search", None) + if page < 1 or page_size < 1: + return error_response("page and page_size must be greater than 0", 400) + # limit page size + page_size = min(page_size, 100) + + all_templates = [] + all_categories = set() + template_files = await data_manager.list_template_files("tasks") + template_files = [] if template_files is None else template_files + + for template_file in template_files: + try: + bytes_data = await data_manager.load_template_file("tasks", template_file) + template_data = json.loads(bytes_data) + template_data["task"]["model_cls"] = model_pipelines.outer_model_name(template_data["task"]["model_cls"]) + all_categories.update(template_data["task"]["tags"]) + if category and category not in template_data["task"]["tags"]: + continue + if search and search not in template_data["task"]["params"]["prompt"] + template_data["task"]["params"]["negative_prompt"] + template_data["task"]["model_cls"] + template_data["task"][ + "stage" + ] + template_data["task"]["task_type"] + 
",".join(template_data["task"]["tags"]): + continue + all_templates.append(template_data["task"]) + except Exception as e: + logger.warning(f"Failed to load template file {template_file}: {e}") + + # page info + total_templates = len(all_templates) + total_pages = (total_templates + page_size - 1) // page_size + paginated_templates = [] + + if page <= total_pages: + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + paginated_templates = all_templates[start_idx:end_idx] + + return {"templates": paginated_templates, "pagination": {"page": page, "page_size": page_size, "total": total_templates, "total_pages": total_pages}, "categories": list(all_categories)} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/template/{template_id}") +async def api_v1_template_get(template_id: str, user=None): + try: + template_files = await data_manager.list_template_files("tasks") + template_files = [] if template_files is None else template_files + + for template_file in template_files: + try: + bytes_data = await data_manager.load_template_file("tasks", template_file) + template_data = json.loads(bytes_data) + if template_data["task"]["task_id"] == template_id: + return template_data["task"] + except Exception as e: + logger.warning(f"Failed to load template file {template_file}: {e}") + continue + return error_response("Template not found", 404) + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.post("/api/v1/share/create") +async def api_v1_share_create(request: Request, user=Depends(verify_user_access)): + try: + params = await request.json() + task_id = params["task_id"] + valid_days = params.get("valid_days", 7) + auth_type = params.get("auth_type", "public") + auth_value = params.get("auth_value", "") + share_type = params.get("share_type", "task") + assert auth_type == "public", "Only public share is supported now" + + if share_type == "template": + template = await api_v1_template_get(task_id, user) + assert isinstance(template, dict) and template["task_id"] == task_id, f"Template {task_id} not found" + else: + task = await task_manager.query_task(task_id, user["user_id"], only_task=True) + assert task, f"Task {task_id} not found" + + if auth_type == "user_id": + assert auth_value != "", "Target user is required for auth_type = user_id" + target_user = await task_manager.query_user(auth_value) + assert target_user and target_user["user_id"] == auth_value, f"Target user {auth_value} not found" + + share_id = await task_manager.create_share(task_id, user["user_id"], share_type, valid_days, auth_type, auth_value) + return {"share_id": share_id, "share_url": f"/share/{share_id}"} + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/share/{share_id}") +async def api_v1_share_get(share_id: str): + try: + share_data = await task_manager.query_share(share_id) + assert share_data, f"Share {share_id} not found or expired or deleted" + task_id = share_data["task_id"] + share_type = share_data["share_type"] + assert share_data["auth_type"] == "public", "Only public share is supported now" + + if share_type == "template": + task = await api_v1_template_get(task_id, None) + assert isinstance(task, dict) and task["task_id"] == task_id, f"Template {task_id} not found" + else: + task = await task_manager.query_task(task_id, only_task=True) + assert task, f"Task {task_id} not found" + + user_info = await 
task_manager.query_user(share_data["user_id"]) + username = user_info.get("username", "用户") if user_info else "用户" + + share_info = { + "task_id": task_id, + "share_type": share_type, + "user_id": share_data["user_id"], + "username": username, + "task_type": task["task_type"], + "model_cls": task["model_cls"], + "stage": task["stage"], + "prompt": task["params"].get("prompt", ""), + "negative_prompt": task["params"].get("negative_prompt", ""), + "inputs": task["inputs"], + "outputs": task["outputs"], + "create_t": task["create_t"], + "valid_days": share_data["valid_days"], + "valid_t": share_data["valid_t"], + "auth_type": share_data["auth_type"], + "auth_value": share_data["auth_value"], + "output_video_url": None, + "input_urls": {}, + } + + for input_name, input_filename in task["inputs"].items(): + if share_type == "template": + template_type = "images" if "image" in input_name else "audios" + input_url = await data_manager.presign_template_url(template_type, input_filename) + else: + input_url = await data_manager.presign_url(input_filename) + share_info["input_urls"][input_name] = input_url + + for output_name, output_filename in task["outputs"].items(): + if share_type == "template": + assert "video" in output_name, "Only video output is supported for template share" + output_url = await data_manager.presign_template_url("videos", output_filename) + else: + output_url = await data_manager.presign_url(output_filename) + share_info["output_video_url"] = output_url + + return share_info + + except Exception as e: + traceback.print_exc() + return error_response(str(e), 500) + + +@app.get("/api/v1/voices/list") +async def api_v1_voices_list(request: Request): + try: + version = request.query_params.get("version", "all") + if volcengine_tts_client is None: + return error_response("Volcengine TTS client not loaded", 500) + voices = volcengine_tts_client.get_voice_list() + if voices is None: + return error_response("No voice list found", 404) + if version != "all": + voices = copy.deepcopy(voices) + voices["voices"] = [v for v in voices["voices"] if v["version"] == version] + return voices + except Exception as e: + traceback.print_exc() + return error_response("Failed to get voice list", 500) + + +@app.post("/api/v1/tts/generate") +async def api_v1_tts_generate(request: TTSRequest): + """Generate TTS audio from text""" + try: + # Validate parameters + if not request.text.strip(): + return JSONResponse({"error": "Text cannot be empty"}, status_code=400) + + if not request.voice_type: + return JSONResponse({"error": "Voice type is required"}, status_code=400) + + # Generate unique output filename + output_filename = f"tts_output_{uuid.uuid4().hex}.mp3" + output_path = os.path.join(tempfile.gettempdir(), output_filename) + + # Generate TTS + success = await volcengine_tts_client.tts_request( + text=request.text, + voice_type=request.voice_type, + context_texts=request.context_texts, + emotion=request.emotion, + emotion_scale=request.emotion_scale, + speech_rate=request.speech_rate, + loudness_rate=request.loudness_rate, + pitch=request.pitch, + output=output_path, + resource_id=request.resource_id, + ) + + if success and os.path.exists(output_path): + # Return the audio file + return FileResponse(output_path, media_type="audio/mpeg", filename=output_filename) + else: + return JSONResponse({"error": "TTS generation failed"}, status_code=500) + + except Exception as e: + logger.error(f"TTS generation error: {e}") + return JSONResponse({"error": f"TTS generation failed: {str(e)}"}, 
status_code=500) + + +@app.websocket("/api/v1/podcast/generate") +async def api_v1_podcast_generate_ws(websocket: WebSocket): + await websocket.accept() + + def ws_get_user_id(): + token = websocket.query_params.get("token") + if not token: + token = websocket.headers.get("authorization") or websocket.headers.get("Authorization") + if token and token.startswith("Bearer "): + token = token[7:] + payload = auth_manager.verify_jwt_token(token) + user_id = payload["user_id"] + return user_id + + async def safe_send_json(payload): + try: + await websocket.send_json(payload) + except (WebSocketDisconnect, RuntimeError) as e: + logger.warning(f"WebSocket send skipped: {e}") + + try: + user_id = ws_get_user_id() + data = await websocket.receive_text() + request_data = json.loads(data) + + # stop request + if request_data.get("type") == "stop": + logger.info("Received stop signal from client") + await safe_send_json({"type": "stopped"}) + return + + # user input prompt + input_text = request_data.get("input", "") + is_url = input_text.startswith(("http://", "https://")) + if not input_text: + await safe_send_json({"error": "输入不能为空"}) + return + + session_id = "session_" + str(uuid.uuid4()) + params = { + "session_id": session_id, + "data_manager": data_manager, + "text": "" if is_url else input_text, + "input_url": input_text if is_url else "", + "action": 0, + "use_head_music": False, + "use_tail_music": False, + "skip_round_audio_save": False, + } + logger.info(f"WebSocket generating podcast with params: {params}") + + # 使用回调函数实时推送音频 + async def on_round_complete(round_info): + await safe_send_json({"type": "audio_update", "data": round_info}) + + params["on_round_complete"] = on_round_complete + + # 创建一个任务来处理停止信号 + async def listen_for_stop(podcast_task): + while True: + try: + if podcast_task.done(): + return + data = await asyncio.wait_for(websocket.receive_text(), timeout=0.1) + request = json.loads(data) + if request.get("type") == "stop": + logger.warning("Stop signal received during podcast generation") + podcast_task.cancel() + return + except asyncio.TimeoutError: + continue + except Exception as e: + logger.warning(f"Stop listener ended: {e}") + return + + podcast_task = asyncio.create_task(volcengine_podcast_client.podcast_request(**params)) + stop_listener_task = asyncio.create_task(listen_for_stop(podcast_task)) + podcast_info = None + + try: + podcast_info = await podcast_task + except asyncio.CancelledError: + logger.warning("Podcast generation cancelled by user") + await safe_send_json({"type": "stopped"}) + return + finally: + stop_listener_task.cancel() + if podcast_info is None: + await safe_send_json({"error": "播客生成失败,请稍后重试"}) + return + + audio_path = podcast_info["audio_name"] + rounds = podcast_info["subtitles"] + await task_manager.create_podcast(session_id, user_id, input_text, audio_path, rounds) + audio_url = await data_manager.presign_podcast_output_url(audio_path) + if not audio_url: + audio_url = f"/api/v1/podcast/audio?session_id={session_id}&filename={audio_path}" + logger.info(f"completed podcast generation (session: {session_id})") + + await safe_send_json( + { + "type": "complete", + "data": { + "audio_url": audio_url, + "subtitles": podcast_info["subtitles"], + "session_id": session_id, + "user_id": user_id, + }, + } + ) + + except WebSocketDisconnect: + logger.info("WebSocket disconnected") + + except Exception: + logger.error(f"Error in websocket: {traceback.format_exc()}") + await safe_send_json({"error": "websocket internal error, please try again later!"}) 
+ + +@app.get("/api/v1/podcast/audio") +async def api_v1_podcast_audio(request: Request, user=Depends(verify_user_access_from_query)): + try: + user_id = user["user_id"] + session_id = request.query_params.get("session_id") + filename = request.query_params.get("filename") + if not session_id or not filename: + return JSONResponse({"error": "session_id and filename are required"}, status_code=400) + + ext = os.path.splitext(filename)[1].lower() + assert ext == ".mp3", f"Unsupported file extension: {ext}" + + # 解析 Range 头,格式:bytes=start-end 或 bytes=start- + range_header = request.headers.get("Range") + start_byte, end_byte = None, None + if range_header: + match = re.match(r"bytes=(\d+)-(\d*)", range_header) + if match: + start_byte = int(match.group(1)) + end_byte = int(match.group(2)) + 1 if match.group(2) else None + + podcast_data = await task_manager.query_podcast(session_id, user_id) + if podcast_data: + # generate is finished and save info to database + func = data_manager.load_podcast_output_file + filename = podcast_data["audio_path"] + func_args = (filename,) + else: + func = data_manager.load_podcast_temp_session_file + func_args = (session_id, filename) + + logger.debug(f"Serving audio file from {func.__name__} with args: {func_args}, start_byte: {start_byte}, end_byte: {end_byte}") + file_bytes = await func(*func_args) + file_size = len(file_bytes) + file_bytes = file_bytes[start_byte:end_byte] + + content_length = len(file_bytes) + media_type = "audio/mpeg" + status_code = 200 + headers = {"Content-Length": str(content_length), "Accept-Ranges": "bytes", "Content-Type": media_type, "Content-Disposition": f'attachment; filename="{filename}"'} + + if start_byte is not None and start_byte > 0: + status_code = 206 # Partial Content + headers["Content-Range"] = f"bytes {start_byte}-{start_byte + content_length - 1}/{file_size}" + return Response(content=file_bytes, media_type=media_type, status_code=status_code, headers=headers) + + except Exception as e: + logger.error(f"Error serving audio: {e}") + traceback.print_exc() + return JSONResponse({"error": str(e)}, status_code=500) + + +@app.get("/api/v1/podcast/history") +async def api_v1_podcast_history(request: Request, user=Depends(verify_user_access)): + try: + user_id = user["user_id"] + page = int(request.query_params.get("page", 1)) + page_size = int(request.query_params.get("page_size", 10)) + assert page > 0 and page_size > 0, "page and page_size must be greater than 0" + status = request.query_params.get("status", None) # has_audio, no_audio + + query_params = {"user_id": user_id} + if status == "has_audio": + query_params["has_audio"] = True + elif status == "no_audio": + query_params["has_audio"] = False + + total_tasks = await task_manager.list_podcasts(count=True, **query_params) + total_pages = (total_tasks + page_size - 1) // page_size + page_info = {"page": page, "page_size": page_size, "total": total_tasks, "total_pages": total_pages} + if page > total_pages: + return {"sessions": [], "pagination": page_info} + + query_params["offset"] = (page - 1) * page_size + query_params["limit"] = page_size + sessions = await task_manager.list_podcasts(**query_params) + return {"sessions": sessions, "pagination": page_info} + + except Exception as e: + logger.error(f"Error getting podcast history: {e}") + traceback.print_exc() + return {"sessions": []} + + +@app.get("/api/v1/podcast/session/{session_id}/audio_url") +async def api_v1_podcast_session_audio_url(session_id: str, user=Depends(verify_user_access)): + try: + user_id = 
user["user_id"] + podcast_data = await task_manager.query_podcast(session_id, user_id) + if not podcast_data: + return JSONResponse({"error": "Podcast session not found"}, status_code=404) + + audio_path = podcast_data["audio_path"] + audio_url = await data_manager.presign_podcast_output_url(audio_path) + if not audio_url: + audio_url = f"/api/v1/podcast/audio?session_id={session_id}&filename={audio_path}" + return {"audio_url": audio_url} + + except Exception as e: + logger.error(f"Error getting podcast session audio URL: {e}") + traceback.print_exc() + return JSONResponse({"error": str(e)}, status_code=500) + + +class FaceDetectRequest(BaseModel): + image: str # Base64 encoded image + + +class AudioSeparateRequest(BaseModel): + audio: str # Base64 encoded audio + num_speakers: int = None # Optional: number of speakers to separate + + +@app.post("/api/v1/face/detect") +async def api_v1_face_detect(request: FaceDetectRequest, user=Depends(verify_user_access)): + """Detect faces in image (only detection, no cropping - cropping is done on frontend) + Supports both base64 encoded images and URLs (blob URLs, http URLs, etc.) + """ + try: + if not face_detector: + return error_response("Face detector not initialized", 500) + + # 验证输入 + if not request.image or not request.image.strip(): + logger.error("Face detection request: image is empty") + return error_response("Image input is empty", 400) + + image_bytes = None + try: + # Check if input is a URL (blob:, http:, https:, or data: URL) + if request.image.startswith(("http://", "https://")): + timeout = int(os.getenv("REQUEST_TIMEOUT", "10")) + image_bytes = await fetch_resource(request.image, timeout=timeout) + logger.debug(f"Fetched image from URL for face detection: {request.image[:100]}... (size: {len(image_bytes)} bytes)") + else: + encoded = request.image + # Data URL format: "data:image/png;base64,..." + if encoded.startswith("data:image"): + _, encoded = encoded.split(",", 1) + image_bytes = base64.b64decode(encoded) + logger.debug(f"Decoded base64 image: {request.image[:100]}... 
(size: {len(image_bytes)} bytes)") + + # Validate image format before passing to face detector + image_bytes = await asyncio.to_thread(format_image_data, image_bytes) + + except Exception as e: + logger.error(f"Failed to decode base64 image: {e}, image length: {len(request.image) if request.image else 0}") + return error_response(f"Invalid image format: {str(e)}", 400) + + # Detect faces only (no cropping) + result = await asyncio.to_thread(face_detector.detect_faces, image_bytes, return_image=False) + faces_data = [] + for i, face in enumerate(result["faces"]): + faces_data.append( + { + "index": i, + "bbox": face["bbox"], # [x1, y1, x2, y2] - absolute pixel coordinates in original image + "confidence": face["confidence"], + "class_id": face["class_id"], + "class_name": face["class_name"], + # Note: face_image is not included - frontend will crop it based on bbox + } + ) + return {"faces": faces_data, "total": len(faces_data)} + + except Exception as e: + logger.error(f"Face detection error: {traceback.format_exc()}") + return error_response(f"Face detection failed: {str(e)}", 500) + + +@app.post("/api/v1/audio/separate") +async def api_v1_audio_separate(request: AudioSeparateRequest, user=Depends(verify_user_access)): + """Separate different speakers in audio""" + try: + if not audio_separator: + return error_response("Audio separator not initialized", 500) + audio_bytes = None + try: + encoded = request.audio + if encoded.startswith("data:"): + # Remove data URL prefix (e.g., "data:audio/mpeg;base64," or "data:application/octet-stream;base64,") + _, encoded = encoded.split(",", 1) + audio_bytes = await asyncio.to_thread(base64.b64decode, encoded, validate=True) + logger.debug(f"Successfully decoded base64 audio, size: {len(audio_bytes)} bytes") + + except Exception as e: + logger.error(f"Failed to decode base64 audio {request.audio[:100]}..., error: {str(e)}") + return error_response(f"Invalid base64 audio data", 400) + + # Separate speakers + result = await asyncio.to_thread(audio_separator.separate_speakers, audio_bytes, num_speakers=request.num_speakers) + + # Convert audio tensors to base64 strings (without saving to file) + speakers_data = [] + for speaker in result["speakers"]: + # Convert audio tensor directly to base64 + audio_base64 = await asyncio.to_thread(audio_separator.speaker_audio_to_base64, speaker["audio"], speaker["sample_rate"], format="wav") + speakers_data.append( + { + "speaker_id": speaker["speaker_id"], + "audio": audio_base64, # Base64 encoded audio + "segments": speaker["segments"], + "sample_rate": speaker["sample_rate"], + } + ) + return {"speakers": speakers_data, "total": len(speakers_data), "method": result.get("method", "pyannote")} + + except Exception as e: + logger.error(f"Audio separation error: {traceback.format_exc()}") + return error_response(f"Audio separation failed: {str(e)}", 500) + + +# 所有未知路由 fallback 到 index.html (必须在所有API路由之后) +@app.get("/{full_path:path}", response_class=HTMLResponse) +async def vue_fallback(full_path: str): + index_path = os.path.join(static_dir, "index.html") + if os.path.exists(index_path): + return FileResponse(index_path) + return HTMLResponse("
Frontend not found
", status_code=404) + + +# ========================= +# Main Entry +# ========================= + +if __name__ == "__main__": + ProcessManager.register_signal_handler() + parser = argparse.ArgumentParser() + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + base_dir = os.path.abspath(os.path.join(cur_dir, "../../..")) + dft_pipeline_json = os.path.join(base_dir, "configs/model_pipeline.json") + dft_task_url = os.path.join(base_dir, "local_task") + dft_data_url = os.path.join(base_dir, "local_data") + dft_queue_url = os.path.join(base_dir, "local_queue") + dft_volcengine_tts_list_json = os.path.join(base_dir, "configs/volcengine_voices_list.json") + + parser.add_argument("--pipeline_json", type=str, default=dft_pipeline_json) + parser.add_argument("--task_url", type=str, default=dft_task_url) + parser.add_argument("--data_url", type=str, default=dft_data_url) + parser.add_argument("--queue_url", type=str, default=dft_queue_url) + parser.add_argument("--redis_url", type=str, default="") + parser.add_argument("--template_dir", type=str, default="") + parser.add_argument("--volcengine_tts_list_json", type=str, default=dft_volcengine_tts_list_json) + parser.add_argument("--ip", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=8080) + parser.add_argument("--face_detector_model_path", type=str, default=None) + parser.add_argument("--audio_separator_model_path", type=str, default="") + args = parser.parse_args() + logger.info(f"args: {args}") + + model_pipelines = Pipeline(args.pipeline_json) + volcengine_tts_client = VolcEngineTTSClient(args.volcengine_tts_list_json) + volcengine_podcast_client = VolcEnginePodcastClient() + face_detector = FaceDetector(model_path=args.face_detector_model_path) + audio_separator = AudioSeparator(model_path=args.audio_separator_model_path) + auth_manager = AuthManager() + if args.task_url.startswith("/"): + task_manager = LocalTaskManager(args.task_url, metrics_monitor) + elif args.task_url.startswith("postgresql://"): + task_manager = PostgresSQLTaskManager(args.task_url, metrics_monitor) + else: + raise NotImplementedError + if args.data_url.startswith("/"): + data_manager = LocalDataManager(args.data_url, args.template_dir) + elif args.data_url.startswith("{"): + data_manager = S3DataManager(args.data_url, args.template_dir) + else: + raise NotImplementedError + if args.queue_url.startswith("/"): + queue_manager = LocalQueueManager(args.queue_url) + elif args.queue_url.startswith("amqp://"): + queue_manager = RabbitMQQueueManager(args.queue_url) + else: + raise NotImplementedError + if args.redis_url: + server_monitor = RedisServerMonitor(model_pipelines, task_manager, queue_manager, args.redis_url) + else: + server_monitor = ServerMonitor(model_pipelines, task_manager, queue_manager) + + uvicorn.run(app, host=args.ip, port=args.port, reload=False, workers=1) diff --git a/lightx2v/deploy/server/auth.py b/lightx2v/deploy/server/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..11801257463992bc102bed480314943db147e3eb --- /dev/null +++ b/lightx2v/deploy/server/auth.py @@ -0,0 +1,205 @@ +import os +import time +import uuid + +import aiohttp +import jwt +from fastapi import HTTPException +from loguru import logger + +from lightx2v.deploy.common.aliyun import AlibabaCloudClient + + +class AuthManager: + def __init__(self): + # Worker access token + self.worker_secret_key = os.getenv("WORKER_SECRET_KEY", "worker-secret-key-change-in-production") + + # GitHub OAuth + self.github_client_id = 
os.getenv("GITHUB_CLIENT_ID", "") + self.github_client_secret = os.getenv("GITHUB_CLIENT_SECRET", "") + + # Google OAuth + self.google_client_id = os.getenv("GOOGLE_CLIENT_ID", "") + self.google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", "") + self.google_redirect_uri = os.getenv("GOOGLE_REDIRECT_URI", "") + + self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256") + self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production") + self.jwt_expiration_hours = int(os.getenv("JWT_EXPIRATION_HOURS", "168")) + self.refresh_token_expiration_days = int(os.getenv("REFRESH_TOKEN_EXPIRATION_DAYS", "30")) + self.refresh_jwt_secret_key = os.getenv("REFRESH_JWT_SECRET_KEY", self.jwt_secret_key) + + # Aliyun SMS + self.aliyun_client = AlibabaCloudClient() + + logger.info(f"AuthManager: GITHUB_CLIENT_ID: {self.github_client_id}") + logger.info(f"AuthManager: GITHUB_CLIENT_SECRET: {self.github_client_secret}") + logger.info(f"AuthManager: GOOGLE_CLIENT_ID: {self.google_client_id}") + logger.info(f"AuthManager: GOOGLE_CLIENT_SECRET: {self.google_client_secret}") + logger.info(f"AuthManager: GOOGLE_REDIRECT_URI: {self.google_redirect_uri}") + logger.info(f"AuthManager: JWT_SECRET_KEY: {self.jwt_secret_key}") + logger.info(f"AuthManager: WORKER_SECRET_KEY: {self.worker_secret_key}") + + def _create_token(self, data, expires_in_seconds, token_type, secret_key): + now = int(time.time()) + payload = { + "user_id": data["user_id"], + "username": data["username"], + "email": data["email"], + "homepage": data["homepage"], + "token_type": token_type, + "iat": now, + "exp": now + expires_in_seconds, + "jti": str(uuid.uuid4()), + } + return jwt.encode(payload, secret_key, algorithm=self.jwt_algorithm) + + def create_access_token(self, data): + return self._create_token(data, self.jwt_expiration_hours * 3600, "access", self.jwt_secret_key) + + def create_refresh_token(self, data): + return self._create_token(data, self.refresh_token_expiration_days * 24 * 3600, "refresh", self.refresh_jwt_secret_key) + + def create_tokens(self, data): + return self.create_access_token(data), self.create_refresh_token(data) + + def create_jwt_token(self, data): + # Backwards compatibility for callers that still expect this name + return self.create_access_token(data) + + async def auth_github(self, code): + try: + logger.info(f"GitHub OAuth code: {code}") + token_url = "https://github.com/login/oauth/access_token" + token_data = {"client_id": self.github_client_id, "client_secret": self.github_client_secret, "code": code} + headers = {"Accept": "application/json"} + + proxy = os.getenv("auth_https_proxy", None) + if proxy: + logger.info(f"auth_github use proxy: {proxy}") + async with aiohttp.ClientSession() as session: + async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response: + response.raise_for_status() + token_info = await response.json() + + if "error" in token_info: + raise HTTPException(status_code=400, detail=f"GitHub OAuth error: {token_info['error']}") + + access_token = token_info.get("access_token") + if not access_token: + raise HTTPException(status_code=400, detail="Failed to get access token") + + user_url = "https://api.github.com/user" + user_headers = {"Authorization": f"token {access_token}", "Accept": "application/vnd.github.v3+json"} + async with aiohttp.ClientSession() as session: + async with session.get(user_url, headers=user_headers, proxy=proxy) as response: + response.raise_for_status() + user_info = await response.json() + + return { + 
"source": "github", + "id": str(user_info["id"]), + "username": user_info["login"], + "email": user_info.get("email", ""), + "homepage": user_info.get("html_url", ""), + "avatar_url": user_info.get("avatar_url", ""), + } + + except aiohttp.ClientError as e: + logger.error(f"GitHub API request failed: {e}") + raise HTTPException(status_code=500, detail="Failed to authenticate with GitHub") + + except Exception as e: + logger.error(f"Authentication error: {e}") + raise HTTPException(status_code=500, detail="Authentication failed") + + async def auth_google(self, code): + try: + logger.info(f"Google OAuth code: {code}") + token_url = "https://oauth2.googleapis.com/token" + token_data = { + "client_id": self.google_client_id, + "client_secret": self.google_client_secret, + "code": code, + "redirect_uri": self.google_redirect_uri, + "grant_type": "authorization_code", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + proxy = os.getenv("auth_https_proxy", None) + if proxy: + logger.info(f"auth_google use proxy: {proxy}") + async with aiohttp.ClientSession() as session: + async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response: + response.raise_for_status() + token_info = await response.json() + + if "error" in token_info: + raise HTTPException(status_code=400, detail=f"Google OAuth error: {token_info['error']}") + + access_token = token_info.get("access_token") + if not access_token: + raise HTTPException(status_code=400, detail="Failed to get access token") + + # get user info + user_url = "https://www.googleapis.com/oauth2/v2/userinfo" + user_headers = {"Authorization": f"Bearer {access_token}"} + async with aiohttp.ClientSession() as session: + async with session.get(user_url, headers=user_headers, proxy=proxy) as response: + response.raise_for_status() + user_info = await response.json() + return { + "source": "google", + "id": str(user_info["id"]), + "username": user_info.get("name", user_info.get("email", "")), + "email": user_info.get("email", ""), + "homepage": user_info.get("link", ""), + "avatar_url": user_info.get("picture", ""), + } + + except aiohttp.ClientError as e: + logger.error(f"Google API request failed: {e}") + raise HTTPException(status_code=500, detail="Failed to authenticate with Google") + + except Exception as e: + logger.error(f"Google authentication error: {e}") + raise HTTPException(status_code=500, detail="Google authentication failed") + + async def send_sms(self, phone_number): + return await self.aliyun_client.send_sms(phone_number) + + async def check_sms(self, phone_number, verify_code): + ok = await self.aliyun_client.check_sms(phone_number, verify_code) + if not ok: + return None + return { + "source": "phone", + "id": phone_number, + "username": phone_number, + "email": "", + "homepage": "", + "avatar_url": "", + } + + def _verify_token(self, token, expected_type, secret_key): + try: + payload = jwt.decode(token, secret_key, algorithms=[self.jwt_algorithm]) + token_type = payload.get("token_type") + if token_type and token_type != expected_type: + raise HTTPException(status_code=401, detail="Token type mismatch") + return payload + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token has expired") + except Exception as e: + logger.error(f"verify_jwt_token error: {e}") + raise HTTPException(status_code=401, detail="Could not validate credentials") + + def verify_jwt_token(self, token): + return self._verify_token(token, "access", self.jwt_secret_key) + + def 
verify_refresh_token(self, token): + return self._verify_token(token, "refresh", self.refresh_jwt_secret_key) + + def verify_worker_token(self, token): + return token == self.worker_secret_key diff --git a/lightx2v/deploy/server/frontend/.gitignore b/lightx2v/deploy/server/frontend/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a547bf36d8d11a4f89c59c144f24795749086dd1 --- /dev/null +++ b/lightx2v/deploy/server/frontend/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/lightx2v/deploy/server/frontend/README.md b/lightx2v/deploy/server/frontend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1511959c22b74118c06679739cf9abe2ceeea48c --- /dev/null +++ b/lightx2v/deploy/server/frontend/README.md @@ -0,0 +1,5 @@ +# Vue 3 + Vite + +This template should help get you started developing with Vue 3 in Vite. The template uses Vue 3 ` + + + +
+ [frontend landing page text; markup not preserved in this excerpt]
+ LightX2V
+ A free, lightweight, and fast AI digital-human video generation platform, with end-to-end acceleration provided by the Light AI toolchain.
+ To learn more about the toolchain and the latest updates, visit the Light AI website and the LightX2V GitHub.
+ Feature highlights
+   • Cinematic-grade digital-human video generation
+   • 20x generation speedup
+   • Ultra-low-cost generation
+   • Precise lip-sync alignment
+   • Minute-level video duration
+   • Multi-scenario applications
+   • Latest TTS speech synthesis, supporting multiple languages, 100+ voices, and voice-instruction control over synthesis details
+ Quick start
+   1. Upload an image and audio, enter a video generation prompt, and click Start Generation
+   2. Generate and download the video
+   3. Apply a template to generate the same style of digital-human video in one click
+ + + diff --git a/lightx2v/deploy/server/frontend/package-lock.json b/lightx2v/deploy/server/frontend/package-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..fa8a049b09b361b7696b4a5186781956d8dd1106 --- /dev/null +++ b/lightx2v/deploy/server/frontend/package-lock.json @@ -0,0 +1,2143 @@ +{ + "name": "frontend", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "frontend", + "version": "0.0.0", + "dependencies": { + "@flaticon/flaticon-uicons": "^3.3.1", + "@headlessui/vue": "^1.7.23", + "@heroicons/vue": "^2.2.0", + "@tailwindcss/vite": "^4.1.13", + "pinia": "^3.0.3", + "tailwindcss": "^4.1.13", + "vue": "^3.5.21", + "vue-i18n": "^11.1.12", + "vue-router": "^4.5.1" + }, + "devDependencies": { + "@vitejs/plugin-vue": "^6.0.1", + "vite": "^7.1.7" + } + }, + "node_modules/@babel/helper-string-parser": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.27.1.tgz", + "integrity": "sha512-D2hP9eA+Sqx1kBZgzxZh0y1trbuU+JoDkiEwqhQ36nodYqJwyEIhPSdMNd7lOm/4io72luTPWH20Yda0xOuUow==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/parser": { + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.4.tgz", + "integrity": "sha512-yZbBqeM6TkpP9du/I2pUZnJsRMGGvOuIrhjzC1AwHwW+6he4mni6Bp/m8ijn0iOuZuPI2BfkCoSRunpyjnrQKg==", + "license": "MIT", + "dependencies": { + "@babel/types": "^7.28.4" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/types": { + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.4.tgz", + "integrity": "sha512-bkFqkLhh3pMBUQQkpVgWDWq/lqzc2678eUyDlTBhRqhCHFguYYGM0Efga7tYk4TogG/3x0EEl66/OQ+WGbWB/Q==", + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.27.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.11.tgz", + "integrity": "sha512-Xt1dOL13m8u0WE8iplx9Ibbm+hFAO0GsU2P34UNoDGvZYkY8ifSiy6Zuc1lYxfG7svWE2fzqCUmFp5HCn51gJg==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.11.tgz", + "integrity": "sha512-uoa7dU+Dt3HYsethkJ1k6Z9YdcHjTrSb5NUy66ZfZaSV8hEYGD5ZHbEMXnqLFlbBflLsl89Zke7CAdDJ4JI+Gg==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.11.tgz", + "integrity": "sha512-9slpyFBc4FPPz48+f6jyiXOx/Y4v34TUeDDXJpZqAWQn/08lKGeD8aDp9TMn9jDz2CiEuHwfhRmGBvpnd/PWIQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": 
true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.11.tgz", + "integrity": "sha512-Sgiab4xBjPU1QoPEIqS3Xx+R2lezu0LKIEcYe6pftr56PqPygbB7+szVnzoShbx64MUupqoE0KyRlN7gezbl8g==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.11.tgz", + "integrity": "sha512-VekY0PBCukppoQrycFxUqkCojnTQhdec0vevUL/EDOCnXd9LKWqD/bHwMPzigIJXPhC59Vd1WFIL57SKs2mg4w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.11.tgz", + "integrity": "sha512-+hfp3yfBalNEpTGp9loYgbknjR695HkqtY3d3/JjSRUyPg/xd6q+mQqIb5qdywnDxRZykIHs3axEqU6l1+oWEQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.11.tgz", + "integrity": "sha512-CmKjrnayyTJF2eVuO//uSjl/K3KsMIeYeyN7FyDBjsR3lnSJHaXlVoAK8DZa7lXWChbuOk7NjAc7ygAwrnPBhA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.11.tgz", + "integrity": "sha512-Dyq+5oscTJvMaYPvW3x3FLpi2+gSZTCE/1ffdwuM6G1ARang/mb3jvjxs0mw6n3Lsw84ocfo9CrNMqc5lTfGOw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.11.tgz", + "integrity": "sha512-TBMv6B4kCfrGJ8cUPo7vd6NECZH/8hPpBHHlYI3qzoYFvWu2AdTvZNuU/7hsbKWqu/COU7NIK12dHAAqBLLXgw==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.11.tgz", + "integrity": "sha512-Qr8AzcplUhGvdyUF08A1kHU3Vr2O88xxP0Tm8GcdVOUm25XYcMPp2YqSVHbLuXzYQMf9Bh/iKx7YPqECs6ffLA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.11.tgz", + "integrity": "sha512-TmnJg8BMGPehs5JKrCLqyWTVAvielc615jbkOirATQvWWB1NMXY77oLMzsUjRLa0+ngecEmDGqt5jiDC6bfvOw==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.11.tgz", + "integrity": 
"sha512-DIGXL2+gvDaXlaq8xruNXUJdT5tF+SBbJQKbWy/0J7OhU8gOHOzKmGIlfTTl6nHaCOoipxQbuJi7O++ldrxgMw==", + "cpu": [ + "loong64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.11.tgz", + "integrity": "sha512-Osx1nALUJu4pU43o9OyjSCXokFkFbyzjXb6VhGIJZQ5JZi8ylCQ9/LFagolPsHtgw6himDSyb5ETSfmp4rpiKQ==", + "cpu": [ + "mips64el" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.11.tgz", + "integrity": "sha512-nbLFgsQQEsBa8XSgSTSlrnBSrpoWh7ioFDUmwo158gIm5NNP+17IYmNWzaIzWmgCxq56vfr34xGkOcZ7jX6CPw==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.11.tgz", + "integrity": "sha512-HfyAmqZi9uBAbgKYP1yGuI7tSREXwIb438q0nqvlpxAOs3XnZ8RsisRfmVsgV486NdjD7Mw2UrFSw51lzUk1ww==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.11.tgz", + "integrity": "sha512-HjLqVgSSYnVXRisyfmzsH6mXqyvj0SA7pG5g+9W7ESgwA70AXYNpfKBqh1KbTxmQVaYxpzA/SvlB9oclGPbApw==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.11.tgz", + "integrity": "sha512-HSFAT4+WYjIhrHxKBwGmOOSpphjYkcswF449j6EjsjbinTZbp8PJtjsVK1XFJStdzXdy/jaddAep2FGY+wyFAQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.11.tgz", + "integrity": "sha512-hr9Oxj1Fa4r04dNpWr3P8QKVVsjQhqrMSUzZzf+LZcYjZNqhA3IAfPQdEh1FLVUJSiu6sgAwp3OmwBfbFgG2Xg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.11.tgz", + "integrity": "sha512-u7tKA+qbzBydyj0vgpu+5h5AeudxOAGncb8N6C9Kh1N4n7wU1Xw1JDApsRjpShRpXRQlJLb9wY28ELpwdPcZ7A==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.11.tgz", + "integrity": "sha512-Qq6YHhayieor3DxFOoYM1q0q1uMFYb7cSpLD2qzDSvK1NAvqFi8Xgivv0cFC6J+hWVw2teCYltyy9/m/14ryHg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.25.11", + "resolved": 
"https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.11.tgz", + "integrity": "sha512-CN+7c++kkbrckTOz5hrehxWN7uIhFFlmS/hqziSFVWpAzpWrQoAG4chH+nN3Be+Kzv/uuo7zhX716x3Sn2Jduw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openharmony-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.25.11.tgz", + "integrity": "sha512-rOREuNIQgaiR+9QuNkbkxubbp8MSO9rONmwP5nKncnWJ9v5jQ4JxFnLu4zDSRPf3x4u+2VN4pM4RdyIzDty/wQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.11.tgz", + "integrity": "sha512-nq2xdYaWxyg9DcIyXkZhcYulC6pQ2FuCgem3LI92IwMgIZ69KHeY8T4Y88pcwoLIjbed8n36CyKoYRDygNSGhA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.11.tgz", + "integrity": "sha512-3XxECOWJq1qMZ3MN8srCJ/QfoLpL+VaxD/WfNRm1O3B4+AZ/BnLVgFbUV3eiRYDMXetciH16dwPbbHqwe1uU0Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.11.tgz", + "integrity": "sha512-3ukss6gb9XZ8TlRyJlgLn17ecsK4NSQTmdIXRASVsiS2sQ6zPPZklNJT5GR5tE/MUarymmy8kCEf5xPCNCqVOA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.11.tgz", + "integrity": "sha512-D7Hpz6A2L4hzsRpPaCYkQnGOotdUpDzSGRIv9I+1ITdHROSFUWW95ZPZWQmGka1Fg7W3zFJowyn9WGwMJ0+KPA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@flaticon/flaticon-uicons": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/@flaticon/flaticon-uicons/-/flaticon-uicons-3.3.1.tgz", + "integrity": "sha512-WN2zuECCdjuGBQrjzN0kpeSygzC5fgF8Q7pDR+FUuGtYWczSdIhIwoD+/fKBEfwqKfNIMZ1WouidevGQ4OJORg==", + "license": "SEE LICENSE IN LICENSE", + "optionalDependencies": { + "esbuild-linux-64": "^0.14.5" + } + }, + "node_modules/@headlessui/vue": { + "version": "1.7.23", + "resolved": "https://registry.npmjs.org/@headlessui/vue/-/vue-1.7.23.tgz", + "integrity": "sha512-JzdCNqurrtuu0YW6QaDtR2PIYCKPUWq28csDyMvN4zmGccmE7lz40Is6hc3LA4HFeCI7sekZ/PQMTNmn9I/4Wg==", + "license": "MIT", + "dependencies": { + "@tanstack/vue-virtual": "^3.0.0-beta.60" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "vue": "^3.2.0" + } + }, + "node_modules/@heroicons/vue": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@heroicons/vue/-/vue-2.2.0.tgz", + "integrity": "sha512-G3dbSxoeEKqbi/DFalhRxJU4mTXJn7GwZ7ae8NuEQzd1bqdd0jAbdaBZlHPcvPD2xI1iGzNVB4k20Un2AguYPw==", + "license": "MIT", + "peerDependencies": { + "vue": ">= 3" + } + }, + "node_modules/@intlify/core-base": { + "version": "11.1.12", + "resolved": 
"https://registry.npmjs.org/@intlify/core-base/-/core-base-11.1.12.tgz", + "integrity": "sha512-whh0trqRsSqVLNEUCwU59pyJZYpU8AmSWl8M3Jz2Mv5ESPP6kFh4juas2NpZ1iCvy7GlNRffUD1xr84gceimjg==", + "license": "MIT", + "dependencies": { + "@intlify/message-compiler": "11.1.12", + "@intlify/shared": "11.1.12" + }, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/kazupon" + } + }, + "node_modules/@intlify/message-compiler": { + "version": "11.1.12", + "resolved": "https://registry.npmjs.org/@intlify/message-compiler/-/message-compiler-11.1.12.tgz", + "integrity": "sha512-Fv9iQSJoJaXl4ZGkOCN1LDM3trzze0AS2zRz2EHLiwenwL6t0Ki9KySYlyr27yVOj5aVz0e55JePO+kELIvfdQ==", + "license": "MIT", + "dependencies": { + "@intlify/shared": "11.1.12", + "source-map-js": "^1.0.2" + }, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/kazupon" + } + }, + "node_modules/@intlify/shared": { + "version": "11.1.12", + "resolved": "https://registry.npmjs.org/@intlify/shared/-/shared-11.1.12.tgz", + "integrity": "sha512-Om86EjuQtA69hdNj3GQec9ZC0L0vPSAnXzB3gP/gyJ7+mA7t06d9aOAiqMZ+xEOsumGP4eEBlfl8zF2LOTzf2A==", + "license": "MIT", + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/kazupon" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/remapping": { + "version": "2.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/remapping/-/remapping-2.3.5.tgz", + "integrity": "sha512-LI9u/+laYG4Ds1TDKSJW2YPrIlcVYOwi2fUC6xB43lueCjgxV4lffOCZCtYFiH6TNOX+tQKXx97T4IKHbhyHEQ==", + "license": "MIT", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@rolldown/pluginutils": { + "version": "1.0.0-beta.29", + "resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.29.tgz", + "integrity": "sha512-NIJgOsMjbxAXvoGq/X0gD7VPMQ8j9g0BiDaNjVNVjvl+iKXxL3Jre0v31RmBYeLEmkbj2s02v8vFTbUXi5XS2Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.52.5", + 
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.52.5.tgz", + "integrity": "sha512-8c1vW4ocv3UOMp9K+gToY5zL2XiiVw3k7f1ksf4yO1FlDFQ1C2u72iACFnSOceJFsWskc2WZNqeRhFRPzv+wtQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-android-arm64": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.52.5.tgz", + "integrity": "sha512-mQGfsIEFcu21mvqkEKKu2dYmtuSZOBMmAl5CFlPGLY94Vlcm+zWApK7F/eocsNzp8tKmbeBP8yXyAbx0XHsFNA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.52.5.tgz", + "integrity": "sha512-takF3CR71mCAGA+v794QUZ0b6ZSrgJkArC+gUiG6LB6TQty9T0Mqh3m2ImRBOxS2IeYBo4lKWIieSvnEk2OQWA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-darwin-x64": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.52.5.tgz", + "integrity": "sha512-W901Pla8Ya95WpxDn//VF9K9u2JbocwV/v75TE0YIHNTbhqUTv9w4VuQ9MaWlNOkkEfFwkdNhXgcLqPSmHy0fA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-freebsd-arm64": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.52.5.tgz", + "integrity": "sha512-QofO7i7JycsYOWxe0GFqhLmF6l1TqBswJMvICnRUjqCx8b47MTo46W8AoeQwiokAx3zVryVnxtBMcGcnX12LvA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-freebsd-x64": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.52.5.tgz", + "integrity": "sha512-jr21b/99ew8ujZubPo9skbrItHEIE50WdV86cdSoRkKtmWa+DDr6fu2c/xyRT0F/WazZpam6kk7IHBerSL7LDQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.52.5.tgz", + "integrity": "sha512-PsNAbcyv9CcecAUagQefwX8fQn9LQ4nZkpDboBOttmyffnInRy8R8dSg6hxxl2Re5QhHBf6FYIDhIj5v982ATQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm-musleabihf": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.52.5.tgz", + "integrity": "sha512-Fw4tysRutyQc/wwkmcyoqFtJhh0u31K+Q6jYjeicsGJJ7bbEq8LwPWV/w0cnzOqR2m694/Af6hpFayLJZkG2VQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.52.5.tgz", + "integrity": "sha512-a+3wVnAYdQClOTlyapKmyI6BLPAFYs0JM8HRpgYZQO02rMR09ZcV9LbQB+NL6sljzG38869YqThrRnfPMCDtZg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.52.5", + "resolved": 
"https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.52.5.tgz", + "integrity": "sha512-AvttBOMwO9Pcuuf7m9PkC1PUIKsfaAJ4AYhy944qeTJgQOqJYJ9oVl2nYgY7Rk0mkbsuOpCAYSs6wLYB2Xiw0Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-loong64-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.52.5.tgz", + "integrity": "sha512-DkDk8pmXQV2wVrF6oq5tONK6UHLz/XcEVow4JTTerdeV1uqPeHxwcg7aFsfnSm9L+OO8WJsWotKM2JJPMWrQtA==", + "cpu": [ + "loong64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.52.5.tgz", + "integrity": "sha512-W/b9ZN/U9+hPQVvlGwjzi+Wy4xdoH2I8EjaCkMvzpI7wJUs8sWJ03Rq96jRnHkSrcHTpQe8h5Tg3ZzUPGauvAw==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.52.5.tgz", + "integrity": "sha512-sjQLr9BW7R/ZiXnQiWPkErNfLMkkWIoCz7YMn27HldKsADEKa5WYdobaa1hmN6slu9oWQbB6/jFpJ+P2IkVrmw==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-musl": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.52.5.tgz", + "integrity": "sha512-hq3jU/kGyjXWTvAh2awn8oHroCbrPm8JqM7RUpKjalIRWWXE01CQOf/tUNWNHjmbMHg/hmNCwc/Pz3k1T/j/Lg==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-s390x-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.52.5.tgz", + "integrity": "sha512-gn8kHOrku8D4NGHMK1Y7NA7INQTRdVOntt1OCYypZPRt6skGbddska44K8iocdpxHTMMNui5oH4elPH4QOLrFQ==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.52.5.tgz", + "integrity": "sha512-hXGLYpdhiNElzN770+H2nlx+jRog8TyynpTVzdlc6bndktjKWyZyiCsuDAlpd+j+W+WNqfcyAWz9HxxIGfZm1Q==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.52.5.tgz", + "integrity": "sha512-arCGIcuNKjBoKAXD+y7XomR9gY6Mw7HnFBv5Rw7wQRvwYLR7gBAgV7Mb2QTyjXfTveBNFAtPt46/36vV9STLNg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-openharmony-arm64": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.52.5.tgz", + "integrity": "sha512-QoFqB6+/9Rly/RiPjaomPLmR/13cgkIGfA40LHly9zcH1S0bN2HVFYk3a1eAyHQyjs3ZJYlXvIGtcCs5tko9Cw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ] + }, + "node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": 
"4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.52.5.tgz", + "integrity": "sha512-w0cDWVR6MlTstla1cIfOGyl8+qb93FlAVutcor14Gf5Md5ap5ySfQ7R9S/NjNaMLSFdUnKGEasmVnu3lCMqB7w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.52.5.tgz", + "integrity": "sha512-Aufdpzp7DpOTULJCuvzqcItSGDH73pF3ko/f+ckJhxQyHtp67rHw3HMNxoIdDMUITJESNE6a8uh4Lo4SLouOUg==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-gnu": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.52.5.tgz", + "integrity": "sha512-UGBUGPFp1vkj6p8wCRraqNhqwX/4kNQPS57BCFc8wYh0g94iVIW33wJtQAx3G7vrjjNtRaxiMUylM0ktp/TRSQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.52.5.tgz", + "integrity": "sha512-TAcgQh2sSkykPRWLrdyy2AiceMckNf5loITqXxFI5VuQjS5tSuw3WlwdN8qv8vzjLAUTvYaH/mVjSFpbkFbpTg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@tailwindcss/node": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.15.tgz", + "integrity": "sha512-HF4+7QxATZWY3Jr8OlZrBSXmwT3Watj0OogeDvdUY/ByXJHQ+LBtqA2brDb3sBxYslIFx6UP94BJ4X6a4L9Bmw==", + "license": "MIT", + "dependencies": { + "@jridgewell/remapping": "^2.3.4", + "enhanced-resolve": "^5.18.3", + "jiti": "^2.6.0", + "lightningcss": "1.30.2", + "magic-string": "^0.30.19", + "source-map-js": "^1.2.1", + "tailwindcss": "4.1.15" + } + }, + "node_modules/@tailwindcss/oxide": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.15.tgz", + "integrity": "sha512-krhX+UOOgnsUuks2SR7hFafXmLQrKxB4YyRTERuCE59JlYL+FawgaAlSkOYmDRJdf1Q+IFNDMl9iRnBW7QBDfQ==", + "license": "MIT", + "engines": { + "node": ">= 10" + }, + "optionalDependencies": { + "@tailwindcss/oxide-android-arm64": "4.1.15", + "@tailwindcss/oxide-darwin-arm64": "4.1.15", + "@tailwindcss/oxide-darwin-x64": "4.1.15", + "@tailwindcss/oxide-freebsd-x64": "4.1.15", + "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.15", + "@tailwindcss/oxide-linux-arm64-gnu": "4.1.15", + "@tailwindcss/oxide-linux-arm64-musl": "4.1.15", + "@tailwindcss/oxide-linux-x64-gnu": "4.1.15", + "@tailwindcss/oxide-linux-x64-musl": "4.1.15", + "@tailwindcss/oxide-wasm32-wasi": "4.1.15", + "@tailwindcss/oxide-win32-arm64-msvc": "4.1.15", + "@tailwindcss/oxide-win32-x64-msvc": "4.1.15" + } + }, + "node_modules/@tailwindcss/oxide-android-arm64": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.15.tgz", + "integrity": "sha512-TkUkUgAw8At4cBjCeVCRMc/guVLKOU1D+sBPrHt5uVcGhlbVKxrCaCW9OKUIBv1oWkjh4GbunD/u/Mf0ql6kEA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-darwin-arm64": { + "version": "4.1.15", + "resolved": 
"https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.15.tgz", + "integrity": "sha512-xt5XEJpn2piMSfvd1UFN6jrWXyaKCwikP4Pidcf+yfHTSzSpYhG3dcMktjNkQO3JiLCp+0bG0HoWGvz97K162w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-darwin-x64": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.15.tgz", + "integrity": "sha512-TnWaxP6Bx2CojZEXAV2M01Yl13nYPpp0EtGpUrY+LMciKfIXiLL2r/SiSRpagE5Fp2gX+rflp/Os1VJDAyqymg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-freebsd-x64": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.15.tgz", + "integrity": "sha512-quISQDWqiB6Cqhjc3iWptXVZHNVENsWoI77L1qgGEHNIdLDLFnw3/AfY7DidAiiCIkGX/MjIdB3bbBZR/G2aJg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm-gnueabihf": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.15.tgz", + "integrity": "sha512-ObG76+vPlab65xzVUQbExmDU9FIeYLQ5k2LrQdR2Ud6hboR+ZobXpDoKEYXf/uOezOfIYmy2Ta3w0ejkTg9yxg==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm64-gnu": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.15.tgz", + "integrity": "sha512-4WbBacRmk43pkb8/xts3wnOZMDKsPFyEH/oisCm2q3aLZND25ufvJKcDUpAu0cS+CBOL05dYa8D4U5OWECuH/Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-arm64-musl": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.15.tgz", + "integrity": "sha512-AbvmEiteEj1nf42nE8skdHv73NoR+EwXVSgPY6l39X12Ex8pzOwwfi3Kc8GAmjsnsaDEbk+aj9NyL3UeyHcTLg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-x64-gnu": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.15.tgz", + "integrity": "sha512-+rzMVlvVgrXtFiS+ES78yWgKqpThgV19ISKD58Ck+YO5pO5KjyxLt7AWKsWMbY0R9yBDC82w6QVGz837AKQcHg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-linux-x64-musl": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.15.tgz", + "integrity": "sha512-fPdEy7a8eQN9qOIK3Em9D3TO1z41JScJn8yxl/76mp4sAXFDfV4YXxsiptJcOwy6bGR+70ZSwFIZhTXzQeqwQg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-wasm32-wasi": { + "version": "4.1.15", + "resolved": 
"https://registry.npmjs.org/@tailwindcss/oxide-wasm32-wasi/-/oxide-wasm32-wasi-4.1.15.tgz", + "integrity": "sha512-sJ4yd6iXXdlgIMfIBXuVGp/NvmviEoMVWMOAGxtxhzLPp9LOj5k0pMEMZdjeMCl4C6Up+RM8T3Zgk+BMQ0bGcQ==", + "bundleDependencies": [ + "@napi-rs/wasm-runtime", + "@emnapi/core", + "@emnapi/runtime", + "@tybys/wasm-util", + "@emnapi/wasi-threads", + "tslib" + ], + "cpu": [ + "wasm32" + ], + "license": "MIT", + "optional": true, + "dependencies": { + "@emnapi/core": "^1.5.0", + "@emnapi/runtime": "^1.5.0", + "@emnapi/wasi-threads": "^1.1.0", + "@napi-rs/wasm-runtime": "^1.0.7", + "@tybys/wasm-util": "^0.10.1", + "tslib": "^2.4.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@tailwindcss/oxide-win32-arm64-msvc": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.15.tgz", + "integrity": "sha512-sJGE5faXnNQ1iXeqmRin7Ds/ru2fgCiaQZQQz3ZGIDtvbkeV85rAZ0QJFMDg0FrqsffZG96H1U9AQlNBRLsHVg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/oxide-win32-x64-msvc": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.15.tgz", + "integrity": "sha512-NLeHE7jUV6HcFKS504bpOohyi01zPXi2PXmjFfkzTph8xRxDdxkRsXm/xDO5uV5K3brrE1cCwbUYmFUSHR3u1w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@tailwindcss/vite": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/@tailwindcss/vite/-/vite-4.1.15.tgz", + "integrity": "sha512-B6s60MZRTUil+xKoZoGe6i0Iar5VuW+pmcGlda2FX+guDuQ1G1sjiIy1W0frneVpeL/ZjZ4KEgWZHNrIm++2qA==", + "license": "MIT", + "dependencies": { + "@tailwindcss/node": "4.1.15", + "@tailwindcss/oxide": "4.1.15", + "tailwindcss": "4.1.15" + }, + "peerDependencies": { + "vite": "^5.2.0 || ^6 || ^7" + } + }, + "node_modules/@tanstack/virtual-core": { + "version": "3.13.12", + "resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.13.12.tgz", + "integrity": "sha512-1YBOJfRHV4sXUmWsFSf5rQor4Ss82G8dQWLRbnk3GA4jeP8hQt1hxXh0tmflpC0dz3VgEv/1+qwPyLeWkQuPFA==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, + "node_modules/@tanstack/vue-virtual": { + "version": "3.13.12", + "resolved": "https://registry.npmjs.org/@tanstack/vue-virtual/-/vue-virtual-3.13.12.tgz", + "integrity": "sha512-vhF7kEU9EXWXh+HdAwKJ2m3xaOnTTmgcdXcF2pim8g4GvI7eRrk2YRuV5nUlZnd/NbCIX4/Ja2OZu5EjJL06Ww==", + "license": "MIT", + "dependencies": { + "@tanstack/virtual-core": "3.13.12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "vue": "^2.7.0 || ^3.0.0" + } + }, + "node_modules/@types/estree": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", + "license": "MIT" + }, + "node_modules/@vitejs/plugin-vue": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/@vitejs/plugin-vue/-/plugin-vue-6.0.1.tgz", + "integrity": "sha512-+MaE752hU0wfPFJEUAIxqw18+20euHHdxVtMvbFcOEpjEyfqXH/5DCoTHiVJ0J29EhTJdoTkjEv5YBKU9dnoTw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rolldown/pluginutils": "1.0.0-beta.29" + }, + 
"engines": { + "node": "^20.19.0 || >=22.12.0" + }, + "peerDependencies": { + "vite": "^5.0.0 || ^6.0.0 || ^7.0.0", + "vue": "^3.2.25" + } + }, + "node_modules/@vue/compiler-core": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/compiler-core/-/compiler-core-3.5.22.tgz", + "integrity": "sha512-jQ0pFPmZwTEiRNSb+i9Ow/I/cHv2tXYqsnHKKyCQ08irI2kdF5qmYedmF8si8mA7zepUFmJ2hqzS8CQmNOWOkQ==", + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.28.4", + "@vue/shared": "3.5.22", + "entities": "^4.5.0", + "estree-walker": "^2.0.2", + "source-map-js": "^1.2.1" + } + }, + "node_modules/@vue/compiler-dom": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/compiler-dom/-/compiler-dom-3.5.22.tgz", + "integrity": "sha512-W8RknzUM1BLkypvdz10OVsGxnMAuSIZs9Wdx1vzA3mL5fNMN15rhrSCLiTm6blWeACwUwizzPVqGJgOGBEN/hA==", + "license": "MIT", + "dependencies": { + "@vue/compiler-core": "3.5.22", + "@vue/shared": "3.5.22" + } + }, + "node_modules/@vue/compiler-sfc": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/compiler-sfc/-/compiler-sfc-3.5.22.tgz", + "integrity": "sha512-tbTR1zKGce4Lj+JLzFXDq36K4vcSZbJ1RBu8FxcDv1IGRz//Dh2EBqksyGVypz3kXpshIfWKGOCcqpSbyGWRJQ==", + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.28.4", + "@vue/compiler-core": "3.5.22", + "@vue/compiler-dom": "3.5.22", + "@vue/compiler-ssr": "3.5.22", + "@vue/shared": "3.5.22", + "estree-walker": "^2.0.2", + "magic-string": "^0.30.19", + "postcss": "^8.5.6", + "source-map-js": "^1.2.1" + } + }, + "node_modules/@vue/compiler-ssr": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/compiler-ssr/-/compiler-ssr-3.5.22.tgz", + "integrity": "sha512-GdgyLvg4R+7T8Nk2Mlighx7XGxq/fJf9jaVofc3IL0EPesTE86cP/8DD1lT3h1JeZr2ySBvyqKQJgbS54IX1Ww==", + "license": "MIT", + "dependencies": { + "@vue/compiler-dom": "3.5.22", + "@vue/shared": "3.5.22" + } + }, + "node_modules/@vue/devtools-api": { + "version": "7.7.7", + "resolved": "https://registry.npmjs.org/@vue/devtools-api/-/devtools-api-7.7.7.tgz", + "integrity": "sha512-lwOnNBH2e7x1fIIbVT7yF5D+YWhqELm55/4ZKf45R9T8r9dE2AIOy8HKjfqzGsoTHFbWbr337O4E0A0QADnjBg==", + "license": "MIT", + "dependencies": { + "@vue/devtools-kit": "^7.7.7" + } + }, + "node_modules/@vue/devtools-kit": { + "version": "7.7.7", + "resolved": "https://registry.npmjs.org/@vue/devtools-kit/-/devtools-kit-7.7.7.tgz", + "integrity": "sha512-wgoZtxcTta65cnZ1Q6MbAfePVFxfM+gq0saaeytoph7nEa7yMXoi6sCPy4ufO111B9msnw0VOWjPEFCXuAKRHA==", + "license": "MIT", + "dependencies": { + "@vue/devtools-shared": "^7.7.7", + "birpc": "^2.3.0", + "hookable": "^5.5.3", + "mitt": "^3.0.1", + "perfect-debounce": "^1.0.0", + "speakingurl": "^14.0.1", + "superjson": "^2.2.2" + } + }, + "node_modules/@vue/devtools-shared": { + "version": "7.7.7", + "resolved": "https://registry.npmjs.org/@vue/devtools-shared/-/devtools-shared-7.7.7.tgz", + "integrity": "sha512-+udSj47aRl5aKb0memBvcUG9koarqnxNM5yjuREvqwK6T3ap4mn3Zqqc17QrBFTqSMjr3HK1cvStEZpMDpfdyw==", + "license": "MIT", + "dependencies": { + "rfdc": "^1.4.1" + } + }, + "node_modules/@vue/reactivity": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/reactivity/-/reactivity-3.5.22.tgz", + "integrity": "sha512-f2Wux4v/Z2pqc9+4SmgZC1p73Z53fyD90NFWXiX9AKVnVBEvLFOWCEgJD3GdGnlxPZt01PSlfmLqbLYzY/Fw4A==", + "license": "MIT", + "dependencies": { + "@vue/shared": "3.5.22" + } + }, + "node_modules/@vue/runtime-core": { + "version": "3.5.22", + "resolved": 
"https://registry.npmjs.org/@vue/runtime-core/-/runtime-core-3.5.22.tgz", + "integrity": "sha512-EHo4W/eiYeAzRTN5PCextDUZ0dMs9I8mQ2Fy+OkzvRPUYQEyK9yAjbasrMCXbLNhF7P0OUyivLjIy0yc6VrLJQ==", + "license": "MIT", + "dependencies": { + "@vue/reactivity": "3.5.22", + "@vue/shared": "3.5.22" + } + }, + "node_modules/@vue/runtime-dom": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/runtime-dom/-/runtime-dom-3.5.22.tgz", + "integrity": "sha512-Av60jsryAkI023PlN7LsqrfPvwfxOd2yAwtReCjeuugTJTkgrksYJJstg1e12qle0NarkfhfFu1ox2D+cQotww==", + "license": "MIT", + "dependencies": { + "@vue/reactivity": "3.5.22", + "@vue/runtime-core": "3.5.22", + "@vue/shared": "3.5.22", + "csstype": "^3.1.3" + } + }, + "node_modules/@vue/server-renderer": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/server-renderer/-/server-renderer-3.5.22.tgz", + "integrity": "sha512-gXjo+ao0oHYTSswF+a3KRHZ1WszxIqO7u6XwNHqcqb9JfyIL/pbWrrh/xLv7jeDqla9u+LK7yfZKHih1e1RKAQ==", + "license": "MIT", + "dependencies": { + "@vue/compiler-ssr": "3.5.22", + "@vue/shared": "3.5.22" + }, + "peerDependencies": { + "vue": "3.5.22" + } + }, + "node_modules/@vue/shared": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/@vue/shared/-/shared-3.5.22.tgz", + "integrity": "sha512-F4yc6palwq3TT0u+FYf0Ns4Tfl9GRFURDN2gWG7L1ecIaS/4fCIuFOjMTnCyjsu/OK6vaDKLCrGAa+KvvH+h4w==", + "license": "MIT" + }, + "node_modules/birpc": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/birpc/-/birpc-2.6.1.tgz", + "integrity": "sha512-LPnFhlDpdSH6FJhJyn4M0kFO7vtQ5iPw24FnG0y21q09xC7e8+1LeR31S1MAIrDAHp4m7aas4bEkTDTvMAtebQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/antfu" + } + }, + "node_modules/copy-anything": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/copy-anything/-/copy-anything-3.0.5.tgz", + "integrity": "sha512-yCEafptTtb4bk7GLEQoM8KVJpxAfdBJYaXyzQEgQQQgYrZiDp8SJmGKlYza6CYjEDNstAdNdKA3UuoULlEbS6w==", + "license": "MIT", + "dependencies": { + "is-what": "^4.1.8" + }, + "engines": { + "node": ">=12.13" + }, + "funding": { + "url": "https://github.com/sponsors/mesqueeb" + } + }, + "node_modules/csstype": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", + "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", + "license": "MIT" + }, + "node_modules/detect-libc": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.1.2.tgz", + "integrity": "sha512-Btj2BOOO83o3WyH59e8MgXsxEQVcarkUOpEYrubB0urwnN10yQ364rsiByU11nZlqWYZm05i/of7io4mzihBtQ==", + "license": "Apache-2.0", + "engines": { + "node": ">=8" + } + }, + "node_modules/enhanced-resolve": { + "version": "5.18.3", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.3.tgz", + "integrity": "sha512-d4lC8xfavMeBjzGr2vECC3fsGXziXZQyJxD868h2M/mBI3PwAuODxAkLkq5HYuvrPYcUtiLzsTo8U3PgX3Ocww==", + "license": "MIT", + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/entities": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/entities/-/entities-4.5.0.tgz", + "integrity": "sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/esbuild": 
{ + "version": "0.25.11", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.11.tgz", + "integrity": "sha512-KohQwyzrKTQmhXDW1PjCv3Tyspn9n5GcY2RTDqeORIdIJY8yKIF7sTSopFmn/wpMPW4rdPXI0UE5LJLuq3bx0Q==", + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.25.11", + "@esbuild/android-arm": "0.25.11", + "@esbuild/android-arm64": "0.25.11", + "@esbuild/android-x64": "0.25.11", + "@esbuild/darwin-arm64": "0.25.11", + "@esbuild/darwin-x64": "0.25.11", + "@esbuild/freebsd-arm64": "0.25.11", + "@esbuild/freebsd-x64": "0.25.11", + "@esbuild/linux-arm": "0.25.11", + "@esbuild/linux-arm64": "0.25.11", + "@esbuild/linux-ia32": "0.25.11", + "@esbuild/linux-loong64": "0.25.11", + "@esbuild/linux-mips64el": "0.25.11", + "@esbuild/linux-ppc64": "0.25.11", + "@esbuild/linux-riscv64": "0.25.11", + "@esbuild/linux-s390x": "0.25.11", + "@esbuild/linux-x64": "0.25.11", + "@esbuild/netbsd-arm64": "0.25.11", + "@esbuild/netbsd-x64": "0.25.11", + "@esbuild/openbsd-arm64": "0.25.11", + "@esbuild/openbsd-x64": "0.25.11", + "@esbuild/openharmony-arm64": "0.25.11", + "@esbuild/sunos-x64": "0.25.11", + "@esbuild/win32-arm64": "0.25.11", + "@esbuild/win32-ia32": "0.25.11", + "@esbuild/win32-x64": "0.25.11" + } + }, + "node_modules/esbuild-linux-64": { + "version": "0.14.54", + "resolved": "https://registry.npmjs.org/esbuild-linux-64/-/esbuild-linux-64-0.14.54.tgz", + "integrity": "sha512-EgjAgH5HwTbtNsTqQOXWApBaPVdDn7XcK+/PtJwZLT1UmpLoznPd8c5CxqsH2dQK3j05YsB3L17T8vE7cp4cCg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/estree-walker": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", + "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", + "license": "MIT" + }, + "node_modules/fdir": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", + "integrity": "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==", + "license": "MIT", + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "picomatch": "^3 || ^4" + }, + "peerDependenciesMeta": { + "picomatch": { + "optional": true + } + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "license": "ISC" + }, + "node_modules/hookable": { + "version": "5.5.3", + "resolved": "https://registry.npmjs.org/hookable/-/hookable-5.5.3.tgz", + "integrity": "sha512-Yc+BQe8SvoXH1643Qez1zqLRmbA5rCL+sSmk6TVos0LWVfNIB7PGncdlId77WzLGSIB5KaWgTaNTs2lNVEI6VQ==", + "license": "MIT" + }, + "node_modules/is-what": { + "version": "4.1.16", + "resolved": "https://registry.npmjs.org/is-what/-/is-what-4.1.16.tgz", + "integrity": 
"sha512-ZhMwEosbFJkA0YhFnNDgTM4ZxDRsS6HqTo7qsZM08fehyRYIYa0yHu5R6mgo1n/8MgaPBXiPimPD77baVFYg+A==", + "license": "MIT", + "engines": { + "node": ">=12.13" + }, + "funding": { + "url": "https://github.com/sponsors/mesqueeb" + } + }, + "node_modules/jiti": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.6.1.tgz", + "integrity": "sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==", + "license": "MIT", + "bin": { + "jiti": "lib/jiti-cli.mjs" + } + }, + "node_modules/lightningcss": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.30.2.tgz", + "integrity": "sha512-utfs7Pr5uJyyvDETitgsaqSyjCb2qNRAtuqUeWIAKztsOYdcACf2KtARYXg2pSvhkt+9NfoaNY7fxjl6nuMjIQ==", + "license": "MPL-2.0", + "dependencies": { + "detect-libc": "^2.0.3" + }, + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + }, + "optionalDependencies": { + "lightningcss-android-arm64": "1.30.2", + "lightningcss-darwin-arm64": "1.30.2", + "lightningcss-darwin-x64": "1.30.2", + "lightningcss-freebsd-x64": "1.30.2", + "lightningcss-linux-arm-gnueabihf": "1.30.2", + "lightningcss-linux-arm64-gnu": "1.30.2", + "lightningcss-linux-arm64-musl": "1.30.2", + "lightningcss-linux-x64-gnu": "1.30.2", + "lightningcss-linux-x64-musl": "1.30.2", + "lightningcss-win32-arm64-msvc": "1.30.2", + "lightningcss-win32-x64-msvc": "1.30.2" + } + }, + "node_modules/lightningcss-android-arm64": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-android-arm64/-/lightningcss-android-arm64-1.30.2.tgz", + "integrity": "sha512-BH9sEdOCahSgmkVhBLeU7Hc9DWeZ1Eb6wNS6Da8igvUwAe0sqROHddIlvU06q3WyXVEOYDZ6ykBZQnjTbmo4+A==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-darwin-arm64": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.30.2.tgz", + "integrity": "sha512-ylTcDJBN3Hp21TdhRT5zBOIi73P6/W0qwvlFEk22fkdXchtNTOU4Qc37SkzV+EKYxLouZ6M4LG9NfZ1qkhhBWA==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-darwin-x64": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.30.2.tgz", + "integrity": "sha512-oBZgKchomuDYxr7ilwLcyms6BCyLn0z8J0+ZZmfpjwg9fRVZIR5/GMXd7r9RH94iDhld3UmSjBM6nXWM2TfZTQ==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-freebsd-x64": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.30.2.tgz", + "integrity": "sha512-c2bH6xTrf4BDpK8MoGG4Bd6zAMZDAXS569UxCAGcA7IKbHNMlhGQ89eRmvpIUGfKWNVdbhSbkQaWhEoMGmGslA==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": 
"https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm-gnueabihf": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.30.2.tgz", + "integrity": "sha512-eVdpxh4wYcm0PofJIZVuYuLiqBIakQ9uFZmipf6LF/HRj5Bgm0eb3qL/mr1smyXIS1twwOxNWndd8z0E374hiA==", + "cpu": [ + "arm" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm64-gnu": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.30.2.tgz", + "integrity": "sha512-UK65WJAbwIJbiBFXpxrbTNArtfuznvxAJw4Q2ZGlU8kPeDIWEX1dg3rn2veBVUylA2Ezg89ktszWbaQnxD/e3A==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-arm64-musl": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.30.2.tgz", + "integrity": "sha512-5Vh9dGeblpTxWHpOx8iauV02popZDsCYMPIgiuw97OJ5uaDsL86cnqSFs5LZkG3ghHoX5isLgWzMs+eD1YzrnA==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-x64-gnu": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.30.2.tgz", + "integrity": "sha512-Cfd46gdmj1vQ+lR6VRTTadNHu6ALuw2pKR9lYq4FnhvgBc4zWY1EtZcAc6EffShbb1MFrIPfLDXD6Xprbnni4w==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-linux-x64-musl": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.30.2.tgz", + "integrity": "sha512-XJaLUUFXb6/QG2lGIW6aIk6jKdtjtcffUT0NKvIqhSBY3hh9Ch+1LCeH80dR9q9LBjG3ewbDjnumefsLsP6aiA==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-win32-arm64-msvc": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.30.2.tgz", + "integrity": "sha512-FZn+vaj7zLv//D/192WFFVA0RgHawIcHqLX9xuWiQt7P0PtdFEVaxgF9rjM/IRYHQXNnk61/H/gb2Ei+kUQ4xQ==", + "cpu": [ + "arm64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/lightningcss-win32-x64-msvc": { + "version": "1.30.2", + "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.30.2.tgz", + "integrity": 
"sha512-5g1yc73p+iAkid5phb4oVFMB45417DkRevRbt/El/gKXJk4jid+vPFF/AXbxn05Aky8PapwzZrdJShv5C0avjw==", + "cpu": [ + "x64" + ], + "license": "MPL-2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/parcel" + } + }, + "node_modules/magic-string": { + "version": "0.30.19", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.19.tgz", + "integrity": "sha512-2N21sPY9Ws53PZvsEpVtNuSW+ScYbQdp4b9qUaL+9QkHUrGFKo56Lg9Emg5s9V/qrtNBmiR01sYhUOwu3H+VOw==", + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.5" + } + }, + "node_modules/mitt": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mitt/-/mitt-3.0.1.tgz", + "integrity": "sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==", + "license": "MIT" + }, + "node_modules/nanoid": { + "version": "3.3.11", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.11.tgz", + "integrity": "sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/perfect-debounce": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/perfect-debounce/-/perfect-debounce-1.0.0.tgz", + "integrity": "sha512-xCy9V055GLEqoFaHoC1SoLIaLmWctgCUaBaWxDZ7/Zx4CTyX7cJQLJOok/orfjZAh9kEYpjJa4d0KcJmCbctZA==", + "license": "MIT" + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "license": "ISC" + }, + "node_modules/picomatch": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pinia": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/pinia/-/pinia-3.0.3.tgz", + "integrity": "sha512-ttXO/InUULUXkMHpTdp9Fj4hLpD/2AoJdmAbAeW2yu1iy1k+pkFekQXw5VpC0/5p51IOR/jDaDRfRWRnMMsGOA==", + "license": "MIT", + "dependencies": { + "@vue/devtools-api": "^7.7.2" + }, + "funding": { + "url": "https://github.com/sponsors/posva" + }, + "peerDependencies": { + "typescript": ">=4.4.4", + "vue": "^2.7.0 || ^3.5.11" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/postcss": { + "version": "8.5.6", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz", + "integrity": "sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "nanoid": "^3.3.11", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + 
"node_modules/rfdc": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/rfdc/-/rfdc-1.4.1.tgz", + "integrity": "sha512-q1b3N5QkRUWUl7iyylaaj3kOpIT0N2i9MqIEQXP73GVsN9cw3fdx8X63cEmWhJGi2PPCF23Ijp7ktmd39rawIA==", + "license": "MIT" + }, + "node_modules/rollup": { + "version": "4.52.5", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.52.5.tgz", + "integrity": "sha512-3GuObel8h7Kqdjt0gxkEzaifHTqLVW56Y/bjN7PSQtkKr0w3V/QYSdt6QWYtd7A1xUtYQigtdUfgj1RvWVtorw==", + "license": "MIT", + "dependencies": { + "@types/estree": "1.0.8" + }, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.52.5", + "@rollup/rollup-android-arm64": "4.52.5", + "@rollup/rollup-darwin-arm64": "4.52.5", + "@rollup/rollup-darwin-x64": "4.52.5", + "@rollup/rollup-freebsd-arm64": "4.52.5", + "@rollup/rollup-freebsd-x64": "4.52.5", + "@rollup/rollup-linux-arm-gnueabihf": "4.52.5", + "@rollup/rollup-linux-arm-musleabihf": "4.52.5", + "@rollup/rollup-linux-arm64-gnu": "4.52.5", + "@rollup/rollup-linux-arm64-musl": "4.52.5", + "@rollup/rollup-linux-loong64-gnu": "4.52.5", + "@rollup/rollup-linux-ppc64-gnu": "4.52.5", + "@rollup/rollup-linux-riscv64-gnu": "4.52.5", + "@rollup/rollup-linux-riscv64-musl": "4.52.5", + "@rollup/rollup-linux-s390x-gnu": "4.52.5", + "@rollup/rollup-linux-x64-gnu": "4.52.5", + "@rollup/rollup-linux-x64-musl": "4.52.5", + "@rollup/rollup-openharmony-arm64": "4.52.5", + "@rollup/rollup-win32-arm64-msvc": "4.52.5", + "@rollup/rollup-win32-ia32-msvc": "4.52.5", + "@rollup/rollup-win32-x64-gnu": "4.52.5", + "@rollup/rollup-win32-x64-msvc": "4.52.5", + "fsevents": "~2.3.2" + } + }, + "node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/speakingurl": { + "version": "14.0.1", + "resolved": "https://registry.npmjs.org/speakingurl/-/speakingurl-14.0.1.tgz", + "integrity": "sha512-1POYv7uv2gXoyGFpBCmpDVSNV74IfsWlDW216UPjbWufNf+bSU6GdbDsxdcxtfwb4xlI3yxzOTKClUosxARYrQ==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/superjson": { + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/superjson/-/superjson-2.2.2.tgz", + "integrity": "sha512-5JRxVqC8I8NuOUjzBbvVJAKNM8qoVuH0O77h4WInc/qC2q5IreqKxYwgkga3PfA22OayK2ikceb/B26dztPl+Q==", + "license": "MIT", + "dependencies": { + "copy-anything": "^3.0.2" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/tailwindcss": { + "version": "4.1.15", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.15.tgz", + "integrity": "sha512-k2WLnWkYFkdpRv+Oby3EBXIyQC8/s1HOFMBUViwtAh6Z5uAozeUSMQlIsn/c6Q2iJzqG6aJT3wdPaRNj70iYxQ==", + "license": "MIT" + }, + "node_modules/tapable": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", + "license": "MIT", + "engines": { + "node": ">=6" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/tinyglobby": { + "version": "0.2.15", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", + "integrity": 
"sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==", + "license": "MIT", + "dependencies": { + "fdir": "^6.5.0", + "picomatch": "^4.0.3" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/SuperchupuDev" + } + }, + "node_modules/vite": { + "version": "7.1.11", + "resolved": "https://registry.npmjs.org/vite/-/vite-7.1.11.tgz", + "integrity": "sha512-uzcxnSDVjAopEUjljkWh8EIrg6tlzrjFUfMcR1EVsRDGwf/ccef0qQPRyOrROwhrTDaApueq+ja+KLPlzR/zdg==", + "license": "MIT", + "dependencies": { + "esbuild": "^0.25.0", + "fdir": "^6.5.0", + "picomatch": "^4.0.3", + "postcss": "^8.5.6", + "rollup": "^4.43.0", + "tinyglobby": "^0.2.15" + }, + "bin": { + "vite": "bin/vite.js" + }, + "engines": { + "node": "^20.19.0 || >=22.12.0" + }, + "funding": { + "url": "https://github.com/vitejs/vite?sponsor=1" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + }, + "peerDependencies": { + "@types/node": "^20.19.0 || >=22.12.0", + "jiti": ">=1.21.0", + "less": "^4.0.0", + "lightningcss": "^1.21.0", + "sass": "^1.70.0", + "sass-embedded": "^1.70.0", + "stylus": ">=0.54.8", + "sugarss": "^5.0.0", + "terser": "^5.16.0", + "tsx": "^4.8.1", + "yaml": "^2.4.2" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + }, + "jiti": { + "optional": true + }, + "less": { + "optional": true + }, + "lightningcss": { + "optional": true + }, + "sass": { + "optional": true + }, + "sass-embedded": { + "optional": true + }, + "stylus": { + "optional": true + }, + "sugarss": { + "optional": true + }, + "terser": { + "optional": true + }, + "tsx": { + "optional": true + }, + "yaml": { + "optional": true + } + } + }, + "node_modules/vue": { + "version": "3.5.22", + "resolved": "https://registry.npmjs.org/vue/-/vue-3.5.22.tgz", + "integrity": "sha512-toaZjQ3a/G/mYaLSbV+QsQhIdMo9x5rrqIpYRObsJ6T/J+RyCSFwN2LHNVH9v8uIcljDNa3QzPVdv3Y6b9hAJQ==", + "license": "MIT", + "dependencies": { + "@vue/compiler-dom": "3.5.22", + "@vue/compiler-sfc": "3.5.22", + "@vue/runtime-dom": "3.5.22", + "@vue/server-renderer": "3.5.22", + "@vue/shared": "3.5.22" + }, + "peerDependencies": { + "typescript": "*" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/vue-i18n": { + "version": "11.1.12", + "resolved": "https://registry.npmjs.org/vue-i18n/-/vue-i18n-11.1.12.tgz", + "integrity": "sha512-BnstPj3KLHLrsqbVU2UOrPmr0+Mv11bsUZG0PyCOzsawCivk8W00GMXHeVUWIDOgNaScCuZah47CZFE+Wnl8mw==", + "license": "MIT", + "dependencies": { + "@intlify/core-base": "11.1.12", + "@intlify/shared": "11.1.12", + "@vue/devtools-api": "^6.5.0" + }, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/kazupon" + }, + "peerDependencies": { + "vue": "^3.0.0" + } + }, + "node_modules/vue-i18n/node_modules/@vue/devtools-api": { + "version": "6.6.4", + "resolved": "https://registry.npmjs.org/@vue/devtools-api/-/devtools-api-6.6.4.tgz", + "integrity": "sha512-sGhTPMuXqZ1rVOk32RylztWkfXTRhuS7vgAKv0zjqk8gbsHkJ7xfFf+jbySxt7tWObEJwyKaHMikV/WGDiQm8g==", + "license": "MIT" + }, + "node_modules/vue-router": { + "version": "4.6.3", + "resolved": "https://registry.npmjs.org/vue-router/-/vue-router-4.6.3.tgz", + "integrity": "sha512-ARBedLm9YlbvQomnmq91Os7ck6efydTSpRP3nuOKCvgJOHNrhRoJDSKtee8kcL1Vf7nz6U+PMBL+hTvR3bTVQg==", + "license": "MIT", + "dependencies": { + "@vue/devtools-api": "^6.6.4" + }, + "funding": { + "url": "https://github.com/sponsors/posva" + }, + "peerDependencies": { + "vue": "^3.5.0" + } + }, + 
"node_modules/vue-router/node_modules/@vue/devtools-api": { + "version": "6.6.4", + "resolved": "https://registry.npmjs.org/@vue/devtools-api/-/devtools-api-6.6.4.tgz", + "integrity": "sha512-sGhTPMuXqZ1rVOk32RylztWkfXTRhuS7vgAKv0zjqk8gbsHkJ7xfFf+jbySxt7tWObEJwyKaHMikV/WGDiQm8g==", + "license": "MIT" + } + } +} diff --git a/lightx2v/deploy/server/frontend/package.json b/lightx2v/deploy/server/frontend/package.json new file mode 100644 index 0000000000000000000000000000000000000000..a73da99376df363063ab3b7156dbaeb6bdb232ad --- /dev/null +++ b/lightx2v/deploy/server/frontend/package.json @@ -0,0 +1,26 @@ +{ + "name": "frontend", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "dependencies": { + "@flaticon/flaticon-uicons": "^3.3.1", + "@headlessui/vue": "^1.7.23", + "@heroicons/vue": "^2.2.0", + "@tailwindcss/vite": "^4.1.13", + "pinia": "^3.0.3", + "tailwindcss": "^4.1.13", + "vue": "^3.5.21", + "vue-i18n": "^11.1.12", + "vue-router": "^4.5.1" + }, + "devDependencies": { + "@vitejs/plugin-vue": "^6.0.1", + "vite": "^7.1.7" + } +} diff --git a/lightx2v/deploy/server/frontend/public/cover.png b/lightx2v/deploy/server/frontend/public/cover.png new file mode 100644 index 0000000000000000000000000000000000000000..75d5f2dc8ba46e8c4785e4658d00f60a76d15f7e Binary files /dev/null and b/lightx2v/deploy/server/frontend/public/cover.png differ diff --git a/lightx2v/deploy/server/frontend/public/female.svg b/lightx2v/deploy/server/frontend/public/female.svg new file mode 100644 index 0000000000000000000000000000000000000000..5acbf50e191c113f05f57b7bb135f7b5aec6986e --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/female.svg @@ -0,0 +1,3 @@ + + + diff --git a/lightx2v/deploy/server/frontend/public/logo.svg b/lightx2v/deploy/server/frontend/public/logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..fc1c012f2c82c45035779371988a8d115b4e81f5 --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/logo.svg @@ -0,0 +1 @@ + diff --git a/lightx2v/deploy/server/frontend/public/logo_black.png b/lightx2v/deploy/server/frontend/public/logo_black.png new file mode 100644 index 0000000000000000000000000000000000000000..71f1e37eb0b65a8da9280f76abf907fb38a8c467 Binary files /dev/null and b/lightx2v/deploy/server/frontend/public/logo_black.png differ diff --git a/lightx2v/deploy/server/frontend/public/logo_black.svg b/lightx2v/deploy/server/frontend/public/logo_black.svg new file mode 100644 index 0000000000000000000000000000000000000000..6b48e31019f68919c65266eda836199122a9855e --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/logo_black.svg @@ -0,0 +1 @@ + diff --git a/lightx2v/deploy/server/frontend/public/male.svg b/lightx2v/deploy/server/frontend/public/male.svg new file mode 100644 index 0000000000000000000000000000000000000000..5a8e879db7f1149381a0fe89acd962b125c029da --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/male.svg @@ -0,0 +1,3 @@ + + + diff --git a/lightx2v/deploy/server/frontend/public/robots.txt b/lightx2v/deploy/server/frontend/public/robots.txt new file mode 100644 index 0000000000000000000000000000000000000000..f99277b958c6451b01be56231a31ece6f033b083 --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/robots.txt @@ -0,0 +1,3 @@ +User-agent: * +Allow: / +Sitemap: https://x2v.light-ai.top/sitemap.xml diff --git a/lightx2v/deploy/server/frontend/public/sitemap.xml b/lightx2v/deploy/server/frontend/public/sitemap.xml new file mode 
100644 index 0000000000000000000000000000000000000000..64ff8fd8ab19933c58bfc23d264dde0bcceeeba7 --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/sitemap.xml @@ -0,0 +1,14 @@ + + + + https://x2v.light-ai.top/ + daily + 1.0 + + + https://x2v.light-ai.top/examples + + + https://x2v.light-ai.top/docs + + diff --git a/lightx2v/deploy/server/frontend/public/vite.svg b/lightx2v/deploy/server/frontend/public/vite.svg new file mode 100644 index 0000000000000000000000000000000000000000..ee9fadaf9c4a762ac0ec010ca16ce8fa39a09e56 --- /dev/null +++ b/lightx2v/deploy/server/frontend/public/vite.svg @@ -0,0 +1 @@ + diff --git a/lightx2v/deploy/server/frontend/src/App.vue b/lightx2v/deploy/server/frontend/src/App.vue new file mode 100644 index 0000000000000000000000000000000000000000..fa7a9e74224f069ffe8548b1ba3ec36657f5dad0 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/App.vue @@ -0,0 +1,125 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/Alert.vue b/lightx2v/deploy/server/frontend/src/components/Alert.vue new file mode 100644 index 0000000000000000000000000000000000000000..0d86e1acae7d5afc544722f87db5e6e0de2d0724 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Alert.vue @@ -0,0 +1,311 @@ + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/AudioPreviewTest.vue b/lightx2v/deploy/server/frontend/src/components/AudioPreviewTest.vue new file mode 100644 index 0000000000000000000000000000000000000000..21ab03b89ab17cb01c0a655e7b5c2c14a7ec0183 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/AudioPreviewTest.vue @@ -0,0 +1,67 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/Confirm.vue b/lightx2v/deploy/server/frontend/src/components/Confirm.vue new file mode 100644 index 0000000000000000000000000000000000000000..1dfff450ac0360a7a79b7462897670bb6c725353 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Confirm.vue @@ -0,0 +1,70 @@ + + diff --git a/lightx2v/deploy/server/frontend/src/components/DropdownMenu.vue b/lightx2v/deploy/server/frontend/src/components/DropdownMenu.vue new file mode 100644 index 0000000000000000000000000000000000000000..3c332594501a2fcb60f81f98ebf2f49fc814864d --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/DropdownMenu.vue @@ -0,0 +1,96 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/FloatingParticles.vue b/lightx2v/deploy/server/frontend/src/components/FloatingParticles.vue new file mode 100644 index 0000000000000000000000000000000000000000..2afd56b83835181229b71268cadaf80f74b6bed7 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/FloatingParticles.vue @@ -0,0 +1,73 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/Generate.vue b/lightx2v/deploy/server/frontend/src/components/Generate.vue new file mode 100644 index 0000000000000000000000000000000000000000..d2a92f273e30e33ac30b7e7528e63db8b01bd1f9 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Generate.vue @@ -0,0 +1,3517 @@ + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/Inspirations.vue b/lightx2v/deploy/server/frontend/src/components/Inspirations.vue new file mode 100644 index 0000000000000000000000000000000000000000..9772301931bdbea687e403c052e7d9bde3deca10 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Inspirations.vue @@ -0,0 +1,289 @@ + + diff --git a/lightx2v/deploy/server/frontend/src/components/LeftBar.vue 
b/lightx2v/deploy/server/frontend/src/components/LeftBar.vue new file mode 100644 index 0000000000000000000000000000000000000000..ca678b2a335215769692bf42c4b5edeaf0ea8d6a --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/LeftBar.vue @@ -0,0 +1,56 @@ + + diff --git a/lightx2v/deploy/server/frontend/src/components/Loading.vue b/lightx2v/deploy/server/frontend/src/components/Loading.vue new file mode 100644 index 0000000000000000000000000000000000000000..41d0e6eaa59f1f666149cf46b580a06dd1f8ac98 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Loading.vue @@ -0,0 +1,15 @@ + + diff --git a/lightx2v/deploy/server/frontend/src/components/LoginCard.vue b/lightx2v/deploy/server/frontend/src/components/LoginCard.vue new file mode 100644 index 0000000000000000000000000000000000000000..684d389ad4efbd6b21eb40d128b6c285be084da4 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/LoginCard.vue @@ -0,0 +1,124 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/MediaTemplate.vue b/lightx2v/deploy/server/frontend/src/components/MediaTemplate.vue new file mode 100644 index 0000000000000000000000000000000000000000..1a343e2905192f6def1697cb34af3624c374db67 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/MediaTemplate.vue @@ -0,0 +1,466 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/ModelDropdown.vue b/lightx2v/deploy/server/frontend/src/components/ModelDropdown.vue new file mode 100644 index 0000000000000000000000000000000000000000..bc2d12e3fc9ce8141612b4d570f7a696b63aa51d --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/ModelDropdown.vue @@ -0,0 +1,57 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/Projects.vue b/lightx2v/deploy/server/frontend/src/components/Projects.vue new file mode 100644 index 0000000000000000000000000000000000000000..ba7b0dbf3d1cc7fc741720f8e84aa30f98aa3938 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Projects.vue @@ -0,0 +1,609 @@ + + diff --git a/lightx2v/deploy/server/frontend/src/components/PromptTemplate.vue b/lightx2v/deploy/server/frontend/src/components/PromptTemplate.vue new file mode 100644 index 0000000000000000000000000000000000000000..7f9b39b576e78d0941d5d130a922f53ef76ffe1a --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/PromptTemplate.vue @@ -0,0 +1,123 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/SiteFooter.vue b/lightx2v/deploy/server/frontend/src/components/SiteFooter.vue new file mode 100644 index 0000000000000000000000000000000000000000..e66df66ba05c69f1b800f1b740fed4c1f98873d5 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/SiteFooter.vue @@ -0,0 +1,38 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/TaskCarousel.vue b/lightx2v/deploy/server/frontend/src/components/TaskCarousel.vue new file mode 100644 index 0000000000000000000000000000000000000000..29163828c790dfc4b9cea39d10d00a8f8adc71c7 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/TaskCarousel.vue @@ -0,0 +1,497 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/TaskDetails.vue b/lightx2v/deploy/server/frontend/src/components/TaskDetails.vue new file mode 100644 index 0000000000000000000000000000000000000000..51550acb84b8c9eb16fb15282875ed9eaefc8ff0 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/TaskDetails.vue @@ -0,0 +1,1154 @@ + + + + diff --git 
a/lightx2v/deploy/server/frontend/src/components/TemplateDetails.vue b/lightx2v/deploy/server/frontend/src/components/TemplateDetails.vue new file mode 100644 index 0000000000000000000000000000000000000000..d46e768faa26c7798fa6715ab41e0ded294735fc --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/TemplateDetails.vue @@ -0,0 +1,335 @@ + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/TemplateDisplay.vue b/lightx2v/deploy/server/frontend/src/components/TemplateDisplay.vue new file mode 100644 index 0000000000000000000000000000000000000000..b36096f6690f0f27600c7de9cae7fbe10b1c7079 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/TemplateDisplay.vue @@ -0,0 +1,294 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/TopBar.vue b/lightx2v/deploy/server/frontend/src/components/TopBar.vue new file mode 100644 index 0000000000000000000000000000000000000000..ec4ebe1ebd50bb83976f60d1b6808c72918bc409 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/TopBar.vue @@ -0,0 +1,106 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/VoiceSelector.vue b/lightx2v/deploy/server/frontend/src/components/VoiceSelector.vue new file mode 100644 index 0000000000000000000000000000000000000000..dc42a70a83a02d7b00f3e104b6c7d11ff60fad2d --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/VoiceSelector.vue @@ -0,0 +1,233 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/components/VoiceTtsHistoryPanel.vue b/lightx2v/deploy/server/frontend/src/components/VoiceTtsHistoryPanel.vue new file mode 100644 index 0000000000000000000000000000000000000000..f38a6a5c6ddbe76850c809d88125756eeed01cbe --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/VoiceTtsHistoryPanel.vue @@ -0,0 +1,308 @@ + + + diff --git a/lightx2v/deploy/server/frontend/src/components/Voice_tts.vue b/lightx2v/deploy/server/frontend/src/components/Voice_tts.vue new file mode 100644 index 0000000000000000000000000000000000000000..131f085849c8df54651174a068df7dbcc9780f48 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/components/Voice_tts.vue @@ -0,0 +1,2511 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/locales/en.json b/lightx2v/deploy/server/frontend/src/locales/en.json new file mode 100644 index 0000000000000000000000000000000000000000..77974b7b8ef6341b45ccf0a571f363c7610366eb --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/locales/en.json @@ -0,0 +1,651 @@ +{ + "faceEdit": "Face Edit", + "adjustFaceBox": "Adjust Character", + "rolesAndAudio": "Multiple roles mode (drag and drop to adjust the role and audio pairing)", + "roleModeInfo": { + "title": "Role Mode Guide", + "singleMode": { + "title": "Single Role Mode", + "points": [ + "All characters use the same audio track for lip-sync", + "Ideal for single-person dialogue or unified voice-over scenarios", + "Simple operation, no manual pairing required" + ] + }, + "multiMode": { + "title": "Multi-Role Mode", + "points": [ + "Image and audio tracks are automatically separated", + "Manually pair each character with its corresponding audio track", + "Each character will only speak the content of its assigned track" + ] + } + }, + "useImage": "Use Image", + "detectingCharacters": "Detecting characters...", + "waitingForMultipleRolesAudio": "Waiting for multiple roles audio...", + "deleteAudio": "Delete Audio", + "deleteImage": "Delete Image", + "addNewRole": "Add New Role", + "adjustFaceBoxDescription": "Adjust the face box to ensure 
the face is fully displayed.", + "adjustFaceBoxConfirm": "Adjust Face Box", + "adjustFaceBoxConfirmDescription": "Adjust the face box to ensure the face is fully displayed.", + "adjustFaceBoxConfirmButton": "Confirm Adjust", + "adjustFaceBoxCancelButton": "Cancel", + "adjustFaceBoxSaveButton": "Save", + "detectedCharacters": "Detected Characters", + "starOnGitHub": "Star on GitHub", + "pageNotFound": "Page Not Found", + "pageNotFoundDescription": "The page you are looking for does not exist.", + "goBackHome": "Home", + "goBack": "Go back", + "viewErrorDetails": "Error Details", + "viewAll": "All", + "stop": "Stop", + "preview": "Preview", + "share": "Share", + "noAvailableTemplates": "No available templates", + "pleaseSelectTaskType": "Please select task type first", + "textToSpeech": "Text to Speech", + "historyTask": "History Task", + "total": "Total", + "tasks": "tasks", + "records": "records", + "goToHome": "Go to Home", + "imageTemplates": "Image Templates", + "audioTemplates": "Audio Templates", + "noImageTemplates": "No image templates", + "noAudioTemplates": "No audio templates", + "templateDetail": "Template detail", + "viewTemplateDetail": "View Template Detail", + "viewTaskDetails": "View Task Details", + "templateInfo": "Template Info", + "model": "Model", + "type": "Type", + "inputMaterials": "Input Materials", + "inputImage": "Input Image", + "inputAudio": "Input Audio", + "optional": "(Optional)", + "pageTitle": "LightX2V Service", + "uploadVideo": "Upload Video", + "supportedVideoFormats": "Supported MP4, WebM, QuickTime formats", + "pleaseEnterThePromptForVideoGeneration": "Please enter the prompt for video generation", + "describeTheContentStyleSceneOfTheVideo": "Describe the content, style, and scene of the video...", + "describeTheDigitalHumanImageBackgroundStyleActionRequirements": "Describe the digital human expression, tone, and action...", + "describeTheContentActionRequirementsBasedOnTheImage": "Describe the content and action requirements based on the image...", + "loginSubtitle": "A powerful video generation platform", + "whatDoYouWantToDo": "What do you want to do today?", + "whatMaterialsDoYouNeed": "What materials do you need to create a video?", + "pleaseEnterTheMostDetailedVideoScript": "Please enter the most detailed video script", + "pleaseUploadAnImageAsTheFirstFrameOfTheVideoAndTheMostDetailedVideoScript": "Please upload an image as the first frame of the video and the most detailed video script", + "pleaseUploadARoleImageAnAudioAndTheGeneralVideoRequirements": "Please upload a role image, an audio, and the general video requirements", + "collapseCreationArea": "Collapse creation area", + "startCreatingVideo": "Start creating video···", + "loginWithGitHub": "GitHub", + "loginWithGoogle": "Google", + "loginWithSMS": "SMS", + "loggingIn": "Logging in...", + "logout": "Logout", + "loggedOut": "Logged out", + "loginFailed": "Login failed", + "loginError": "Error occurred during login", + "authFailed": "Authentication failed, please login again", + "loginExpired": "Login expired, please login again", + "orLoginWith": "Or login with", + "login": "Login / Register", + "loginLoading": "Logging in···", + "sendSmsCode": "Send SMS", + "phoneNumber": "Phone Number", + "verifyCode": "Verify Code", + "feature1": "Cinema-grade digital human videos", + "feature2": "20x faster generation", + "feature3": "Ultra-low cost generation", + "feature4": "Precise lip-sync alignment", + "feature5": "Minute-level video duration", + "feature6": "Multi-scenario applications", + 
"generateVideo": "Generate Video", + "history": "History", + "inspiration": "Inspiration", + "discoverCreativity": "Discover creativity, inspire ideas", + "searchTasks": "Search history tasks...", + "searchInspiration": "Search inspiration...", + "refresh": "Refresh task list", + "noHistoryTasks": "No history tasks", + "startToCreateYourFirstAIVideo": "Start creating your first AI video", + "switchLanguage": "Switch Language", + "selectTaskType": "Select Task Type", + "selectTaskTypeFirst": "Please select task type first", + "noHistoryRecords": "No history records", + "imageHistoryAutoSave": "Image history will be automatically saved when you start using images", + "audioHistoryAutoSave": "Audio history will be automatically saved when you start using audio", + "clearHistory": "Clear history", + "clear": "Clear", + "promptHistoryAutoSave": "Prompt history will be automatically saved when you start creating tasks", + "promptTip": "Describe the video content you want in detail", + "viewFailureReason": "View Failure Reason", + "failureReason": "Failure Reason", + "noPrompt": "No Prompt", + "uploadMaterials": "Upload Materials", + "image": "Image", + "video": "Video", + "historyAudio": "History Audio", + "historyAudioApplied": "History audio applied", + "myProjects": "My Projects", + "initializationFailed": "Initialization Failed, Please Refresh The Page", + "browserNotSupported": "Browser Not Supported", + "videoLoadFailed": "Video Load Failed", + "loadingVideo": "Loading Video···", + "videoGenerating": "Video Generating", + "taskProgress": "Task Progress", + "subtask": "Subtask", + "queuePosition": "Waiting for", + "availableWorker": "Available Worker", + "videoGeneratingFailed": "Video Generating Failed", + "sorryYourVideoGenerationTaskFailed": "Sorry Your Video Generation Task Failed", + "thisTaskHasBeenCancelledYouCanRegenerateOrViewTheMaterialsYouUploadedBefore": "This Task Has Been Cancelled You Can Regenerate Or View The Materials You Uploaded Before", + "taskCompleted": "Task Completed", + "taskFailed": "Task Failed", + "taskCancelled": "Cancelled", + "taskRunning": "Task Running", + "taskPending": "Task Pending", + "taskInfo": "Task Info", + "taskID": "Task ID", + "modelName": "Model Name", + "createTime": "Create Time", + "updateTime": "Update Time", + "aiIsGeneratingYourVideo": "LightX2V is generating your video...", + "taskSubmittedSuccessfully": "Task submitted successfully, accelerating processing...", + "taskQueuePleaseWait": "The task is a little bit, accelerating queueing...", + "success": "Success", + "failed": "Failed", + "pending": "Waiting", + "cancelled": "Cancelled", + "all": "All", + "created": "Created", + "status": "Status", + "reuseTask": "Reuse", + "regenerateTask": "Retry", + "retryTask": "Retry", + "downloadTask": "Download Video", + "downloadVideo": "Download Video", + "deleteTask": "Delete", + "cancelTask": "Cancel", + "download": "Download", + "downloadPreparing": "Preparing download...", + "downloadFetching": "Fetching file...", + "downloadSaving": "Saving file...", + "mobileSaveToAlbumTip": "Long press the video in the new tab to save it to your gallery.", + "mobileSavePreviewTitle": "Preview & Save", + "mobileSaveInstruction": "Tap the full-screen button or long press the video to save it to your photo library.", + "mute": "Mute", + "unmute": "Unmute", + "unsupportedAudioOrVideo": "Please select an audio or video file.", + "unsupportedVideoFormat": "Only MP4/M4V/MPEG video files are supported for audio extraction.", + "downloadInProgressNotice": "A download 
is already in progress. Please wait.", + "downloadCancelledAlert": "Download cancelled", + "createVideo": "Create Video", + "selectTemplate": "Select Template", + "uploadImage": "Upload Image", + "uploadAudio": "Upload Audio or Video", + "recordAudio": "Record Audio", + "recording": "Recording...", + "takePhoto": "Take Photo", + "retake": "Retake", + "usePhoto": "Use Photo", + "upload": "Upload", + "stopRecording": "Stop Recording", + "recordingStarted": "Recording started", + "recordingStopped": "Recording stopped", + "recordingCompleted": "Recording completed", + "recordingFailed": "Recording failed", + "enterPrompt": "Enter Prompt", + "selectModel": "Select Model", + "startGeneration": "Start Generation", + "templates": "Templates", + "useTemplate": "Use Template", + "applyImage": "Apply Image", + "applyAudio": "Apply Audio", + "featuredTemplates": "Featured Templates", + "discoverFeaturedCreativity": "Discover Featured Creativity", + "refreshRandomTemplates": "Refresh Random Templates", + "discover": "Discover", + "viewMoreTemplates": "View More Templates", + "searchTemplates": "Search Templates", + "browseCategories": "Browse Categories", + "inspirationGallery": "Inspirations", + "viewMore": "View More", + "more": "More", + "applyPrompt": "Apply Prompt", + "imageApplied": "Image applied", + "audioApplied": "Audio applied", + "promptApplied": "Prompt applied", + "copy": "Copy", + "view": "View", + "promptCopied": "Prompt copied to clipboard", + "outputVideo": "Output Video", + "textToVideo": "Text to Video", + "imageToVideo": "Image to Video", + "speechToVideo": "Speech to Video", + "animate": "Character Replacement", + "prompt": "Prompt (Optional)", + "negativePrompt": "Negative Prompt", + "promptTemplates": "Prompt Templates", + "promptHistory": "Prompt History", + "t2vHint1": "Enter a text description, and AI will generate wonderful video content for you", + "t2vHint2": "Supports multiple styles: realistic, animation, art, etc.", + "t2vHint3": "Can describe scenes, actions, emotions, etc.", + "t2vHint4": "Let your creativity become a vivid video", + "i2vHint1": "Upload an image, and AI will generate a dynamic video for you", + "i2vHint2": "Supports multiple image formats: JPG, PNG, WebP, etc.", + "i2vHint3": "Can generate various dynamic effects.", + "i2vHint4": "Bring static images to life and create infinite possibilities", + "s2vHint1": "Upload a character image + an audio clip", + "s2vHint2": "AI will make the character speak and move according to the audio content.", + "s2vHint3": "Bring your character to life.", + "s2vHint4": "Create your own digital human.", + "uploadImageFile": "Upload Image File", + "uploadAudioFile": "Upload Audio or Video File", + "dragDropHere": "Drag and drop files here or click to upload", + "supportedImageFormats": "Supports jpg, png, webp image formats (< 10MB)", + "supportedAudioFormats": "Supports audio or video formats (< 120s).", + "supportedAudioFormatsShort": "Supports audio or video formats (< 120s).", + "prefillLoadingDefault": "Preparing materials...", + "prefillLoadingTemplate": "Loading template assets...", + "prefillLoadingTask": "Loading task materials...", + "clearCharacterImageTip": "Upload a clear character image", + "maxFileSize": "Max file size", + "taskDetail": "Task Details", + "taskId": "Task ID", + "taskType": "Task Type", + "taskStatus": "Task Status", + "createdAt": "Created At", + "completedAt": "Completed At", + "duration": "Duration", + "confirm": "Confirm", + "cancel": "Cancel", + "save": "Save", + "edit": "Edit", + "delete": "Delete", + "close": "Close", + "copyLink": "Copy Link",
+ "pleaseCopyManually": "Please manually select and copy the text below", + "back": "Back", + "next": "Next", + "previous": "Previous", + "finish": "Finish", + "submitting": "Submitting...", + "operationSuccess": "Operation successful", + "operationFailed": "Operation failed", + "pleaseWait": "Please wait...", + "loading": "Loading···", + "noData": "No data", + "errorOccurred": "Error occurred", + "networkError": "Network error", + "serverError": "Server error", + "seconds": "seconds", + "deleteTaskConfirm": "Delete Task?", + "deleteTaskConfirmMessage": "This action cannot be undone. It will delete the task record, generated files, and related data.", + "confirmDelete": "Delete", + "regenerateTaskConfirm": "Regenerate Task?", + "regenerateTaskConfirmMessage": "Regenerating will delete the current task and generated content, then create a new task with the same parameters. This action cannot be undone!", + "confirmRegenerate": "Regenerate", + "regeneratingTaskAlert": "Regenerating task...", + "deletingTaskAlert": "Deleting task...", + "taskDeletedSuccessAlert": "Task deleted successfully", + "deleteTaskFailedAlert": "Delete task failed", + "getTaskDetailFailedAlert": "Get task detail failed", + "taskNotExistAlert": "Task not exist", + "loadTaskFilesFailedAlert": "Load task files failed", + "taskMaterialReuseSuccessAlert": "Task material reuse successfully", + "loadTaskDataFailedAlert": "Load task data failed", + "fileUnavailableAlert": "File unavailable", + "downloadFailedAlert": "Download failed. Please try again.", + "taskSubmitSuccessAlert": "Task submit successfully", + "taskSubmitFailedAlert": "Task submit failed", + "submitTaskFailedAlert": "Submit task failed", + "downloadSuccessAlert": "File download successfully", + "getTaskResultFailedAlert": "Get task result failed", + "downloadTaskResultFailedAlert": "Download task result failed", + "viewTaskResultFailedAlert": "View task result failed", + "cancelTaskConfirm": "Cancel task?", + "cancelTaskConfirmMessage": "Cancel task will stop the task execution, and the generated part of the result may be lost, can be regenerated later.", + "confirmCancel": "Cancel", + "taskCancelSuccessAlert": "Task cancel successfully", + "cancelTaskFailedAlert": "Cancel task failed", + "taskRetrySuccessAlert": "Task retry successfully", + "retryTaskFailedAlert": "Retry task failed", + "taskRegenerateSuccessAlert": "Task regenerated successfully", + "regenerateTaskFailedAlert": "Regenerate task failed", + "taskNotFoundAlert": "Task not found", + "templateApplied": "Template applied", + "shareTemplate": "Share", + "copyShareLink": "Share", + "promptHistoryApplied": "Prompt history applied", + "promptHistoryCleared": "Prompt history cleared", + "getPromptHistoryFailed": "Get prompt history failed", + "saveTaskHistoryFailed": "Save task history failed", + "parseTaskHistoryFailed": "Parse task history failed", + "getTaskHistoryFailed": "Get task history failed", + "getImageHistoryFailed": "Get image history failed", + "taskHistorySaved": "Task history saved", + "taskHistoryCleared": "Task history cleared", + "clickToDownload": "Click to download", + "clickApply": "Click to apply", + "justNow": "Just now", + "minutesAgo": "minutes ago", + "hoursAgo": "hours ago", + "daysAgo": "days ago", + "weeksAgo": "weeks ago", + "monthsAgo": "months ago", + "yearsAgo": "years ago", + "oneMinuteAgo": "one minute ago", + "oneHourAgo": "one hour ago", + "oneDayAgo": "one day ago", + "oneWeekAgo": "one week ago", + "oneMonthAgo": "one month ago", + "oneYearAgo": "one year ago", + 
"shareNotFound": "Share not found", + "backToHome": "Back to Home", + "videoNotAvailable": "Video not available", + "createdWithAI": "Created with AI", + "createSimilar": "Create Similar", + "createSimilarDescription": "Click the button to create your video with the same settings", + "shareDataImported": "Share data imported successfully", + "shareDataImportFailed": "Failed to import share data", + "templatesGeneratedByLightX2V": "The following videos are generated by LightX2V, hover/click to play", + "materials": "Materials", + "template": "Template", + "templateDescription": "The video is generated by LightX2V-digital human model", + "pleaseLoginFirst": "Please login first", + "showDetails": "Show Details", + "hideDetails": "Hide Details", + "shareLinkCopied": "Share link copied", + "randomTemplates": "Random Refresh Templates", + "oneClickReplication": "One-click replication", + "customizableContent": "Customizable content", + "poweredByLightX2V": "Speed-generated video - LightX2V", + "latestAIModel": "Latest AI digital human model, rapid video generation", + "customizableCharacter": "Freely customizable characters and audio", + "userGeneratedVideo": " generated video", + "noImage": "No Images", + "noAudio": "No Audio", + "noVideo": "No Video", + "taskCompletedSuccessfully": "LightX2V has generated video for you successfully", + "onlyUseImage": "Only use image", + "onlyUseAudio": "Only use audio", + "reUseImage": "Reuse image", + "reUseAudio": "Reuse audio", + "templateVideo": "Speech-to-video generation template", + "description": "The video is generated by LightX2V", + "timeCost": "Time cost ", + "voiceSynthesis": "Voice Synthesis", + "applySelectedVoice": "Apply selected voice", + "generatedAudio": "Generated Audio", + "synthesizedAudio": "Synthesized Audio", + "enterTextToConvert": "Enter text to convert", + "ttsPlaceholder": "The weather is nice today, let's go for a walk~", + "voiceInstruction": "Voice Instruction", + "voiceInstructionHint": "(Only for v2.0 voices)", + "voiceInstructionPlaceholder": "Use instruction to control voice details, including emotion, context, dialect, tone, speed, pitch, etc. 
Example: Please use a warm and friendly voice", + "selectVoice": "Select Voice", + "searchVoice": "Search Voice", + "filter": "Filter", + "filterVoices": "Filter Voices", + "voiceSettings": "Voice Settings", + "speechRate": "Speech Rate", + "volume": "Volume", + "pitch": "Pitch", + "emotionIntensity": "Emotion Intensity", + "emotionType": "Emotion Type", + "neutral": "Neutral", + "scene": "Scene", + "version": "Version", + "language": "Language", + "gender": "Gender", + "reset": "Reset", + "done": "Done", + "ttsGenerationFailed": "TTS generation failed, please retry", + "applyAudioFailed": "Apply audio failed, please retry", + "allScenes": "All Scenes", + "generalScene": "General", + "customerServiceScene": "Customer Service", + "educationScene": "Education", + "funAccent": "Fun Accent", + "rolePlaying": "Role Playing", + "audiobook": "Audiobook", + "multilingual": "Multilingual", + "multiEmotion": "Multi-emotion", + "videoDubbing": "Video Dubbing", + "ttsHistoryTitle": "History", + "ttsHistoryHint": "We automatically keep the last 20 voice texts and instructions you used.", + "ttsHistoryEmpty": "No saved entries yet", + "ttsHistoryEmptyHint": "Generate voice once to create your first history entry.", + "ttsHistoryTextLabel": "Voice Text", + "ttsHistoryInstructionLabel": "Voice Instruction", + "ttsHistoryTextEmpty": "Empty text", + "ttsHistoryInstructionEmpty": "Empty instruction", + "ttsHistoryVoiceLabel": "Voice History", + "ttsHistoryVoiceEmpty": "Not set", + "ttsHistoryApply": "Use This", + "ttsHistoryApplySelected": "Apply", + "ttsHistoryDeleteEntry": "Remove", + "ttsHistoryTabCombined": "All", + "ttsHistoryTabText": "Text History", + "ttsHistoryTabInstruction": "Instruction History", + "ttsHistoryTabVoice": "Voice History", + "ttsHistoryTitleCombined": "All History", + "ttsHistoryTitleText": "Text History", + "ttsHistoryTitleInstruction": "Instruction History", + "ttsHistoryTitleVoice": "Voice History", + "ttsHistoryClear": "Clear History", + "allVersions": "All Versions", + "allLanguages": "All Languages", + "allGenders": "All Genders", + "female": "Female", + "male": "Male", + "taskCountdown": "Task countdown", + "footer": { + "tagline": "AI digital human video generation powered by LightX2V framework", + "links": { + "home": "Light AI Homepage", + "github": "GitHub", + "xiaohongshu": "RedNote" + }, + "alt": { + "github": "GitHub logo", + "xiaohongshu": "RedNote logo" + }, + "copyright": "© {year} Light AI. All rights reserved." 
+ }, + "tts": { + "title": "AI Voice Synthesis", + "subtitle": "Synthesize your voice with AI", + "inputText": "Enter text to synthesize", + "voiceSelection": "Select voice", + "voiceSettings": "Voice settings", + "speechRate": "Speech rate", + "volume": "Volume", + "pitch": "Pitch", + "emotionIntensity": "Emotion intensity", + "emotionType": "Emotion type", + "neutral": "Neutral", + "scene": "Scene", + "version": "Version", + "language": "Language", + "gender": "Gender", + "reset": "Reset", + "done": "Done", + "ttsGenerationFailed": "TTS generation failed, please retry", + "applyAudioFailed": "Apply audio failed, please retry", + "allScenes": "All Scenes", + "generalScene": "General", + "customerServiceScene": "Customer Service", + "educationScene": "Education", + "funAccent": "Fun Accent", + "rolePlaying": "Role Playing", + "audiobook": "Audiobook", + "multilingual": "Multilingual", + "multiEmotion": "Multi-emotion", + "videoDubbing": "Video Dubbing", + "allVersions": "All Versions", + "allLanguages": "All Languages", + "allGenders": "All Genders", + "female": "Female", + "male": "Male", + "multiSegmentMode": "Multi-segment Mode", + "singleSegmentMode": "Single-segment Mode", + "switchToSingleSegmentMode": "Switch to Single-segment Mode", + "switchToMultiSegmentMode": "Switch to Multi-segment Mode", + "mergedAudio": "Merged Audio", + "applyMergedAudio": "Apply Merged Audio", + "addSegment": "Add Segment", + "segment": "Segment", + "segmentNumber": "Segment {index}", + "dragToReorder": "Drag to Reorder", + "copySegment": "Copy Segment", + "deleteSegment": "Delete Segment", + "selectVoice": "Select Voice", + "generate": "Generate", + "text": "Text", + "voiceInstructionOptional": "Voice Instruction (Optional)", + "segmentCopied": "Segment copied and added to end", + "noSegmentsToApply": "No audio segments available to apply", + "mergedAudioLoadFailed": "Merged audio load failed: {error}", + "mergedAudioFailed": "Merged audio failed: {error}", + "unknownError": "Unknown error", + "playbackFailed": "Playback failed: {error}", + "audioDecodeFailed": "Audio decode failed: {error}" + }, + "podcast": { + "title": "AI Dual-Person Podcast Generator", + "subtitle": "Make knowledge audible", + "generating": "Generating podcast...", + "generatingStatusWithCount": "Generating podcast ({count} segments generated)...", + "ready": "Audio ready, click to play", + "readyWithCount": "Audio ready ({count} segments generated), click to play", + "preparingFirstAudio": "Preparing first audio segment...", + "preparingAudio": "Preparing audio...", + "completed": "Generation completed ({count} segments total)", + "stopped": "Generation stopped", + "generationFailed": "Generation failed", + "generatePodcast": "Generate Podcast", + "dualPersonPodcast": "Podcast", + "stopGeneration": "Stop Generation", + "downloadAudio": "Download Audio", + "applyToDigitalHuman": "Convert to Digital Human Video", + "generateMore": "Generate More Podcasts", + "historyTitle": "Generated Records", + "toggleSidebar": "Collapse/Expand", + "noHistory": "No history records", + "completedStatus": "Completed", + "generatingStatus": "Generating...", + "showSubtitles": "Show Subtitles", + "hideSubtitles": "Hide Subtitles", + "inputPlaceholder": "Enter article/file link or specify a topic, e.g.: Principles of AI", + "enterLinkOrTopic": "Please enter a link or topic", + "audioReady": "Audio ready", + "audioLoading": "Audio loading, please wait...", + "playbackFailed": "Playback failed, please retry", + "playbackFailedWithError": "Playback failed: 
{error}", + "audioLoadFailed": "Audio load failed, please check network connection", + "noAudioAvailable": "No audio available", + "noAudioToDownload": "No audio available to download", + "pleaseGenerateFirst": "Please generate a podcast first", + "applySuccess": "Podcast added to audio materials", + "applyFailed": "Failed to apply to digital human", + "loadAudioFailed": "Failed to load podcast audio", + "sessionDataNotFound": "Session data not found, please refresh history", + "loadSessionFailed": "Failed to load session", + "loadAudioFailedDetail": "Failed to load audio", + "audioDecodeFailed": "Audio decode failed: {error}", + "audioLoadFailedNetwork": "Audio load failed, please check network connection", + "audioLoadFailedFormat": "Audio load failed, please check network connection or audio format", + "audioLoadFailedWithError": "Audio load failed: {error}", + "audioMayBeSilent": "Audio playback may be silent, please refresh page and retry", + "unknownError": "Unknown error", + "exampleInputs": [ + "https://github.com/ModelTC/LightX2V", + "Principles of LLM Large Models", + "What is Deep Learning?", + "How to Balance Work and Life?", + "How to Lose Weight Scientifically" + ] + }, + "faceDetectionFailed": "Face detection failed", + "pleaseUploadImage": "Please upload an image first", + "multiRoleModeRequires": "Multi-role mode requires at least 2 roles, please manually add more roles", + "audioSeparationFailed": "Audio separation failed", + "singleRoleModeInfo": "In single-role mode, all characters will use the same audio track for lip-sync.", + "ttsCompleted": "Text-to-speech completed, automatically added to audio materials", + "imageDragSuccess": "Image drag and drop upload successful", + "pleaseDragImage": "Please drag an image file", + "audioDragSuccess": "Audio/video drag and drop upload successful", + "pleaseDragAudio": "Please drag an audio or video file", + "videoDragSuccess": "Video drag and drop upload successful", + "pleaseDragVideo": "Please drag a video file", + "authFailedPleaseRelogin": "Authentication failed, please log in again", + "getGitHubAuthUrlFailed": "Failed to get GitHub authentication URL", + "getGoogleAuthUrlFailed": "Failed to get Google authentication URL", + "pleaseEnterPhoneNumber": "Please enter phone number", + "pleaseEnterValidPhoneNumber": "Please enter a valid phone number format", + "verificationCodeSent": "Verification code sent, please check your SMS", + "sendVerificationCodeFailed": "Failed to send verification code", + "sendVerificationCodeFailedRetry": "Failed to send verification code, please retry", + "pleaseEnterPhoneAndCode": "Please enter phone number and verification code", + "loginSuccess": "Login successful", + "verificationCodeErrorOrExpired": "Verification code error or expired", + "loginFailedRetry": "Login failed, please retry", + "loginError": "An error occurred during login", + "loggedOut": "Logged out", + "loadModelListFailed": "Failed to load model list", + "loadModelFailed": "Failed to load model", + "imageTemplateSelected": "Image template selected", + "loadImageTemplateFailed": "Failed to load image template", + "audioTemplateSelected": "Audio template selected", + "loadAudioTemplateFailed": "Failed to load audio template", + "audioFileUrlFailed": "Failed to get audio file URL", + "audioPlaybackFailed": "Audio playback failed", + "templateLoadingPleaseWait": "Template is loading, please try again later", + "pleaseSelectTaskType": "Please select task type", + "pleaseSelectModel": "Please select model", + "pleaseEnterPrompt": 
"Please enter prompt", + "promptTooLong": "Prompt length cannot exceed 1000 characters", + "i2vTaskRequiresImage": "Image-to-video task requires uploading a reference image", + "s2vTaskRequiresImage": "Digital human task requires uploading a character image", + "s2vTaskRequiresAudio": "Digital human task requires uploading an audio file", + "animateTaskRequiresImage": "Character replacement task requires uploading a character image", + "animateTaskRequiresVideo": "Character replacement task requires uploading a reference video", + "prepareMultiPersonAudioFailed": "Failed to prepare multi-person audio", + "taskSubmittedButParseFailed": "Task submitted successfully, but failed to parse response", + "refreshTaskListFailed": "Failed to refresh task list", + "getResultFailed": "Failed to get result", + "initFailedPleaseRefresh": "Initialization failed, please refresh the page and try again", + "historyCleared": "History storage cleared", + "historyImageApplied": "History image applied", + "applyHistoryImageFailed": "Failed to apply history image", + "applyHistoryAudioFailed": "Failed to apply history audio", + "audioHistoryUrlFailed": "Failed to get audio history URL", + "imageHistoryCleared": "Image history cleared", + "audioHistoryCleared": "Audio history cleared", + "storageCleared": "Storage cleared", + "clearStorageFailed": "Failed to clear storage", + "loginExpiredPleaseRelogin": "Login expired, please log in again", + "networkRequestFailed": "Network request failed", + "videoLoadTimeout": "Video load timeout, please retry", + "templateDataIncomplete": "Template data incomplete", + "loadMoreInspirationComingSoon": "Load more inspiration feature coming soon...", + "microphonePermissionDenied": "Microphone permission denied. Please click the 🔒 or 🎤 icon on the left side of the Chrome address bar, select 'Allow' microphone access, then refresh the page and try again", + "microphoneNotFound": "Microphone device not found, please check device connection or use another device", + "recordingNotSupportedOnMobile": "Recording is not supported on mobile browsers, you can record video instead", + "microphoneInUse": "Microphone is being used by another application, please close other programs using the microphone and try again", + "microphoneNotCompatible": "Microphone device does not support the required recording parameters, please use another microphone device", + "securityErrorUseHttps": "Security restriction: Please ensure you are accessing the website using HTTPS protocol", + "shareDataIncomplete": "Share data incomplete", + "audioLoadFailed": "Audio load failed", + "pleaseRelogin": "Please log in again", + "pleaseLoginFirst": "Please log in first", + "cancelTaskFailedRetry": "Failed to cancel task, please retry", + "shareFailedRetry": "Share failed, please retry", + "retryTaskFailedRetry": "Failed to retry task, please retry", + "splitingAudio": "Splitting audio in multi-role mode..." 
+} diff --git a/lightx2v/deploy/server/frontend/src/locales/zh.json b/lightx2v/deploy/server/frontend/src/locales/zh.json new file mode 100644 index 0000000000000000000000000000000000000000..8aaae13beab6fdc7dcc3912dc73e612081880484 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/locales/zh.json @@ -0,0 +1,696 @@ +{ + "faceEdit": "脸部编辑", + "adjustFaceBox": "调整角色", + "rolesAndAudio": "多角色模式(拖拽调整角色和音频的配对)", + "roleModeInfo": { + "title": "角色模式说明", + "singleMode": { + "title": "单角色模式", + "points": [ + "所有角色使用同一个音轨进行对口型,多人同时对口型", + "适合单人对话或统一配音场景", + "操作简单,无需手动配对" + ] + }, + "multiMode": { + "title": "多角色模式", + "points": [ + "图片和音轨自动分离,多人情况下每个人有单独的音轨", + "需要手动将角色与音轨配对", + "每个角色只会说对应音轨的内容" + ] + } + }, + "useImage": "使用图片", + "detectingCharacters": "角色识别中...", + "waitingForMultipleRolesAudio": "等待上传多角色对话音频...", + "deleteAudio": "删除音频", + "deleteImage": "删除图片", + "addNewRole": "新增角色", + "adjustFaceBoxDescription": "调整人脸边界框,确保人脸完整显示。", + "adjustFaceBoxConfirm": "调整人脸边界框", + "adjustFaceBoxConfirmDescription": "调整人脸边界框,确保人脸完整显示。", + "adjustFaceBoxConfirmButton": "确认调整", + "adjustFaceBoxCancelButton": "取消", + "adjustFaceBoxSaveButton": "保存", + "detectedCharacters": "检测到的角色", + "starOnGitHub": "Star on GitHub", + "pageNotFound": "页面不存在", + "pageNotFoundDescription": "您访问的页面不存在。", + "goBackHome": "主页", + "goBack": "返回", + "viewAll": "全部", + "viewErrorDetails": "错误详情", + "stop": "停止", + "preview": "预览", + "share": "分享", + "noAvailableTemplates": "暂无可用素材", + "pleaseSelectTaskType": "请先选择任务类型", + "textToSpeech": "文本转语音", + "historyTask": "历史任务", + "total": "共", + "tasks": "条", + "records": "条", + "goToHome": "返回主页", + "imageTemplates": "图片素材", + "audioTemplates": "音频素材", + "noImageTemplates": "目前暂无图片素材", + "noAudioTemplates": "目前暂无音频素材", + "templateDetail": "模板详情", + "viewTemplateDetail": "查看模板详情", + "viewTaskDetails": "查看任务详情", + "taskDetail": "任务详情", + "templateInfo": "模板信息", + "useTemplate": "使用模板", + "model": "模型", + "type": "类型", + "inputMaterials": "上传素材", + "inputImage": "输入图片", + "inputAudio": "输入音频", + "applyImage": "只应用图片", + "applyAudio": "只应用音频", + "featuredTemplates": "精选模版", + "discoverFeaturedCreativity": "发现精选创意", + "refreshRandomTemplates": "随机获取精选模版", + "discover": "发现", + "viewMoreTemplates": "查看更多模版", + "searchTemplates": "搜索模版", + "browseCategories": "分类浏览", + "inspirationGallery": "灵感", + "viewMore": "查看更多", + "more": "更多", + "applyPrompt": "只应用提示词", + "imageApplied": "图片已应用", + "audioApplied": "音频已应用", + "promptApplied": "提示词已应用", + "copy": "复制", + "view": "查看", + "promptCopied": "提示词已复制到剪贴板", + "outputVideo": "输出视频", + "audio": "音频", + "optional": "(选填)", + "pageTitle": "LightX2V服务", + "uploadVideo": "上传视频", + "supportedVideoFormats": "支持 MP4、WebM、QuickTime 格式", + "pleaseEnterThePromptForVideoGeneration": "视频生成提示词", + "describeTheContentStyleSceneOfTheVideo": "描述视频内容、风格、场景等...", + "describeTheDigitalHumanImageBackgroundStyleActionRequirements": "描述数字人表情、动作等,例如:角色应根据音频做出夸张的动作", + "describeTheContentActionRequirementsBasedOnTheImage": "描述基于图片的视频内容、动作要求等...", + "loginSubtitle": "一个强大的视频生成平台", + "loginWithGitHub": "使用GitHub登录", + "loginWithGoogle": "使用Google登录", + "loginWithSMS": "使用短信登录", + "loggingIn": "登录中...", + "logout": "退出登录", + "loggedOut": "已退出登录", + "loginFailed": "登录失败", + "loginError": "登录过程中发生错误", + "authFailed": "认证失败,请重新登录", + "loginExpired": "登录已过期,请重新登录", + "orLoginWith": "或使用以下方式登录", + "login": "登录 / 注册", + "loginLoading": "登录中···", + "sendSmsCode": "发送验证码", + "phoneNumber": "手机号", + "verifyCode": "验证码", + "feature1": "电影级数字人视频", + "feature2": "20倍生成提速", + 
"feature3": "超低成本生成", + "feature4": "精准口型对齐", + "feature5": "分钟级视频时长", + "feature6": "多场景应用", + "generateVideo": "生成视频", + "history": "历史记录", + "inspiration": "灵感", + "myProjects": "我的项目", + "discoverCreativity": "发现创意,激发灵感", + "searchInspiration": "搜索灵感...", + "refresh": "刷新任务列表", + "refreshTasks": "刷新任务列表", + "noHistoryTasks": "暂无历史任务", + "startToCreateYourFirstAIVideo": "开始创建你的第一个AI视频吧", + "switchLanguage": "切换语言", + "selectTaskType": "选择任务类型", + "selectTaskTypeFirst": "请先选择任务类型", + "noHistoryRecords": "暂无历史记录", + "imageHistoryAutoSave": "开始使用图片后,历史记录将自动保存", + "audioHistoryAutoSave": "开始使用音频后,历史记录将自动保存", + "clearHistory": "清空历史记录", + "clear": "清空", + "promptHistoryAutoSave": "开始创建任务后,提示词将自动保存", + "searchTasks": "搜索任务", + "initializationFailed": "初始化失败,请刷新页面重试", + "whatDoYouWantToDo": "今天想做什么样的视频呢?", + "whatMaterialsDoYouNeed": "创作视频需要什么素材呢?", + "pleaseEnterTheMostDetailedVideoScript": "请输入尽可能详细的视频脚本", + "pleaseUploadAnImageAsTheFirstFrameOfTheVideoAndTheMostDetailedVideoScript": "请上传一张图片作为视频的首帧图,以及尽可能详细的视频脚本", + "pleaseUploadARoleImageAnAudioAndTheGeneralVideoRequirements": "仅需要一张角色图片和一段音频", + "collapseCreationArea": "收起创作区域", + "startCreatingVideo": "开始创作视频···", + "aiIsGeneratingYourVideo": "LightX2V 正在光速生成您的视频...", + "taskProgress": "任务进度", + "subtask": "子任务", + "queuePosition": "还需等待", + "taskSubmittedSuccessfully": "任务已提交成功,加速处理中...", + "taskQueuePleaseWait": "任务有点多,加速排队中...", + "availableWorker": "可用Worker", + "videoGeneratingFailed": "视频生成失败", + "sorryYourVideoGenerationTaskFailed": "很抱歉,您的视频生成任务未能完成。", + "thisTaskHasBeenCancelledYouCanRegenerateOrViewTheMaterialsYouUploadedBefore": "此任务已被取消,您可以重新生成或查看之前上传的素材。", + "taskCompleted": "任务已完成", + "taskFailed": "任务失败", + "taskCancelled": "任务已取消", + "taskRunning": "任务运行中", + "taskPending": "任务排队中", + "taskInfo": "任务信息", + "taskID": "任务ID", + "taskType": "任务类型", + "modelName": "模型名称", + "createTime": "创建时间", + "updateTime": "更新时间", + "viewFailureReason": "查看失败原因", + "failureReason": "失败原因", + "noPrompt": "无提示词", + "uploadMaterials": "上传素材", + "loading": "加载中···", + "image": "图片", + "video": "视频", + "shareTemplate": "分享模板", + "copyShareLink": "分享", + "historyAudio": "历史音频", + "status": "状态", + "browserNotSupported": "您的浏览器不支持播放", + "videoLoadFailed": "视频加载失败", + "loadingVideo": "加载视频中···", + "videoGenerating": "视频生成中", + "succeed": "成功", + "success": "成功", + "failed": "失败", + "running": "运行", + "pending": "排队", + "remaining": "剩余", + "cancelled": "取消", + "all": "全部", + "reuseTask": "复用", + "regenerateTask": "重试", + "cancelTask": "取消", + "retryTask": "重试", + "downloadTask": "下载视频", + "downloadVideo": "下载视频", + "downloadPreparing": "正在准备下载…", + "downloadFetching": "正在获取文件…", + "downloadSaving": "正在保存文件...", + "mobileSaveToAlbumTip": "新窗口打开后长按视频即可保存到相册。", + "mobileSavePreviewTitle": "预览并保存", + "mobileSaveInstruction": "可点击全屏或长按视频,将其保存到手机相册。", + "mute": "静音", + "unmute": "取消静音", + "unsupportedAudioOrVideo": "请选择音频或视频文件。", + "unsupportedVideoFormat": "仅支持 MP4/M4V/MPEG 视频文件转换音频。", + "downloadInProgressNotice": "已有下载任务正在进行,请稍候。", + "downloadCancelledAlert": "已取消下载", + "deleteTask": "删除", + "createVideo": "创建视频", + "selectTemplate": "选择模板", + "uploadImage": "上传图片", + "uploadAudio": "上传音频", + "recordAudio": "录音", + "recording": "录音中...", + "takePhoto": "拍照", + "retake": "重拍", + "usePhoto": "使用照片", + "upload": "上传", + "stopRecording": "停止录音", + "recordingStarted": "开始录音", + "recordingStopped": "录音已停止", + "recordingCompleted": "录音完成", + "recordingFailed": "录音失败", + "enterPrompt": "输入提示词", + "selectModel": "选择模型", + "startGeneration": "开始生成", + 
"templates": "素材", + "textToVideo": "文生视频", + "imageToVideo": "图生视频", + "speechToVideo": "数字人", + "animate": "角色替换", + "prompt": "提示词(选填)", + "negativePrompt": "负面提示词", + "promptTemplates": "提示词模板", + "promptHistory": "提示词历史", + "uploadImageFile": "上传图片", + "uploadAudioFile": "上传音频", + "dragDropHere": "拖拽文件到此处或点击上传", + "supportedImageFormats": "支持10MB以内的图片", + "supportedAudioFormats": "支持120s以内的音频/视频", + "supportedAudioFormatsShort": "支持120s以内的音频/视频", + "prefillLoadingDefault": "正在准备素材...", + "prefillLoadingTemplate": "正在加载模板素材...", + "prefillLoadingTask": "正在加载任务素材...", + "maxFileSize": "最大文件大小", + "taskId": "任务ID", + "taskStatus": "任务状态", + "createdAt": "创建时间", + "completedAt": "完成时间", + "duration": "持续时间", + "confirm": "确认", + "cancel": "取消", + "save": "保存", + "edit": "编辑", + "delete": "删除", + "close": "关闭", + "copyLink": "复制链接", + "pleaseCopyManually": "请手动选择并复制下面的文本", + "back": "返回", + "next": "下一步", + "previous": "上一步", + "finish": "完成", + "submitting": "提交中...", + "t2vHint1": "输入文字描述,AI将为您生成精彩的视频内容", + "t2vHint2": "支持多种风格:写实、动画、艺术等", + "t2vHint3": "可以描述场景、动作、情感等细节", + "t2vHint4": "让您的创意通过文字变成生动的视频", + "i2vHint1": "上传一张图片,AI将为您生成动态视频", + "i2vHint2": "支持多种图片格式:JPG、PNG、WebP等", + "i2vHint3": "可以生成各种风格的动态效果", + "i2vHint4": "让静态图片动起来,创造无限可能", + "s2vHint1": "上传一张角色图片+一段音频", + "s2vHint2": "AI将让角色根据音频内容说话和动作", + "s2vHint3": "让您的角色栩栩如生地动起来", + "s2vHint4": "来创造属于您的专属数字人吧", + "operationSuccess": "操作成功", + "operationFailed": "操作失败", + "pleaseWait": "请稍候...", + "noData": "暂无数据", + "errorOccurred": "发生错误", + "networkError": "网络错误", + "serverError": "服务器错误", + "deleteTaskConfirm": "删除任务?", + "deleteTaskConfirmMessage": "删除后无法恢复,包括任务记录、生成的文件、相关数据。此操作不可撤销!", + "confirmDelete": "删除", + "regenerateTaskConfirm": "重新生成任务?", + "regenerateTaskConfirmMessage": "重新生成将删除当前任务和已生成的内容,然后使用相同参数创建新任务。此操作不可撤销!", + "confirmRegenerate": "重新生成", + "regeneratingTaskAlert": "正在重新生成任务...", + "deletingTaskAlert": "正在删除任务...", + "taskDeletedSuccessAlert": "任务删除成功", + "deleteTaskFailedAlert": "删除任务失败", + "getTaskDetailFailedAlert": "获取任务详情失败", + "taskNotExistAlert": "任务不存在", + "loadTaskFilesFailedAlert": "加载任务文件失败", + "taskMaterialReuseSuccessAlert": "任务素材复用成功", + "loadTaskDataFailedAlert": "加载任务数据失败", + "fileUnavailableAlert": "文件不可用", + "downloadFailedAlert": "下载失败,请重试。", + "taskSubmitSuccessAlert": "任务提交成功", + "taskSubmitFailedAlert": "任务提交失败", + "submitTaskFailedAlert": "任务提交失败", + "downloadSuccessAlert": "文件下载成功", + "getTaskResultFailedAlert": "获取结果失败", + "downloadTaskResultFailedAlert": "下载结果失败", + "viewTaskResultFailedAlert": "查看结果失败", + "cancelTaskConfirm": "取消任务?", + "cancelTaskConfirmMessage": "取消任务后任务将停止执行,已生成的部分结果可能丢失,可以稍后重新生成。", + "confirmCancel": "取消", + "taskCancelSuccessAlert": "任务取消成功", + "cancelTaskFailedAlert": "取消任务失败", + "taskRetrySuccessAlert": "任务重试成功", + "retryTaskFailedAlert": "重试任务失败", + "taskRegenerateSuccessAlert": "任务重新生成成功", + "regenerateTaskFailedAlert": "重新生成任务失败", + "taskNotFoundAlert": "任务未找到", + "seconds": "秒", + "minutes": "分钟", + "hours": "小时", + "days": "天", + "weeks": "周", + "months": "月", + "years": "年", + "position": "位置", + "calculating": "计算中", + "completed": "已完成", + "unknown": "未知", + "queueing": "排队中", + "overallProgress": "总体进度", + "queueStatus": "排队状态", + "templateApplied": "已应用模板", + "promptHistoryApplied": "已应用历史提示词", + "promptHistoryCleared": "提示词历史已清空", + "getPromptHistoryFailed": "获取提示词历史失败", + "saveTaskHistoryFailed": "保存任务历史失败", + "parseTaskHistoryFailed": "解析任务历史失败", + "getTaskHistoryFailed": "获取任务历史失败", + "getImageHistoryFailed": "获取图片历史失败", + "taskHistorySaved": 
"任务历史已保存", + "taskHistoryCleared": "任务历史已清空", + "clickToDownload": "点击下载", + "clickApply": "点击应用", + "justNow": "刚刚", + "minutesAgo": "分钟前", + "hoursAgo": "小时前", + "daysAgo": "天前", + "weeksAgo": "周前", + "monthsAgo": "月前", + "yearsAgo": "年前", + "oneMinuteAgo": "一分钟前", + "oneHourAgo": "一小时前", + "oneDayAgo": "一天前", + "oneWeekAgo": "一周前", + "oneMonthAgo": "一个月前", + "oneYearAgo": "一年前", + "shareNotFound": "分享不存在", + "backToHome": "返回首页", + "videoNotAvailable": "视频不可用", + "shareLinkCopied": "分享链接已复制", + "createdWithAI": "由AI生成", + "createSimilar": "做同款", + "createSimilarDescription": "点击按钮使用相同的设置创建您的视频", + "templatesGeneratedByLightX2V": "以下视频由LightX2V生成,鼠标悬停/点击播放", + "randomTemplates": "随机刷新模版", + "shareDataImported": "分享数据已导入", + "shareDataImportFailed": "分享数据导入失败", + "materials": "素材", + "template": "模板", + "templateDescription": "该视频由LightX2V-数字人模型生成", + "pleaseLoginFirst": "请先登录", + "showDetails": "查看详情", + "hideDetails": "隐藏详情", + "oneClickReplication": "一键复刻同款", + "customizableContent": "可自定义内容", + "poweredByLightX2V": "速生视频 - LightX2V", + "latestAIModel": "最新AI数字人模型,飞速生成视频", + "customizableCharacter": "可自由更换角色与音频", + "userGeneratedVideo": "生成的视频", + "noImage": "暂无图片", + "noAudio": "暂无音频", + "noVideo": "暂无视频", + "taskCompletedSuccessfully": "视频生成完成!", + "onlyUseImage": "仅使用图片", + "onlyUseAudio": "仅使用音频", + "reUseImage": "复用图片", + "reUseAudio": "复用音频", + "templateVideo": "语音驱动视频生成模板", + "description": "该视频由 LightX2V 生成", + "timeCost": "耗时 ", + "voiceSynthesis": "语音合成", + "applySelectedVoice": "应用当前选择的声音", + "generatedAudio": "生成的音频", + "synthesizedAudio": "合成音频", + "enterTextToConvert": "输入要转换的文本", + "ttsPlaceholder": "今天的天气好好呀,要一起出去走走吗~", + "voiceInstruction": "语音指令", + "voiceInstructionHint": "(仅适用于v2.0音色)", + "voiceInstructionPlaceholder": "使用指令控制合成语音细节,包括但不限于情绪、语境、方言、语气、速度、音调等,例如:请用温暖亲切的声线介绍", + "voiceInstructionOptional": "语音指令(可选)", + "selectVoice": "选择音色", + "searchVoice": "搜索音色", + "filter": "筛选", + "filterVoices": "筛选音色", + "voiceSettings": "语音设置", + "speechRate": "语速", + "volume": "音量", + "pitch": "音调", + "emotionIntensity": "情感强度", + "emotionType": "情感类型", + "neutral": "中性", + "scene": "场景", + "version": "版本", + "language": "语言", + "gender": "性别", + "reset": "重置", + "done": "完成", + "ttsGenerationFailed": "语音生成失败,请重试", + "applyAudioFailed": "应用音频失败,请重试", + "allScenes": "全部场景", + "generalScene": "通用场景", + "customerServiceScene": "客服场景", + "educationScene": "教育场景", + "funAccent": "趣味口音", + "rolePlaying": "角色扮演", + "audiobook": "有声阅读", + "multilingual": "多语种", + "multiEmotion": "多情感", + "videoDubbing": "视频配音", + "singleSegmentMode": "单段模式", + "multiSegmentMode": "多段模式", + "switchToSingleSegmentMode": "切换到单段模式", + "switchToMultiSegmentMode": "切换到多段模式", + "mergedAudio": "合并音频", + "applyMergedAudio": "应用合并音频", + "addSegment": "添加段落", + "segment": "段落", + "segmentNumber": "段落 {index}", + "dragToReorder": "拖拽调整顺序", + "copySegment": "复制段落", + "deleteSegment": "删除段落", + "segmentCopied": "段落已复制并添加到末尾", + "noSegmentsToApply": "没有可应用的音频段", + "mergedAudioLoadFailed": "合并音频加载失败: {error}", + "mergedAudioFailed": "合并音频失败: {error}", + "playbackFailed": "播放失败: {error}", + "audioDecodeFailed": "音频解码失败: {error}", + "ttsHistoryTitle": "历史记录", + "ttsHistoryHint": "系统会自动保留最近 20 条使用过的文本与语音指令。", + "ttsHistoryEmpty": "暂未保存任何记录", + "ttsHistoryEmptyHint": "生成一次语音即可创建首条历史记录。", + "ttsHistoryTextLabel": "语音文本历史", + "ttsHistoryInstructionLabel": "语音指令历史", + "ttsHistoryTextEmpty": "(文本为空)", + "ttsHistoryInstructionEmpty": "(指令为空)", + "ttsHistoryVoiceLabel": "音色历史", + "ttsHistoryVoiceEmpty": "未设置音色", + 
"ttsHistoryApply": "使用该记录", + "ttsHistoryApplySelected": "应用", + "ttsHistoryDeleteEntry": "删除此记录", + "ttsHistoryTabCombined": "全部", + "ttsHistoryTabText": "输入文本历史", + "ttsHistoryTabInstruction": "语音指令历史", + "ttsHistoryTabVoice": "音色历史", + "ttsHistoryTitleCombined": "全部历史记录", + "ttsHistoryTitleText": "语音文本历史", + "ttsHistoryTitleInstruction": "语音指令历史", + "ttsHistoryTitleVoice": "音色历史", + "ttsHistoryClear": "清空历史", + "allVersions": "全部版本", + "allLanguages": "全部语言", + "allGenders": "全部性别", + "female": "女性", + "male": "男性", + "taskCountdown": "任务倒计时", + "footer": { + "tagline": "由流光AI - LightX2V驱动的 AI 数字人视频生成平台", + "links": { + "home": "流光 AI 官网", + "github": "GitHub", + "xiaohongshu": "小红书" + }, + "alt": { + "github": "GitHub 标志", + "xiaohongshu": "小红书标志" + }, + "copyright": "© {year} 流光 AI 版权所有" + }, + "tts": { + "title": "AI 语音合成", + "subtitle": "让您的文字变成动听的声音", + "inputText": "请输入要合成的文字", + "voice": "选择音色", + "speed": "语速", + "volume": "音量", + "pitch": "音调", + "emotion": "情感", + "generateVoice": "生成语音", + "cancel": "取消", + "generating": "生成中...", + "generated": "已生成", + "error": "生成失败", + "errorMessage": "请检查输入文本或选择音色", + "voiceOptions": "音色选项", + "fast": "快速", + "normal": "正常", + "slow": "慢速", + "angry": "愤怒", + "happy": "开心", + "sad": "悲伤", + "neutral": "中性", + "excited": "兴奋", + "calm": "平静", + "gentle": "温柔", + "serious": "严肃", + "friendly": "友好", + "professional": "专业", + "child": "儿童", + "robot": "机器人", + "male": "男声", + "female": "女声", + "other": "其他", + "search": "搜索音色", + "noResults": "没有找到相关音色", + "placeholder": "请输入要合成的文字", + "placeholderVoice": "请选择音色", + "placeholderSpeed": "请选择语速", + "placeholderVolume": "请选择音量", + "placeholderPitch": "请选择音调", + "placeholderEmotion": "请选择情感", + "placeholderGenerate": "请点击生成按钮", + "multiSegmentMode": "多段模式", + "singleSegmentMode": "单段模式", + "switchToSingleSegmentMode": "切换到单段模式", + "switchToMultiSegmentMode": "切换到多段模式", + "mergedAudio": "合并音频", + "applyMergedAudio": "应用合并音频", + "addSegment": "添加段落", + "segment": "段落", + "segmentNumber": "段落 {index}", + "dragToReorder": "拖拽调整顺序", + "copySegment": "复制段落", + "deleteSegment": "删除段落", + "selectVoice": "选择音色", + "generate": "生成", + "text": "文本", + "voiceInstructionOptional": "语音指令(可选)", + "segmentCopied": "段落已复制并添加到末尾", + "noSegmentsToApply": "没有可应用的音频段", + "mergedAudioLoadFailed": "合并音频加载失败: {error}", + "mergedAudioFailed": "合并音频失败: {error}", + "unknownError": "未知错误", + "playbackFailed": "播放失败: {error}", + "audioDecodeFailed": "音频解码失败: {error}" + }, + "podcast": { + "exampleInputs.1": "https://github.com/ModelTC/LightX2V", + "exampleInputs.2": "LLM大模型的原理", + "exampleInputs.3": "什么是深度学习?", + "exampleInputs.4": "如何平衡工作和生活?", + "exampleInputs.5": "如何科学减肥", + "title": "AI 双人播客生成器", + "subtitle": "让知识\"听\"得见", + "generating": "正在生成播客...", + "generatingStatusWithCount": "正在生成播客 (已生成 {count} 段)...", + "ready": "音频已就绪,点击播放", + "readyWithCount": "音频已就绪 (已生成 {count} 段),点击播放", + "preparingFirstAudio": "正在准备第一段音频...", + "preparingAudio": "正在准备音频...", + "completed": "生成完成 (共 {count} 段)", + "stopped": "已停止生成", + "generationFailed": "生成失败", + "generatePodcast": "生成播客", + "dualPersonPodcast": "双人播客", + "stopGeneration": "停止生成", + "downloadAudio": "下载音频", + "applyToDigitalHuman": "转为数字人视频", + "generateMore": "生成更多播客", + "historyTitle": "已生成播客记录", + "toggleSidebar": "折叠/展开", + "noHistory": "暂无历史记录", + "completedStatus": "✓ 已完成", + "generatingStatus": "生成中...", + "showSubtitles": "显示字幕", + "hideSubtitles": "隐藏字幕", + "inputPlaceholder": "输入文章/文件链接或指定主题,例如:AI的原理", + "enterLinkOrTopic": "请输入链接或主题", + "audioReady": 
"音频已就绪", + "audioLoading": "音频加载中,请稍候...", + "playbackFailed": "播放失败,请重试", + "playbackFailedWithError": "播放失败: {error}", + "audioLoadFailed": "音频加载失败,请检查网络连接", + "noAudioAvailable": "暂无音频可播放", + "noAudioToDownload": "暂无音频可下载", + "pleaseGenerateFirst": "请先生成播客", + "applySuccess": "播客已添加到音频素材", + "applyFailed": "应用到数字人失败", + "loadAudioFailed": "加载播客音频失败", + "sessionDataNotFound": "会话数据不存在,请刷新历史记录", + "loadSessionFailed": "加载会话失败", + "loadAudioFailedDetail": "加载音频失败", + "audioDecodeFailed": "音频解码失败: {error}", + "audioLoadFailedNetwork": "音频加载失败,请检查网络连接", + "audioLoadFailedFormat": "音频加载失败,请检查网络连接或音频格式", + "audioLoadFailedWithError": "音频加载失败: {error}", + "audioMayBeSilent": "音频播放可能无声,请刷新页面重试", + "unknownError": "未知错误", + "exampleInputs": [ + "https://github.com/ModelTC/LightX2V", + "LLM大模型的原理", + "什么是深度学习?", + "如何平衡工作和生活?", + "如何科学减肥" + ] + }, + "faceDetectionFailed": "角色识别失败", + "pleaseUploadImage": "请先上传图片", + "multiRoleModeRequires": "多角色模式需要至少2个角色,请手动添加更多角色", + "audioSeparationFailed": "音频分离失败", + "singleRoleModeInfo": "当前为单角色模式,多个角色将统一按照同一个音轨进行对口型同步。", + "ttsCompleted": "语音合成完成,已自动添加到音频素材", + "imageDragSuccess": "图片拖拽上传成功", + "pleaseDragImage": "请拖拽图片文件", + "audioDragSuccess": "音频/视频拖拽上传成功", + "pleaseDragAudio": "请拖拽音频或视频文件", + "videoDragSuccess": "视频拖拽上传成功", + "pleaseDragVideo": "请拖拽视频文件", + "authFailedPleaseRelogin": "认证失败,请重新登录", + "getGitHubAuthUrlFailed": "获取GitHub认证URL失败", + "getGoogleAuthUrlFailed": "获取Google认证URL失败", + "pleaseEnterPhoneNumber": "请输入手机号", + "pleaseEnterValidPhoneNumber": "请输入正确的手机号格式", + "verificationCodeSent": "验证码已发送,请查收短信", + "sendVerificationCodeFailed": "发送验证码失败", + "sendVerificationCodeFailedRetry": "发送验证码失败,请重试", + "pleaseEnterPhoneAndCode": "请输入手机号和验证码", + "loginSuccess": "登录成功", + "verificationCodeErrorOrExpired": "验证码错误或已过期", + "loginFailedRetry": "登录失败,请重试", + "loginError": "登录过程中发生错误", + "loggedOut": "已退出登录", + "loadModelListFailed": "加载模型列表失败", + "loadModelFailed": "加载模型失败", + "imageTemplateSelected": "图片素材已选择", + "loadImageTemplateFailed": "加载图片素材失败", + "audioTemplateSelected": "音频素材已选择", + "loadAudioTemplateFailed": "加载音频素材失败", + "audioFileUrlFailed": "音频文件URL获取失败", + "audioPlaybackFailed": "音频播放失败", + "templateLoadingPleaseWait": "模板正在加载中,请稍后再试", + "pleaseSelectTaskType": "请选择任务类型", + "pleaseSelectModel": "请选择模型", + "pleaseEnterPrompt": "请输入提示词", + "promptTooLong": "提示词长度不能超过1000个字符", + "i2vTaskRequiresImage": "图生视频任务需要上传参考图片", + "s2vTaskRequiresImage": "数字人任务需要上传角色图片", + "s2vTaskRequiresAudio": "数字人任务需要上传音频文件", + "animateTaskRequiresImage": "角色替换任务需要上传角色图片", + "animateTaskRequiresVideo": "角色替换任务需要上传参考视频", + "prepareMultiPersonAudioFailed": "准备多人音频失败", + "taskSubmittedButParseFailed": "任务提交成功,但解析响应失败", + "refreshTaskListFailed": "刷新任务列表失败", + "getResultFailed": "获取结果失败", + "initFailedPleaseRefresh": "初始化失败,请刷新页面重试", + "historyCleared": "历史记录空间已清理", + "historyImageApplied": "已应用历史图片", + "applyHistoryImageFailed": "应用历史图片失败", + "historyAudioApplied": "已应用历史音频", + "applyHistoryAudioFailed": "应用历史音频失败", + "audioHistoryUrlFailed": "音频历史URL获取失败", + "imageHistoryCleared": "图片历史已清空", + "audioHistoryCleared": "音频历史已清空", + "storageCleared": "存储空间已清理", + "clearStorageFailed": "清理存储空间失败", + "loginExpiredPleaseRelogin": "登录已过期,请重新登录", + "networkRequestFailed": "网络请求失败", + "videoLoadTimeout": "视频加载超时,请重试", + "templateDataIncomplete": "模板数据不完整", + "loadMoreInspirationComingSoon": "加载更多灵感功能开发中...", + "microphonePermissionDenied": "麦克风权限被拒绝。请点击Chrome地址栏左侧的🔒或🎤图标,选择允许麦克风访问,然后刷新页面重试", + "microphoneNotFound": "未找到麦克风设备,请检查设备连接或使用其他设备", + "recordingNotSupportedOnMobile": 
"移动端浏览器不支持录音功能,可以拍摄视频来代替录音", + "microphoneInUse": "麦克风被其他应用占用,请关闭其他使用麦克风的程序后重试", + "microphoneNotCompatible": "麦克风设备不支持所需的录音参数,请使用其他麦克风设备", + "securityErrorUseHttps": "安全限制:请确保使用HTTPS协议访问网站", + "shareDataIncomplete": "分享数据不完整", + "pleaseRelogin": "请重新登录", + "pleaseLoginFirst": "请先登录", + "cancelTaskFailedRetry": "取消任务失败,请重试", + "shareFailedRetry": "分享失败,请重试", + "retryTaskFailedRetry": "重试任务失败,请重试", + "splitingAudio": "多角色模式,自动分割音频中..." +} diff --git a/lightx2v/deploy/server/frontend/src/main.js b/lightx2v/deploy/server/frontend/src/main.js new file mode 100644 index 0000000000000000000000000000000000000000..04d6dcd78b2f3319e295c2d42d2e15c20c217f4c --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/main.js @@ -0,0 +1,20 @@ +import { createApp } from 'vue' +import router from './router' + +import './style.css' +import App from './App.vue' +import { createPinia } from 'pinia' + +import i18n, { initLanguage } from './utils/i18n' + +const app = createApp(App) +const pinia = createPinia() + +app.use(i18n) +app.use(pinia) +app.use(router) + +// 初始化语言 +initLanguage().then(() => { + app.mount('#app') + }) diff --git a/lightx2v/deploy/server/frontend/src/router/index.js b/lightx2v/deploy/server/frontend/src/router/index.js new file mode 100644 index 0000000000000000000000000000000000000000..39523c8842bcbe0f36ae3b2360a585e65f664359 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/router/index.js @@ -0,0 +1,144 @@ +// src/router/index.js +import { createRouter, createWebHistory } from 'vue-router' +import Login from '../views/Login.vue' +import Layout from '../views/Layout.vue' +import Generate from '../components/Generate.vue' +import Projects from '../components/Projects.vue' +import Inspirations from '../components/Inspirations.vue' +import Share from '../views/Share.vue' +import PodcastGenerate from '../views/PodcastGenerate.vue' +import { showAlert } from '../utils/other' +import i18n from '../utils/i18n' + +const routes = [ + { + path: '/', + redirect: (to) => { + // 保留查询参数(用于 OAuth 回调) + return { path: '/generate', query: to.query } + } + }, + { + path: '/login', name: 'Login', component: Login, meta: { requiresAuth: false } + }, + { + path: '/share/:shareId', name: 'Share', component: Share, meta: { requiresAuth: false } + }, + { + path: '/podcast_generate', name: 'PodcastGenerate', component: PodcastGenerate, meta: { requiresAuth: true } + }, + { + path: '/podcast_generate/:session_id', name: 'PodcastSession', component: PodcastGenerate, meta: { requiresAuth: true } + }, + { + path: '/home', + component: Layout, + meta: { + requiresAuth: true + }, + children: [ + { + path: '/generate', + name: 'Generate', + component: Generate, + meta: { requiresAuth: true }, + props: route => ({ query: route.query }) + }, + { + path: '/projects', + name: 'Projects', + component: Projects, + meta: { requiresAuth: true }, + props: route => ({ query: route.query }) + }, + { + path: '/inspirations', + name: 'Inspirations', + component: Inspirations, + meta: { requiresAuth: true }, + props: route => ({ query: route.query }) + }, + { + path: '/task/:taskId', + name: 'TaskDetail', + component: Projects, + meta: { requiresAuth: true }, + props: route => ({ taskId: route.params.taskId, query: route.query }) + }, + { + path: '/template/:templateId', + name: 'TemplateDetail', + component: Inspirations, + meta: { requiresAuth: true }, + props: route => ({ templateId: route.params.templateId, query: route.query }) + }, + ] + }, + { + path: '/:pathMatch(.*)*', + name: 'NotFound', + component: () => 
import('../views/404.vue') + } +] + +const router = createRouter({ + history: createWebHistory(), + routes +}) + +// 路由守卫 - 整合和优化后的逻辑 +router.beforeEach((to, from, next) => { + const token = localStorage.getItem('accessToken') + console.log('token', token) + // 检查 URL 中是否有 code 参数(OAuth 回调) + // 可以从路由查询参数或实际 URL 中获取 + const hasOAuthCode = to.query?.code !== undefined || + (typeof window !== 'undefined' && new URLSearchParams(window.location.search).get('code') !== null) + + // 1. OAuth 回调处理:如果有 code 参数(GitHub/Google 登录回调),直接放行 + // App.vue 的 onMounted 会处理登录回调逻辑 + if (hasOAuthCode) { + console.log('检测到 OAuth 回调,放行让 App.vue 处理') + next() + return + } + + // 2. 不需要登录的页面(登录页、分享页等) + if (to.meta.requiresAuth === false) { + // 如果已登录用户访问登录页,重定向到生成页面 + if (token && to.path === '/login') { + console.log('已登录用户访问登录页,重定向到生成页') + next('/generate') + } else { + next() + } + return + } + + // 3. 需要登录的页面处理 + if (!token) { + // 未登录但访问需要登录的页面,跳转到登录页 + console.log('需要登录但未登录,跳转到登录页') + next('/login') + // 延迟显示提示,确保路由跳转完成 + setTimeout(() => { + showAlert(i18n.global.t('pleaseLoginFirst'), 'warning') + }, 100) + return + } + + // 4. 已登录用户处理 + // 已登录用户访问首页,重定向到生成页(保留查询参数) + if (to.path === '/') { + console.log('已登录用户访问首页,重定向到生成页') + const query = to.query && Object.keys(to.query).length > 0 ? to.query : {} + next({ path: '/generate', query }) + } else { + // 已登录且访问其他页面,正常放行 + next() + } +}) + + + +export default router; diff --git a/lightx2v/deploy/server/frontend/src/style.css b/lightx2v/deploy/server/frontend/src/style.css new file mode 100644 index 0000000000000000000000000000000000000000..1d75e5a53811395f991c0b857ee00a4349c7febd --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/style.css @@ -0,0 +1,2540 @@ +@import "tailwindcss"; +@import "@flaticon/flaticon-uicons/css/all/all"; +/* Tailwind v4 深色模式配置 */ +@variant dark (&:is(.dark *)); + +@theme { + /* 自定义颜色 */ + --color-primary: #9a72ff; + --color-secondary: #1b1240; + --color-accent: #b78bff; + --color-dark: #0b0a20; + --color-dark-light: #0f0e22; + --color-laser-purple: #d2c1ff; + --color-neon-purple: #a88bff; + --color-electric-purple: #8e88ff; + --font-display: "Inter", "sans-serif"; + --gradient-primary: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff); + --gradient-main: linear-gradient(135deg, #0b0a20 0%, #1b1240 50%, #0f0e22 100%); + --box-shadow-neon: 0 0 10px rgba(154, 114, 255, 0.5), 0 0 20px rgba(154, 114, 255, 0.3); + --box-shadow-neon-lg: 0 0 15px rgba(154, 114, 255, 0.7), 0 0 30px rgba(154, 114, 255, 0.5); + --box-shadow-laser: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2); + --box-shadow-laser-intense: 0 0 25px rgba(154, 114, 255, 0.9), 0 0 50px rgba(154, 114, 255, 0.7), 0 0 75px rgba(154, 114, 255, 0.5), 0 0 100px rgba(154, 114, 255, 0.3), 0 0 125px rgba(154, 114, 255, 0.1); + --box-shadow-electric: 0 0 15px rgba(124, 106, 255, 0.8), 0 0 30px rgba(124, 106, 255, 0.6), 0 0 45px rgba(124, 106, 255, 0.4); + }; + +:root { + --gradient-primary: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff); + color: white; + + /* Apple 极简黑白风格 - 品牌色 */ + --brand-primary: #90dce1; + --brand-primary-light: #90dce1; + --brand-primary-rgb: 89, 194, 233; + --brand-primary-light-rgb: 142, 220, 255; + + /* Apple 极简黑白风格 - 基础颜色 */ + --text-primary: #1d1d1f; /* 浅色模式主要文字 */ + --text-primary-dark: #f5f5f7; /* 深色模式主要文字 */ + --text-secondary: #86868b; /* 浅色模式次要文字/图标 */ + --text-secondary-dark: #98989d; /* 深色模式次要文字/图标 */ + + --bg-primary: #ffffff; /* 浅色模式主背景 */ + 
--bg-primary-dark: #000000; /* 深色模式主背景 */ + --bg-secondary: #f5f5f7; /* 浅色模式次要背景 */ + --bg-secondary-dark: #1c1c1e; /* 深色模式次要背景 */ + + --card-bg: rgba(255, 255, 255, 0.8); /* 浅色模式卡片背景 */ + --card-bg-dark: rgba(44, 44, 46, 0.8); /* 深色模式卡片背景 */ + --card-bg-hover-dark: #3a3a3c; /* 深色模式卡片 hover 背景 */ + + --surface-bg: rgba(255, 255, 255, 0.95); /* 浅色模式表面背景 */ + --surface-bg-dark: rgba(30, 30, 30, 0.95); /* 深色模式表面背景 */ + + --border-light: rgba(0, 0, 0, 0.06); /* 浅色模式浅边框 */ + --border-light-dark: rgba(255, 255, 255, 0.08); /* 深色模式浅边框 */ + --border-medium: rgba(0, 0, 0, 0.08); /* 浅色模式中等边框 */ + --border-medium-dark: rgba(255, 255, 255, 0.12); /* 深色模式中等边框 */ +} + +/* Apple 风格全局主题样式 */ +html { + background: #ffffff; + color: #1d1d1f; + transition: background-color 0.15s ease, color 0.15s ease; + will-change: background-color, color; +} + +html.dark { + background: #000000; + color: #f5f5f7; +} + +/* 主题切换时禁用过渡动画,提高切换速度 */ +html.theme-transitioning, +html.theme-transitioning * { + transition: none !important; +} + +/* Body 主题支持 */ +body { + background: #ffffff; + color: #1d1d1f; + transition: background-color 0.15s ease, color 0.15s ease; + will-change: background-color, color; +} + +html.dark body { + background: #000000; + color: #f5f5f7; +} + +/* 减少登录页面闪烁 */ +.login-container { + transition: opacity 0.3s ease-in-out; +} + +.main-container { + transition: opacity 0.3s ease-in-out; +} + +/* 防止翻译闪烁 */ +.app-loading { + opacity: 0; + transition: opacity 0.2s ease-in-out; +} + +.app-loaded { + opacity: 1; +} +/* 确保html和body能够正确填充 */ +html { + height: 100%; + width: 100%; + overflow: hidden; +} + +/* body 作为 #app 的容器 */ +body { + margin: 0; + padding: 0; + width: 100vw; + height: 100vh; + overflow: hidden; +} + +@layer utilities { + .content-auto { + content-visibility: auto; + } + .text-gradient { + background-clip: text; + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + } + /* 新增渐变图标颜色类 */ + .text-gradient-primary { + background: var(--gradient-primary); + -webkit-background-clip: text; + background-clip: text; + -webkit-text-fill-color: transparent; + } + .scrollbar-thin { + scrollbar-width: thin; + } + .scrollbar-thin::-webkit-scrollbar { + width: 6px; + height: 6px; + } + .scrollbar-thin::-webkit-scrollbar-thumb { + background-color: rgba(210, 193, 255, 0.6); + border-radius: 3px; + } + .scrollbar-thin::-webkit-scrollbar-track { + background-color: rgba(31, 41, 55, 0.3); + border-radius: 3px; + } + .scrollbar-thin::-webkit-scrollbar-thumb:hover { + background-color: rgba(210, 193, 255, 0.8); + } + + /* 自定义滚动条颜色 */ + .scrollbar-thumb-laser-purple\/30::-webkit-scrollbar-thumb { + background-color: rgba(168, 85, 247, 0.3); + } + .scrollbar-track-gray-800\/30::-webkit-scrollbar-track { + background-color: rgba(31, 41, 55, 0.3); + } + + /* 历史任务区域滚动条样式 - 与主内容区域保持一致 */ + .history-tasks-scroll::-webkit-scrollbar { + width: 8px !important; + } + + .history-tasks-scroll::-webkit-scrollbar-track { + background: rgba(27, 18, 64, 0.3) !important; + border-radius: 4px; + } + + .history-tasks-scroll::-webkit-scrollbar-thumb { + background: linear-gradient(135deg, rgba(210, 193, 255, 0.8), rgba(168, 139, 255, 0.8)) !important; + border-radius: 4px; + border: 1px solid rgba(210, 193, 255, 0.3); + } + + .history-tasks-scroll::-webkit-scrollbar-thumb:hover { + background: linear-gradient(135deg, rgba(210, 193, 255, 1), rgba(168, 139, 255, 1)) !important; + } + + /* Apple 极简风格滚动条 */ + .main-scrollbar { + scroll-behavior: smooth; + -webkit-overflow-scrolling: touch; + } + + /* 滚动条轨道 - Apple 
风格(几乎透明) */ + .main-scrollbar::-webkit-scrollbar { + width: 6px; + height: 6px; + } + + .main-scrollbar::-webkit-scrollbar-track { + background: transparent; + } + + /* 滚动条滑块 - Apple 极简风格 */ + .main-scrollbar::-webkit-scrollbar-thumb { + background: rgba(0, 0, 0, 0.2); + border-radius: 3px; + border: none; + transition: background 0.2s ease; + } + + .main-scrollbar::-webkit-scrollbar-thumb:hover { + background: rgba(0, 0, 0, 0.35); + } + + /* 深色模式下的滚动条 */ + .dark .main-scrollbar::-webkit-scrollbar-thumb { + background: rgba(255, 255, 255, 0.15); + } + + .dark .main-scrollbar::-webkit-scrollbar-thumb:hover { + background: rgba(255, 255, 255, 0.25); + } + + /* Firefox 滚动条样式 */ + .main-scrollbar { + scrollbar-width: thin; + scrollbar-color: rgba(0, 0, 0, 0.2) transparent; + } + + .dark .main-scrollbar { + scrollbar-color: rgba(255, 255, 255, 0.15) transparent; + } + .animate-pulse-slow { + animation: pulse 3s cubic-bezier(0.4, 0, 0.6, 0.5) infinite; + } + .animate-electric-pulse { + animation: electricPulse 1.5s ease-in-out infinite; + } + @keyframes electricPulse { + 0%, 100% { + box-shadow: 0 0 15px rgba(142, 136, 255, 0.8), 0 0 30px rgba(142, 136, 255, 0.6); + transform: scale(1); + } + 50% { + box-shadow: 0 0 25px rgba(142, 136, 255, 1), 0 0 50px rgba(142, 136, 255, 0.8), 0 0 75px rgba(142, 136, 255, 0.4); + transform: scale(1.02); + } + } + .bg-laser-gradient { + background: linear-gradient(135deg, #d2c1ff 0%, #a88bff 25%, #8e88ff 50%, #d2c1ff 75%, #a88bff 100%); + background-size: 200% 200%; + animation: gradientShift 3s ease-in-out infinite; + } + @keyframes gradientShift { + 0%, 100% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } + } + .text-laser-glow { + text-shadow: 0 0 10px rgba(154, 114, 255, 0.8), 0 0 20px rgba(154, 114, 255, 0.6), 0 0 30px rgba(154, 114, 255, 0.4); + } + .border-laser { + border-color: #d2c1ff; + box-shadow: 0 0 15px rgba(154, 114, 255, 0.6), inset 0 0 15px rgba(154, 114, 255, 0.1); + } + .btn-primary{ + padding: 15px 25px; + border-radius: 14px; + font-weight: 500; + font-size: 14px; + letter-spacing: 0.2px; + font-family: 'Inter', sans-serif; + background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff); + border: 0; + text-decoration: none; + box-shadow: 0 10px 30px rgba(140, 110, 255, 0.4); + transition: transform 0.15s ease, box-shadow 0.15s ease; + } + .btn-primary:hover{ + transform: translateY(-1px); + box-shadow: 0 14px 40px rgba(140, 110, 255, 0.55); + } + /* 修复布局问题 */ + .task-type-btn { + padding: 0.75rem 1rem; + font-size: 0.875rem; + font-weight: 500; + transition-property: color, background-color; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 150ms; + } + + .task-type-btn:hover { + background-color: rgba(154, 114, 255, 0.1); + } + + .model-selection { + display: flex; + flex-wrap: wrap; + gap: 0.5rem; + } + + .upload-section { + display: grid; + grid-template-columns: repeat(1, minmax(0, 1fr)); + gap: 1.5rem; + margin-bottom: 1.5rem; + } + + @media (min-width: 768px) { + .upload-section { + grid-template-columns: repeat(2, minmax(0, 1fr)); + } + } + + .upload-area { + position: relative; + border: 2px dashed rgba(154, 114, 255, 0.4); + border-radius: 0.75rem; + padding: 1.5rem; + text-align: center; + justify-content: center; + align-items: center; + transition-property: all; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 150ms; + cursor: pointer; + height: 30vh; + background-color: rgba(27, 18, 64, 0.1); + } + + /* 光球按钮样式 */ + .floating-orb-btn 
{ + position: relative; + width: 100px; + height: 100px; + border-radius: 50%; + background: linear-gradient(135deg, #9a72ff, #a855f7, #ec4899); + border: none; + cursor: pointer; + transition: all 0.3s ease; + overflow: hidden; + box-shadow: 0 4px 20px rgba(154, 114, 255, 0.4); + } + + .floating-orb-btn:hover { + transform: scale(1.1); + box-shadow: 0 0 40px rgba(154, 114, 255, 0.8), 0 0 80px rgba(154, 114, 255, 0.6); + } + + .orb-glow { + position: absolute; + inset: -15px; + border-radius: 50%; + background: radial-gradient(circle, rgba(154, 114, 255, 0.3) 0%, transparent 70%); + animation: pulse 2s infinite; + } + + .orb-content { + position: relative; + z-index: 2; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + height: 100%; + color: white; + } + + @keyframes pulse { + 0%, 100% { opacity: 0.5; transform: scale(1); } + 50% { opacity: 0.8; transform: scale(1.05); } + } + + /* 创作区域样式 */ + .creation-area { + opacity: 0; + transform: scale(0.8) translateY(20px); + pointer-events: none; + transition: all 0.4s cubic-bezier(0.34, 1.56, 0.64, 1); + transform-origin: center center; + } + + .creation-area.show { + opacity: 1; + transform: scale(1) translateY(0); + pointer-events: auto; + } + + .creation-area.hide { + opacity: 0; + transform: scale(0.8) translateY(20px); + pointer-events: none; + } + + /* 提示文字淡入动画 */ + .animate-fade-in { + animation: fadeIn 0.5s ease-in-out; + } + + @keyframes fadeIn { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + /* 提示文字滚动动画 */ + .hint-fade-enter-active, .hint-fade-leave-active { + transition: all 0.5s ease-in-out; + } + .hint-fade-enter-from { + opacity: 0; + transform: translateY(20px); + } + .hint-fade-leave-to { + opacity: 0; + transform: translateY(-20px); + } + + .upload-area:hover { + border-color: rgba(154, 114, 255, 0.7); + box-shadow: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2); + } + + .upload-icon { + margin: 0 auto; + width: 4rem; + height: 4rem; + background-color: rgba(154, 114, 255, 0.2); + border-radius: 9999px; + display: flex; + align-items: center; + justify-content: center; + margin-bottom: 1rem; + transition-property: all; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 150ms; + } + + .upload-icon { + background-color: rgba(154, 114, 255, 0.3); + } + + /* 图片预览占据整个上传区域 */ + .image-preview { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + overflow: hidden; + z-index: 10; + display: flex; + align-items: center; + justify-content: center; + background-color: rgba(154, 114, 255, 0.1); + cursor: pointer; + } + + .image-preview img { + height: 100%; + width: auto; + max-width: 100%; + display: block; + margin: 0 auto; + object-fit: contain; + transition: all 0.3s ease; + background-color: rgba(154, 114, 255, 0.1); + cursor: pointer; + } + + /* 音频预览占据整个上传区域 */ + .audio-preview { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + overflow: hidden; + z-index: 10; + display: flex; + align-items: center; + justify-content: center; + background-color: rgba(154, 114, 255, 0.1); + cursor: pointer; + } + + .audio-preview audio { + width: 90%; + height: 60px; + max-height: 80%; + border-radius: 0.5rem; + background-color: rgba(27, 18, 64, 0.3); + display: block; + } + + /* 确保音频控件在容器中正确显示 */ + .audio-preview audio::-webkit-media-controls { + background-color: 
rgba(27, 18, 64, 0.5); + border-radius: 0.5rem; + } + + /* 上传内容样式 */ + .upload-content { + width: 100%; + height: 100%; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + } + + .btn-close { + position: absolute; + top: 0.5rem; + right: 0.5rem; + background-color: #ef4444; + color: white; + border-radius: 9999px; + width: 1.5rem; + height: 1.5rem; + display: flex; + align-items: center; + justify-content: center; + font-size: 0.75rem; + cursor: pointer; + z-index: 20; + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3); + } + + /* 整体缩放80% - 放大容器后缩放,使得铺满屏幕 */ + #app { + display: flex; + width: 125%; + height: 125%; + transform: scale(0.8); + transform-origin: 0 0; + overflow: visible; + } + + /* Firefox 兼容 */ + @supports (-moz-appearance: none) { + #app { + -moz-transform: scale(0.8); + -moz-transform-origin: 0 0; + } + } + + .bg-linear-dark { + background-color: linear-gradient(135deg, #0b0a20 0%, #1b1240 50%, #0f0e22 100%); + } + + aside { + flex-shrink: 0; + width: 280px; /* 默认展开宽度 */ + min-width: 3rem; /* 最小宽度 */ + max-width: 500px; /* 最大宽度 */ + background-color: transparent; + border-right: 1px solid rgba(154, 114, 255, 0.4); + display: flex; + flex-direction: column; + transition-property: all; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 300ms; + z-index: 10; + position: relative; + } + + /* 拖拽调整器 */ + .resize-handle { + position: absolute; + top: 0; + right: 0; + width: 6px; + height: 100%; + background: transparent; + cursor: col-resize; + z-index: 50; + transition: background-color 0.2s ease; + } + + .resize-handle:hover { + background: rgba(154, 114, 255, 0.5); + } + + .resize-handle:active { + background: rgba(154, 114, 255, 0.8); + } + + /* 确保拖拽手柄可见 */ + .resize-handle::before { + content: ''; + position: absolute; + top: 50%; + right: 1px; + width: 2px; + height: 20px; + background: rgba(154, 114, 255, 0.3); + transform: translateY(-50%); + border-radius: 1px; + } + + /* 拖拽时的视觉反馈 */ + .resizing { + user-select: none; + pointer-events: none; + } + + .resizing * { + pointer-events: none; + } + + main { + flex: 1; + display: flex; + flex-direction: column; + min-width: 0; + width: calc(100% - 280px); /* 主内容区域占据剩余宽度,适应展开的侧边栏 */ + height: 100%; + } + + /* 内容区域全屏显示 */ + .content-area { + flex: 1; + overflow-y: auto; + /* background-color: #0b0a20; */ + padding: 1rem; + width: 100%; + min-height: 0; /* 确保flex子元素可以收缩 */ + } + + /* 任务创建面板全屏 */ + #task-creator { + max-width: none; + width: 80%; + } + + #inspiration-gallery { + max-width: none; + width: 90%; + padding: 0 1rem; + } + + /* 移动端全屏显示 */ + @media (max-width: 768px) { + #task-creator { + width: 95%; + } + + #inspiration-gallery { + width: 100%; + padding: 0 0.5rem; + } + + /* 任务详情面板移动端全屏 */ + .task-detail-panel { + width: 100%; + padding: 0 0.5rem; + } + + /* 任务运行面板移动端全屏 */ + .task-running-panel { + width: 100%; + padding: 0 0.5rem; + } + + /* 任务失败面板移动端全屏 */ + .task-failed-panel { + width: 100%; + padding: 0 0.5rem; + } + + /* 任务取消面板移动端全屏 */ + .task-cancelled-panel { + width: 100%; + padding: 0 0.5rem; + } + + /* 移动端内容区域调整 */ + .content-area { + padding: 0.5rem !important; + } + + /* 移动端创作区域调整 */ + .creation-area-container { + padding: 0.5rem; + } + } + + /* 任务详情面板全屏 */ + .task-detail-panel { + max-width: none; + width: 90%; + padding: 0 0rem; + } + + /* 上传区域全屏布局 */ + .upload-section { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); + gap: 2rem; + margin-bottom: 2rem; + width: 100%; + } + + /* 任务类型选择全屏 */ + 
.task-type-selection { + width: 100%; + margin-bottom: 2rem; + } + + .task-type-buttons { + display: flex; + width: 100%; + border-bottom: 1px solid rgba(154, 114, 255, 0.3); + } + + .task-type-btn { + flex: 1; + padding: 1rem 1.5rem; + font-size: 1rem; + font-weight: 500; + transition-property: color, background-color; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 150ms; + text-align: center; + } + + + /* 模型选择全屏 */ + .model-selection { + display: flex; + flex-wrap: wrap; + gap: 1rem; + width: 100%; + justify-content: flex-start; + } + + /* 提示词输入全屏 */ + .prompt-input-section { + width: 100%; + margin-bottom: 2rem; + } + + .prompt-textarea { + width: 100%; + min-height: 150px; + resize: vertical; + } + + + + /* 移动端响应式设计 */ + @media (max-width: 640px) { + /* 通用移动端样式 */ + .mobile-bottom-nav { + width: 100% !important; + height: auto !important; + padding: 0 !important; + backdrop-filter: blur(20px) !important; + z-index: 50 !important; + } + + .mobile-nav-buttons { + display: flex !important; + flex-direction: row !important; + justify-content: center !important; + align-items: center !important; + gap: 1rem !important; + padding: 1rem !important; + width: 100% !important; + } + + /* 确保LeftBar容器在小屏幕下居中 */ + .relative.w-20.pl-5.flex.flex-col.z-10 { + margin-left: auto !important; + margin-right: auto !important; + padding-left: 0 !important; + } + + .mobile-nav-btn { + width: 3rem !important; + height: 3rem !important; + flex-shrink: 0 !important; + } + + + /* 主布局调整为垂直布局 */ + .flex.flex-row { + flex-direction: column; + } + + /* 左侧功能区在移动端移动到下方 */ + .p-2.flex.flex-col.justify-center.h-full { + margin-top: 0 !important; + padding: 1rem !important; + } + + .p-2.flex.flex-col.justify-center.h-full nav { + display: flex !important; + flex-direction: row !important; + justify-content: space-around !important; + align-items: center !important; + gap: 1rem !important; + } + + /* 确保按钮容器在移动端完全对齐 */ + .mobile-nav-buttons .relative.group { + display: flex !important; + justify-content: center !important; + align-items: center !important; + flex: 1 !important; + margin: 0 !important; + } + + /* 按钮在移动端调整大小 */ + .p-2.flex.flex-col.justify-center.h-full nav button { + width: 3rem !important; + height: 3rem !important; + flex-shrink: 0 !important; + } + + + /* 历史任务区域调整 */ + .flex-1.overflow-y-auto.p-10.content-area.main-scrollbar { + padding: 1rem !important; + } + + /* 搜索和筛选区域在移动端垂直排列 */ + .flex.flex-col.gap-4.mb-6 { + flex-direction: column !important; + gap: 1rem !important; + } + + /* 筛选按钮在移动端换行 */ + .flex.gap-2 { + flex-wrap: wrap !important; + gap: 0.5rem !important; + } + + .flex.gap-2 button { + font-size: 0.75rem !important; + padding: 0.5rem 0.75rem !important; + } + + /* 上传区域在移动端调整 */ + .upload-section { + grid-template-columns: 1fr !important; + gap: 1rem !important; + } + + .upload-area { + padding: 1rem !important; + } + + /* 任务类型选择在移动端调整 */ + .grid.grid-cols-1.gap-4 { + grid-template-columns: 1fr !important; + gap: 0.75rem !important; + } + + /* 模型选择在移动端调整 */ + .grid.grid-cols-2.gap-3 { + grid-template-columns: repeat(2, 1fr) !important; + gap: 0.5rem !important; + } + + /* 参数设置区域在移动端调整 */ + .bg-dark-light.rounded-xl.p-6 { + padding: 1rem !important; + } + + /* 提交按钮在移动端调整 */ + .btn-primary.flex.items-center.justify-center.px-8.py-3 { + width: 100% !important; + padding: 1rem !important; + } + + /* 灵感广场移动端适配 */ + .grid.grid-cols-1.gap-6 { + grid-template-columns: 1fr !important; + gap: 1rem !important; + } + + /* 模态框在移动端调整 */ + .fixed.inset-0.z-50 { + padding: 
1rem !important; + } + + .bg-dark.rounded-2xl.shadow-2xl.max-w-4xl.w-full { + max-height: 90vh !important; + margin: 0 !important; + } + } + + /* 超小屏幕适配 (iPhone SE等) */ + @media (max-width: 375px) { + /* 底部导航按钮更紧凑 */ + .p-2.flex.flex-col.justify-center.h-full nav button { + width: 2.5rem !important; + height: 2.5rem !important; + } + + + /* 主内容区域调整 */ + .flex-1.flex.flex-col.min-h-0 { + margin-bottom: 4rem !important; + } + + /* 搜索框和按钮调整 */ + .flex.flex-col.gap-4.mb-6 input { + font-size: 0.875rem !important; + } + + .flex.gap-2 button { + font-size: 0.7rem !important; + padding: 0.4rem 0.6rem !important; + } + + /* 任务卡片在超小屏幕调整 */ + .bg-dark-light.rounded-xl.p-4 { + padding: 0.75rem !important; + } + + /* 上传区域在超小屏幕调整 */ + .upload-area { + padding: 0.75rem !important; + } + + .upload-area p { + font-size: 0.875rem !important; + } + + /* 模态框在超小屏幕调整 */ + .fixed.inset-0.z-50 { + padding: 0.5rem !important; + } + + .bg-dark.rounded-2xl.shadow-2xl.max-w-4xl.w-full { + max-height: 95vh !important; + } + + /* 超小屏幕表单优化 */ + .sms-login-form .form-control { + font-size: 13px !important; + padding: 12px 16px !important; + } + + .sms-login-form .btn-sms-code { + min-width: 100px !important; + font-size: 12px !important; + padding: 6px 12px !important; + } + + .sms-login-form .btn-placeholder { + min-width: 100px !important; + } + } + + /* 分页组件样式 */ + .pagination-container { + border-bottom: 1px solid rgba(154, 114, 255, 0.15); + padding-bottom: 0.5rem; + } + + .pagination-btn-compact { + background: rgba(27, 18, 64, 0.2); + border: 1px solid rgba(154, 114, 255, 0.2); + color: rgba(255, 255, 255, 0.6); + min-width: 20px; + height: 20px; + display: flex; + align-items: center; + justify-content: center; + font-size: 10px; + } + + .pagination-btn-compact:hover:not(.disabled) { + background: rgba(154, 114, 255, 0.15); + border-color: rgba(154, 114, 255, 0.4); + color: rgba(255, 255, 255, 0.8); + transform: translateY(-0.5px); + } + + .pagination-btn-compact.active { + background: linear-gradient(135deg, rgba(154, 114, 255, 0.6), rgba(168, 139, 255, 0.6)); + border-color: rgba(154, 114, 255, 0.6); + color: white; + box-shadow: 0 0 6px rgba(154, 114, 255, 0.4); + } + + .pagination-btn-compact.disabled { + opacity: 0.3; + cursor: not-allowed; + background: rgba(27, 18, 64, 0.05); + border-color: rgba(154, 114, 255, 0.05); + color: rgba(255, 255, 255, 0.2); + } + + .pagination-btn-compact.disabled:hover { + transform: none; + background: rgba(27, 18, 64, 0.05); + border-color: rgba(154, 114, 255, 0.05); + color: rgba(255, 255, 255, 0.2); + } + + /* 页码输入框样式 */ + .page-input { + width: 8px; + height: 5px; + text-align: center; + font-size: 12px; + border-radius: 4px; + background: linear-gradient(135deg, rgba(154, 114, 255, 0.3), rgba(168, 139, 255, 0.3)); + border: 1px solid rgba(154, 114, 255, 0.4); + color: white; + outline: none; + transition: all 0.2s ease; + font-weight: 500; + } + + .page-input:focus { + border-color: rgba(154, 114, 255, 0.8); + background: linear-gradient(135deg, rgba(154, 114, 255, 0.5), rgba(168, 139, 255, 0.5)); + box-shadow: 0 0 6px rgba(154, 114, 255, 0.4); + } + + .page-input::-webkit-outer-spin-button, + .page-input::-webkit-inner-spin-button { + -webkit-appearance: none; + margin: 0; + } + + .page-input[type=number] { + -moz-appearance: textfield; + } + + /* 修复状态指示器 */ + .status-indicator { + width: 0.75rem; + height: 0.75rem; + border-radius: 9999px; + box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); + } + + /* 修复按钮样式 */ + .btn-primary { + 
padding: 12px 22px; + border-radius: 14px; + font-weight: 700; + letter-spacing: 0.2px; + background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff); + border: 0; + text-decoration: none; + box-shadow: 0 10px 30px rgba(140, 110, 255, 0.4); + transition: transform 0.15s ease, box-shadow 0.15s ease; + cursor: pointer; + display: inline-block; + } + + .btn-primary:hover { + transform: translateY(-1px); + box-shadow: 0 14px 40px rgba(140, 110, 255, 0.55); + } + + /* 修复模型按钮样式 */ + .model-btn { + padding: 0.5rem 1rem; + border-radius: 0.5rem; + font-size: 0.875rem; + transition-property: all; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 150ms; + cursor: pointer; + border: 1px solid; + } + + .model-btn.active { + background-color: rgba(154, 114, 255, 0.2); + border-color: rgba(154, 114, 255, 0.4); + box-shadow: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2); + animation: electricPulse 1.5s ease-in-out infinite; + } + + /* 确保内容区域正确滚动 */ + .content-scroll { + flex: 1; + overflow-y: auto; + } + + /* 任务进行中面板样式 */ + .task-running-panel .animate-pulse-slow { + animation: pulse 3s cubic-bezier(0.4, 0, 0.6, 0.5) infinite; + } + + @keyframes shimmer { + 0% { transform: translateX(-100%); } + 100% { transform: translateX(100%); } + } + + .subtask-item { + background: rgba(27, 18, 64, 0.2); + border: 1px solid rgba(154, 114, 255, 0.2); + border-radius: 8px; + padding: 0.75rem; + margin: 0.5rem 0; + } + + .subtask-header { + display: flex; + justify-content: between; + align-items: center; + margin-bottom: 0.5rem; + } + + .subtask-status { + padding: 0.25rem 0.5rem; + border-radius: 4px; + font-size: 0.75rem; + font-weight: 500; + } + + .subtask-status.pending { + background: rgba(251, 191, 36, 0.2); + color: #fbbf24; + } + + .subtask-status.running { + background: rgba(59, 130, 246, 0.2); + color: #3b82f6; + } + + .subtask-info { + font-size: 0.75rem; + color: rgba(255, 255, 255, 0.6); + margin-top: 0.25rem; + } + + /* 任务失败面板样式 */ + .task-failed-panel .bg-red-500\/10 { + background-color: rgba(239, 68, 68, 0.1); + } + + .error-details { + background: rgba(239, 68, 68, 0.1); + border: 1px solid rgba(239, 68, 68, 0.3); + border-radius: 8px; + padding: 1rem; + margin: 1rem 0; + text-align: left; + } + + .error-details pre { + background: rgba(0, 0, 0, 0.3); + border-radius: 4px; + padding: 0.75rem; + margin: 0.5rem 0 0 0; + font-size: 0.75rem; + color: #fca5a5; + overflow-x: auto; + white-space: pre-wrap; + word-break: break-word; + } + + .subtask-error { + background: rgba(239, 68, 68, 0.05); + border-left: 3px solid #ef4444; + padding: 0.75rem; + margin: 0.5rem 0; + border-radius: 0 4px 4px 0; + } + + .subtask-error pre { + background: #1a1a1a; + border: 1px solid #dc2626; + border-radius: 6px; + padding: 12px; + margin: 8px 0; + color: #fca5a5; + font-size: 12px; + line-height: 1.4; + white-space: pre-wrap; + word-break: break-word; + max-height: 200px; + overflow-y: auto; + } + + .error-details { + animation: slideDown 0.3s ease-out; + } + + @keyframes slideDown { + from { + opacity: 0; + transform: translateY(-10px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + .task-detail-panel video { + width: 100%; + height: 100%; + object-fit: cover; + } + + /* 素材预览样式 */ + + /* 任务状态指示器增强 */ + .status-indicator { + position: relative; + } + + .status-indicator::after { + content: ''; + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + 
width: 0.25rem; + height: 0.25rem; + background-color: currentColor; + border-radius: 50%; + opacity: 0.8; + } + + /* 响应式任务面板 */ + @media (max-width: 768px) { + .task-detail-panel { + padding: 0 0.5rem; + } + } + + /* 提示消息动画 */ + .animate-slide-down { + animation: slideDown 0.3s ease-out; + } + + @keyframes slideDown { + 0% { + opacity: 0; + transform: translate(-50%, -100%); + } + 100% { + opacity: 1; + transform: translate(-50%, 0); + } + } + + /* 提示消息样式 - 统一浅色透明背景 */ + .alert { + backdrop-filter: blur(15px); + background: rgba(0, 0, 0, 0.8); + border-radius: 0.75rem; + box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3); + color: #fff; + } +} + + .floating-toggle-btn { + position: fixed; + top: 50%; + left: 256px; /* 默认位置,对应 w-64 (256px) */ + transform: translateY(-50%); + width: 20px; + height: 40px; + background: linear-gradient(135deg, #1a1a2e 0%, #2a2a4e 50%, #1e1e3e 100%); + border: 1px solid rgba(139, 92, 246, 0.3); + border-left: none; + border-radius: 0 8px 8px 0; + color: #9ca3af; + display: flex; + align-items: center; + justify-content: center; + cursor: pointer; + transition: background-color 0.3s ease, color 0.3s ease, box-shadow 0.3s ease; + z-index: 20; + box-shadow: 2px 0 8px rgba(0, 0, 0, 0.3); + } + + .floating-toggle-btn:hover { + background: linear-gradient(135deg, #2a2a4e 0%, #3a3a5e 50%, #2e2e4e 100%); + color: #8b5cf6; + box-shadow: 2px 0 12px rgba(139, 92, 246, 0.3); + } + + .floating-toggle-btn.collapsed { + border-radius: 0 8px 8px 0; + border-left: 1px solid rgba(139, 92, 246, 0.3); + border-right: none; + } + + .resizing .floating-toggle-btn { + transition: none !important; + } + + .left-glow-zone.show-glow { + opacity: 1; + } + + .history-section { + max-height: calc(100% - 200px); + border-radius: 0 12px 12px 0; + border: 2px solid rgba(139, 92, 246, 0.4); + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3); + margin: 8px 8px 8px 0; + transition: all 0.3s ease; + cursor: pointer; + background: rgba(139, 92, 246, 0.05); + } + + .history-section:hover { + background: rgba(139, 92, 246, 0.05) !important; + border-color: rgba(139, 92, 246, 0.15) !important; + box-shadow: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2); + transform: translateY(-2px); + } + + .history-section:hover { + background: rgba(139, 92, 246, 0.08) !important; + } + + /* 修复任务项样式 */ + .task-item { + border: 1px solid rgba(139, 92, 246, 0.3); + padding: 0.75rem; + border-radius: 0.5rem; + cursor: pointer; + transition-property: all; + transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); + transition-duration: 200ms; + transition: all 0.2s ease; + } + + .task-item:hover { + border: 1px solid rgba(167, 132, 255, 0.2); + background-color: rgba(167, 132, 255, 0.2); + transform: translateX(5px); + } + + /* 任务操作菜单样式 */ + .task-menu-container { + position: relative; + } + + .task-menu-dropdown { + animation: fadeInUp 0.2s ease-out; + } + + @keyframes fadeInUp { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + .task-menu-item { + transition: all 0.2s ease; + } + + .task-menu-item:hover { + transform: translateX(2px); + } + + /* 短信登录样式 */ + .sms-login-form { + background: transparent; + border-radius: 0; + padding: 0; + border: none; + margin: 0 auto 2rem auto; + max-width: 80%; + animation: slideDown 0.3s ease-out; + } + + /* 输入组样式 */ + .input-group { + display: flex; + gap: 12px; + align-items: stretch; + margin-bottom: 1rem; + } + + .input-group 
.form-control { + flex: 1; + border-radius: 12px; + border: 2px solid rgba(255, 255, 255, 0.2); + background: rgba(255, 255, 255, 0.1); + color: white; + padding: 14px 18px; + font-size: 14px; + font-weight: 400; + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + backdrop-filter: blur(10px); + position: relative; + min-width: 0; + height: 48px; + } + + .input-group .form-control:focus { + outline: none; + border-color: #9a72ff; + background: rgba(255, 255, 255, 0.12); + box-shadow: 0 0 0 4px rgba(154, 114, 255, 0.15), + 0 8px 25px rgba(154, 114, 255, 0.1); + transform: translateY(-1px); + } + + .input-group .form-control:hover:not(:focus) { + border-color: rgba(255, 255, 255, 0.25); + background: rgba(255, 255, 255, 0.1); + } + + .input-group .form-control::placeholder { + color: rgba(255, 255, 255, 0.5); + font-weight: 400; + transition: color 0.3s ease; + } + + .input-group .form-control:focus::placeholder { + color: rgba(255, 255, 255, 0.3); + } + + /* 单独的输入框样式 */ + .form-control { + width: 50%; + border: 2px solid rgba(255, 255, 255, 0.15); + background: rgba(255, 255, 255, 0.08); + color: white; + padding: 16px 20px; + font-size: 12px; + font-weight: 400; + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + backdrop-filter: blur(10px); + position: relative; + } + + .form-control:focus { + outline: none; + border-color: #9a72ff; + background: rgba(255, 255, 255, 0.12); + box-shadow: 0 0 0 4px rgba(154, 114, 255, 0.15), + 0 8px 25px rgba(154, 114, 255, 0.1); + transform: translateY(-1px); + } + + .input-group .btn { + border-radius: 16px; + padding: 16px 16px; + font-weight: 500; + font-size: 14px; + white-space: nowrap; + min-width: 100px; + flex-shrink: 0; + border: 2px solid transparent; + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + position: relative; + overflow: hidden; + } + + .input-group .btn:hover { + transform: translateY(-1px); + box-shadow: 0 8px 25px rgba(154, 114, 255, 0.2); + } + + .input-group .btn:active { + transform: translateY(0); + } + + /* 分隔线样式 */ + .divider { + position: relative; + text-align: center; + margin: 20px 0; + } + + .divider::before { + content: ''; + position: absolute; + top: 50%; + left: 0; + right: 0; + height: 1px; + background: linear-gradient(90deg, + transparent 0%, + rgba(255, 255, 255, 0.1) 20%, + rgba(255, 255, 255, 0.3) 50%, + rgba(255, 255, 255, 0.1) 80%, + transparent 100%); + } + + .divider-text { + background: rgba(27, 18, 64, 0.9); + padding: 8px 20px; + color: rgba(255, 255, 255, 0.7); + font-size: 13px; + font-weight: 500; + position: relative; + z-index: 1; + border-radius: 20px; + border: 1px solid rgba(255, 255, 255, 0.1); + backdrop-filter: blur(10px); + } + + /* 社交登录按钮样式 */ + .social-login-buttons { + display: flex; + justify-content: center; + gap: 20px; + } + + .btn-icon { + width: 48px; + height: 48px; + border-radius: 50%; + border: 2px solid rgba(255, 255, 255, 0.2); + background: rgba(255, 255, 255, 0.05); + color: white; + display: flex; + align-items: center; + justify-content: center; + font-size: 20px; + transition: all 0.3s ease; + cursor: pointer; + position: relative; + overflow: hidden; + } + + .btn-icon:hover { + transform: translateY(-3px) scale(1.05); + box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3); + } + + .btn-icon:active { + transform: translateY(-1px) scale(1.02); + } + + .btn-icon:hover { + background: rgba(255, 255, 255, 0.15); + border-color: rgba(255, 255, 255, 0.6); + box-shadow: 0 10px 30px rgba(255, 255, 255, 0.1); + } + + .btn-icon:hover i { + color: #ffffff; + text-shadow: 0 0 10px 
rgba(255, 255, 255, 0.5); + } + + .btn-icon { + background: rgba(255, 255, 255, 0.08); + border-color: rgba(255, 255, 255, 0.3); + } + + /* 图标按钮的波纹效果 */ + .btn-icon::before { + content: ''; + position: absolute; + top: 50%; + left: 50%; + width: 0; + height: 0; + border-radius: 50%; + background: rgba(255, 255, 255, 0.1); + transform: translate(-50%, -50%); + transition: width 0.3s ease, height 0.3s ease; + } + + .btn-icon:hover::before { + width: 100%; + height: 100%; + } + + /* 提交按钮区域样式 */ + .login-submit-section { + display: flex; + justify-content: center; + margin-top: 24px; + } + + .btn-submit { + width: 100%; + padding: 10px 28px; + border-radius: 10px; + font-weight: 600; + font-size: 14px; + letter-spacing: 0.5px; + background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff); + border: 2px solid transparent; + color: #0c0920; + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + position: relative; + overflow: hidden; + box-shadow: 0 8px 25px rgba(154, 114, 255, 0.3); + } + + .btn-submit:hover:not(:disabled) { + transform: translateY(-2px); + box-shadow: 0 12px 35px rgba(154, 114, 255, 0.4); + } + + .btn-submit:active:not(:disabled) { + transform: translateY(0); + } + + .btn-submit:disabled { + opacity: 0.6; + cursor: not-allowed; + transform: none; + box-shadow: 0 4px 15px rgba(154, 114, 255, 0.2); + } + + .btn-submit::before { + content: ''; + position: absolute; + top: 0; + left: -100%; + width: 100%; + height: 100%; + background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); + transition: left 0.5s ease; + } + + .btn-submit:hover::before { + left: 100%; + } + + /* 响应式设计 */ + @media (max-width: 768px) { + .login-card .card-body { + padding: 1.5rem !important; + } + + .social-login-buttons { + gap: 16px; + } + + .btn-icon { + width: 42px; + height: 42px; + font-size: 18px; + } + + .divider-text { + font-size: 12px; + padding: 0 12px; + } + } + + @media (max-width: 480px) { + .login-card .card-body { + padding: 1rem !important; + } + + .input-group .form-control { + padding: 14px 16px; + font-size: 14px; + } + + .btn-icon { + width: 38px; + height: 38px; + font-size: 16px; + } + + .social-login-buttons { + gap: 12px; + } + } + + @keyframes slideDown { + from { + opacity: 0; + transform: translateY(-10px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + /* 语言切换时的文本过渡动画 */ + .language-transition { + transition: all 0.3s ease-in-out; + } + + .sms-login-form .form-control { + background: rgba(255, 255, 255, 0.1); + border: 1px solid rgba(255, 255, 255, 0.2); + color: white; + border-radius: 8px; + transition: all 0.3s ease; + height: 50px; + padding: 12px 16px; + font-size: 16px; + } + + .sms-login-form .form-control:focus { + background: rgba(255, 255, 255, 0.15); + border-color: #007bff; + box-shadow: 0 0 0 0.2rem rgba(0, 123, 255, 0.25); + color: white; + } + + .sms-login-form .form-control::placeholder { + color: rgba(255, 255, 255, 1.0); + } + + .sms-login-form .btn-primary:disabled { + opacity: 0.6; + cursor: not-allowed; + } + + .sms-login-form .btn-primary { + color: white; + transition: all 0.3s ease; + height: 50px; + padding: 12px 20px; + font-size: 16px; + font-weight: 500; + } + + /* 发送验证码按钮专用样式 */ + .btn-sms-code { + background: linear-gradient(135deg, #9a72ff 0%, #7c6aff 100%); + border: 2px solid rgba(255, 255, 255, 0.2); + color: white; + font-weight: 500; + font-size: 13px; + padding: 8px 16px; + border-radius: 12px; + transition: all 0.3s ease; + box-shadow: 0 4px 12px rgba(154, 114, 255, 0.3); + white-space: 
nowrap; + min-width: 110px; + height: 48px; + display: flex; + align-items: center; + justify-content: center; + flex-shrink: 0; + } + + .btn-sms-code:hover:not(:disabled) { + background: linear-gradient(135deg, #8a5fff 0%, #6b4aff 100%); + transform: translateY(-1px); + box-shadow: 0 6px 16px rgba(154, 114, 255, 0.4); + border-color: rgba(255, 255, 255, 0.3); + } + + .btn-sms-code:active:not(:disabled) { + transform: translateY(0); + box-shadow: 0 2px 8px rgba(154, 114, 255, 0.3); + } + + .btn-sms-code:disabled { + background: linear-gradient(135deg, #9ca3af 0%, #6b7280 100%); + color: #d1d5db; + cursor: not-allowed; + transform: none; + box-shadow: 0 1px 4px rgba(156, 163, 175, 0.2); + border-color: rgba(255, 255, 255, 0.1); + } + + /* 按钮占位符样式,用于对齐 */ + .btn-placeholder { + min-width: 110px; + height: 48px; + flex-shrink: 0; + } + + /* 表单整体优化 */ + .sms-login-form .input-group:last-child { + margin-bottom: 0; + } + + /* 输入框聚焦时的统一效果 */ + .sms-login-form .form-control:focus { + border-color: #9a72ff; + background: rgba(255, 255, 255, 0.15); + box-shadow: 0 0 0 3px rgba(154, 114, 255, 0.2), + 0 4px 12px rgba(154, 114, 255, 0.1); + transform: translateY(-1px); + } + + /* 占位符文字优化 */ + .sms-login-form .form-control::placeholder { + color: rgba(255, 255, 255, 0.6); + font-weight: 400; + transition: color 0.3s ease; + } + + .sms-login-form .form-control:focus::placeholder { + color: rgba(255, 255, 255, 0.4); + } + /* 文本截断样式 */ +.line-clamp-1 { +display: -webkit-box; +-webkit-line-clamp: 1; +-webkit-box-orient: vertical; +overflow: hidden; +} + +.line-clamp-2 { +display: -webkit-box; +-webkit-line-clamp: 2; +-webkit-box-orient: vertical; +overflow: hidden; +} + +.line-clamp-3 { +display: -webkit-box; +-webkit-line-clamp: 3; +-webkit-box-orient: vertical; +overflow: hidden; +} + +/* 侧边栏动画样式 */ +.sidebar-expand { +transition: width 0.3s cubic-bezier(0.4, 0, 0.2, 1); +} + +.sidebar-text { +transition: opacity 0.3s cubic-bezier(0.4, 0, 0.2, 1), transform 0.3s cubic-bezier(0.4, 0, 0.2, 1); +transform: translateX(-10px); +} + +.sidebar-text.show { +opacity: 1; +transform: translateX(0); +} + +/* 平滑过渡动画 */ +.smooth-transition { +transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); +} + +/* 登录页面样式 */ +.share-container { + min-height: 100%; + min-width: 100%; + background: linear-gradient(135deg, #f5f5f7 0%, #e8e8ed 50%, #f0f0f5 100%); + position: relative; + overflow: auto; + display: flex; + align-items: center; + justify-content: center; + transition: background 0.3s ease; + } + +html.dark .share-container { + background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%); +} + +/* 登录页面样式 - 支持深浅色主题 */ +.login-container { +min-height: 100%; +min-width: 100%; +background: linear-gradient(135deg, #f5f5f7 0%, #e8e8ed 50%, #f0f0f5 100%); +position: relative; +overflow: hidden; +display: flex; +align-items: center; +justify-content: center; +transition: background 0.3s ease; +} + +html.dark .login-container { +background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%); +} + +.login-container::before { +content: ''; +position: absolute; +top: 0; +left: 0; +right: 0; +bottom: 0; +background: +radial-gradient(circle at 20% 80%, rgba(var(--brand-primary-rgb), 0.08) 0%, transparent 50%), +radial-gradient(circle at 80% 20%, rgba(var(--brand-primary-light-rgb), 0.08) 0%, transparent 50%), +radial-gradient(circle at 40% 40%, rgba(var(--brand-primary-rgb), 0.04) 0%, transparent 50%); +animation: backgroundShift 20s ease-in-out infinite; +} + +/* 主容器样式 - 支持深浅色主题 */ +.main-container { 
+min-height: 100%; +min-width: 100%; +background: linear-gradient(135deg, #f5f5f7 0%, #e8e8ed 50%, #f0f0f5 100%); +overflow-x: visible; +overflow-y: hidden; +display: flex; +flex-direction: column; +transition: background 0.3s ease; +} + +html.dark .main-container { +background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%); +} + +/* 顶部栏样式 */ +.top-bar { +top: 0; +left: 0; +right: 0; +height: 70px; +background: rgba(11, 10, 32, 0.95); +backdrop-filter: blur(20px); +border-bottom: 1px solid rgba(143, 143, 143, 0.2); +z-index: 1000; +display: flex; +align-items: center; +} + +.top-bar-content { +width: 100%; +height: 100%; +display: flex; +align-items: center; +justify-content: space-between; +padding: 0 24px; +} + +.top-bar-left { +display: flex; +align-items: center; +} + +.top-bar-logo { +height: 50px; +width: auto; +filter: brightness(0) invert(1); +} + +.top-bar-right { +display: flex; +align-items: center; +} + +.user-info { +display: flex; +align-items: center; +gap: 12px; +} + +.user-avatar { +width: 36px; +height: 36px; +border-radius: 50%; +border: 1px solid white; +display: flex; +align-items: center; +justify-content: center; +overflow: hidden; +} + +.avatar-img { +width: 100%; +height: 100%; +object-fit: cover; +} + +.user-avatar i { +color: #9a72ff; +font-size: 16px; +} + +.user-details { +display: flex; +flex-direction: column; +align-items: flex-end; +} + +.username { +font-size: 14px; +font-weight: 500; +color: #ffffff; +line-height: 1.2; +} + +.user-email { +font-size: 12px; +color: #9ca3af; +line-height: 1.2; +} + +.logout-btn { +width: 32px; +height: 32px; +border-radius: 6px; +background: rgba(154, 114, 255, 0.1); +border: 1px solid rgba(154, 114, 255, 0.2); +color: #9ca3af; +display: flex; +align-items: center; +justify-content: center; +transition: all 0.2s ease; +cursor: pointer; +} + +.logout-btn:hover { +background: rgba(154, 114, 255, 0.2); +color: #9a72ff; +border-color: rgba(154, 114, 255, 0.4); +} + +@keyframes backgroundShift { +0%, 100% { opacity: 1; } +50% { opacity: 0.8; } +} + +.login-card { + position: relative; + overflow: hidden; + transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1); + max-width: 400px; + width: 100%; + margin: 0 auto; +} + + +.login-logo { + background: white; + -webkit-background-clip: text; + background-clip: text; + -webkit-text-fill-color: transparent; + font-size: 3rem; + font-weight: 600; + margin-bottom: 0.5rem; + letter-spacing: -0.02em; + text-align: center; + font-family: "Inter", sans-serif; + line-height: 1.2; +} + +@keyframes logoGlow { +0% { +filter: drop-shadow(0 0 10px rgba(154, 114, 255, 0.5)); +} +100% { +filter: drop-shadow(0 0 20px rgba(154, 114, 255, 0.8)); +} +} + +.login-subtitle { + color: rgba(255, 255, 255, 0.6); + font-size: 1rem; + margin-bottom: 2.5rem; + font-weight: 400; + text-align: center; + line-height: 1.5; +} + +.floating-particles { +position: absolute; +top: 0; +left: 0; +width: 100%; +height: 100%; +overflow: hidden; +pointer-events: none; +} + +.particle { +position: absolute; +width: 4px; +height: 4px; +background: rgba(193, 169, 255, 0.6); +border-radius: 50%; +animation: floatParticle 15s linear infinite; +} + +.particle:nth-child(1) { left: 10%; animation-delay: 0s; } +.particle:nth-child(2) { left: 20%; animation-delay: 12s; } +.particle:nth-child(3) { left: 30%; animation-delay: 10s; } +.particle:nth-child(4) { left: 40%; animation-delay: 6s; } +.particle:nth-child(5) { left: 50%; animation-delay: 8s; } +.particle:nth-child(6) { left: 60%; animation-delay: 14s; } 
+.particle:nth-child(7) { left: 70%; animation-delay: 16s; } +.particle:nth-child(8) { left: 80%; animation-delay: 2s; } +.particle:nth-child(9) { left: 90%; animation-delay: 4s; } + +@keyframes floatParticle { +0% { +transform: translateY(100vh) scale(0); +opacity: 0; +} +10% { +opacity: 1; +} +90% { +opacity: 1; +} +100% { +transform: translateY(-100px) scale(1); +opacity: 0; +} +} + +.login-features { +padding-top: 2rem; +padding-left: 10rem; +border-top: 1px solid rgba(154, 114, 255, 0.2); +} + +.feature-item { +display: flex; +align-items: center; +margin-bottom: 1rem; +color: rgba(255, 255, 255, 0.8); +font-size: 0.9rem; +} + +.feature-icon { +width: 20px; +height: 20px; +background: linear-gradient(135deg, #9a72ff, #b78bff); +border-radius: 50%; +margin-right: 12px; +display: flex; +align-items: center; +justify-content: center; +font-size: 10px; +color: white; +} + +/* 登录页面进入动画 */ +.login-container { +animation: fadeInUp 0.8s ease-out; +} + +.login-card { +animation: slideInUp 0.6s ease-out 0.2s both; +} + +.login-logo { +animation: logoGlow 3s ease-in-out infinite alternate, fadeInScale 0.8s ease-out 0.4s both; +} + +.login-subtitle { +animation: fadeIn 0.8s ease-out 0.6s both; +} + +.login-features { +animation: fadeIn 0.8s ease-out 1s both; +} + +.feature-item { +animation: slideInLeft 0.6s ease-out both; +} + +.feature-item:nth-child(1) { animation-delay: 1.2s; } +.feature-item:nth-child(2) { animation-delay: 1.4s; } +.feature-item:nth-child(3) { animation-delay: 1.6s; } +.feature-item:nth-child(4) { animation-delay: 1.8s; } +.feature-item:nth-child(5) { animation-delay: 2.0s; } +.feature-item:nth-child(6) { animation-delay: 2.2s; } + +@keyframes fadeInUp { +from { +opacity: 0; +transform: translateY(30px); +} +to { +opacity: 1; +transform: translateY(0); +} +} + +@keyframes slideInUp { +from { +opacity: 0; +transform: translateY(50px); +} +to { +opacity: 1; +transform: translateY(0); +} +} + +@keyframes fadeInScale { +from { +opacity: 0; +transform: scale(0.8); +} +to { +opacity: 1; +transform: scale(1); +} +} + +@keyframes fadeIn { +from { +opacity: 0; +} +to { +opacity: 1; +} +} + +@keyframes slideInLeft { +from { +opacity: 0; +transform: translateX(-20px); +} +to { +opacity: 1; +transform: translateX(0); +} +} + +/* 响应式设计 */ +@media (max-width: 768px) { +.login-logo { +font-size: 3rem; +} + +.login-subtitle { +font-size: 1rem; +} + +.login-card { +margin: 20px; +border-radius: 20px; +} +} + +@media (max-width: 480px) { +.login-logo { +font-size: 2rem; +} + +.login-subtitle { +font-size: 0.9rem; +} + +.login-card .card-body { +padding: 2rem !important; +} +} + +/* 简约登录页面样式 */ +.login-header { + text-align: center; + margin-bottom: 2rem; +} + +.login-form { + margin-bottom: 2rem; +} + +.form-group { + margin-bottom: 1.5rem; +} + +.form-input { + width: 100%; + padding: 1rem 1.25rem; + background: rgba(255, 255, 255, 0.04); + border: 1px solid rgba(255, 255, 255, 0.08); + border-radius: 12px; + color: #ffffff; + font-size: 1rem; + font-weight: 400; + transition: all 0.3s ease; + backdrop-filter: blur(10px); +} + +.form-input::placeholder { + color: rgba(255, 255, 255, 0.4); + font-weight: 400; +} + +.form-input:focus { + outline: none; + border-color: rgba(154, 114, 255, 0.4); + background: rgba(255, 255, 255, 0.06); + box-shadow: 0 0 0 3px rgba(154, 114, 255, 0.1); +} + +.verify-code-container { + display: flex; + gap: 0.75rem; + align-items: stretch; +} + +.verify-code-container .form-input { + flex: 1; +} + +.send-code-btn { + padding: 1rem 1.25rem; + background: rgba(154, 
114, 255, 0.1); + border: 1px solid rgba(154, 114, 255, 0.2); + border-radius: 12px; + color: #9a72ff; + font-size: 0.9rem; + font-weight: 500; + white-space: nowrap; + transition: all 0.3s ease; + cursor: pointer; +} + +.send-code-btn:hover:not(:disabled) { + background: rgba(154, 114, 255, 0.15); + border-color: rgba(154, 114, 255, 0.3); + transform: translateY(-1px); +} + +.send-code-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; +} + +.login-btn { + width: 100%; + padding: 1rem 1.25rem; + background: linear-gradient(135deg, #9a72ff, #7c6aff); + border: none; + border-radius: 12px; + color: #ffffff; + font-size: 1rem; + font-weight: 600; + cursor: pointer; + transition: all 0.3s ease; + position: relative; + overflow: hidden; +} + +.login-btn:hover:not(:disabled) { + transform: translateY(-2px); + box-shadow: 0 12px 24px rgba(154, 114, 255, 0.3); +} + +.login-btn:active:not(:disabled) { + transform: translateY(0); +} + +.login-btn:disabled { + opacity: 0.6; + cursor: not-allowed; + transform: none; + box-shadow: none; +} + +.divider { + position: relative; + text-align: center; + margin: 2rem 0; + color: rgba(255, 255, 255, 0.4); + font-size: 0.9rem; + font-weight: 500; +} + +.divider::before { + content: ''; + position: absolute; + top: 50%; + left: 0; + right: 0; + height: 1px; + background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.1), transparent); +} + +.divider span { + background: rgba(15, 14, 34, 0.95); + padding: 0 1rem; + position: relative; + z-index: 1; +} + +.social-login { + display: flex; + justify-content: center; + gap: 1rem; +} + +.social-btn { + width: 3rem; + height: 3rem; + background: rgba(255, 255, 255, 0.04); + border: 1px solid rgba(255, 255, 255, 0.08); + border-radius: 12px; + color: rgba(255, 255, 255, 0.7); + font-size: 1.25rem; + cursor: pointer; + transition: all 0.3s ease; + display: flex; + align-items: center; + justify-content: center; +} + +.social-btn:hover:not(:disabled) { + background: rgba(255, 255, 255, 0.08); + border-color: rgba(255, 255, 255, 0.15); + color: #ffffff; + transform: translateY(-2px); +} + +.social-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; +} + +.card-body { + padding: 2.5rem; +} + +/* 响应式设计 */ +@media (max-width: 768px) { + .login-card { + margin: 1rem; + } + + .card-body { + padding: 2rem; + } + + .login-logo { + font-size: 2.5rem; + } + + .login-subtitle { + font-size: 0.9rem; + margin-bottom: 2rem; + } + + .form-input { + padding: 0.875rem 1rem; + font-size: 0.95rem; + } + + .send-code-btn { + padding: 0.875rem 1rem; + font-size: 0.85rem; + } + + .login-btn { + padding: 0.875rem 1rem; + font-size: 0.95rem; + } + + .social-btn { + width: 2.75rem; + height: 2.75rem; + font-size: 1.1rem; + } +} + +@media (max-width: 480px) { + .login-card { + + margin: 0.5rem; + border-radius: 16px; + } + + .card-body { + padding: 1.5rem; + } + + .login-logo { + font-size: 2rem; + } + + .login-subtitle { + font-size: 0.85rem; + margin-bottom: 1.5rem; + } + + .form-group { + margin-bottom: 1.25rem; + } + + .form-input { + padding: 0.75rem 0.875rem; + font-size: 0.9rem; + border-radius: 10px; + } + + .verify-code-container { + gap: 0.5rem; + } + + .send-code-btn { + padding: 0.75rem 0.875rem; + font-size: 0.8rem; + border-radius: 10px; + } + + .login-btn { + padding: 0.75rem 0.875rem; + font-size: 0.9rem; + border-radius: 10px; + } + + .divider { + margin: 1.5rem 0; + font-size: 0.85rem; + } + + .social-login { + gap: 0.75rem; + } + + .social-btn { + width: 2.5rem; + height: 
2.5rem; + font-size: 1rem; + border-radius: 10px; + } +} + +/* Alert动画 */ +@keyframes slide-down { +from { +opacity: 0; +transform: translateX(-50%) translateY(-20px); +} +to { +opacity: 1; +transform: translateX(-50%) translateY(0); +} +} + +.animate-slide-down { +animation: slide-down 0.3s ease-out; +} diff --git a/lightx2v/deploy/server/frontend/src/utils/i18n.js b/lightx2v/deploy/server/frontend/src/utils/i18n.js new file mode 100644 index 0000000000000000000000000000000000000000..d993ed0fc0bcdf8201ae4637160282bd7d9a7127 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/utils/i18n.js @@ -0,0 +1,62 @@ +import { createI18n } from 'vue-i18n' +import { ref } from 'vue' + +const loadedLanguages = new Set() + +// 创建 i18n 实例(初始只设置 locale,不加载全部语言) +const i18n = createI18n({ + legacy: false, + globalInjection: true, + locale: 'zh', + fallbackLocale: 'en', + messages: {} +}) + +// 异步加载语言文件 +async function loadLanguageAsync(lang) { + if (!loadedLanguages.has(lang)) { + const messages = await import(`../locales/${lang}.json`) + i18n.global.setLocaleMessage(lang, messages.default) + loadedLanguages.add(lang) + } + if (i18n.global.locale.value === lang) return lang + i18n.global.locale.value = lang + localStorage.setItem('app-lang', lang) // ✅ 记住用户选择 + document.documentElement.lang = lang === 'zh' ? 'zh-CN' : 'en'; + return lang +} + +// 初始化默认语言 +async function initLanguage() { + const savedLang = localStorage.getItem('app-lang') || 'zh' + return loadLanguageAsync(savedLang) +} +async function switchLang() { + const newLang = i18n.global.locale.value === 'zh' ? 'en' : 'zh' + await loadLanguageAsync(newLang) +} + + // // 语言切换功能 + // const switchLanguage = (langCode) => { + // currentLanguage.value = langCode; + // localStorage.setItem('preferredLanguage', langCode); + + // // 更新页面标题 + // document.title = t('pageTitle'); + + // // 更新HTML lang属性 + // document.documentElement.lang = langCode === 'zh' ? 'zh-CN' : 'en'; + // }; + + // // 简单语言切换功能(中英文切换) + // const toggleLanguage = () => { + // const newLang = currentLanguage.value === 'zh' ? 
'en' : 'zh'; + // switchLanguage(newLang); + // }; + + const languageOptions = ref([ + { code: 'zh', name: '中文', flag: '中' }, + { code: 'en', name: 'English', flag: 'EN' } +]); + +export { i18n as default, loadLanguageAsync, initLanguage, switchLang, languageOptions } diff --git a/lightx2v/deploy/server/frontend/src/utils/other.js b/lightx2v/deploy/server/frontend/src/utils/other.js new file mode 100644 index 0000000000000000000000000000000000000000..5a0181d666f5ddab68fb2846f19a9300b3c05186 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/utils/other.js @@ -0,0 +1,7839 @@ +import { ref, computed, watch, nextTick } from 'vue'; +import { useRoute, useRouter } from 'vue-router'; +import i18n from './i18n' +import router from '../router' +export const t = i18n.global.t +export const locale = i18n.global.locale + + // 响应式数据 + const loading = ref(false); + const loginLoading = ref(false); + const initLoading = ref(false); + const downloadLoading = ref(false); + const downloadLoadingMessage = ref(''); + const isLoading = ref(false); // 页面加载loading状态 + const isPageLoading = ref(false); // 分页加载loading状态 + + // 录音相关状态 + const isRecording = ref(false); + const mediaRecorder = ref(null); + const audioChunks = ref([]); + const recordingDuration = ref(0); + const recordingTimer = ref(null); + const alert = ref({ show: false, message: '', type: 'info' }); + + + // 短信登录相关数据 + const phoneNumber = ref(''); + const verifyCode = ref(''); + const smsCountdown = ref(0); + const showSmsForm = ref(false); + const showErrorDetails = ref(false); + const showFailureDetails = ref(false); + + // 任务类型下拉菜单 + const showTaskTypeMenu = ref(false); + const showModelMenu = ref(false); + + // 任务状态轮询相关 + const pollingInterval = ref(null); + const pollingTasks = ref(new Set()); // 正在轮询的任务ID集合 + const confirmDialog = ref({ + show: false, + title: '', + message: '', + confirmText: '确认', // 使用静态文本,避免翻译依赖 + warning: null, + confirm: () => { } + }); + const submitting = ref(false); + const templateLoading = ref(false); // 模板/任务复用加载状态 + const templateLoadingMessage = ref(''); + const taskSearchQuery = ref(''); + const sidebarCollapsed = ref(false); + const showExpandHint = ref(false); + const showGlow = ref(false); + const isDefaultStateHidden = ref(false); + const isCreationAreaExpanded = ref(false); + const hasUploadedContent = ref(false); + const isContracting = ref(false); + const faceDetecting = ref(false); // Face detection loading state + const audioSeparating = ref(false); // Audio separation loading state + + const showTaskDetailModal = ref(false); + const modalTask = ref(null); + + // TTS 模态框状态 + const showVoiceTTSModal = ref(false); + const showPodcastModal = ref(false); + + // TaskCarousel当前任务状态 + const currentTask = ref(null); + + // 视频加载状态跟踪 + const videoLoadedStates = ref(new Map()); // 跟踪每个视频的加载状态 + + // 检查视频是否已加载完成 + const isVideoLoaded = (videoSrc) => { + return videoLoadedStates.value.get(videoSrc) || false; + }; + + // 设置视频加载状态 + const setVideoLoaded = (videoSrc, loaded) => { + videoLoadedStates.value.set(videoSrc, loaded); + }; + + // 灵感广场相关数据 + const inspirationSearchQuery = ref(''); + const selectedInspirationCategory = ref(''); + const inspirationItems = ref([]); + const InspirationCategories = ref([]); + + // 灵感广场分页相关变量 + const inspirationPagination = ref(null); + const inspirationCurrentPage = ref(1); + const inspirationPageSize = ref(20); + const inspirationPageInput = ref(1); + const inspirationPaginationKey = ref(0); + + // 模板详情弹窗相关数据 + const showTemplateDetailModal = ref(false); + const 
selectedTemplate = ref(null); + + // 图片放大弹窗相关数据 + const showImageZoomModal = ref(false); + const zoomedImageUrl = ref(''); + + // 任务文件缓存系统 + const taskFileCache = ref(new Map()); + const taskFileCacheLoaded = ref(false); + + // 模板文件缓存系统 + const templateFileCache = ref(new Map()); + const templateFileCacheLoaded = ref(false); + + // Podcast 音频 URL 缓存系统(模仿任务文件缓存) + const podcastAudioCache = ref(new Map()); + const podcastAudioCacheLoaded = ref(false); + + // 防重复获取的状态管理 + const templateUrlFetching = ref(new Set()); // 正在获取的URL集合 + const taskUrlFetching = ref(new Map()); // 正在获取的任务URL集合 + + // localStorage缓存相关常量 + const TASK_FILE_CACHE_KEY = 'lightx2v_task_files'; + const TEMPLATE_FILE_CACHE_KEY = 'lightx2v_template_files'; + const PODCAST_AUDIO_CACHE_KEY = 'lightx2v_podcast_audio'; + const TASK_FILE_CACHE_EXPIRY = 24 * 60 * 60 * 1000; // 24小时过期 + const PODCAST_AUDIO_CACHE_EXPIRY = 24 * 60 * 60 * 1000; // 24小时过期 + const MODELS_CACHE_KEY = 'lightx2v_models'; + const MODELS_CACHE_EXPIRY = 60 * 60 * 1000; // 1小时过期 + const TEMPLATES_CACHE_KEY = 'lightx2v_templates'; + const TEMPLATES_CACHE_EXPIRY = 24 * 60 * 60 * 1000; // 24小时过期 + const TASKS_CACHE_KEY = 'lightx2v_tasks'; + const TASKS_CACHE_EXPIRY = 5 * 60 * 1000; // 5分钟过期 + + const imageTemplates = ref([]); + const audioTemplates = ref([]); + const mergedTemplates = ref([]); // 合并后的模板列表 + const showImageTemplates = ref(false); + const showAudioTemplates = ref(false); + const mediaModalTab = ref('history'); + + // Template分页相关变量 + const templatePagination = ref(null); + const templateCurrentPage = ref(1); + const templatePageSize = ref(20); // 图片模板每页12个,音频模板每页10个 + const templatePageInput = ref(1); + const templatePaginationKey = ref(0); + const imageHistory = ref([]); + const audioHistory = ref([]); + const ttsHistory = ref([]); + + // 模板文件缓存,避免重复下载 + const currentUser = ref({}); + const models = ref([]); + const tasks = ref([]); + const isLoggedIn = ref(null); // null表示未初始化,false表示未登录,true表示已登录 + + const selectedTaskId = ref(null); + const selectedTask = ref(null); + const selectedModel = ref(null); + const selectedTaskFiles = ref({ inputs: {}, outputs: {} }); // 存储任务的输入输出文件 + const loadingTaskFiles = ref(false); // 加载任务文件的状态 + const statusFilter = ref('ALL'); + const pagination = ref(null); + const currentTaskPage = ref(1); + const taskPageSize = ref(20); + const taskPageInput = ref(1); + const paginationKey = ref(0); // 用于强制刷新分页组件 + const taskMenuVisible = ref({}); // 管理每个任务的菜单显示状态 + const nameMap = computed(() => ({ + 't2v': t('textToVideo'), + 'i2v': t('imageToVideo'), + 's2v': t('speechToVideo'), + 'animate': t('animate') + })); + + // 任务类型提示信息 + const taskHints = computed(() => ({ + 't2v': [ + t('t2vHint1'), + t('t2vHint2'), + t('t2vHint3'), + t('t2vHint4') + ], + 'i2v': [ + t('i2vHint1'), + t('i2vHint2'), + t('i2vHint3'), + t('i2vHint4') + ], + 's2v': [ + t('s2vHint1'), + t('s2vHint2'), + t('s2vHint3'), + t('s2vHint4') + ], + 'animate': [ + t('animateHint1') || '上传目标角色图片和参考视频', + t('animateHint2') || '将视频中的角色替换为目标角色', + ] + })); + + // 当前任务类型的提示信息 + const currentTaskHints = computed(() => { + return taskHints.value[selectedTaskId.value] || taskHints.value['s2v']; + }); + + // 滚动提示相关 + const currentHintIndex = ref(0); + const hintInterval = ref(null); + + // 开始滚动提示 + const startHintRotation = () => { + if (hintInterval.value) { + clearInterval(hintInterval.value); + } + hintInterval.value = setInterval(() => { + currentHintIndex.value = (currentHintIndex.value + 1) % currentTaskHints.value.length; + }, 3000); // 每3秒切换一次 + }; + + // 停止滚动提示 
+ const stopHintRotation = () => { + if (hintInterval.value) { + clearInterval(hintInterval.value); + hintInterval.value = null; + } + }; + + // 为三个任务类型分别创建独立的表单 + const t2vForm = ref({ + task: 't2v', + model_cls: '', + stage: '', + prompt: '', + seed: 42 + }); + + const i2vForm = ref({ + task: 'i2v', + model_cls: '', + stage: '', + imageFile: null, + prompt: '', + seed: 42, + detectedFaces: [] // List of detected faces: [{ index, bbox, face_image, roleName, ... }] + }); + + const s2vForm = ref({ + task: 's2v', + model_cls: '', + stage: '', + imageFile: null, + audioFile: null, + prompt: '', + seed: 42, + detectedFaces: [], // List of detected faces: [{ index, bbox, face_image, roleName, ... }] + separatedAudios: [] // List of separated audio tracks: [{ speaker_id, audio (base64), roleName, ... }] + }); + + const animateForm = ref({ + task: 'animate', + model_cls: '', + stage: '', + imageFile: null, + videoFile: null, + prompt: '视频中的人在做动作', + seed: 42, + detectedFaces: [] // List of detected faces: [{ index, bbox, face_image, roleName, ... }] + }); + + // 根据当前选择的任务类型获取对应的表单 + const getCurrentForm = () => { + switch (selectedTaskId.value) { + case 't2v': + return t2vForm.value; + case 'i2v': + return i2vForm.value; + case 's2v': + return s2vForm.value; + case 'animate': + return animateForm.value; + default: + return t2vForm.value; + } + }; + + // 控制默认状态显示/隐藏的方法 + const hideDefaultState = () => { + isDefaultStateHidden.value = true; + }; + + const showDefaultState = () => { + isDefaultStateHidden.value = false; + }; + + // 控制创作区域展开/收缩的方法 + const expandCreationArea = () => { + isCreationAreaExpanded.value = true; + // 添加show类来触发动画 + setTimeout(() => { + const creationArea = document.querySelector('.creation-area'); + if (creationArea) { + creationArea.classList.add('show'); + } + }, 10); + }; + + const contractCreationArea = () => { + isContracting.value = true; + const creationArea = document.querySelector('.creation-area'); + if (creationArea) { + // 添加hide类来触发收起动画 + creationArea.classList.add('hide'); + creationArea.classList.remove('show'); + } + // 等待动画完成后更新状态 + setTimeout(() => { + isCreationAreaExpanded.value = false; + isContracting.value = false; + if (creationArea) { + creationArea.classList.remove('hide'); + } + }, 400); + }; + + // 为每个任务类型创建独立的预览变量 + const i2vImagePreview = ref(null); + const s2vImagePreview = ref(null); + const s2vAudioPreview = ref(null); + const animateImagePreview = ref(null); + const animateVideoPreview = ref(null); + + // 监听上传内容变化 + const updateUploadedContentStatus = () => { + hasUploadedContent.value = !!(getCurrentImagePreview() || getCurrentAudioPreview() || getCurrentVideoPreview() || getCurrentForm().prompt?.trim()); + }; + + // 监听表单变化 + watch([i2vImagePreview, s2vImagePreview, s2vAudioPreview, animateImagePreview, animateVideoPreview, () => getCurrentForm().prompt], () => { + updateUploadedContentStatus(); + }, { deep: true }); + + // 监听任务类型变化,重置提示滚动 + watch(selectedTaskId, () => { + currentHintIndex.value = 0; + stopHintRotation(); + startHintRotation(); + }); + + // 根据当前任务类型获取对应的预览变量 + const getCurrentImagePreview = () => { + switch (selectedTaskId.value) { + case 't2v': + return null; + case 'i2v': + return i2vImagePreview.value; + case 's2v': + return s2vImagePreview.value; + case 'animate': + return animateImagePreview.value; + default: + return null; + } + }; + + const getCurrentAudioPreview = () => { + switch (selectedTaskId.value) { + case 't2v': + return null + case 'i2v': + return null + case 's2v': + return s2vAudioPreview.value; + 
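+      // Only the s2v (speech-to-video) form carries an audio file; every other task type falls through to null below.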
default: + return null; + } + }; + + const setCurrentImagePreview = (value) => { + switch (selectedTaskId.value) { + case 't2v': + break; + case 'i2v': + i2vImagePreview.value = value; + break; + case 's2v': + s2vImagePreview.value = value; + break; + case 'animate': + animateImagePreview.value = value; + break; + } + // 清除图片预览缓存,确保新图片能正确显示 + urlCache.value.delete('current_image_preview'); + }; + + const setCurrentAudioPreview = (value) => { + switch (selectedTaskId.value) { + case 't2v': + break; + case 'i2v': + break; + case 'animate': + break; + case 's2v': + s2vAudioPreview.value = value; + break; + } + // 清除音频预览缓存,确保新音频能正确显示 + urlCache.value.delete('current_audio_preview'); + }; + + // 获取当前任务类型的视频预览 + const getCurrentVideoPreview = () => { + switch (selectedTaskId.value) { + case 'animate': + return animateVideoPreview.value; + default: + return null; + } + }; + + // 设置当前任务类型的视频预览 + const setCurrentVideoPreview = (value) => { + switch (selectedTaskId.value) { + case 'animate': + animateVideoPreview.value = value; + break; + } + // 清除视频预览缓存,确保新视频能正确显示 + urlCache.value.delete('current_video_preview'); + }; + + // 提示词模板相关 + const showTemplates = ref(false); + const showHistory = ref(false); + const showPromptModal = ref(false); + const promptModalTab = ref('templates'); + + // 计算属性 + const availableTaskTypes = computed(() => { + const types = [...new Set(models.value.map(m => m.task))]; + // 重新排序,确保数字人在最左边 + const orderedTypes = []; + + // 检查是否有s2v模型,如果有则添加s2v类型 + const hasS2vModels = models.value.some(m => + m.task === 's2v' + ); + + // 优先添加数字人(如果存在相关模型) + if (hasS2vModels) { + orderedTypes.push('s2v'); + } + + // 然后添加其他类型 + types.forEach(type => { + if (type !== 's2v') { + orderedTypes.push(type); + } + }); + + return orderedTypes; + }); + + const availableModelClasses = computed(() => { + if (!selectedTaskId.value) return []; + + return [...new Set(models.value + .filter(m => m.task === selectedTaskId.value) + .map(m => m.model_cls))]; + }); + + const filteredTasks = computed(() => { + let filtered = tasks.value; + + // 状态过滤 + if (statusFilter.value !== 'ALL') { + filtered = filtered.filter(task => task.status === statusFilter.value); + } + + // 搜索过滤 + if (taskSearchQuery.value) { + filtered = filtered.filter(task => + task.params.prompt?.toLowerCase().includes(taskSearchQuery.value.toLowerCase()) || + task.task_id.toLowerCase().includes(taskSearchQuery.value.toLowerCase()) || + nameMap.value[task.task_type].toLowerCase().includes(taskSearchQuery.value.toLowerCase()) + ); + } + + // 按时间排序,最新的任务在前面 + filtered = filtered.sort((a, b) => { + const timeA = parseInt(a.create_t) || 0; + const timeB = parseInt(b.create_t) || 0; + return timeB - timeA; // 降序排列,最新的在前 + }); + + return filtered; + }); + + // 监听状态筛选变化,重置分页到第一页 + watch(statusFilter, (newStatus, oldStatus) => { + if (newStatus !== oldStatus) { + currentTaskPage.value = 1; + taskPageInput.value = 1; + refreshTasks(true); // 强制刷新 + } + }); + + // 监听搜索查询变化,重置分页到第一页 + watch(taskSearchQuery, (newQuery, oldQuery) => { + if (newQuery !== oldQuery) { + currentTaskPage.value = 1; + taskPageInput.value = 1; + refreshTasks(true); // 强制刷新 + } + }); + + // 分页信息计算属性,确保响应式更新 + const paginationInfo = computed(() => { + if (!pagination.value) return null; + + return { + total: pagination.value.total || 0, + total_pages: pagination.value.total_pages || 0, + current_page: pagination.value.current_page || currentTaskPage.value, + page_size: pagination.value.page_size || taskPageSize.value + }; + }); + + // Template分页信息计算属性 + const 
templatePaginationInfo = computed(() => { + if (!templatePagination.value) return null; + + return { + total: templatePagination.value.total || 0, + total_pages: templatePagination.value.total_pages || 0, + current_page: templatePagination.value.current_page || templateCurrentPage.value, + page_size: templatePagination.value.page_size || templatePageSize.value + }; + }); + + // 灵感广场分页信息计算属性 + const inspirationPaginationInfo = computed(() => { + if (!inspirationPagination.value) return null; + + return { + total: inspirationPagination.value.total || 0, + total_pages: inspirationPagination.value.total_pages || 0, + current_page: inspirationPagination.value.current_page || inspirationCurrentPage.value, + page_size: inspirationPagination.value.page_size || inspirationPageSize.value + }; + }); + + + // 通用URL缓存 + const urlCache = ref(new Map()); + + // 通用URL缓存函数 + const getCachedUrl = (key, urlGenerator) => { + if (urlCache.value.has(key)) { + return urlCache.value.get(key); + } + + const url = urlGenerator(); + urlCache.value.set(key, url); + return url; + }; + + // 获取历史图片URL(带缓存) + const getHistoryImageUrl = (history) => { + if (!history || !history.thumbnail) return ''; + return getCachedUrl(`history_image_${history.filename}`, () => history.thumbnail); + }; + + // 获取用户头像URL(带缓存) + const getUserAvatarUrl = (user) => { + if (!user || !user.avatar) return ''; + return getCachedUrl(`user_avatar_${user.username}`, () => user.avatar); + }; + + // 获取当前图片预览URL(带缓存) + const getCurrentImagePreviewUrl = () => { + const preview = getCurrentImagePreview(); + if (!preview) return ''; + return getCachedUrl(`current_image_preview`, () => preview); + }; + + // 获取当前音频预览URL(带缓存) + const getCurrentAudioPreviewUrl = () => { + const preview = getCurrentAudioPreview(); + if (!preview) return ''; + return getCachedUrl(`current_audio_preview`, () => preview); + }; + + const getCurrentVideoPreviewUrl = () => { + const preview = getCurrentVideoPreview(); + if (!preview) return ''; + return getCachedUrl(`current_video_preview`, () => preview); + }; + + // Alert定时器,用于清除之前的定时器 + let alertTimeout = null; + + // 方法 + const showAlert = (message, type = 'info', action = null) => { + // 清除之前的定时器 + if (alertTimeout) { + clearTimeout(alertTimeout); + alertTimeout = null; + } + + // 如果当前有alert正在显示,先关闭它 + if (alert.value && alert.value.show) { + alert.value.show = false; + // 等待transition完成(约400ms)后再显示新的alert + setTimeout(() => { + createNewAlert(message, type, action); + }, 450); + } else { + // 如果没有alert在显示,立即创建新的 + // 如果alert存在但已关闭,先重置它以确保状态干净 + if (alert.value && !alert.value.show) { + alert.value = { show: false, message: '', type: 'info', action: null }; + } + // 立即创建新alert,不需要等待nextTick + createNewAlert(message, type, action); + } + }; + + // 创建新alert的辅助函数 + const createNewAlert = (message, type, action) => { + // 再次清除定时器,防止重复设置 + if (alertTimeout) { + clearTimeout(alertTimeout); + alertTimeout = null; + } + + // 创建全新的对象,使用时间戳确保每次都是新对象 + const newAlert = { + show: true, + message: String(message), + type: String(type), + action: action ? 
{ + label: String(action.label), + onClick: action.onClick + } : null, + // 添加一个时间戳确保每次都是新对象,用于key + _timestamp: Date.now() + }; + + // 直接赋值新对象 + alert.value = newAlert; + + // 设置自动关闭定时器 + alertTimeout = setTimeout(() => { + if (alert.value && alert.value.show && alert.value._timestamp === newAlert._timestamp) { + alert.value.show = false; + } + alertTimeout = null; + }, 5000); + }; + + // 显示确认对话框 + const showConfirmDialog = (options) => { + return new Promise((resolve) => { + confirmDialog.value = { + show: true, + title: options.title || '确认操作', + message: options.message || '确定要执行此操作吗?', + confirmText: options.confirmText || '确认', + warning: options.warning || null, + confirm: () => { + confirmDialog.value.show = false; + resolve(true); + }, + cancel: () => { + confirmDialog.value.show = false; + resolve(false); + } + }; + }); + }; + + const setLoading = (value) => { + loading.value = value; + }; + + const apiCall = async (endpoint, options = {}) => { + const url = `${endpoint}`; + const headers = { + 'Content-Type': 'application/json', + ...options.headers + }; + + if (localStorage.getItem('accessToken')) { + headers['Authorization'] = `Bearer ${localStorage.getItem('accessToken')}`; + } + + const response = await fetch(url, { + ...options, + headers + }); + + if (response.status === 401) { + logout(false); + showAlert(t('authFailedPleaseRelogin'), 'warning', { + label: t('login'), + onClick: login + }); + throw new Error(t('authFailedPleaseRelogin')); + } + if (response.status === 400) { + const error = await response.json(); + showAlert(error.message, 'danger'); + throw new Error(error.message); + } + + // 添加50ms延迟,防止触发服务端频率限制 + await new Promise(resolve => setTimeout(resolve, 50)); + + return response; + }; + + const loginWithGitHub = async () => { + try { + console.log('starting GitHub login') + const response = await fetch('/auth/login/github'); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + const data = await response.json(); + localStorage.setItem('loginSource', 'github'); + window.location.href = data.auth_url; + } catch (error) { + console.log('GitHub login error:', error); + showAlert(t('getGitHubAuthUrlFailed'), 'danger'); + } + }; + + const loginWithGoogle = async () => { + try { + console.log('starting Google login') + const response = await fetch('/auth/login/google'); + if (!response.ok) { + throw new Error(`HTTP error! 
status: ${response.status}`); + } + const data = await response.json(); + localStorage.setItem('loginSource', 'google'); + window.location.href = data.auth_url; + } catch (error) { + console.error('Google login error:', error); + showAlert(t('getGoogleAuthUrlFailed'), 'danger'); + } + }; + + // 发送短信验证码 + const sendSmsCode = async () => { + if (!phoneNumber.value) { + showAlert(t('pleaseEnterPhoneNumber'), 'warning'); + return; + } + + // 简单的手机号格式验证 + const phoneRegex = /^1[3-9]\d{9}$/; + if (!phoneRegex.test(phoneNumber.value)) { + showAlert(t('pleaseEnterValidPhoneNumber'), 'warning'); + return; + } + + try { + const response = await fetch(`./auth/login/sms?phone_number=${phoneNumber.value}`); + const data = await response.json(); + + if (response.ok) { + showAlert(t('verificationCodeSent'), 'success'); + // 开始倒计时 + startSmsCountdown(); + } else { + showAlert(data.message || t('sendVerificationCodeFailed'), 'danger'); + } + } catch (error) { + showAlert(t('sendVerificationCodeFailedRetry'), 'danger'); + } + }; + + // 短信验证码登录 + const loginWithSms = async () => { + if (!phoneNumber.value || !verifyCode.value) { + showAlert(t('pleaseEnterPhoneAndCode'), 'warning'); + return; + } + + try { + const response = await fetch(`./auth/callback/sms?phone_number=${phoneNumber.value}&verify_code=${verifyCode.value}`); + const data = await response.json(); + + if (response.ok) { + localStorage.setItem('accessToken', data.access_token); + if (data.refresh_token) { + localStorage.setItem('refreshToken', data.refresh_token); + } + localStorage.setItem('currentUser', JSON.stringify(data.user_info)); + currentUser.value = data.user_info; + + // 登录成功后初始化数据 + await init(); + + router.push('/generate'); + console.log('login with sms success'); + isLoggedIn.value = true; + + showAlert(t('loginSuccess'), 'success'); + } else { + showAlert(data.message || t('verificationCodeErrorOrExpired'), 'danger'); + } + } catch (error) { + showAlert(t('loginFailedRetry'), 'danger'); + } + }; + + // 处理手机号输入框回车键 + const handlePhoneEnter = () => { + if (phoneNumber.value && !smsCountdown.value) { + sendSmsCode(); + } + }; + + // 处理验证码输入框回车键 + const handleVerifyCodeEnter = () => { + if (phoneNumber.value && verifyCode.value) { + loginWithSms(); + } + }; + + // 移动端检测和样式应用 + const applyMobileStyles = () => { + if (window.innerWidth <= 640) { + // 为左侧功能区添加移动端样式 + const leftNav = document.querySelector('.relative.w-20.pl-5.flex.flex-col.z-10'); + if (leftNav) { + leftNav.classList.add('mobile-bottom-nav'); + } + + // 为导航按钮容器添加移动端样式 + const navContainer = document.querySelector('.p-2.flex.flex-col.justify-center.h-full'); + if (navContainer) { + navContainer.classList.add('mobile-nav-buttons'); + } + + // 为所有导航按钮添加移动端样式 + const navButtons = document.querySelectorAll('.relative.w-20.pl-5.flex.flex-col.z-10 button'); + navButtons.forEach(btn => { + btn.classList.add('mobile-nav-btn'); + }); + + // 为主内容区域添加移动端样式 + const contentAreas = document.querySelectorAll('.flex-1.flex.flex-col.min-h-0'); + contentAreas.forEach(area => { + area.classList.add('mobile-content'); + }); + } + }; + + // 短信验证码倒计时 + const startSmsCountdown = () => { + smsCountdown.value = 60; + const timer = setInterval(() => { + smsCountdown.value--; + if (smsCountdown.value <= 0) { + clearInterval(timer); + } + }, 1000); + }; + + // 切换短信登录表单显示 + const toggleSmsLogin = () => { + showSmsForm.value = !showSmsForm.value; + if (!showSmsForm.value) { + // 重置表单数据 + phoneNumber.value = ''; + verifyCode.value = ''; + smsCountdown.value = 0; + } + }; + + const handleLoginCallback 
= async (code, source) => { + try { + const response = await fetch(`/auth/callback/${source}?code=${code}`); + if (response.ok) { + const data = await response.json(); + console.log(data); + localStorage.setItem('accessToken', data.access_token); + if (data.refresh_token) { + localStorage.setItem('refreshToken', data.refresh_token); + } + localStorage.setItem('currentUser', JSON.stringify(data.user_info)); + currentUser.value = data.user_info; + isLoggedIn.value = true; + + // 在进入新页面前显示loading + isLoading.value = true; + + // 登录成功后初始化数据 + await init(); + + // 检查是否有分享数据需要导入 + const shareData = localStorage.getItem('shareData'); + if (shareData) { + // 解析分享数据获取shareId + try { + const parsedShareData = JSON.parse(shareData); + const shareId = parsedShareData.share_id || parsedShareData.task_id; + if (shareId) { + localStorage.removeItem('shareData'); + // 跳转回分享页面,让createSimilar函数处理数据 + router.push(`/share/${shareId}`); + return; + } + } catch (error) { + console.warn('Failed to parse share data:', error); + } + localStorage.removeItem('shareData'); + } + + // 默认跳转到生成页面 + router.push('/generate'); + console.log('login with callback success'); + + // 清除URL中的code参数 + window.history.replaceState({}, document.title, window.location.pathname); + } else { + const error = await response.json(); + showAlert(`${t('loginFailedRetry')}: ${error.detail}`, 'danger'); + } + } catch (error) { + showAlert(t('loginError'), 'danger'); + console.error(error); + } + }; + + let refreshPromise = null; + + const logout = (showMessage = true) => { + localStorage.removeItem('accessToken'); + localStorage.removeItem('refreshToken'); + localStorage.removeItem('currentUser'); + refreshPromise = null; + + clearAllCache(); + switchToLoginView(); + isLoggedIn.value = false; + + models.value = []; + tasks.value = []; + if (showMessage) { + showAlert(t('loggedOut'), 'info'); + } + }; + + const login = () => { + switchToLoginView(); + isLoggedIn.value = false; + }; + + const loadModels = async (forceRefresh = false) => { + try { + // 如果不是强制更新,先尝试从缓存加载 + if (!forceRefresh) { + const cachedModels = loadFromCache(MODELS_CACHE_KEY, MODELS_CACHE_EXPIRY); + if (cachedModels) { + console.log('成功从缓存加载模型列表'); + models.value = cachedModels; + return; + } + } + + console.log('开始加载模型列表...'); + const response = await apiRequest('/api/v1/model/list'); + if (response && response.ok) { + const data = await response.json(); + console.log('模型列表数据:', data); + const modelsData = data.models || []; + models.value = modelsData; + + // 保存到缓存 + saveToCache(MODELS_CACHE_KEY, modelsData); + console.log('模型列表已缓存'); + } else if (response) { + console.error('模型列表API响应失败:', response); + showAlert(t('loadModelListFailed'), 'danger'); + } + // 如果response为null,说明是认证错误,apiRequest已经处理了 + } catch (error) { + console.error('加载模型失败:', error); + showAlert(`${t('loadModelFailed')}: ${error.message}`, 'danger'); + } + }; + + const refreshTemplateFileUrl = (templatesData) => { + for (const img of templatesData.images) { + console.log('刷新图片素材文件URL:', img.filename, img.url); + setTemplateFileToCache(img.filename, {url: img.url, timestamp: Date.now()}); + } + for (const audio of templatesData.audios) { + console.log('刷新音频素材文件URL:', audio.filename, audio.url); + setTemplateFileToCache(audio.filename, {url: audio.url, timestamp: Date.now()}); + } + for (const video of templatesData.videos) { + console.log('刷新视频素材文件URL:', video.filename, video.url); + setTemplateFileToCache(video.filename, {url: video.url, timestamp: Date.now()}); + } + } + + // 加载模板文件 + const 
loadImageAudioTemplates = async (forceRefresh = false) => { + try { + // 如果不是强制刷新,先尝试从缓存加载 + const cacheKey = `${TEMPLATES_CACHE_KEY}_IMAGE_AUDIO_MERGED_${templateCurrentPage.value}_${templatePageSize.value}`; + if (!forceRefresh) { + // 构建缓存键,包含分页和过滤条件 + const cachedTemplates = loadFromCache(cacheKey, TEMPLATES_CACHE_EXPIRY); + if (cachedTemplates && cachedTemplates.templates) { + console.log('成功从缓存加载模板列表'); + // 优先使用合并后的模板列表 + if (cachedTemplates.templates.merged) { + mergedTemplates.value = cachedTemplates.templates.merged || []; + // 从合并列表中提取图片和音频 + const images = []; + const audios = []; + mergedTemplates.value.forEach(template => { + if (template.image) { + images.push(template.image); + } + if (template.audio) { + audios.push(template.audio); + } + }); + imageTemplates.value = images; + audioTemplates.value = audios; + } else { + // 向后兼容:如果没有合并列表,使用旧的格式 + imageTemplates.value = cachedTemplates.templates.images || []; + audioTemplates.value = cachedTemplates.templates.audios || []; + } + templatePagination.value = cachedTemplates.pagination || null; + return; + } + } + + console.log('开始加载图片音乐素材库...'); + const response = await publicApiCall(`/api/v1/template/list?page=${templateCurrentPage.value}&page_size=${templatePageSize.value}`); + if (response.ok) { + const data = await response.json(); + console.log('图片音乐素材库数据:', data); + + // 使用合并后的模板列表 + const merged = data.templates?.merged || []; + mergedTemplates.value = merged; + + // 为了保持向后兼容,从合并列表中提取图片和音频 + const images = []; + const audios = []; + merged.forEach(template => { + if (template.image) { + images.push(template.image); + } + if (template.audio) { + audios.push(template.audio); + } + }); + + refreshTemplateFileUrl({ images, audios, videos: data.templates?.videos || [] }); + const templatesData = { + images: images, + audios: audios, + merged: merged + }; + + imageTemplates.value = images; + audioTemplates.value = audios; + templatePagination.value = data.pagination || null; + + // 保存到缓存 + saveToCache(cacheKey, { + templates: templatesData, + pagination: templatePagination.value + }); + console.log('图片音乐素材库已缓存:', templatesData); + + } else { + console.warn('加载素材库失败'); + } + } catch (error) { + console.warn('加载素材库失败:', error); + } + }; + + // 获取素材文件的通用函数(带缓存) + const getTemplateFile = async (template) => { + const cacheKey = template.url; + + // 先检查内存缓存 + if (templateFileCache.value.has(cacheKey)) { + console.log('从内存缓存获取素材文件:', template.filename); + return templateFileCache.value.get(cacheKey); + } + + // 如果缓存中没有,则下载并缓存 + console.log('下载素材文件:', template.filename); + const response = await fetch(template.url, { + cache: 'force-cache' // 强制使用浏览器缓存 + }); + + if (response.ok) { + const blob = await response.blob(); + + // 根据文件扩展名确定正确的MIME类型 + let mimeType = blob.type; + const extension = template.filename.toLowerCase().split('.').pop(); + + if (extension === 'wav') { + mimeType = 'audio/wav'; + } else if (extension === 'mp3') { + mimeType = 'audio/mpeg'; + } else if (extension === 'm4a') { + mimeType = 'audio/mp4'; + } else if (extension === 'ogg') { + mimeType = 'audio/ogg'; + } else if (extension === 'webm') { + mimeType = 'audio/webm'; + } + + console.log('文件扩展名:', extension, 'MIME类型:', mimeType); + + const file = new File([blob], template.filename, { type: mimeType }); + + // 缓存文件对象 + templateFileCache.value.set(cacheKey, file); + console.log('下载素材文件完成:', template.filename); + return file; + } else { + throw new Error('下载素材文件失败'); + } + }; + + // 选择图片素材 + const selectImageTemplate = async (template) => { + try { + const file = 
await getTemplateFile(template); + + if (selectedTaskId.value === 'i2v') { + i2vForm.value.imageFile = file; + i2vForm.value.detectedFaces = []; // Reset detected faces + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.imageFile = file; + s2vForm.value.detectedFaces = []; // Reset detected faces + } else if (selectedTaskId.value === 'animate') { + animateForm.value.imageFile = file; + animateForm.value.detectedFaces = []; // Reset detected faces + } + + // 获取图片的 http/https URL(用于人脸识别和预览) + let imageUrl = null; + // 优先使用 template.url(如果是 http/https URL) + if (template.url && (template.url.startsWith('http://') || template.url.startsWith('https://'))) { + imageUrl = template.url; + } else if (template.inputs && template.inputs.input_image) { + // 如果有 inputs.input_image,使用 getTemplateFileUrlAsync 获取 URL + imageUrl = await getTemplateFileUrlAsync(template.inputs.input_image, 'images'); + } else if (template.filename) { + // 如果有 filename,尝试使用 getTemplateFileUrlAsync + imageUrl = await getTemplateFileUrlAsync(template.filename, 'images'); + } + + // 创建预览(使用 data URL 作为预览) + const reader = new FileReader(); + reader.onload = async (e) => { + const imageDataUrl = e.target.result; + // 如果有 http/https URL,使用它作为预览;否则使用 data URL + setCurrentImagePreview(imageUrl || imageDataUrl); + updateUploadedContentStatus(); + showImageTemplates.value = false; + showAlert(t('imageTemplateSelected'), 'success'); + // 不再自动检测人脸,等待用户手动打开多角色模式开关 + }; + reader.readAsDataURL(file); + + } catch (error) { + showAlert(`${t('loadImageTemplateFailed')}: ${error.message}`, 'danger'); + } + }; + + // 选择音频素材 + const selectAudioTemplate = async (template) => { + try { + const file = await getTemplateFile(template); + + s2vForm.value.audioFile = file; + + // 创建预览 + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentAudioPreview(e.target.result); + updateUploadedContentStatus(); + }; + reader.readAsDataURL(file); + + showAudioTemplates.value = false; + showAlert(t('audioTemplateSelected'), 'success'); + } catch (error) { + showAlert(`${t('loadAudioTemplateFailed')}: ${error.message}`, 'danger'); + } + }; + + // 预览音频素材 + const previewAudioTemplate = (template) => { + console.log('预览音频模板:', template); + const audioUrl = getTemplateFileUrl(template.filename, 'audios'); + console.log('音频URL:', audioUrl); + if (!audioUrl) { + showAlert(t('audioFileUrlFailed'), 'danger'); + return; + } + + // 停止当前播放的音频 + if (currentPlayingAudio) { + currentPlayingAudio.pause(); + currentPlayingAudio.currentTime = 0; + currentPlayingAudio = null; + } + + const audio = new Audio(audioUrl); + currentPlayingAudio = audio; + + // 监听音频播放结束事件 + audio.addEventListener('ended', () => { + currentPlayingAudio = null; + // 调用停止回调 + if (audioStopCallback) { + audioStopCallback(); + audioStopCallback = null; + } + }); + + audio.addEventListener('error', () => { + console.error('音频播放失败:', audio.error); + showAlert(t('audioPlaybackFailed'), 'danger'); + currentPlayingAudio = null; + // 调用停止回调 + if (audioStopCallback) { + audioStopCallback(); + audioStopCallback = null; + } + }); + + audio.play().catch(error => { + console.error('音频播放失败:', error); + showAlert(t('audioPlaybackFailed'), 'danger'); + currentPlayingAudio = null; + }); + }; + + const handleImageUpload = async (event) => { + const file = event.target.files[0]; + if (file) { + if (selectedTaskId.value === 'i2v') { + i2vForm.value.imageFile = file; + i2vForm.value.detectedFaces = []; // Reset detected faces + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.imageFile 
= file; + s2vForm.value.detectedFaces = []; // Reset detected faces + } else if (selectedTaskId.value === 'animate') { + animateForm.value.imageFile = file; + animateForm.value.detectedFaces = []; // Reset detected faces + } + const reader = new FileReader(); + reader.onload = async (e) => { + const imageDataUrl = e.target.result; + setCurrentImagePreview(imageDataUrl); + updateUploadedContentStatus(); + + // 不再自动检测人脸,等待用户手动打开多角色模式开关 + }; + reader.readAsDataURL(file); + } else { + // 用户取消了选择,保持原有图片不变 + // 不做任何操作 + } + }; + + // Crop face image from original image based on bbox coordinates + const cropFaceImage = (imageUrl, bbox) => { + return new Promise((resolve, reject) => { + // Validate bbox + if (!bbox || bbox.length !== 4) { + reject(new Error('Invalid bbox coordinates')) + return + } + + const [x1, y1, x2, y2] = bbox + const width = x2 - x1 + const height = y2 - y1 + + if (width <= 0 || height <= 0) { + reject(new Error(`Invalid bbox dimensions: ${width}x${height}`)) + return + } + + const img = new Image() + + // For data URLs, crossOrigin is not needed + if (imageUrl.startsWith('data:')) { + img.onload = () => { + try { + // Create Canvas to crop image + const canvas = document.createElement('canvas') + canvas.width = width + canvas.height = height + const ctx = canvas.getContext('2d') + + // Draw the cropped region to Canvas + ctx.drawImage( + img, + x1, y1, width, height, // Source image crop region + 0, 0, width, height // Canvas drawing position + ) + + // Convert to base64 + const base64 = canvas.toDataURL('image/png') + resolve(base64) + } catch (error) { + reject(error) + } + } + img.onerror = (e) => { + reject(new Error('Failed to load image for cropping')) + } + img.src = imageUrl + } else { + // For other URLs, set crossOrigin + img.crossOrigin = 'anonymous' + img.onload = () => { + try { + // Create Canvas to crop image + const canvas = document.createElement('canvas') + canvas.width = width + canvas.height = height + const ctx = canvas.getContext('2d') + + // Draw the cropped region to Canvas + ctx.drawImage( + img, + x1, y1, width, height, // Source image crop region + 0, 0, width, height // Canvas drawing position + ) + + // Convert to base64 + const base64 = canvas.toDataURL('image/png') + resolve(base64) + } catch (error) { + reject(error) + } + } + img.onerror = (e) => { + reject(new Error('Failed to load image for cropping (CORS or network error)')) + } + img.src = imageUrl + } + }) + } + + // Detect faces in uploaded image + const detectFacesInImage = async (imageDataUrl) => { + try { + // 验证输入 + if (!imageDataUrl || imageDataUrl.trim() === '') { + console.error('detectFacesInImage: imageDataUrl is empty'); + return; + } + + faceDetecting.value = true; + + // Convert blob URL to data URL (backend can't access blob URLs) + // For http/https URLs, send directly to backend + let imageInput = imageDataUrl; + if (imageDataUrl.startsWith('blob:')) { + // Blob URL: convert to data URL since backend can't access blob URLs + try { + const response = await fetch(imageDataUrl); + if (!response.ok) { + throw new Error(`Failed to fetch image: ${response.statusText}`); + } + const blob = await response.blob(); + imageInput = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result); + reader.onerror = reject; + reader.readAsDataURL(blob); + }); + } catch (error) { + console.error('Failed to convert blob URL to data URL:', error); + throw error; + } + } + // For data URLs and http/https URLs, send directly to 
backend + + // 再次验证 imageInput + if (!imageInput || imageInput.trim() === '') { + console.error('detectFacesInImage: imageInput is empty after processing'); + return; + } + + const response = await apiCall('/api/v1/face/detect', { + method: 'POST', + body: JSON.stringify({ + image: imageInput + }) + }); + + if (!response.ok) { + console.error('Face detection failed:', response.status, response.statusText); + return; + } + + const data = await response.json(); + console.log('Face detection response:', data); + + if (data && data.faces) { + // Crop face images for each detected face + // Use the original imageDataUrl for cropping (cropFaceImage can handle both data URLs and regular URLs) + const facesWithImages = await Promise.all( + data.faces.map(async (face, index) => { + try { + // Crop face image from original image based on bbox + const croppedImage = await cropFaceImage(imageDataUrl, face.bbox) + + // Remove data URL prefix, keep only base64 part (consistent with backend format) + // croppedImage is in format: "data:image/png;base64,xxxxx" + let base64Data = croppedImage + if (croppedImage.includes(',')) { + base64Data = croppedImage.split(',')[1] + } + + if (!base64Data) { + console.error(`Failed to extract base64 from cropped image for face ${index}`) + base64Data = null + } + + return { + ...face, + face_image: base64Data, // Base64 encoded face region image (without data URL prefix) + roleName: `角色${index + 1}`, + isEditing: false // Track editing state for each face + } + } catch (error) { + console.error(`Failed to crop face ${index}:`, error, 'bbox:', face.bbox); + // Return face without face_image if cropping fails + return { + ...face, + face_image: null, + roleName: `角色${index + 1}`, + isEditing: false + } + } + }) + ); + + const currentForm = getCurrentForm(); + if (currentForm) { + currentForm.detectedFaces = facesWithImages; + console.log('Updated detectedFaces:', currentForm.detectedFaces.length, 'faces with images'); + // 音频分离由统一的 watch 监听器处理,不需要在这里手动调用 + } + } + } catch (error) { + console.error('Face detection error:', error); + // Silently fail, don't show error to user + } finally { + faceDetecting.value = false; + } + }; + + // Update role name for a detected face + const updateFaceRoleName = (faceIndex, roleName) => { + const currentForm = getCurrentForm(); + if (currentForm && currentForm.detectedFaces && currentForm.detectedFaces[faceIndex]) { + // 使用展开运算符创建新对象,确保响应式更新 + currentForm.detectedFaces[faceIndex] = { + ...currentForm.detectedFaces[faceIndex], + roleName: roleName + }; + // 触发响应式更新 + currentForm.detectedFaces = [...currentForm.detectedFaces]; + } + }; + + // Toggle editing state for a face + const toggleFaceEditing = (faceIndex) => { + const currentForm = getCurrentForm(); + if (currentForm && currentForm.detectedFaces && currentForm.detectedFaces[faceIndex]) { + // 使用展开运算符创建新对象,确保响应式更新 + currentForm.detectedFaces[faceIndex] = { + ...currentForm.detectedFaces[faceIndex], + isEditing: !currentForm.detectedFaces[faceIndex].isEditing + }; + // 触发响应式更新 + currentForm.detectedFaces = [...currentForm.detectedFaces]; + } + }; + + // Save face role name and exit editing + const saveFaceRoleName = (faceIndex, roleName) => { + const currentForm = getCurrentForm(); + if (currentForm && currentForm.detectedFaces && currentForm.detectedFaces[faceIndex]) { + // 同时更新 roleName 和 isEditing,确保响应式更新 + currentForm.detectedFaces[faceIndex] = { + ...currentForm.detectedFaces[faceIndex], + roleName: roleName || currentForm.detectedFaces[faceIndex].roleName, + isEditing: false 
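+          // isEditing: false closes the inline editor; an empty roleName keeps the previous value via the || fallback above.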
+ }; + // 触发响应式更新 + currentForm.detectedFaces = [...currentForm.detectedFaces]; + } + // 同步更新所有关联的音频播放器角色名 + // 只有当任务类型是 s2v 且有分离的音频时才需要更新 + if (selectedTaskId.value === 's2v' && s2vForm.value.separatedAudios) { + s2vForm.value.separatedAudios.forEach((audio, index) => { + // 如果音频的 roleIndex 等于当前修改的 faceIndex,则更新其 roleName + if (audio.roleIndex === faceIndex) { + s2vForm.value.separatedAudios[index].roleName = roleName || `角色${faceIndex + 1}`; + } + }); + // 使用展开运算符确保响应式更新 + s2vForm.value.separatedAudios = [...s2vForm.value.separatedAudios]; + } + }; + + const selectTask = (taskType) => { + console.log('[selectTask] 开始切换任务类型:', { + taskType, + currentSelectedTaskId: selectedTaskId.value, + currentSelectedModel: selectedModel.value, + currentFormModel: getCurrentForm().model_cls, + currentFormStage: getCurrentForm().stage + }); + + for (const t of models.value.map(m => m.task)) { + if (getTaskTypeName(t) === taskType) { + taskType = t; + } + } + selectedTaskId.value = taskType; + + console.log('[selectTask] 任务类型已更新:', { + newTaskType: selectedTaskId.value, + availableModels: models.value.filter(m => m.task === taskType) + }); + + // 根据任务类型恢复对应的预览 + if (taskType === 'i2v' && i2vForm.value.imageFile) { + // 恢复图片预览 + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentImagePreview(e.target.result); + }; + reader.readAsDataURL(i2vForm.value.imageFile); + } else if (taskType === 's2v') { + // 恢复数字人任务的图片和音频预览 + if (s2vForm.value.imageFile) { + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentImagePreview(e.target.result); + }; + reader.readAsDataURL(s2vForm.value.imageFile); + } + if (s2vForm.value.audioFile) { + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentAudioPreview(e.target.result); + }; + reader.readAsDataURL(s2vForm.value.audioFile); + } + } else if (taskType === 'animate') { + // 恢复角色替换任务的图片和视频预览 + if (animateForm.value.imageFile) { + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentImagePreview(e.target.result); + }; + reader.readAsDataURL(animateForm.value.imageFile); + } + if (animateForm.value.videoFile) { + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentVideoPreview(e.target.result); + }; + reader.readAsDataURL(animateForm.value.videoFile); + } + // 确保 animate 任务类型有默认 prompt + if (!animateForm.value.prompt || animateForm.value.prompt.trim() === '') { + animateForm.value.prompt = '视频中的人在做动作'; + } + } + + // 自动选择该任务类型下的第一个模型(仅在当前模型无效且 URL 中没有 model 参数时) + const currentForm = getCurrentForm(); + const urlParams = new URLSearchParams(window.location.search); + const hasModelInUrl = urlParams.has('model'); + + // 获取新任务类型下的可用模型 + const availableModels = models.value.filter(m => m.task === taskType); + + if (availableModels.length > 0) { + // 检查当前选择的模型和stage是否属于新任务类型 + const currentModel = currentForm.model_cls || selectedModel.value; + const currentStage = currentForm.stage; + const isCurrentModelValid = currentModel && availableModels.some(m => + m.model_cls === currentModel && m.stage === currentStage + ); + + // 如果当前模型无效且 URL 中没有 model 参数,自动选择第一个模型 + if (!isCurrentModelValid || !hasModelInUrl) { + const firstModel = availableModels[0]; + console.log('[selectTask] 自动选择第一个模型:', { + firstModel: firstModel.model_cls, + firstStage: firstModel.stage, + reason: !isCurrentModelValid ? 
'当前模型或stage无效' : 'URL中没有model参数', + currentModel, + currentStage + }); + // 直接调用 selectModel 来确保路由也会更新 + selectModel(firstModel.model_cls); + } else { + console.log('[selectTask] 不自动选择模型:', { + isCurrentModelValid, + hasModelInUrl, + currentModel, + currentStage + }); + } + } + }; + + const selectModel = (model) => { + console.log('[selectModel] 开始切换模型:', { + model, + currentSelectedModel: selectedModel.value, + currentTaskType: selectedTaskId.value, + currentFormModel: getCurrentForm().model_cls, + currentFormStage: getCurrentForm().stage + }); + + selectedModel.value = model; + const currentForm = getCurrentForm(); + currentForm.model_cls = model; + // 自动设置 stage 为模型对应的第一个 stage + const availableStages = models.value + .filter(m => m.task === selectedTaskId.value && m.model_cls === model) + .map(m => m.stage); + if (availableStages.length > 0) { + currentForm.stage = availableStages[0]; + console.log('[selectModel] 自动设置 stage:', { + stage: currentForm.stage, + availableStages + }); + } + + console.log('[selectModel] 模型切换完成:', { + selectedModel: selectedModel.value, + formModel: currentForm.model_cls, + formStage: currentForm.stage + }); + }; + + const triggerImageUpload = () => { + document.querySelector('input[type="file"][accept="image/*"]').click(); + }; + + const triggerAudioUpload = () => { + const audioInput = document.querySelector('input[type="file"][data-role="audio-input"]'); + if (audioInput) { + audioInput.click(); + } else { + console.warn('音频输入框未找到'); + } + }; + + const removeImage = () => { + setCurrentImagePreview(null); + if (selectedTaskId.value === 'i2v') { + i2vForm.value.imageFile = null; + i2vForm.value.detectedFaces = []; + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.imageFile = null; + s2vForm.value.detectedFaces = []; + } else if (selectedTaskId.value === 'animate') { + animateForm.value.imageFile = null; + animateForm.value.detectedFaces = []; + } + updateUploadedContentStatus(); + // 重置文件输入框,确保可以重新选择相同文件 + const imageInput = document.querySelector('input[type="file"][accept="image/*"]'); + if (imageInput) { + imageInput.value = ''; + } + }; + + const removeAudio = () => { + setCurrentAudioPreview(null); + s2vForm.value.audioFile = null; + s2vForm.value.separatedAudios = []; + updateUploadedContentStatus(); + console.log('音频已移除'); + // 重置音频文件输入框,确保可以重新选择相同文件 + const audioInput = document.querySelector('input[type="file"][data-role="audio-input"]'); + if (audioInput) { + audioInput.value = ''; + } + }; + + // 删除视频(用于 animate 任务类型) + const removeVideo = () => { + setCurrentVideoPreview(null); + animateForm.value.videoFile = null; + updateUploadedContentStatus(); + console.log('视频已移除'); + // 重置视频文件输入框,确保可以重新选择相同文件 + const videoInput = document.querySelector('input[type="file"][data-role="video-input"]'); + if (videoInput) { + videoInput.value = ''; + } + }; + + // Update role assignment for separated audio + const updateSeparatedAudioRole = (speakerIndex, roleIndex) => { + if (s2vForm.value.separatedAudios && s2vForm.value.separatedAudios[speakerIndex]) { + const currentForm = getCurrentForm(); + const detectedFaces = currentForm?.detectedFaces || []; + + if (roleIndex >= 0 && roleIndex < detectedFaces.length) { + s2vForm.value.separatedAudios[speakerIndex].roleName = detectedFaces[roleIndex].roleName || `角色${roleIndex + 1}`; + s2vForm.value.separatedAudios[speakerIndex].roleIndex = roleIndex; + } + } + }; + + // Update audio name for a separated audio + const updateSeparatedAudioName = (audioIndex, audioName) => { + if 
(s2vForm.value.separatedAudios && s2vForm.value.separatedAudios[audioIndex]) { + s2vForm.value.separatedAudios[audioIndex].audioName = audioName; + } + }; + + // Toggle editing state for a separated audio + const toggleSeparatedAudioEditing = (audioIndex) => { + if (s2vForm.value.separatedAudios && s2vForm.value.separatedAudios[audioIndex]) { + s2vForm.value.separatedAudios[audioIndex].isEditing = !s2vForm.value.separatedAudios[audioIndex].isEditing; + } + }; + + // Save separated audio name and exit editing + const saveSeparatedAudioName = (audioIndex, audioName) => { + updateSeparatedAudioName(audioIndex, audioName); + toggleSeparatedAudioEditing(audioIndex); + }; + + const getAudioMimeType = () => { + if (s2vForm.value.audioFile) { + return s2vForm.value.audioFile.type; + } + return 'audio/mpeg'; // 默认类型 + }; + + // Separate audio tracks for multiple speakers + const separateAudioTracks = async (audioDataUrl, numSpeakers) => { + audioSeparating.value = true; // 开始音频分割,显示加载状态 + try { + // 优先使用 audioFile(如果存在),因为它包含完整的文件信息,避免 data URL 格式问题 + const currentForm = getCurrentForm(); + let audioData = audioDataUrl; + + if (currentForm?.audioFile && currentForm.audioFile instanceof File) { + // 使用 File 对象,读取为 base64,确保格式正确 + try { + const fileDataUrl = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result); + reader.onerror = reject; + reader.readAsDataURL(currentForm.audioFile); + }); + audioData = fileDataUrl; + console.log('Using audioFile for separation, format:', currentForm.audioFile.type); + } catch (error) { + console.warn('Failed to read audioFile, falling back to audioDataUrl:', error); + // 如果读取失败,继续使用 audioDataUrl + } + } + + // Clean and validate base64 string before sending + let cleanedAudioData = audioData; + if (audioData.includes(',')) { + // If it's a data URL, extract the base64 part + const parts = audioData.split(','); + if (parts.length > 1) { + cleanedAudioData = parts.slice(1).join(','); // Join in case there are multiple commas + } + } + + // Remove any whitespace and newlines + cleanedAudioData = cleanedAudioData.trim().replace(/\s/g, ''); + + // Check if it's a valid base64 string length (must be multiple of 4) + const missingPadding = cleanedAudioData.length % 4; + if (missingPadding !== 0) { + console.warn(`[separateAudioTracks] Base64 string length (${cleanedAudioData.length}) is not a multiple of 4, adding padding`); + cleanedAudioData += '='.repeat(4 - missingPadding); + } + + // Reconstruct data URL if it was originally a data URL + if (audioData.startsWith('data:')) { + const header = audioData.split(',')[0]; + cleanedAudioData = `${header},${cleanedAudioData}`; + } + + console.log(`[separateAudioTracks] Sending audio for separation, length: ${cleanedAudioData.length}, num_speakers: ${numSpeakers}`); + + const response = await apiCall('/api/v1/audio/separate', { + method: 'POST', + body: JSON.stringify({ + audio: cleanedAudioData, + num_speakers: numSpeakers + }) + }); + + if (!response.ok) { + console.error('Audio separation failed:', response.status, response.statusText); + audioSeparating.value = false; + return; + } + + const data = await response.json(); + console.log('Audio separation response:', data); + + if (data && data.speakers && data.speakers.length > 0) { + const currentForm = getCurrentForm(); + const detectedFaces = currentForm?.detectedFaces || []; + + // Map separated speakers to detected faces + // Initialize with first role if available + const separatedAudios = 
data.speakers.map((speaker, index) => { + const faceIndex = index < detectedFaces.length ? index : 0; + return { + speaker_id: speaker.speaker_id, + audio: speaker.audio, // Base64 encoded audio + audioDataUrl: `data:audio/wav;base64,${speaker.audio}`, // Data URL for preview + audioName: `音色${index + 1}`, // 音频名称,默认显示为"音色1"、"音色2"等 + roleName: detectedFaces[faceIndex]?.roleName || `角色${faceIndex + 1}`, // 关联的角色名称 + roleIndex: faceIndex, + isEditing: false, // 编辑状态 + sample_rate: speaker.sample_rate, + segments: speaker.segments + }; + }); + + // Update separatedAudios and trigger reactivity + s2vForm.value.separatedAudios = [...separatedAudios]; // Use spread to ensure reactivity + console.log('Updated separatedAudios:', s2vForm.value.separatedAudios.length, 'speakers', s2vForm.value.separatedAudios); + } else { + console.warn('No speakers found in separation response:', data); + s2vForm.value.separatedAudios = []; + } + audioSeparating.value = false; // 音频分割完成,隐藏加载状态 + } catch (error) { + console.error('Audio separation error:', error); + audioSeparating.value = false; // 发生错误时也要隐藏加载状态 + throw error; + } + }; + + const handleAudioUpload = async (event) => { + const file = event.target.files[0]; + + if (file && (file.type?.startsWith('audio/') || file.type?.startsWith('video/'))) { + const allowedVideoTypes = ['video/mp4', 'video/x-m4v', 'video/mpeg']; + if (file.type.startsWith('video/') && !allowedVideoTypes.includes(file.type)) { + showAlert(t('unsupportedVideoFormat'), 'warning'); + setCurrentAudioPreview(null); + s2vForm.value.separatedAudios = []; + updateUploadedContentStatus(); + return; + } + s2vForm.value.audioFile = file; + + // Read file as data URL for preview + const reader = new FileReader(); + reader.onload = async (e) => { + const audioDataUrl = e.target.result; + setCurrentAudioPreview(audioDataUrl); + updateUploadedContentStatus(); + // 音频分离由统一的 watch 监听器处理,不需要在这里手动调用 + console.log('[handleAudioUpload] 音频上传完成,音频分离将由统一的监听器自动处理'); + }; + reader.readAsDataURL(file); + } else { + setCurrentAudioPreview(null); + s2vForm.value.separatedAudios = []; + updateUploadedContentStatus(); + if (file) { + showAlert(t('unsupportedAudioOrVideo'), 'warning'); + } + } + }; + + // 处理视频上传(用于 animate 任务类型) + const handleVideoUpload = async (event) => { + const file = event.target.files[0]; + + if (file && file.type?.startsWith('video/')) { + const allowedVideoTypes = ['video/mp4', 'video/x-m4v', 'video/mpeg', 'video/webm', 'video/quicktime']; + if (!allowedVideoTypes.includes(file.type)) { + showAlert(t('unsupportedVideoFormat') || '不支持的视频格式', 'warning'); + setCurrentVideoPreview(null); + animateForm.value.videoFile = null; + updateUploadedContentStatus(); + return; + } + animateForm.value.videoFile = file; + + // Read file as data URL for preview + const reader = new FileReader(); + reader.onload = async (e) => { + const videoDataUrl = e.target.result; + setCurrentVideoPreview(videoDataUrl); + updateUploadedContentStatus(); + }; + reader.readAsDataURL(file); + } else { + setCurrentVideoPreview(null); + animateForm.value.videoFile = null; + updateUploadedContentStatus(); + if (file) { + showAlert(t('unsupportedVideoFormat') || '不支持的视频格式', 'warning'); + } + } + }; + + // 开始录音 + const startRecording = async () => { + try { + console.log('开始录音...'); + + // 检查浏览器支持 + if (!navigator.mediaDevices) { + throw new Error('该浏览器不支持录音功能'); + } + + if (!navigator.mediaDevices.getUserMedia) { + throw new Error('浏览器不支持录音功能,请确保使用HTTPS协议访问'); + } + + if (!window.MediaRecorder) { + throw new 
Error('浏览器不支持MediaRecorder,请更新到最新版本浏览器'); + } + + // 检查HTTPS协议 + console.log('当前协议:', location.protocol, '主机名:', location.hostname); + if (location.protocol !== 'https:' && location.hostname !== 'localhost' && !location.hostname.includes('127.0.0.1')) { + throw new Error(`录音功能需要HTTPS协议,当前使用${location.protocol}协议。请使用HTTPS访问网站或通过localhost:端口号访问`); + } + + console.log('浏览器支持检查通过,请求麦克风权限...'); + + // 记录浏览器支持状态用于调试 + const browserSupport = { + mediaDevices: !!navigator.mediaDevices, + getUserMedia: !!navigator.mediaDevices?.getUserMedia, + MediaRecorder: !!window.MediaRecorder, + protocol: location.protocol, + hostname: location.hostname, + userAgent: navigator.userAgent + }; + console.log('浏览器支持状态:', browserSupport); + + // 请求麦克风权限 + console.log('正在请求麦克风权限...'); + const stream = await navigator.mediaDevices.getUserMedia({ + audio: { + echoCancellation: true, + noiseSuppression: true, + sampleRate: 44100 + } + }); + console.log('麦克风权限获取成功,音频流:', stream); + + // 创建MediaRecorder + mediaRecorder.value = new MediaRecorder(stream, { + mimeType: 'audio/webm;codecs=opus' + }); + + audioChunks.value = []; + + // 监听数据可用事件 + mediaRecorder.value.ondataavailable = (event) => { + if (event.data.size > 0) { + audioChunks.value.push(event.data); + } + }; + + // 监听录音停止事件 + mediaRecorder.value.onstop = () => { + const audioBlob = new Blob(audioChunks.value, { type: 'audio/webm' }); + const audioFile = new File([audioBlob], 'recording.webm', { type: 'audio/webm' }); + + // 设置到表单 + s2vForm.value.audioFile = audioFile; + + // 创建预览URL + const audioUrl = URL.createObjectURL(audioBlob); + setCurrentAudioPreview(audioUrl); + updateUploadedContentStatus(); + + // 停止所有音频轨道 + stream.getTracks().forEach(track => track.stop()); + + showAlert(t('recordingCompleted'), 'success'); + }; + + // 开始录音 + mediaRecorder.value.start(1000); // 每秒收集一次数据 + isRecording.value = true; + recordingDuration.value = 0; + + // 开始计时 + recordingTimer.value = setInterval(() => { + recordingDuration.value++; + }, 1000); + + showAlert(t('recordingStarted'), 'info'); + + } catch (error) { + console.error('录音失败:', error); + let errorMessage = t('recordingFailed'); + + if (error.name === 'NotAllowedError') { + errorMessage = t('microphonePermissionDenied'); + } else if (error.name === 'NotFoundError') { + errorMessage = t('microphoneNotFound'); + } else if (error.name === 'NotSupportedError') { + errorMessage = t('recordingNotSupportedOnMobile'); + } else if (error.name === 'NotReadableError') { + errorMessage = t('microphoneInUse'); + } else if (error.name === 'OverconstrainedError') { + errorMessage = t('microphoneNotCompatible'); + } else if (error.name === 'SecurityError') { + errorMessage = t('securityErrorUseHttps'); + } else if (error.message) { + errorMessage = error.message; + } + + // 添加调试信息 + const debugInfo = { + userAgent: navigator.userAgent, + protocol: location.protocol, + hostname: location.hostname, + mediaDevices: !!navigator.mediaDevices, + getUserMedia: !!navigator.mediaDevices?.getUserMedia, + MediaRecorder: !!window.MediaRecorder, + isSecureContext: window.isSecureContext, + chromeVersion: navigator.userAgent.match(/Chrome\/(\d+)/)?.[1] || '未知' + }; + console.log('浏览器调试信息:', debugInfo); + + // 如果是Chrome但仍有问题,提供特定建议 + if (navigator.userAgent.includes('Chrome')) { + console.log('检测到Chrome浏览器,可能的问题:'); + console.log('1. 请确保使用HTTPS协议或localhost访问'); + console.log('2. 检查Chrome地址栏是否有麦克风权限'); + console.log('3. 尝试在Chrome设置中重置网站权限'); + console.log('4. 
确保没有其他应用占用麦克风'); + } + + showAlert(errorMessage, 'danger'); + } + }; + + // 停止录音 + const stopRecording = () => { + if (mediaRecorder.value && isRecording.value) { + mediaRecorder.value.stop(); + isRecording.value = false; + + if (recordingTimer.value) { + clearInterval(recordingTimer.value); + recordingTimer.value = null; + } + + showAlert(t('recordingStopped'), 'info'); + } + }; + + // 格式化录音时长 + const formatRecordingDuration = (seconds) => { + const mins = Math.floor(seconds / 60); + const secs = seconds % 60; + return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`; + }; + + const submitTask = async () => { + try { + // 检查是否正在加载模板 + if (templateLoading.value) { + showAlert(t('templateLoadingPleaseWait'), 'warning'); + return; + } + + const currentForm = getCurrentForm(); + + // 表单验证 + if (!selectedTaskId.value) { + showAlert(t('pleaseSelectTaskType'), 'warning'); + return; + } + + if (!currentForm.model_cls) { + showAlert(t('pleaseSelectModel'), 'warning'); + return; + } + + // animate 任务类型不需要 prompt,其他任务类型需要 + if (selectedTaskId.value !== 'animate') { + if (!currentForm.prompt || currentForm.prompt.trim().length === 0) { + if (selectedTaskId.value === 's2v') { + currentForm.prompt = '让角色根据音频内容自然说话'; + } else { + showAlert(t('pleaseEnterPrompt'), 'warning'); + return; + } + } + + if (currentForm.prompt.length > 1000) { + showAlert(t('promptTooLong'), 'warning'); + return; + } + } + + if (selectedTaskId.value === 'i2v' && !currentForm.imageFile) { + showAlert(t('i2vTaskRequiresImage'), 'warning'); + return; + } + + if (selectedTaskId.value === 's2v' && !currentForm.imageFile) { + showAlert(t('s2vTaskRequiresImage'), 'warning'); + return; + } + + if (selectedTaskId.value === 's2v' && !currentForm.audioFile) { + showAlert(t('s2vTaskRequiresAudio'), 'warning'); + return; + } + + if (selectedTaskId.value === 'animate' && !currentForm.imageFile) { + showAlert(t('animateTaskRequiresImage'), 'warning'); + return; + } + + if (selectedTaskId.value === 'animate' && !currentForm.videoFile) { + showAlert(t('animateTaskRequiresVideo'), 'warning'); + return; + } + submitting.value = true; + + // 确定实际提交的任务类型 + let actualTaskType = selectedTaskId.value; + + var formData = { + task: actualTaskType, + model_cls: currentForm.model_cls, + stage: currentForm.stage, + seed: currentForm.seed || Math.floor(Math.random() * 1000000) + }; + + // animate 任务类型使用默认 prompt,其他任务类型需要用户输入 + if (selectedTaskId.value === 'animate') { + // animate 任务类型使用默认 prompt + formData.prompt = currentForm.prompt && currentForm.prompt.trim().length > 0 + ? currentForm.prompt.trim() + : '视频中的人在做动作'; + } else { + formData.prompt = currentForm.prompt ? 
currentForm.prompt.trim() : ''; + } + + if (currentForm.model_cls.startsWith('wan2.1')) { + formData.negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + } + + if (selectedTaskId.value === 'i2v' && currentForm.imageFile) { + const base64 = await fileToBase64(currentForm.imageFile); + formData.input_image = { + type: 'base64', + data: base64 + }; + } + + if (selectedTaskId.value === 'animate' && currentForm.imageFile) { + const base64 = await fileToBase64(currentForm.imageFile); + formData.input_image = { + type: 'base64', + data: base64 + }; + } + + if (selectedTaskId.value === 'animate' && currentForm.videoFile) { + const base64 = await fileToBase64(currentForm.videoFile); + formData.input_video = { + type: 'base64', + data: base64 + }; + } + + if (selectedTaskId.value === 's2v') { + if (currentForm.imageFile) { + const base64 = await fileToBase64(currentForm.imageFile); + formData.input_image = { + type: 'base64', + data: base64 + }; + } + + // 检测是否为多人模式:有多个分离的音频和多个角色 + const isMultiPersonMode = s2vForm.value.separatedAudios && + s2vForm.value.separatedAudios.length > 1 && + currentForm.detectedFaces && + currentForm.detectedFaces.length > 1; + + if (isMultiPersonMode) { + // 多人模式:生成mask图、保存音频文件、生成config.json + try { + const multiPersonData = await prepareMultiPersonAudio( + currentForm.detectedFaces, + s2vForm.value.separatedAudios, + currentForm.imageFile, + currentForm.audioFile // 传递原始音频文件 + ); + + formData.input_audio = { + type: 'directory', + data: multiPersonData + }; + formData.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + } catch (error) { + console.error('Failed to prepare multi-person audio:', error); + showAlert(t('prepareMultiPersonAudioFailed') + ': ' + error.message, 'danger'); + submitting.value = false; + return; + } + } else if (currentForm.audioFile) { + // 单人模式:使用原始音频文件 + const base64 = await fileToBase64(currentForm.audioFile); + formData.input_audio = { + type: 'base64', + data: base64 + }; + formData.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + } + } + + const response = await apiRequest('/api/v1/task/submit', { + method: 'POST', + body: JSON.stringify(formData) + }); + + if (response && response.ok) { + let result; + try { + result = await response.json(); + } catch (error) { + console.error('Failed to parse response JSON:', error); + showAlert(t('taskSubmittedButParseFailed'), 'warning'); + submitting.value = false; + return null; + } + + showAlert(t('taskSubmitSuccessAlert'), 'success'); + + // 开始轮询新提交的任务状态(不等待,异步执行) + try { + startPollingTask(result.task_id); + } catch (error) { + console.error('Failed to start polling task:', error); + // 不阻止流程继续 + } + + // 保存完整的任务历史(包括提示词、图片和音频)- 异步执行,不阻塞 + // 注意:addTaskToHistory 是同步函数,但为了统一处理,使用 Promise.resolve 包装 + Promise.resolve().then(() => { + try { + addTaskToHistory(selectedTaskId.value, currentForm); + } catch (error) { + console.error('Failed to add task to history:', error); + } + }).catch(error => { + console.error('Failed to add task to history:', error); + }); + + // 重置表单(异步执行,不阻塞)- 使用 Promise.race 添加超时保护 + try { + await Promise.race([ + Promise.resolve(resetForm(selectedTaskId.value)), + new Promise((_, reject) => setTimeout(() => reject(new Error('resetForm timeout')), 
3000)) + ]); + } catch (error) { + console.error('Failed to reset form:', error); + // 不阻止流程继续,只记录错误 + } + + // 重置当前任务类型的表单(保留模型选择,清空图片、音频和提示词) + try { + selectedTaskId.value = selectedTaskId.value; + selectModel(currentForm.model_cls); + } catch (error) { + console.error('Failed to select model:', error); + // 不阻止流程继续 + } + + // 返回新创建的任务ID + return result.task_id; + } else { + let error; + try { + error = await response.json(); + showAlert(`${t('taskSubmitFailedAlert')}: ${error.message || 'Unknown error'},${error.detail || ''}`, 'danger'); + } catch (parseError) { + console.error('Failed to parse error response:', parseError); + showAlert(`${t('taskSubmitFailedAlert')}: ${response.statusText || 'Unknown error'}`, 'danger'); + } + return null; + } + } catch (error) { + showAlert(`${t('submitTaskFailedAlert')}: ${error.message}`, 'danger'); + return null; + } finally { + submitting.value = false; + } + }; + + const fileToBase64 = (file) => { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.readAsDataURL(file); + reader.onload = () => { + const base64 = reader.result.split(',')[1]; + resolve(base64); + }; + reader.onerror = error => reject(error); + }); + }; + + // 准备多人模式的音频数据:生成mask图、保存音频文件、生成config.json + const prepareMultiPersonAudio = async (detectedFaces, separatedAudios, imageFile, originalAudioFile) => { + // 1. 读取原始图片,获取尺寸 + const imageBase64 = await fileToBase64(imageFile); + const imageDataUrl = `data:image/png;base64,${imageBase64}`; + + // 创建图片对象以获取尺寸 + const img = new Image(); + await new Promise((resolve, reject) => { + img.onload = resolve; + img.onerror = reject; + img.src = imageDataUrl; + }); + const imageWidth = img.naturalWidth; + const imageHeight = img.naturalHeight; + + // 2. 为每个角色生成mask图和音频文件 + const directoryFiles = {}; + const talkObjects = []; + + for (let i = 0; i < detectedFaces.length; i++) { + const face = detectedFaces[i]; + const audioIndex = i < separatedAudios.length ? i : 0; + const audio = separatedAudios[audioIndex]; + + // 生成mask图(box部分为白色,其余部分为黑色) + const maskBase64 = await generateMaskImage( + face.bbox, + imageWidth, + imageHeight + ); + + // 保存mask图 + const maskFilename = `p${i + 1}_mask.png`; + directoryFiles[maskFilename] = maskBase64; + + // 保存音频文件 + // 注意:separatedAudios中的audio已经是base64编码的wav格式 + const audioFilename = `p${i + 1}.wav`; + directoryFiles[audioFilename] = audio.audio; // audio.audio是base64编码的wav数据 + + // 添加到talk_objects + talkObjects.push({ + audio: audioFilename, + mask: maskFilename + }); + } + + // 3. 保存原始未分割的音频文件(用于后续复用) + if (originalAudioFile) { + try { + // 将原始音频文件转换为base64 + const originalAudioBase64 = await fileToBase64(originalAudioFile); + // 根据原始文件名确定扩展名,如果没有扩展名则使用.wav + const originalFilename = originalAudioFile.name || 'original_audio.wav'; + const fileExtension = originalFilename.toLowerCase().split('.').pop(); + const validExtensions = ['wav', 'mp3', 'mp4', 'aac', 'ogg', 'm4a']; + const extension = validExtensions.includes(fileExtension) ? fileExtension : 'wav'; + const originalAudioFilename = `original_audio.${extension}`; + directoryFiles[originalAudioFilename] = originalAudioBase64; + console.log('已保存原始音频文件:', originalAudioFilename); + } catch (error) { + console.warn('保存原始音频文件失败:', error); + // 不阻止任务提交,只记录警告 + } + } + + // 4. 
生成config.json + const configJson = { + talk_objects: talkObjects + }; + const configJsonString = JSON.stringify(configJson, null, 4); + const configBase64 = btoa(unescape(encodeURIComponent(configJsonString))); + directoryFiles['config.json'] = configBase64; + + return directoryFiles; + }; + + // 生成mask图:根据bbox坐标生成白色区域,其余为黑色 + const generateMaskImage = async (bbox, imageWidth, imageHeight) => { + // bbox格式: [x1, y1, x2, y2] + const [x1, y1, x2, y2] = bbox; + + // 创建canvas + const canvas = document.createElement('canvas'); + canvas.width = imageWidth; + canvas.height = imageHeight; + const ctx = canvas.getContext('2d'); + + // 填充黑色背景 + ctx.fillStyle = '#000000'; + ctx.fillRect(0, 0, imageWidth, imageHeight); + + // 在bbox区域填充白色 + ctx.fillStyle = '#FFFFFF'; + ctx.fillRect(Math.round(x1), Math.round(y1), Math.round(x2 - x1), Math.round(y2 - y1)); + + // 转换为base64 + return canvas.toDataURL('image/png').split(',')[1]; + }; + + const formatTime = (timestamp) => { + if (!timestamp) return ''; + const date = new Date(timestamp * 1000); + return date.toLocaleString('zh-CN'); + }; + + // 通用缓存管理函数 + const loadFromCache = (cacheKey, expiryKey) => { + try { + const cached = localStorage.getItem(cacheKey); + if (cached) { + const data = JSON.parse(cached); + if (Date.now() - data.timestamp < expiryKey) { + console.log(`成功从缓存加载数据${cacheKey}:`, data.data); + return data.data; + } else { + // 缓存过期,清除 + localStorage.removeItem(cacheKey); + console.log(`缓存过期,清除 ${cacheKey}`); + } + } + } catch (error) { + console.warn(`加载缓存失败 ${cacheKey}:`, error); + localStorage.removeItem(cacheKey); + } + return null; + }; + + const saveToCache = (cacheKey, data) => { + try { + const cacheData = { + data: data, + timestamp: Date.now() + }; + console.log(`成功保存缓存数据 ${cacheKey}:`, cacheData); + localStorage.setItem(cacheKey, JSON.stringify(cacheData)); + } catch (error) { + console.warn(`保存缓存失败 ${cacheKey}:`, error); + } + }; + + // 清除所有应用缓存 + const clearAllCache = () => { + try { + const cacheKeys = [ + TASK_FILE_CACHE_KEY, + TEMPLATE_FILE_CACHE_KEY, + MODELS_CACHE_KEY, + TEMPLATES_CACHE_KEY + ]; + + // 清除所有任务缓存(使用通配符匹配) + for (let i = 0; i < localStorage.length; i++) { + const key = localStorage.key(i); + if (key && key.startsWith(TASKS_CACHE_KEY)) { + localStorage.removeItem(key); + } + } + + // 清除所有模板缓存(使用通配符匹配) + for (let i = 0; i < localStorage.length; i++) { + const key = localStorage.key(i); + if (key && key.startsWith(TEMPLATES_CACHE_KEY)) { + localStorage.removeItem(key); + } + } + // 清除其他缓存 + cacheKeys.forEach(key => { + localStorage.removeItem(key); + }); + + // 清除内存中的任务文件缓存 + taskFileCache.value.clear(); + taskFileCacheLoaded.value = false; + + // 清除内存中的模板文件缓存 + templateFileCache.value.clear(); + templateFileCacheLoaded.value = false; + + console.log('所有缓存已清除'); + } catch (error) { + console.warn('清除缓存失败:', error); + } + }; + + // 模板文件缓存管理函数 + const loadTemplateFilesFromCache = () => { + try { + const cached = localStorage.getItem(TEMPLATE_FILE_CACHE_KEY); + if (cached) { + const data = JSON.parse(cached); + if (data.files) { + for (const [cacheKey, fileData] of Object.entries(data.files)) { + templateFileCache.value.set(cacheKey, fileData); + } + return true; + } else { + console.warn('模板文件缓存数据格式错误'); + return false; + } + } + } catch (error) { + console.warn('加载模板文件缓存失败:', error); + } + return false; + }; + + const saveTemplateFilesToCache = () => { + try { + const files = {}; + for (const [cacheKey, fileData] of templateFileCache.value.entries()) { + files[cacheKey] = fileData; + } + const data = { + files: 
files, + timestamp: Date.now() + }; + localStorage.setItem(TEMPLATE_FILE_CACHE_KEY, JSON.stringify(data)); + } catch (error) { + console.warn('保存模板文件缓存失败:', error); + } + }; + + const getTemplateFileCacheKey = (templateId, fileKey) => { + return `template_${templateId}_${fileKey}`; + }; + + const getTemplateFileFromCache = (cacheKey) => { + return templateFileCache.value.get(cacheKey) || null; + }; + + const setTemplateFileToCache = (fileKey, fileData) => { + templateFileCache.value.set(fileKey, fileData); + // 异步保存到localStorage + setTimeout(() => { + saveTemplateFilesToCache(); + }, 100); + }; + + const getTemplateFileUrlFromApi = async (fileKey, fileType) => { + const apiUrl = `/api/v1/template/asset_url/${fileType}/${fileKey}`; + const response = await apiRequest(apiUrl); + if (response && response.ok) { + const data = await response.json(); + let assertUrl = data.url; + if (assertUrl.startsWith('./assets/')) { + const token = localStorage.getItem('accessToken'); + if (token) { + assertUrl = `${assertUrl}&token=${encodeURIComponent(token)}`; + } + } + setTemplateFileToCache(fileKey, { + url: assertUrl, + timestamp: Date.now() + }); + return assertUrl; + } + return null; + }; + + // 获取模板文件URL(优先从缓存,缓存没有则生成URL)- 同步版本 + const getTemplateFileUrl = (fileKey, fileType) => { + // 检查参数有效性(静默处理,不打印警告,因为模板可能确实没有某些输入) + if (!fileKey) { + return null; + } + + // 先从缓存获取 + const cachedFile = getTemplateFileFromCache(fileKey); + if (cachedFile) { + /* console.log('从缓存获取模板文件url', { fileKey});*/ + return cachedFile.url; + } + // 如果缓存中没有,返回null,让调用方知道需要异步获取 + console.warn('模板文件URL不在缓存中,需要异步获取:', { fileKey, fileType }); + getTemplateFileUrlAsync(fileKey, fileType).then(url => { + return url; + }); + return null; + }; + + // 创建响应式的模板文件URL(用于首屏渲染) + const createTemplateFileUrlRef = (fileKey, fileType) => { + const urlRef = ref(null); + + // 检查参数有效性(静默处理,不打印警告) + if (!fileKey) { + return urlRef; + } + + // 先从缓存获取 + const cachedFile = getTemplateFileFromCache(fileKey); + if (cachedFile) { + urlRef.value = cachedFile.url; + return urlRef; + } + + // 检查是否正在获取中,避免重复请求 + const fetchKey = `${fileKey}_${fileType}`; + if (templateUrlFetching.value.has(fetchKey)) { + console.log('createTemplateFileUrlRef: 正在获取中,跳过重复请求', { fileKey, fileType }); + return urlRef; + } + + // 标记为正在获取 + templateUrlFetching.value.add(fetchKey); + + // 如果缓存中没有,异步获取 + getTemplateFileUrlFromApi(fileKey, fileType).then(url => { + if (url) { + urlRef.value = url; + // 将获取到的URL存储到缓存中 + setTemplateFileToCache(fileKey, { url, timestamp: Date.now() }); + } + }).catch(error => { + console.warn('获取模板文件URL失败:', error); + }).finally(() => { + // 移除获取状态 + templateUrlFetching.value.delete(fetchKey); + }); + + return urlRef; + }; + + // 创建响应式的任务文件URL(用于首屏渲染) + const createTaskFileUrlRef = (taskId, fileKey) => { + const urlRef = ref(null); + + // 检查参数有效性 + if (!taskId || !fileKey) { + console.warn('createTaskFileUrlRef: 参数为空', { taskId, fileKey }); + return urlRef; + } + + // 先从缓存获取 + const cachedFile = getTaskFileFromCache(taskId, fileKey); + if (cachedFile) { + urlRef.value = cachedFile.url; + return urlRef; + } + + // 如果缓存中没有,异步获取 + getTaskFileUrl(taskId, fileKey).then(url => { + if (url) { + urlRef.value = url; + // 将获取到的URL存储到缓存中 + setTaskFileToCache(taskId, fileKey, { url, timestamp: Date.now() }); + } + }).catch(error => { + console.warn('获取任务文件URL失败:', error); + }); + + return urlRef; + }; + + // 获取模板文件URL(异步版本,用于预加载等场景) + const getTemplateFileUrlAsync = async (fileKey, fileType) => { + // 检查参数有效性(静默处理,不打印警告,因为模板可能确实没有某些输入) + if (!fileKey) { + 
return null; + } + + // 先从缓存获取 + const cachedFile = getTemplateFileFromCache(fileKey); + if (cachedFile) { + console.log('getTemplateFileUrlAsync: 从缓存获取', { fileKey, url: cachedFile.url }); + return cachedFile.url; + } + + // 检查是否正在获取中,避免重复请求 + const fetchKey = `${fileKey}_${fileType}`; + if (templateUrlFetching.value.has(fetchKey)) { + console.log('getTemplateFileUrlAsync: 正在获取中,等待完成', { fileKey, fileType }); + // 等待其他请求完成 + return new Promise((resolve) => { + const checkInterval = setInterval(() => { + const cachedFile = getTemplateFileFromCache(fileKey); + if (cachedFile) { + clearInterval(checkInterval); + resolve(cachedFile.url); + } else if (!templateUrlFetching.value.has(fetchKey)) { + clearInterval(checkInterval); + resolve(null); + } + }, 100); + }); + } + + // 标记为正在获取 + templateUrlFetching.value.add(fetchKey); + + // 如果缓存中没有,异步获取 + try { + const url = await getTemplateFileUrlFromApi(fileKey, fileType); + if (url) { + // 将获取到的URL存储到缓存中 + setTemplateFileToCache(fileKey, { url, timestamp: Date.now() }); + } + return url; + } catch (error) { + console.warn('getTemplateFileUrlAsync: 获取URL失败', error); + return null; + } finally { + // 移除获取状态 + templateUrlFetching.value.delete(fetchKey); + } + }; + + // 任务文件缓存管理函数 + const loadTaskFilesFromCache = () => { + try { + const cached = localStorage.getItem(TASK_FILE_CACHE_KEY); + if (cached) { + const data = JSON.parse(cached); + // 检查是否过期 + if (Date.now() - data.timestamp < TASK_FILE_CACHE_EXPIRY) { + // 将缓存数据加载到内存缓存中 + for (const [cacheKey, fileData] of Object.entries(data.files)) { + taskFileCache.value.set(cacheKey, fileData); + } + return true; + } else { + // 缓存过期,清除 + localStorage.removeItem(TASK_FILE_CACHE_KEY); + } + } + } catch (error) { + console.warn('加载任务文件缓存失败:', error); + localStorage.removeItem(TASK_FILE_CACHE_KEY); + } + return false; + }; + + const saveTaskFilesToCache = () => { + try { + const files = {}; + for (const [cacheKey, fileData] of taskFileCache.value.entries()) { + files[cacheKey] = fileData; + } + const data = { + files, + timestamp: Date.now() + }; + localStorage.setItem(TASK_FILE_CACHE_KEY, JSON.stringify(data)); + } catch (error) { + console.warn('保存任务文件缓存失败:', error); + } + }; + + // 生成缓存键 + const getTaskFileCacheKey = (taskId, fileKey) => { + return `${taskId}_${fileKey}`; + }; + + // 从缓存获取任务文件 + const getTaskFileFromCache = (taskId, fileKey) => { + const cacheKey = getTaskFileCacheKey(taskId, fileKey); + return taskFileCache.value.get(cacheKey) || null; + }; + + // 设置任务文件到缓存 + const setTaskFileToCache = (taskId, fileKey, fileData) => { + const cacheKey = getTaskFileCacheKey(taskId, fileKey); + taskFileCache.value.set(cacheKey, fileData); + // 异步保存到localStorage + setTimeout(() => { + saveTaskFilesToCache(); + }, 100); + }; + + const getTaskFileUrlFromApi = async (taskId, fileKey, filename = null) => { + let apiUrl = `/api/v1/task/input_url?task_id=${taskId}&name=${fileKey}`; + if (filename) { + apiUrl += `&filename=${encodeURIComponent(filename)}`; + } + if (fileKey.includes('output')) { + apiUrl = `/api/v1/task/result_url?task_id=${taskId}&name=${fileKey}`; + } + const response = await apiRequest(apiUrl); + if (response && response.ok) { + const data = await response.json(); + let assertUrl = data.url; + if (assertUrl.startsWith('./assets/')) { + const token = localStorage.getItem('accessToken'); + if (token) { + assertUrl = `${assertUrl}&token=${encodeURIComponent(token)}`; + } + } + const cacheKey = filename ? 
`${fileKey}_${filename}` : fileKey; + setTaskFileToCache(taskId, cacheKey, { + url: assertUrl, + timestamp: Date.now() + }); + return assertUrl; + } else if (response && response.status === 400) { + // Handle directory input error (multi-person mode) + try { + const errorData = await response.json(); + if (errorData.error && errorData.error.includes('directory')) { + console.warn(`Input ${fileKey} is a directory (multi-person mode), cannot get single file URL`); + return null; + } + } catch (e) { + // Ignore JSON parse errors + } + } + return null; + }; + + // Podcast 音频 URL 缓存管理函数(模仿任务文件缓存) + const loadPodcastAudioFromCache = () => { + try { + const cached = localStorage.getItem(PODCAST_AUDIO_CACHE_KEY); + if (cached) { + const data = JSON.parse(cached); + // 检查是否过期 + if (Date.now() - data.timestamp < PODCAST_AUDIO_CACHE_EXPIRY) { + // 将缓存数据加载到内存缓存中 + for (const [cacheKey, audioData] of Object.entries(data.audio_urls)) { + podcastAudioCache.value.set(cacheKey, audioData); + } + podcastAudioCacheLoaded.value = true; + return true; + } else { + // 缓存过期,清除 + localStorage.removeItem(PODCAST_AUDIO_CACHE_KEY); + } + } + } catch (error) { + console.warn('加载播客音频缓存失败:', error); + localStorage.removeItem(PODCAST_AUDIO_CACHE_KEY); + } + podcastAudioCacheLoaded.value = true; + return false; + }; + + const savePodcastAudioToCache = () => { + try { + const audio_urls = {}; + for (const [cacheKey, audioData] of podcastAudioCache.value.entries()) { + audio_urls[cacheKey] = audioData; + } + const data = { + audio_urls, + timestamp: Date.now() + }; + localStorage.setItem(PODCAST_AUDIO_CACHE_KEY, JSON.stringify(data)); + } catch (error) { + console.warn('保存播客音频缓存失败:', error); + } + }; + + // 生成播客音频缓存键 + const getPodcastAudioCacheKey = (sessionId) => { + return sessionId; + }; + + // 从缓存获取播客音频 URL + const getPodcastAudioFromCache = (sessionId) => { + const cacheKey = getPodcastAudioCacheKey(sessionId); + return podcastAudioCache.value.get(cacheKey) || null; + }; + + // 设置播客音频 URL 到缓存 + const setPodcastAudioToCache = (sessionId, audioData) => { + const cacheKey = getPodcastAudioCacheKey(sessionId); + podcastAudioCache.value.set(cacheKey, audioData); + // 异步保存到localStorage + setTimeout(() => { + savePodcastAudioToCache(); + }, 100); + }; + + // 从 API 获取播客音频 URL(CDN URL) + const getPodcastAudioUrlFromApi = async (sessionId) => { + try { + const response = await apiCall(`/api/v1/podcast/session/${sessionId}/audio_url`); + if (response && response.ok) { + const data = await response.json(); + const audioUrl = data.audio_url; + setPodcastAudioToCache(sessionId, { + url: audioUrl, + timestamp: Date.now() + }); + return audioUrl; + } + } catch (error) { + console.warn(`Failed to get audio URL for session ${sessionId}:`, error); + } + return null; + }; + + // 获取任务文件URL(优先从缓存,缓存没有则调用后端) + const getTaskFileUrl = async (taskId, fileKey) => { + // 先从缓存获取 + const cachedFile = getTaskFileFromCache(taskId, fileKey); + if (cachedFile) { + return cachedFile.url; + } + return await getTaskFileUrlFromApi(taskId, fileKey); + }; + + // 同步获取任务文件URL(仅从缓存获取,用于模板显示) + const getTaskFileUrlSync = (taskId, fileKey) => { + const cachedFile = getTaskFileFromCache(taskId, fileKey); + if (cachedFile) { + console.log('getTaskFileUrlSync: 从缓存获取', { taskId, fileKey, url: cachedFile.url, type: typeof cachedFile.url }); + return cachedFile.url; + } + console.log('getTaskFileUrlSync: 缓存中没有找到', { taskId, fileKey }); + return null; + }; + + // 预加载任务文件 + const preloadTaskFilesUrl = async (tasks) => { + if (!tasks || tasks.length === 0) return; + + 
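+    // The task-file, template-file and podcast-audio helpers above all share the same
+    // cache-first shape: look in the in-memory Map (hydrated from localStorage) and only
+    // fall back to the backend for a signed URL. A minimal sketch of that shape, kept as
+    // a comment; `resolveUrlCached` is a hypothetical name, not an identifier in this app
+    // (in the real code the *FromApi helpers write the cache entry themselves):
+    //
+    //   const resolveUrlCached = async (cache, key, fetchUrlFromApi) => {
+    //     const hit = cache.get(key);
+    //     if (hit) return hit.url;                      // memory / localStorage hit
+    //     const url = await fetchUrlFromApi();          // ask the backend for a URL
+    //     if (url) cache.set(key, { url, timestamp: Date.now() });
+    //     return url;
+    //   };
+    //
+    //   // getTaskFileUrl(taskId, fileKey) above is essentially
+    //   //   resolveUrlCached(taskFileCache.value, `${taskId}_${fileKey}`,
+    //   //                    () => getTaskFileUrlFromApi(taskId, fileKey));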
// 先尝试从localStorage加载缓存 + if (taskFileCache.value.size === 0) { + loadTaskFilesFromCache(); + } + + console.log(`开始获取 ${tasks.length} 个任务的文件url`); + + // 分批预加载,避免过多并发请求 + const batchSize = 5; + for (let i = 0; i < tasks.length; i += batchSize) { + const batch = tasks.slice(i, i + batchSize); + + const promises = batch.map(async (task) => { + if (!task.task_id) return; + + // 预加载输入图片 + if (task.inputs && task.inputs.input_image) { + await getTaskFileUrl(task.task_id, 'input_image'); + } + // 预加载输入音频 + if (task.inputs && task.inputs.input_audio) { + await getTaskFileUrl(task.task_id, 'input_audio'); + } + // 预加载输出视频 + if (task.outputs && task.outputs.output_video && task.status === 'SUCCEED') { + await getTaskFileUrl(task.task_id, 'output_video'); + } + }); + + await Promise.all(promises); + + // 批次间添加延迟 + if (i + batchSize < tasks.length) { + await new Promise(resolve => setTimeout(resolve, 200)); + } + } + + console.log('任务文件url预加载完成'); + }; + + // 预加载模板文件 + const preloadTemplateFilesUrl = async (templates) => { + if (!templates || templates.length === 0) return; + + // 先尝试从localStorage加载缓存 + if (templateFileCache.value.size === 0) { + loadTemplateFilesFromCache(); + } + + console.log(`开始获取 ${templates.length} 个模板的文件url`); + + // 分批预加载,避免过多并发请求 + const batchSize = 5; + for (let i = 0; i < templates.length; i += batchSize) { + const batch = templates.slice(i, i + batchSize); + + const promises = batch.map(async (template) => { + if (!template.task_id) return; + + // 预加载视频文件 + if (template.outputs?.output_video) { + await getTemplateFileUrlAsync(template.outputs.output_video, 'videos'); + } + + // 预加载图片文件 + if (template.inputs?.input_image) { + await getTemplateFileUrlAsync(template.inputs.input_image, 'images'); + } + + // 预加载音频文件 + if (template.inputs?.input_audio) { + await getTemplateFileUrlAsync(template.inputs.input_audio, 'audios'); + } + }); + + await Promise.all(promises); + + // 批次间添加延迟 + if (i + batchSize < templates.length) { + await new Promise(resolve => setTimeout(resolve, 200)); + } + } + + console.log('模板文件url预加载完成'); + }; + + const refreshTasks = async (forceRefresh = false) => { + try { + console.log('开始刷新任务列表, forceRefresh:', forceRefresh, 'currentPage:', currentTaskPage.value); + + // 构建缓存键,包含分页和过滤条件 + const cacheKey = `${TASKS_CACHE_KEY}_${currentTaskPage.value}_${taskPageSize.value}_${statusFilter.value}_${taskSearchQuery.value}`; + + // 如果不是强制刷新,先尝试从缓存加载 + if (!forceRefresh) { + const cachedTasks = loadFromCache(cacheKey, TASKS_CACHE_EXPIRY); + if (cachedTasks) { + console.log('从缓存加载任务列表'); + tasks.value = cachedTasks.tasks || []; + pagination.value = cachedTasks.pagination || null; + // 强制触发响应式更新 + await nextTick(); + // 强制刷新分页组件 + paginationKey.value++; + // 使用新的任务文件预加载逻辑 + await preloadTaskFilesUrl(tasks.value); + return; + } + } + + const params = new URLSearchParams({ + page: currentTaskPage.value.toString(), + page_size: taskPageSize.value.toString() + }); + + if (statusFilter.value !== 'ALL') { + params.append('status', statusFilter.value); + } + + console.log('请求任务列表API:', `/api/v1/task/list?${params.toString()}`); + const response = await apiRequest(`/api/v1/task/list?${params.toString()}`); + if (response && response.ok) { + const data = await response.json(); + console.log('任务列表API响应:', data); + + // 强制清空并重新赋值,确保Vue检测到变化 + tasks.value = []; + pagination.value = null; + await nextTick(); + + tasks.value = data.tasks || []; + pagination.value = data.pagination || null; + + // 缓存任务数据 + saveToCache(cacheKey, { + tasks: data.tasks || [], + pagination: 
data.pagination || null
+        });
+        console.log('缓存任务列表数据成功');
+
+        // Force a reactivity update
+        await nextTick();
+
+        // Force the pagination component to re-render
+        paginationKey.value++;
+
+        // Preload task file URLs with the new preloading logic
+        await preloadTaskFilesUrl(tasks.value);
+      } else if (response) {
+        showAlert(t('refreshTaskListFailed'), 'danger');
+      }
+      // A null response means an auth error that apiRequest has already handled
+    } catch (error) {
+      console.error('刷新任务列表失败:', error);
+      // showAlert(`刷新任务列表失败: ${error.message}`, 'danger');
+    }
+  };
+
+  // Pagination helpers
+  const goToPage = async (page) => {
+    isPageLoading.value = true;
+    if (page < 1 || page > pagination.value?.total_pages || page === currentTaskPage.value) {
+      isPageLoading.value = false;
+      return;
+    }
+    currentTaskPage.value = page;
+    taskPageInput.value = page; // keep the page input box in sync
+    await refreshTasks();
+    isPageLoading.value = false;
+  };
+
+  const jumpToPage = async () => {
+    const page = parseInt(taskPageInput.value);
+    if (page && page >= 1 && page <= pagination.value?.total_pages && page !== currentTaskPage.value) {
+      await goToPage(page);
+    } else {
+      // Invalid input: fall back to the current page
+      taskPageInput.value = currentTaskPage.value;
+    }
+  };
+
+  // Template pagination helpers
+  const goToTemplatePage = async (page) => {
+    isPageLoading.value = true;
+    if (page < 1 || page > templatePagination.value?.total_pages || page === templateCurrentPage.value) {
+      isPageLoading.value = false;
+      return;
+    }
+    templateCurrentPage.value = page;
+    templatePageInput.value = page; // keep the page input box in sync
+    await loadImageAudioTemplates();
+    isPageLoading.value = false;
+  };
+
+  const jumpToTemplatePage = async () => {
+    const page = parseInt(templatePageInput.value);
+    if (page && page >= 1 && page <= templatePagination.value?.total_pages && page !== templateCurrentPage.value) {
+      await goToTemplatePage(page);
+    } else {
+      // Invalid input: fall back to the current page
+      templatePageInput.value = templateCurrentPage.value;
+    }
+  };
+
+  const getVisiblePages = () => {
+    if (!pagination.value) return [];
+
+    const totalPages = pagination.value.total_pages;
+    const current = currentTaskPage.value;
+    const pages = [];
+
+    // Always show the first page
+    pages.push(1);
+
+    if (totalPages <= 5) {
+      // With 5 pages or fewer, show every page number
+      for (let i = 2; i <= totalPages - 1; i++) {
+        pages.push(i);
+      }
+    } else {
+      // With more than 5 pages, use ellipses
+      if (current <= 3) {
+        // Current page is near the beginning (page 3 or earlier)
+        for (let i = 2; i <= 3; i++) {
+          pages.push(i);
+        }
+        pages.push('...');
+      } else if (current >= totalPages - 2) {
+        // Current page is near the end (within the last 3 pages)
+        pages.push('...');
+        for (let i = totalPages - 2; i <= totalPages - 1; i++) {
+          pages.push(i);
+        }
+      } else {
+        // Current page is somewhere in the middle
+        pages.push('...');
+        for (let i = current - 1; i <= current + 1; i++) {
+          pages.push(i);
+        }
+        pages.push('...');
+      }
+    }
+
+    // Always show the last page (when it is not also the first)
+    if (totalPages > 1) {
+      pages.push(totalPages);
+    }
+
+    return pages;
+  };
+
+  const getVisibleTemplatePages = () => {
+    if (!templatePagination.value) return [];
+
+    const totalPages = templatePagination.value.total_pages;
+    const current = templateCurrentPage.value;
+    const pages = [];
+
+    // Always show the first page
+    pages.push(1);
+
+    if (totalPages <= 5) {
+      // With 5 pages or fewer, show every page number
+      for (let i = 2; i <= totalPages - 1; i++) {
+        pages.push(i);
+      }
+    } else {
+      // Show the page numbers around the current page
+      const start = Math.max(2, current - 1);
+      const end = Math.min(totalPages - 1, current + 1);
+
+      if (start > 2) {
+        pages.push('...');
+      }
+
+      for (let i = start; i <= end; i++) {
+        if (i !== 1 && i !== totalPages) {
+          pages.push(i);
+        }
+      }
+
+      if (end < totalPages - 1) {
+        pages.push('...');
+      }
+    }
+
+    // Always show the last page
+    if (totalPages > 1) {
+      pages.push(totalPages);
+    }
+
+    return pages;
+  };
+
+  // Inspiration gallery pagination helpers
+  const goToInspirationPage = async
(page) => { + isPageLoading.value = true; + if (page < 1 || page > inspirationPagination.value?.total_pages || page === inspirationCurrentPage.value) { + isPageLoading.value = false; + return; + } + inspirationCurrentPage.value = page; + inspirationPageInput.value = page; // 同步更新输入框 + await loadInspirationData(); + isPageLoading.value = false; + }; + + const jumpToInspirationPage = async () => { + const page = parseInt(inspirationPageInput.value); + if (page && page >= 1 && page <= inspirationPagination.value?.total_pages && page !== inspirationCurrentPage.value) { + await goToInspirationPage(page); + } else { + // 如果输入无效,恢复到当前页 + inspirationPageInput.value = inspirationCurrentPage.value; + } + }; + + const getVisibleInspirationPages = () => { + if (!inspirationPagination.value) return []; + + const totalPages = inspirationPagination.value.total_pages; + const current = inspirationCurrentPage.value; + const pages = []; + + // 总是显示第一页 + pages.push(1); + + if (totalPages <= 5) { + // 如果总页数少于等于7页,显示所有页码 + for (let i = 2; i <= totalPages - 1; i++) { + pages.push(i); + } + } else { + // 显示当前页附近的页码 + const start = Math.max(2, current - 1); + const end = Math.min(totalPages - 1, current + 1); + + if (start > 2) { + pages.push('...'); + } + + for (let i = start; i <= end; i++) { + if (i !== 1 && i !== totalPages) { + pages.push(i); + } + } + + if (end < totalPages - 1) { + pages.push('...'); + } + } + + // 总是显示最后一页 + if (totalPages > 1) { + pages.push(totalPages); + } + + return pages; + }; + + const getStatusBadgeClass = (status) => { + const statusMap = { + 'SUCCEED': 'bg-success', + 'FAILED': 'bg-danger', + 'RUNNING': 'bg-warning', + 'PENDING': 'bg-secondary', + 'CREATED': 'bg-secondary' + }; + return statusMap[status] || 'bg-secondary'; + }; + + const viewSingleResult = async (taskId, key) => { + try { + downloadLoading.value = true; + const url = await getTaskFileUrl(taskId, key); + if (url) { + const response = await fetch(url); + if (response.ok) { + const blob = await response.blob(); + const videoBlob = new Blob([blob], { type: 'video/mp4' }); + const url = window.URL.createObjectURL(videoBlob); + window.open(url, '_blank'); + } else { + showAlert(t('getResultFailed'), 'danger'); + } + } else { + showAlert(t('getTaskResultFailedAlert'), 'danger'); + } + } catch (error) { + showAlert(`${t('viewTaskResultFailedAlert')}: ${error.message}`, 'danger'); + } finally { + downloadLoading.value = false; + } + }; + + const cancelTask = async (taskId, fromDetailPage = false) => { + try { + // 显示确认对话框 + const confirmed = await showConfirmDialog({ + title: t('cancelTaskConfirm'), + message: t('cancelTaskConfirmMessage'), + confirmText: t('confirmCancel'), + }); + + if (!confirmed) { + return; + } + + const response = await apiRequest(`/api/v1/task/cancel?task_id=${taskId}`); + if (response && response.ok) { + showAlert(t('taskCancelSuccessAlert'), 'success'); + + // 如果当前在任务详情界面,刷新任务后关闭详情弹窗 + if (fromDetailPage) { + refreshTasks(true); // 强制刷新 + const updatedTask = tasks.value.find(t => t.task_id === taskId); + if (updatedTask) { + modalTask.value = updatedTask; + } + await nextTick(); + closeTaskDetailModal(); + } else { + refreshTasks(true); // 强制刷新 + } + + } else if (response) { + const error = await response.json(); + showAlert(`${t('cancelTaskFailedAlert')}: ${error.message}`, 'danger'); + } + // 如果response为null,说明是认证错误,apiRequest已经处理了 + } catch (error) { + showAlert(`${t('cancelTaskFailedAlert')}: ${error.message}`, 'danger'); + } + }; + + const resumeTask = async (taskId, fromDetailPage = false) 
=> { + try { + // 先获取任务信息,检查任务状态 + const taskResponse = await apiRequest(`/api/v1/task/query?task_id=${taskId}`); + if (!taskResponse || !taskResponse.ok) { + showAlert(t('taskNotFoundAlert'), 'danger'); + return; + } + + const task = await taskResponse.json(); + + // 如果任务已完成,则删除并重新生成 + if (task.status === 'SUCCEED') { + // 显示确认对话框 + const confirmed = await showConfirmDialog({ + title: t('regenerateTaskConfirm'), + message: t('regenerateTaskConfirmMessage'), + confirmText: t('confirmRegenerate') + }); + + if (!confirmed) { + return; + } + + // 显示重新生成中的提示 + showAlert(t('regeneratingTaskAlert'), 'info'); + + const deleteResponse = await apiRequest(`/api/v1/task/delete?task_id=${taskId}`, { + method: 'DELETE' + }); + if (!deleteResponse || !deleteResponse.ok) { + showAlert(t('deleteTaskFailedAlert'), 'danger'); + return; + } + try { + // 设置任务类型 + selectedTaskId.value = task.task_type; + console.log('selectedTaskId.value', selectedTaskId.value); + + // 获取当前表单 + const currentForm = getCurrentForm(); + + // 设置模型 + if (task.params && task.params.model_cls) { + currentForm.model_cls = task.params.model_cls; + } + + // 设置prompt + if (task.params && task.params.prompt) { + currentForm.prompt = task.params.prompt; + } + + // localStorage 不再保存文件内容,直接从后端获取任务文件 + try { + // 使用现有的函数获取图片和音频URL + const imageUrl = await getTaskInputImage(task); + const audioUrl = await getTaskInputAudio(task); + + // 加载图片文件 + if (imageUrl) { + try { + const imageResponse = await fetch(imageUrl); + if (imageResponse && imageResponse.ok) { + const blob = await imageResponse.blob(); + const filename = task.inputs[Object.keys(task.inputs).find(key => + key.includes('image') || + task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/) + )] || 'image.jpg'; + const file = new File([blob], filename, { type: blob.type }); + currentForm.imageFile = file; + setCurrentImagePreview(URL.createObjectURL(file)); + } + } catch (error) { + console.warn('Failed to load image file:', error); + } + } + + // 加载音频文件 + if (audioUrl) { + try { + const audioResponse = await fetch(audioUrl); + if (audioResponse && audioResponse.ok) { + const blob = await audioResponse.blob(); + const filename = task.inputs[Object.keys(task.inputs).find(key => + key.includes('audio') || + task.inputs[key].toString().toLowerCase().match(/\.(mp3|wav|mp4|aac|ogg|m4a)$/) + )] || 'audio.wav'; + + // 根据文件扩展名确定正确的MIME类型 + let mimeType = blob.type; + if (!mimeType || mimeType === 'application/octet-stream') { + const ext = filename.toLowerCase().split('.').pop(); + const mimeTypes = { + 'mp3': 'audio/mpeg', + 'wav': 'audio/wav', + 'mp4': 'audio/mp4', + 'aac': 'audio/aac', + 'ogg': 'audio/ogg', + 'm4a': 'audio/mp4' + }; + mimeType = mimeTypes[ext] || 'audio/mpeg'; + } + + const file = new File([blob], filename, { type: mimeType }); + currentForm.audioFile = file; + console.log('复用任务 - 从后端加载音频文件:', { + name: file.name, + type: file.type, + size: file.size, + originalBlobType: blob.type + }); + // 使用FileReader生成data URL,与正常上传保持一致 + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentAudioPreview(e.target.result); + console.log('复用任务 - 音频预览已设置:', e.target.result.substring(0, 50) + '...'); + }; + reader.readAsDataURL(file); + } + + } catch (error) { + console.warn('Failed to load audio file:', error); + } + } + } catch (error) { + console.warn('Failed to load task data from backend:', error); + } + + showAlert(t('taskMaterialReuseSuccessAlert'), 'success'); + + } catch (error) { + console.error('Failed to resume task:', error); + 
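+        // The audio/image reuse code in the try-block above repeats the same blob-to-File
+        // conversion, including the extension-to-MIME fallback. A condensed sketch of that
+        // step, kept as a comment; `blobToNamedFile` is a hypothetical helper name, not an
+        // identifier in this app:
+        //
+        //   const blobToNamedFile = (blob, filename) => {
+        //     const mimeByExt = { mp3: 'audio/mpeg', wav: 'audio/wav', mp4: 'audio/mp4',
+        //                         aac: 'audio/aac', ogg: 'audio/ogg', m4a: 'audio/mp4' };
+        //     let type = blob.type;
+        //     if (!type || type === 'application/octet-stream') {
+        //       type = mimeByExt[filename.toLowerCase().split('.').pop()] || 'audio/mpeg';
+        //     }
+        //     return new File([blob], filename, { type });
+        //   };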
showAlert(t('loadTaskDataFailedAlert'), 'danger'); + return; + } + // 如果从详情页调用,关闭详情页 + if (fromDetailPage) { + closeTaskDetailModal(); + } + + submitTask(); + + + return; // 不需要继续执行后续的API调用 + } else { + // 对于未完成的任务,使用原有的恢复逻辑 + const response = await apiRequest(`/api/v1/task/resume?task_id=${taskId}`); + if (response && response.ok) { + showAlert(t('taskRetrySuccessAlert'), 'success'); + + // 如果当前在任务详情界面,先刷新任务列表,然后重新获取任务信息 + if (fromDetailPage) { + refreshTasks(true); // 强制刷新 + const updatedTask = tasks.value.find(t => t.task_id === taskId); + if (updatedTask) { + selectedTask.value = updatedTask; + } + startPollingTask(taskId); + await nextTick(); + } else { + refreshTasks(true); // 强制刷新 + + // 开始轮询新提交的任务状态 + startPollingTask(taskId); + } + } else if (response) { + const error = await response.json(); + showAlert(`${t('retryTaskFailedAlert')}: ${error.message}`, 'danger'); + } + } + } catch (error) { + console.error('resumeTask error:', error); + showAlert(`${t('retryTaskFailedAlert')}: ${error.message}`, 'danger'); + } + }; + + // 切换任务菜单显示状态 + const toggleTaskMenu = (taskId) => { + // 先关闭所有其他菜单 + closeAllTaskMenus(); + // 然后打开当前菜单 + taskMenuVisible.value[taskId] = true; + }; + + // 关闭所有任务菜单 + const closeAllTaskMenus = () => { + taskMenuVisible.value = {}; + }; + + // 点击外部关闭菜单 + const handleClickOutside = (event) => { + if (!event.target.closest('.task-menu-container')) { + closeAllTaskMenus(); + } + if (!event.target.closest('.task-type-dropdown')) { + showTaskTypeMenu.value = false; + } + if (!event.target.closest('.model-dropdown')) { + showModelMenu.value = false; + } + }; + + const deleteTask = async (taskId, fromDetailPage = false) => { + try { + // 显示确认对话框 + const confirmed = await showConfirmDialog({ + title: t('deleteTaskConfirm'), + message: t('deleteTaskConfirmMessage'), + confirmText: t('confirmDelete') + }); + + if (!confirmed) { + return; + } + const response = await apiRequest(`/api/v1/task/delete?task_id=${taskId}`, { + method: 'DELETE' + }); + + if (response && response.ok) { + showAlert(t('taskDeletedSuccessAlert'), 'success'); + const deletedTaskIndex = tasks.value.findIndex(task => task.task_id === taskId); + if (deletedTaskIndex !== -1) { + const wasCurrent = currentTask.value?.task_id === taskId; + tasks.value.splice(deletedTaskIndex, 1); + if (wasCurrent) { + currentTask.value = tasks.value[deletedTaskIndex] || tasks.value[deletedTaskIndex - 1] || null; + } + } + refreshTasks(true); // 强制刷新 + + // 如果是从任务详情页删除,删除成功后关闭详情弹窗 + if (fromDetailPage) { + closeTaskDetailModal(); + if (!selectedTaskId.value) { + if (availableTaskTypes.value.includes('s2v')) { + selectTask('s2v'); + } + } + } + } else if (response) { + const error = await response.json(); + showAlert(`${t('deleteTaskFailedAlert')}: ${error.message}`, 'danger'); + } + // 如果response为null,说明是认证错误,apiRequest已经处理了 + } catch (error) { + showAlert(`${t('deleteTaskFailedAlert')}: ${error.message}`, 'danger'); + } + }; + + const loadTaskFiles = async (task) => { + try { + loadingTaskFiles.value = true; + + const files = { inputs: {}, outputs: {} }; + + // 获取输入文件(所有状态的任务都需要) + if (task.inputs) { + for (const [key, inputPath] of Object.entries(task.inputs)) { + try { + const url = await getTaskFileUrl(taskId, key); + if (url) { + const response = await fetch(url); + if (response && response.ok) { + const blob = await response.blob() + files.inputs[key] = { + name: inputPath, // 使用原始文件名而不是key + path: inputPath, + blob: blob, + url: URL.createObjectURL(blob) + } + } + } + } catch (error) { + console.error(`Failed to load 
input ${key}:`, error); + files.inputs[key] = { + name: inputPath, // 使用原始文件名而不是key + path: inputPath, + error: true + }; + } + } + } + + // 只对成功完成的任务获取输出文件 + if (task.status === 'SUCCEED' && task.outputs) { + for (const [key, outputPath] of Object.entries(task.outputs)) { + try { + const url = await getTaskFileUrl(taskId, key); + if (url) { + const response = await fetch(url); + if (response && response.ok) { + const blob = await response.blob() + files.outputs[key] = { + name: outputPath, // 使用原始文件名而不是key + path: outputPath, + blob: blob, + url: URL.createObjectURL(blob) + } + }; + } + } catch (error) { + console.error(`Failed to load output ${key}:`, error); + files.outputs[key] = { + name: outputPath, // 使用原始文件名而不是key + path: outputPath, + error: true + }; + } + } + } + + selectedTaskFiles.value = files; + + } catch (error) { + console.error('Failed to load task files: task_id=', taskId, error); + showAlert(t('loadTaskFilesFailedAlert'), 'danger'); + } finally { + loadingTaskFiles.value = false; + } + }; + + const reuseTask = async (task) => { + if (!task) { + showAlert(t('loadTaskDataFailedAlert'), 'danger'); + return; + } + + try { + templateLoading.value = true; + templateLoadingMessage.value = t('prefillLoadingTask'); + // 跳转到任务创建界面 + isCreationAreaExpanded.value = true; + if (showTaskDetailModal.value) { + closeTaskDetailModal(); + } + + // 设置任务类型 + selectedTaskId.value = task.task_type; + console.log('selectedTaskId.value', selectedTaskId.value); + + // 获取当前表单 + const currentForm = getCurrentForm(); + + // 立即切换到创建视图,后续资产异步加载 + switchToCreateView(); + + // 设置模型 + if (task.params && task.params.model_cls) { + currentForm.model_cls = task.params.model_cls; + } + + // 设置prompt + if (task.params && task.params.prompt) { + currentForm.prompt = task.params.prompt; + } + + // localStorage 不再保存文件内容,直接从后端获取任务文件 + try { + // 使用现有的函数获取图片和音频URL + const imageUrl = await getTaskInputImage(task); + const audioUrl = await getTaskInputAudio(task); + + + // 加载音频文件 + if (audioUrl) { + try { + const audioResponse = await fetch(audioUrl); + if (audioResponse && audioResponse.ok) { + // Check if the response is an error (for directory inputs) + const contentType = audioResponse.headers.get('content-type'); + if (contentType && contentType.includes('application/json')) { + const errorData = await audioResponse.json(); + // Not a directory error, proceed with normal loading + currentForm.audioUrl = audioUrl; + setCurrentAudioPreview(audioUrl); + + const blob = await audioResponse.blob(); + const filename = task.inputs[Object.keys(task.inputs).find(key => + key.includes('audio') || + task.inputs[key].toString().toLowerCase().match(/\.(mp3|wav|mp4|aac|ogg|m4a)$/) + )] || 'audio.wav'; + + // 根据文件扩展名确定正确的MIME类型 + let mimeType = blob.type; + if (!mimeType || mimeType === 'application/octet-stream') { + const ext = filename.toLowerCase().split('.').pop(); + const mimeTypes = { + 'mp3': 'audio/mpeg', + 'wav': 'audio/wav', + 'mp4': 'audio/mp4', + 'aac': 'audio/aac', + 'ogg': 'audio/ogg', + 'm4a': 'audio/mp4' + }; + mimeType = mimeTypes[ext] || 'audio/mpeg'; + } + + const file = new File([blob], filename, { type: mimeType }); + currentForm.audioFile = file; + console.log('复用任务 - 从后端加载音频文件:', { + name: file.name, + type: file.type, + size: file.size, + originalBlobType: blob.type + }); + } else { + // Normal audio file response + currentForm.audioUrl = audioUrl; + setCurrentAudioPreview(audioUrl); + + const blob = await audioResponse.blob(); + const filename = task.inputs[Object.keys(task.inputs).find(key => + 
key.includes('audio') || + task.inputs[key].toString().toLowerCase().match(/\.(mp3|wav|mp4|aac|ogg|m4a)$/) + )] || 'audio.wav'; + + // 根据文件扩展名确定正确的MIME类型 + let mimeType = blob.type; + if (!mimeType || mimeType === 'application/octet-stream') { + const ext = filename.toLowerCase().split('.').pop(); + const mimeTypes = { + 'mp3': 'audio/mpeg', + 'wav': 'audio/wav', + 'mp4': 'audio/mp4', + 'aac': 'audio/aac', + 'ogg': 'audio/ogg', + 'm4a': 'audio/mp4' + }; + mimeType = mimeTypes[ext] || 'audio/mpeg'; + } + + const file = new File([blob], filename, { type: mimeType }); + currentForm.audioFile = file; + console.log('复用任务 - 从后端加载音频文件:', { + name: file.name, + type: file.type, + size: file.size, + originalBlobType: blob.type + }); + } + } + } catch (error) { + console.warn('Failed to load audio file:', error); + } + } + + // 加载图片文件 + if (imageUrl) { + try { + const imageResponse = await fetch(imageUrl); + if (imageResponse && imageResponse.ok) { + const blob = await imageResponse.blob(); + const filename = task.inputs[Object.keys(task.inputs).find(key => + key.includes('image') || + task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/) + )] || 'image.jpg'; + const file = new File([blob], filename, { type: blob.type }); + currentForm.imageFile = file; + const imagePreviewUrl = URL.createObjectURL(file); + setCurrentImagePreview(imageUrl); + + // Reset detected faces + if (selectedTaskId.value === 'i2v') { + i2vForm.value.detectedFaces = []; + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.detectedFaces = []; + } + + // 不再自动检测人脸,等待用户手动打开多角色模式开关 + } + } catch (error) { + console.warn('Failed to load image file:', error); + } + } + } catch (error) { + console.warn('Failed to load task data from backend:', error); + } + + showAlert(t('taskMaterialReuseSuccessAlert'), 'success'); + + } catch (error) { + console.error('Failed to reuse task:', error); + showAlert(t('loadTaskDataFailedAlert'), 'danger'); + } finally { + templateLoading.value = false; + templateLoadingMessage.value = ''; + } + }; + + const downloadFile = async (fileInfo) => { + if (!fileInfo || !fileInfo.blob) { + showAlert(t('fileUnavailableAlert'), 'danger'); + return false; + } + + const blob = fileInfo.blob; + const fileName = fileInfo.name || 'download'; + const mimeType = blob.type || fileInfo.mimeType || 'application/octet-stream'; + + try { + const objectUrl = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = objectUrl; + a.download = fileName; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(objectUrl); + showAlert(t('downloadSuccessAlert'), 'success'); + return true; + } catch (error) { + console.error('Download failed:', error); + showAlert(t('downloadFailedAlert'), 'danger'); + return false; + } + }; + + // 处理文件下载 + const handleDownloadFile = async (taskId, fileKey, fileName) => { + if (downloadLoading.value) { + showAlert(t('downloadInProgressNotice'), 'info'); + return; + } + + downloadLoading.value = true; + downloadLoadingMessage.value = t('downloadPreparing'); + + try { + console.log('开始下载文件:', { taskId, fileKey, fileName }); + + // 处理文件名,确保有正确的后缀名 + let finalFileName = fileName; + if (fileName && typeof fileName === 'string') { + const hasExtension = /\.[a-zA-Z0-9]+$/.test(fileName); + if (!hasExtension) { + const extension = getFileExtension(fileKey); + finalFileName = `${fileName}.${extension}`; + console.log('添加后缀名:', finalFileName); + } + } else { + finalFileName = 
`${fileKey}.${getFileExtension(fileKey)}`; + } + + downloadLoadingMessage.value = t('downloadFetching'); + + let downloadUrl = null; + + const cachedData = getTaskFileFromCache(taskId, fileKey); + if (cachedData?.url) { + downloadUrl = cachedData.url; + } + + if (!downloadUrl) { + downloadUrl = await getTaskFileUrl(taskId, fileKey); + } + + if (!downloadUrl) { + throw new Error('无法获取文件URL'); + } + + const response = await fetch(downloadUrl); + if (!response.ok) { + throw new Error(`文件响应失败: ${response.status}`); + } + + const blob = await response.blob(); + downloadLoadingMessage.value = t('downloadSaving'); + await downloadFile({ + blob, + name: finalFileName, + mimeType: blob.type + }); + } catch (error) { + console.error('下载失败:', error); + showAlert(t('downloadFailedAlert'), 'danger'); + } finally { + downloadLoading.value = false; + downloadLoadingMessage.value = ''; + } + } + + const viewFile = (fileInfo) => { + if (!fileInfo || !fileInfo.url) { + showAlert(t('fileUnavailableAlert'), 'danger'); + return; + } + + // 在新窗口中打开文件 + window.open(fileInfo.url, '_blank'); + }; + + const clearTaskFiles = () => { + // 清理 URL 对象,释放内存 + Object.values(selectedTaskFiles.value.inputs).forEach(file => { + if (file.url) { + URL.revokeObjectURL(file.url); + } + }); + Object.values(selectedTaskFiles.value.outputs).forEach(file => { + if (file.url) { + URL.revokeObjectURL(file.url); + } + }); + selectedTaskFiles.value = { inputs: {}, outputs: {} }; + }; + + const showTaskCreator = () => { + selectedTask.value = null; + // clearTaskFiles(); // 清空文件缓存 + selectedTaskId.value = 's2v'; // 默认选择数字人任务 + + // 停止所有任务状态轮询 + pollingTasks.value.clear(); + if (pollingInterval.value) { + clearInterval(pollingInterval.value); + pollingInterval.value = null; + } + }; + + const toggleSidebar = () => { + sidebarCollapsed.value = !sidebarCollapsed.value; + + if (sidebarCollapsed.value) { + // 收起时,将历史任务栏隐藏到屏幕左侧 + if (sidebar.value) { + sidebar.value.style.transform = 'translateX(-100%)'; + } + } else { + // 展开时,恢复历史任务栏位置 + if (sidebar.value) { + sidebar.value.style.transform = 'translateX(0)'; + } + } + + // 更新悬浮按钮位置 + updateFloatingButtonPosition(sidebarWidth.value); + }; + + const clearPrompt = () => { + getCurrentForm().prompt = ''; + updateUploadedContentStatus(); + }; + + const getTaskItemClass = (status) => { + if (status === 'SUCCEED') return 'bg-laser-purple/15 border border-laser-purple/30'; + if (status === 'RUNNING') return 'bg-laser-purple/15 border border-laser-purple/30'; + if (status === 'FAILED') return 'bg-red-500/15 border border-red-500/30'; + return 'bg-dark-light border border-gray-700'; + }; + + const getStatusIndicatorClass = (status) => { + const base = 'inline-block w-2 aspect-square rounded-full shrink-0 align-middle'; + if (status === 'SUCCEED') + return `${base} bg-gradient-to-r from-emerald-200 to-green-300 shadow-md shadow-emerald-300/30`; + if (status === 'RUNNING') + return `${base} bg-gradient-to-r from-amber-200 to-yellow-300 shadow-md shadow-amber-300/30 animate-pulse`; + if (status === 'FAILED') + return `${base} bg-gradient-to-r from-red-200 to-pink-300 shadow-md shadow-red-300/30`; + return `${base} bg-gradient-to-r from-gray-200 to-gray-300 shadow-md shadow-gray-300/30`; + }; + + const getTaskTypeBtnClass = (taskType) => { + if (selectedTaskId.value === taskType) { + return 'text-gradient-icon border-b-2 border-laser-purple'; + } + return 'text-gray-400 hover:text-gradient-icon'; + }; + + const getModelBtnClass = (model) => { + if (getCurrentForm().model_cls === model) { + return 
'bg-laser-purple/20 border border-laser-purple/40 active shadow-laser'; + } + return 'bg-dark-light border border-gray-700 hover:bg-laser-purple/15 hover:border-laser-purple/40 transition-all hover:shadow-laser'; + }; + + const getTaskTypeIcon = (taskType) => { + const iconMap = { + 't2v': 'fas fa-font', // 文字A形图标 + 'i2v': 'fas fa-image', // 图像图标 + 's2v': 'fas fa-user', // 人物图标 + 'animate': 'fi fi-br-running text-lg' // 角色替换图标 + }; + return iconMap[taskType] || 'fas fa-video'; + }; + + const getTaskTypeName = (task) => { + // 如果传入的是字符串,直接返回映射 + if (!task) { + return '未知'; + } + if (typeof task === 'string') { + return nameMap.value[task] || task; + } + + // 如果传入的是任务对象,根据模型类型判断 + if (task && task.model_cls) { + const modelCls = task.model_cls.toLowerCase(); + + return nameMap.value[task.task_type] || task.task_type; + } + + // 默认返回task_type + return task.task_type || '未知'; + }; + + const getPromptPlaceholder = () => { + if (selectedTaskId.value === 't2v') { + return t('pleaseEnterThePromptForVideoGeneration') + ','+ t('describeTheContentStyleSceneOfTheVideo'); + } else if (selectedTaskId.value === 'i2v') { + return t('pleaseEnterThePromptForVideoGeneration') + ','+ t('describeTheContentActionRequirementsBasedOnTheImage'); + } else if (selectedTaskId.value === 's2v') { + return t('optional') + ' '+ t('pleaseEnterThePromptForVideoGeneration') + ','+ t('describeTheDigitalHumanImageBackgroundStyleActionRequirements'); + } else if (selectedTaskId.value === 'animate') { + return t('optional') + ' '+ t('pleaseEnterThePromptForVideoGeneration') + ','+ t('describeTheContentActionRequirementsBasedOnTheImage'); + } + return t('pleaseEnterThePromptForVideoGeneration') + '...'; + }; + + const getStatusTextClass = (status) => { + if (status === 'SUCCEED') return 'text-emerald-400'; + if (status === 'CREATED') return 'text-blue-400'; + if (status === 'PENDING') return 'text-yellow-400'; + if (status === 'RUNNING') return 'text-amber-400'; + if (status === 'FAILED') return 'text-red-400'; + if (status === 'CANCEL') return 'text-gray-400'; + return 'text-gray-400'; + }; + + const getImagePreview = (base64Data) => { + if (!base64Data) return ''; + return `data:image/jpeg;base64,${base64Data}`; + }; + + const getTaskInputUrl = async (taskId, key) => { + // 优先从缓存获取 + const cachedUrl = getTaskFileUrlSync(taskId, key); + if (cachedUrl) { + console.log('getTaskInputUrl: 从缓存获取', { taskId, key, url: cachedUrl }); + return cachedUrl; + } + return await getTaskFileUrlFromApi(taskId, key); + }; + + const getTaskInputImage = async (task) => { + + if (!task || !task.inputs) { + console.log('getTaskInputImage: 任务或输入为空', { task: task?.task_id, inputs: task?.inputs }); + return null; + } + + const imageInputs = Object.keys(task.inputs).filter(key => + key.includes('image') || + task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/) + ); + + if (imageInputs.length > 0) { + const firstImageKey = imageInputs[0]; + // 优先从缓存获取 + const cachedUrl = getTaskFileUrlSync(task.task_id, firstImageKey); + if (cachedUrl) { + console.log('getTaskInputImage: 从缓存获取', { taskId: task.task_id, key: firstImageKey, url: cachedUrl }); + return cachedUrl; + } + // 缓存没有则生成URL + const url = await getTaskInputUrl(task.task_id, firstImageKey); + console.log('getTaskInputImage: 生成URL', { taskId: task.task_id, key: firstImageKey, url }); + return url; + } + + console.log('getTaskInputImage: 没有找到图片输入'); + return null; + }; + + const getTaskInputAudio = async (task) => { + if (!task || !task.inputs) return null; + + // Directly 
use 'input_audio' key + const audioKey = 'input_audio'; + if (!task.inputs[audioKey]) return null; + + // Always bypass cache and check API directly to detect directory type + // This ensures we get the correct URL even if cache has invalid data + let url = await getTaskFileUrlFromApi(task.task_id, audioKey); + + // If it's a directory (multi-person mode) or URL is null, try to get original_audio file + if (!url) { + console.log(`Audio input ${audioKey} is a directory (multi-person mode), trying to get original_audio file`); + // Try to get original_audio file from directory + // Try common extensions + const extensions = ['wav', 'mp3', 'mp4', 'aac', 'ogg', 'm4a']; + for (const ext of extensions) { + const originalAudioFilename = `original_audio.${ext}`; + url = await getTaskFileUrlFromApi(task.task_id, audioKey, originalAudioFilename); + if (url) { + console.log(`Found original audio file: ${originalAudioFilename}`); + break; + } + } + } + + return url; + }; + + const handleThumbnailError = (event) => { + // 当输入图片加载失败时,显示默认图标 + const img = event.target; + const parent = img.parentElement; + parent.innerHTML = '
'; + }; + + const handleImageError = (event) => { + // 当图片加载失败时,隐藏图片,显示文件名 + const img = event.target; + img.style.display = 'none'; + // 文件名已经显示,不需要额外处理 + }; + + const handleImageLoad = (event) => { + // 当图片加载成功时,显示图片和下载按钮,隐藏文件名 + const img = event.target; + img.style.display = 'block'; + // 显示下载按钮 + const downloadBtn = img.parentElement.querySelector('button'); + if (downloadBtn) { + downloadBtn.style.display = 'block'; + } + // 隐藏文件名span + const span = img.parentElement.parentElement.querySelector('span'); + if (span) { + span.style.display = 'none'; + } + }; + + const handleAudioError = (event) => { + // 当音频加载失败时,隐藏音频控件和下载按钮,显示文件名 + const audio = event.target; + audio.style.display = 'none'; + // 隐藏下载按钮 + const downloadBtn = audio.parentElement.querySelector('button'); + if (downloadBtn) { + downloadBtn.style.display = 'none'; + } + // 文件名已经显示,不需要额外处理 + }; + + const handleAudioLoad = (event) => { + // 当音频加载成功时,显示音频控件和下载按钮,隐藏文件名 + const audio = event.target; + audio.style.display = 'block'; + // 显示下载按钮 + const downloadBtn = audio.parentElement.querySelector('button'); + if (downloadBtn) { + downloadBtn.style.display = 'block'; + } + // 隐藏文件名span + const span = audio.parentElement.parentElement.querySelector('span'); + if (span) { + span.style.display = 'none'; + } + }; + + // 监听currentPage变化,同步更新pageInput + watch(currentTaskPage, (newPage) => { + taskPageInput.value = newPage; + }); + + // 监听pagination变化,确保分页组件更新 + watch(pagination, (newPagination) => { + console.log('pagination变化:', newPagination); + if (newPagination && newPagination.total_pages) { + // 确保当前页不超过总页数 + if (currentTaskPage.value > newPagination.total_pages) { + currentTaskPage.value = newPagination.total_pages; + } + } + }, { deep: true }); + + // 监听templateCurrentPage变化,同步更新templatePageInput + watch(templateCurrentPage, (newPage) => { + templatePageInput.value = newPage; + }); + + // 监听templatePagination变化,确保分页组件更新 + watch(templatePagination, (newPagination) => { + console.log('templatePagination变化:', newPagination); + if (newPagination && newPagination.total_pages) { + // 确保当前页不超过总页数 + if (templateCurrentPage.value > newPagination.total_pages) { + templateCurrentPage.value = newPagination.total_pages; + } + } + }, { deep: true }); + + // 监听inspirationCurrentPage变化,同步更新inspirationPageInput + watch(inspirationCurrentPage, (newPage) => { + inspirationPageInput.value = newPage; + }); + + // 监听inspirationPagination变化,确保分页组件更新 + watch(inspirationPagination, (newPagination) => { + console.log('inspirationPagination变化:', newPagination); + if (newPagination && newPagination.total_pages) { + // 确保当前页不超过总页数 + if (inspirationCurrentPage.value > newPagination.total_pages) { + inspirationCurrentPage.value = newPagination.total_pages; + } + } + }, { deep: true }); + + // 统一的初始化函数 + const init = async () => { + try { + // 0. 初始化主题 + initTheme(); + + // 1. 加载模型和任务数据 + await loadModels(); + + // 2. 
从路由恢复或设置默认值 + const routeQuery = router.currentRoute.value?.query || {}; + const routeTaskType = routeQuery.taskType; + const routeModel = routeQuery.model; + const routeExpanded = routeQuery.expanded; + + if (routeTaskType && availableTaskTypes.value.includes(routeTaskType)) { + // 路由中有 taskType,恢复它 + selectTask(routeTaskType); + + if (routeModel && availableModelClasses.value.includes(routeModel)) { + // 路由中有 model,恢复它(会自动设置 stage) + selectModel(routeModel); + } else { + // 路由中没有 model 或 model 无效,选择第一个模型 + const firstModel = availableModelClasses.value[0]; + if (firstModel) { + selectModel(firstModel); + } + } + } else { + // 路由中没有 taskType,设置默认值:s2v + const defaultTaskType = availableTaskTypes.value.includes('s2v') ? 's2v' : availableTaskTypes.value[0]; + if (defaultTaskType) { + selectTask(defaultTaskType); + + // 选择该任务下的第一个模型 + const firstModel = availableModelClasses.value[0]; + if (firstModel) { + selectModel(firstModel); + } + } + } + + // 3. 恢复 expanded 状态(如果路由中有) + if (routeExpanded === 'true') { + expandCreationArea(); + } + + // 4. 加载历史记录和素材库(异步,不阻塞首屏) + refreshTasks(true); + loadInspirationData(true); + + // 5. 加载历史记录和素材库文件(异步,不阻塞首屏) + getPromptHistory(); + loadTaskFilesFromCache(); + loadTemplateFilesFromCache(); + + // 异步加载模板数据,不阻塞首屏渲染 + setTimeout(() => { + loadImageAudioTemplates(true); + }, 100); + + + console.log('初始化完成:', { + currentUser: currentUser.value, + availableModels: models.value, + tasks: tasks.value, + inspirationItems: inspirationItems.value, + selectedTaskId: selectedTaskId.value, + selectedModel: selectedModel.value, + currentForm: { + model_cls: getCurrentForm().model_cls, + stage: getCurrentForm().stage + } + }); + + } catch (error) { + console.error('初始化失败:', error); + showAlert(t('initFailedPleaseRefresh'), 'danger'); + } + }; + + // 重置表单函数(保留模型选择,清空图片、音频和提示词) + const resetForm = async (taskType) => { + const currentForm = getCurrentForm(); + const currentModel = currentForm.model_cls; + const currentStage = currentForm.stage; + + // 重置表单但保留模型和阶段 + switch (taskType) { + case 't2v': + t2vForm.value = { + task: 't2v', + model_cls: currentModel, + stage: currentStage, + prompt: '', + seed: Math.floor(Math.random() * 1000000) + }; + break; + case 'i2v': + i2vForm.value = { + task: 'i2v', + model_cls: currentModel, + stage: currentStage, + imageFile: null, + prompt: '', + seed: Math.floor(Math.random() * 1000000) + }; + // 直接清空i2v图片预览 + i2vImagePreview.value = null; + // 清理图片文件输入框 + const imageInput = document.querySelector('input[type="file"][accept="image/*"]'); + if (imageInput) { + imageInput.value = ''; + } + break; + case 's2v': + s2vForm.value = { + task: 's2v', + model_cls: currentModel, + stage: currentStage, + imageFile: null, + audioFile: null, + prompt: 'Make the character speak in a natural way according to the audio.', + seed: Math.floor(Math.random() * 1000000) + }; + break; + case 'animate': + animateForm.value = { + task: 'animate', + model_cls: currentModel, + stage: currentStage, + imageFile: null, + videoFile: null, + prompt: '视频中的人在做动作', + seed: Math.floor(Math.random() * 1000000), + detectedFaces: [] + }; + // 直接清空animate图片和视频预览 + animateImagePreview.value = null; + animateVideoPreview.value = null; + // 清理图片和视频文件输入框 + const animateImageInput = document.querySelector('input[type="file"][accept="image/*"]'); + if (animateImageInput) { + animateImageInput.value = ''; + } + const animateVideoInput = document.querySelector('input[type="file"][data-role="video-input"]'); + if (animateVideoInput) { + animateVideoInput.value = ''; + } + 
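+        // Each branch of this switch rebuilds its form the same way: keep model_cls and
+        // stage, clear the media fields, restore the per-type default prompt and draw a
+        // fresh random seed. A sketch of that shared shape, kept as a comment;
+        // `makeBaseForm` is a hypothetical name, not an identifier in this app:
+        //
+        //   const makeBaseForm = (task, model_cls, stage, prompt = '') => ({
+        //     task, model_cls, stage, prompt,
+        //     seed: Math.floor(Math.random() * 1000000)
+        //   });
+        //   // e.g. i2vForm.value = { ...makeBaseForm('i2v', currentModel, currentStage), imageFile: null };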
break;
+ }
+
+ // 强制触发Vue响应式更新
+ setCurrentImagePreview(null);
+ setCurrentAudioPreview(null);
+ await nextTick();
+ };
+
+ // 开始轮询任务状态
+ const startPollingTask = (taskId) => {
+ if (!pollingTasks.value.has(taskId)) {
+ pollingTasks.value.add(taskId);
+ console.log(`开始轮询任务状态: ${taskId}`);
+
+ // 如果还没有轮询定时器,启动一个
+ if (!pollingInterval.value) {
+ pollingInterval.value = setInterval(async () => {
+ await pollTaskStatuses();
+ }, 1000); // 每1秒轮询一次
+ }
+ }
+ };
+
+ // 停止轮询任务状态
+ const stopPollingTask = (taskId) => {
+ pollingTasks.value.delete(taskId);
+ console.log(`停止轮询任务状态: ${taskId}`);
+
+ // 如果没有任务需要轮询了,清除定时器
+ if (pollingTasks.value.size === 0 && pollingInterval.value) {
+ clearInterval(pollingInterval.value);
+ pollingInterval.value = null;
+ console.log('停止所有任务状态轮询');
+ }
+ };
+
+ const refreshTaskFiles = (task) => {
+ for (const [key, inputPath] of Object.entries(task.inputs)) {
+ getTaskFileUrlFromApi(task.task_id, key).then(url => {
+ console.log('refreshTaskFiles: input', task.task_id, key, url);
+ });
+ }
+ for (const [key, outputPath] of Object.entries(task.outputs)) {
+ getTaskFileUrlFromApi(task.task_id, key).then(url => {
+ console.log('refreshTaskFiles: output', task.task_id, key, url);
+ });
+ }
+ };
+
+ // 轮询任务状态
+ const pollTaskStatuses = async () => {
+ if (pollingTasks.value.size === 0) return;
+
+ try {
+ const taskIds = Array.from(pollingTasks.value);
+ const response = await apiRequest(`/api/v1/task/query?task_ids=${taskIds.join(',')}`);
+
+ if (response && response.ok) {
+ const tasksData = await response.json();
+ const updatedTasks = tasksData.tasks || [];
+
+ // 更新任务列表中的任务状态
+ let hasUpdates = false;
+ updatedTasks.forEach(updatedTask => {
+ const existingTaskIndex = tasks.value.findIndex(t => t.task_id === updatedTask.task_id);
+ if (existingTaskIndex !== -1) {
+ const oldTask = tasks.value[existingTaskIndex];
+ tasks.value[existingTaskIndex] = updatedTask;
+ console.log('updatedTask', updatedTask);
+ console.log('oldTask', oldTask);
+
+ // 如果状态发生变化,记录日志
+ if (oldTask !== updatedTask) {
+ hasUpdates = true; // 这里基本都会变,因为任务有进度条
+
+ // 如果当前在查看这个任务的详情,更新selectedTask
+ if (modalTask.value && modalTask.value.task_id === updatedTask.task_id) {
+ modalTask.value = updatedTask;
+ if (updatedTask.status === 'SUCCEED') {
+ console.log('refresh viewing task: output files');
+ loadTaskFiles(updatedTask);
+ }
+ }
+
+ // 如果当前TaskCarousel显示的是这个任务,更新currentTask
+ if (currentTask.value && currentTask.value.task_id === updatedTask.task_id) {
+ currentTask.value = updatedTask;
+ console.log('TaskCarousel: 更新currentTask', updatedTask);
+ }
+
+ // 如果当前在projects页面且变化的是状态,更新tasks
+ if (router.currentRoute.value.path === '/projects' && oldTask.status !== updatedTask.status) {
+ refreshTasks(true);
+ }
+
+ // 如果任务完成或失败,停止轮询并显示提示
+ if (['SUCCEED', 'FAILED', 'CANCEL'].includes(updatedTask.status)) {
+ stopPollingTask(updatedTask.task_id);
+ refreshTaskFiles(updatedTask);
+ refreshTasks(true);
+
+ // 显示任务完成提示
+ if (updatedTask.status === 'SUCCEED') {
+ showAlert(t('taskCompletedSuccessfully'), 'success', {
+ label: t('view'),
+ onClick: () => {
+ openTaskDetailModal(updatedTask);
+ }
+ });
+ } else if (updatedTask.status === 'FAILED') {
+ showAlert(t('videoGeneratingFailed'), 'danger', {
+ label: t('view'),
+ onClick: () => {
+ openTaskDetailModal(updatedTask);
+ }
+ });
+ } else if (updatedTask.status === 'CANCEL') {
+ showAlert(t('taskCancelled'), 'warning');
+ }
+ }
+ }
+ }
+ });
+
+ // 如果有更新,触发界面刷新
+ if (hasUpdates) {
+ await nextTick();
+ }
+ }
+ } catch (error) {
+ console.error('轮询任务状态失败:', error);
+ }
+ };
+
+ // 任务状态管理
+ const getTaskStatusDisplay = (status) => {
+ const statusMap = {
+ 'CREATED': t('created'),
+ 'PENDING': t('pending'),
+ 'RUNNING': t('running'),
+ 'SUCCEED': t('succeed'),
+ 'FAILED': t('failed'),
+ 'CANCEL': t('cancelled')
+ };
+ return statusMap[status] || status;
+ };
+
+ const getTaskStatusColor = (status) => {
+ const colorMap = {
+ 'CREATED': 'text-blue-400',
+ 'PENDING': 'text-yellow-400',
+ 'RUNNING': 'text-amber-400',
+ 'SUCCEED': 'text-emerald-400',
+ 'FAILED': 'text-red-400',
+ 'CANCEL': 'text-gray-400'
+ };
+ return colorMap[status] || 'text-gray-400';
+ };
+
+ const getTaskStatusIcon = (status) => {
+ const iconMap = {
+ 'CREATED': 'fas fa-clock',
+ 'PENDING': 'fas fa-hourglass-half',
+ 'RUNNING': 'fas fa-spinner fa-spin',
+ 'SUCCEED': 'fas fa-check-circle',
+ 'FAILED': 'fas fa-exclamation-triangle',
+ 'CANCEL': 'fas fa-ban'
+ };
+ return iconMap[status] || 'fas fa-question-circle';
+ };
+
+ // 任务时间格式化
+ const getTaskDuration = (startTime, endTime) => {
+ if (!startTime || !endTime) return '未知';
+ const start = new Date(startTime * 1000);
+ const end = new Date(endTime * 1000);
+ const diff = end - start;
+ const minutes = Math.floor(diff / 60000);
+ const seconds = Math.floor((diff % 60000) / 1000);
+ return `${minutes}分${seconds}秒`;
+ };
+
+ // 相对时间格式化
+ const getRelativeTime = (timestamp) => {
+ if (!timestamp) return '未知';
+ const now = new Date();
+ const time = new Date(timestamp * 1000);
+ const diff = now - time;
+
+ const minutes = Math.floor(diff / 60000);
+ const hours = Math.floor(diff / 3600000);
+ const days = Math.floor(diff / 86400000);
+ const months = Math.floor(diff / 2592000000); // 30天
+ const years = Math.floor(diff / 31536000000);
+
+ if (years > 0) {
+ return years === 1 ? t('oneYearAgo') : `${years}${t('yearsAgo')}`;
+ } else if (months > 0) {
+ return months === 1 ? t('oneMonthAgo') : `${months}${t('monthsAgo')}`;
+ } else if (days > 0) {
+ return days === 1 ? t('oneDayAgo') : `${days}${t('daysAgo')}`;
+ } else if (hours > 0) {
+ return hours === 1 ? t('oneHourAgo') : `${hours}${t('hoursAgo')}`;
+ } else if (minutes > 0) {
+ return minutes === 1 ? t('oneMinuteAgo') : `${minutes}${t('minutesAgo')}`;
+ } else {
+ return t('justNow');
+ }
+ };
+
+ // 任务历史记录管理
+ const getTaskHistory = () => {
+ return tasks.value.filter(task =>
+ ['SUCCEED', 'FAILED', 'CANCEL'].includes(task.status)
+ );
+ };
+
+ // 子任务进度相关函数
+ const getOverallProgress = (subtasks) => {
+ if (!subtasks || subtasks.length === 0) return 0;
+
+ let completedCount = 0;
+ subtasks.forEach(subtask => {
+ if (subtask.status === 'SUCCEED') {
+ completedCount++;
+ }
+ });
+
+ return Math.round((completedCount / subtasks.length) * 100);
+ };
+
+ // 获取进度条标题
+ const getProgressTitle = (subtasks) => {
+ if (!subtasks || subtasks.length === 0) return t('overallProgress');
+
+ const pendingSubtasks = subtasks.filter(subtask => subtask.status === 'PENDING');
+ const runningSubtasks = subtasks.filter(subtask => subtask.status === 'RUNNING');
+
+ if (pendingSubtasks.length > 0) {
+ return t('queueStatus');
+ } else if (runningSubtasks.length > 0) {
+ return t('running');
+ } else {
+ return t('overallProgress');
+ }
+ };
+
+ // 获取进度信息
+ const getProgressInfo = (subtasks) => {
+ if (!subtasks || subtasks.length === 0) return '0%';
+
+ const pendingSubtasks = subtasks.filter(subtask => subtask.status === 'PENDING');
+ const runningSubtasks = subtasks.filter(subtask => subtask.status === 'RUNNING');
+
+ if (pendingSubtasks.length > 0) {
+ // 显示排队信息
+ const firstPending = pendingSubtasks[0];
+ const queuePosition = firstPending.estimated_pending_order;
+ const estimatedTime = firstPending.estimated_pending_secs;
+
+ let info = t('queueing');
+ if (queuePosition !== null && queuePosition !== undefined) {
+ info += ` (${t('position')}: ${queuePosition})`;
+ }
+ if (estimatedTime !== null && estimatedTime !== undefined) {
+ info += ` - ${formatDuration(estimatedTime)}`;
+ }
+ return info;
+ } else if (runningSubtasks.length > 0) {
+ // 显示运行信息
+ const firstRunning = runningSubtasks[0];
+ const workerName = firstRunning.worker_name || t('unknown');
+ const estimatedTime = firstRunning.estimated_running_secs;
+
+ let info = `${t('subtask')} ${workerName}`;
+ if (estimatedTime !== null && estimatedTime !== undefined) {
+ const elapses = firstRunning.elapses || {};
+ const runningTime = elapses['RUNNING-'] || 0;
+ const remaining = Math.max(0, estimatedTime - runningTime);
+ info += ` - ${t('remaining')} ${formatDuration(remaining)}`;
+ }
+ return info;
+ } else {
+ // 显示总体进度
+ return getOverallProgress(subtasks) + '%';
+ }
+ };
+
+ const getSubtaskProgress = (subtask) => {
+ if (subtask.status === 'SUCCEED') return 100;
+ if (subtask.status === 'FAILED' || subtask.status === 'CANCEL') return 0;
+
+ // 对于PENDING和RUNNING状态,基于时间估算进度
+ if (subtask.status === 'PENDING') {
+ // 排队中的任务,进度为0
+ return 0;
+ }
+
+ if (subtask.status === 'RUNNING') {
+ // 运行中的任务,基于已运行时间估算进度
+ const elapses = subtask.elapses || {};
+ const runningTime = elapses['RUNNING-'] || 0;
+ const estimatedTotal = subtask.estimated_running_secs || 0;
+
+ if (estimatedTotal > 0) {
+ const progress = Math.min((runningTime / estimatedTotal) * 100, 95); // 最多95%,避免显示100%但未完成
+ return Math.round(progress);
+ }
+
+ // 如果没有时间估算,基于状态显示一个基础进度
+ return 50; // 运行中但无法估算进度时显示50%
+ }
+
+ return 0;
+ };
+
+ const getSubtaskStatusText = (status) => {
+ const statusMap = {
+ 'PENDING': t('pending'),
+ 'RUNNING': t('running'),
+ 'SUCCEED': t('completed'),
+ 'FAILED': t('failed'),
+ 'CANCEL': t('cancelled')
+ };
+ return statusMap[status] || status;
+ };
+
+ // 根据子任务状态返回预计时间的展示文本
+ const formatEstimatedTime = computed(() => {
+ return (subtask) => {
+ if (subtask.status === 'PENDING') {
+ const pendingSecs = subtask.estimated_pending_secs;
+ const queuePosition = subtask.estimated_pending_order;
+
+ if (pendingSecs !== null && pendingSecs !== undefined) {
+ let info = formatDuration(pendingSecs);
+ if (queuePosition !== null && queuePosition !== undefined) {
+ info += ` (${t('position')}: ${queuePosition})`;
+ }
+ return info;
+ }
+ return t('calculating');
+ }
+
+ if (subtask.status === 'RUNNING') {
+ // 使用extra_info.elapses而不是subtask.elapses
+ const elapses = subtask.extra_info?.elapses || {};
+ const runningTime = elapses['RUNNING-'] || 0;
+ const estimatedTotal = subtask.estimated_running_secs || 0;
+
+ if (estimatedTotal > 0) {
+ const remaining = Math.max(0, estimatedTotal - runningTime);
+ return `${t('remaining')} ${formatDuration(remaining)}`;
+ }
+
+ // 如果没有estimated_running_secs,尝试使用elapses计算
+ if (Object.keys(elapses).length > 0) {
+ const totalElapsed = Object.values(elapses).reduce((sum, time) => sum + (time || 0), 0);
+ if (totalElapsed > 0) {
+ return `${t('running')} ${formatDuration(totalElapsed)}`;
+ }
+ }
+
+ return t('calculating');
+ }
+
+ return t('completed');
+ };
+ });
+
+ const formatDuration = (seconds) => {
+ if (seconds < 60) {
+ return `${Math.round(seconds)}${t('seconds')}`;
+ } else if (seconds < 3600) {
+ const minutes = Math.floor(seconds / 60);
+ const remainingSeconds = Math.round(seconds % 60);
+ return `${minutes}${t('minutes')}${remainingSeconds}${t('seconds')}`;
+ } else {
+ const hours = Math.floor(seconds / 3600);
+ const minutes = Math.floor((seconds % 3600) / 60);
+ const remainingSeconds = Math.round(seconds % 60);
+ return `${hours}${t('hours')}${minutes}${t('minutes')}${remainingSeconds}${t('seconds')}`;
+ }
+ };
+
+ const getActiveTasks = () => {
+ return tasks.value.filter(task =>
+ ['CREATED', 'PENDING', 'RUNNING'].includes(task.status)
+ );
+ };
+
+ // 任务搜索和过滤增强
+ const searchTasks = (query) => {
+ if (!query) return tasks.value;
+ return tasks.value.filter(task => {
+ const searchText = [
+ task.task_id,
+ task.task_type,
+ task.model_cls,
+ task.params?.prompt || '',
+ getTaskStatusDisplay(task.status)
+ ].join(' ').toLowerCase();
+ return searchText.includes(query.toLowerCase());
+ });
+ };
+
+ const filterTasksByStatus = (status) => {
+ if (status === 'ALL') return tasks.value;
+ return tasks.value.filter(task => task.status === status);
+ };
+
+ const filterTasksByType = (type) => {
+ if (!type) return tasks.value;
+ return tasks.value.filter(task => task.task_type === type);
+ };
+
+ // 提示消息样式管理
+ const getAlertClass = (type) => {
+ const classMap = {
+ 'success': 'animate-slide-down',
+ 'warning': 'animate-slide-down',
+ 'danger': 'animate-slide-down',
+ 'info': 'animate-slide-down'
+ };
+ return classMap[type] || 'animate-slide-down';
+ };
+
+ const getAlertBorderClass = (type) => {
+ const borderMap = {
+ 'success': 'border-green-500',
+ 'warning': 'border-yellow-500',
+ 'danger': 'border-red-500',
+ 'info': 'border-blue-500'
+ };
+ return borderMap[type] || 'border-gray-500';
+ };
+
+ const getAlertTextClass = (type) => {
+ // 统一使用白色文字
+ return 'text-white';
+ };
+
+ const getAlertIcon = (type) => {
+ const iconMap = {
+ 'success': 'fas fa-check text-white',
+ 'warning': 'fas fa-exclamation text-white',
+ 'danger': 'fas fa-times text-white',
+ 'info': 'fas fa-info text-white'
+ };
+ return iconMap[type] || 'fas fa-info text-white';
+ };
+
+ const
getAlertIconBgClass = (type) => { + const bgMap = { + 'success': 'bg-green-500/30', + 'warning': 'bg-yellow-500/30', + 'danger': 'bg-red-500/30', + 'info': 'bg-laser-purple/30' + }; + return bgMap[type] || 'bg-laser-purple/30'; + }; + + // 监听器 - 监听任务类型变化 + watch(() => selectedTaskId.value, () => { + const currentForm = getCurrentForm(); + + // 只有当当前表单没有选择模型时,才自动选择第一个可用的模型 + if (!currentForm.model_cls) { + let availableModels; + + availableModels = models.value.filter(m => m.task === selectedTaskId.value); + + if (availableModels.length > 0) { + const firstModel = availableModels[0]; + currentForm.model_cls = firstModel.model_cls; + currentForm.stage = firstModel.stage; + } + } + + // 注意:这里不需要重置预览,因为我们要保持每个任务的独立性 + // 预览会在 selectTask 函数中根据文件状态恢复 + }); + + watch(() => getCurrentForm().model_cls, () => { + const currentForm = getCurrentForm(); + + // 只有当当前表单没有选择阶段时,才自动选择第一个可用的阶段 + if (!currentForm.stage) { + let availableStages; + + availableStages = models.value + .filter(m => m.task === selectedTaskId.value && m.model_cls === currentForm.model_cls) + .map(m => m.stage); + + if (availableStages.length > 0) { + currentForm.stage = availableStages[0]; + } + } + }); + + // 提示词模板管理 + const promptTemplates = { + 's2v': [ + { + id: 's2v_1', + title: '情绪表达', + prompt: '根据音频,人物进行情绪化表达,表情丰富,能体现音频中的情绪,手势根据情绪适当调整。' + }, + { + id: 's2v_2', + title: '故事讲述', + prompt: '根据音频,人物进行故事讲述,表情丰富,能体现音频中的情绪,手势根据故事情节适当调整。' + }, + { + id: 's2v_3', + title: '知识讲解', + prompt: '根据音频,人物进行知识讲解,表情严肃,整体风格专业得体,手势根据知识内容适当调整。' + }, + { + id: 's2v_4', + title: '浮夸表演', + prompt: '根据音频,人物进行浮夸表演,表情夸张,动作浮夸,整体风格夸张搞笑。' + }, + { + id: 's2v_5', + title: '商务演讲', + prompt: '根据音频,人物进行商务演讲,表情严肃,手势得体,整体风格专业商务。' + }, + { + id: 's2v_6', + title: '产品介绍', + prompt: '数字人介绍产品特点,语气亲切热情,表情丰富,动作自然,能体现产品特点。' + } + ], + 't2v': [ + { + id: 't2v_1', + title: '自然风景', + prompt: '一个宁静的山谷,阳光透过云层洒在绿色的草地上,远处有雪山,近处有清澈的溪流,画面温暖自然,充满生机。' + }, + { + id: 't2v_2', + title: '城市夜景', + prompt: '繁华的城市夜景,霓虹灯闪烁,高楼大厦林立,车流如织,天空中有星星点缀,营造出都市的繁华氛围。' + }, + { + id: 't2v_3', + title: '科技未来', + prompt: '未来科技城市,飞行汽车穿梭,全息投影随处可见,建筑具有流线型设计,充满科技感和未来感。' + } + ], + 'i2v': [ + { + id: 'i2v_1', + title: '人物动作', + prompt: '基于参考图片,让角色做出自然的行走动作,保持原有的服装和风格,背景可以适当变化。' + }, + { + id: 'i2v_2', + title: '场景转换', + prompt: '保持参考图片中的人物形象,将背景转换为不同的季节或环境,如从室内到户外,从白天到夜晚。' + } + ] + }; + + const getPromptTemplates = (taskType) => { + return promptTemplates[taskType] || []; + }; + + const selectPromptTemplate = (template) => { + getCurrentForm().prompt = template.prompt; + showPromptModal.value = false; + showAlert(`${t('templateApplied')} ${template.title}`, 'success'); + }; + + // 提示词历史记录管理 - 现在直接从taskHistory中获取 + const promptHistory = ref([]); + + const getPromptHistory = async () => { + try { + // 从taskHistory中获取prompt历史,去重并按时间排序 + const taskHistory = await getLocalTaskHistory(); + const uniquePrompts = []; + const seenPrompts = new Set(); + + // 遍历taskHistory,提取唯一的prompt + for (const task of taskHistory) { + if (task.prompt && task.prompt.trim() && !seenPrompts.has(task.prompt.trim())) { + uniquePrompts.push(task.prompt.trim()); + seenPrompts.add(task.prompt.trim()); + } + } + + const result = uniquePrompts.slice(0, 10); // 只显示最近10条 + promptHistory.value = result; // 更新响应式数据 + return result; + } catch (error) { + console.error(t('getPromptHistoryFailed'), error); + promptHistory.value = []; // 更新响应式数据 + return []; + } + }; + + // addPromptToHistory函数已删除,现在prompt历史直接从taskHistory中获取 + + // 保存完整的任务历史(只保存元数据,不保存文件内容) + const addTaskToHistory = (taskType, formData) => { + console.log('开始保存任务历史:', { 
taskType, formData }); + + const historyItem = { + id: Date.now(), + timestamp: new Date().toISOString(), + taskType: taskType, + prompt: formData.prompt || '', + // 只保存文件元数据,不保存文件内容 + imageFile: formData.imageFile ? { + name: formData.imageFile.name, + type: formData.imageFile.type, + size: formData.imageFile.size + // 不再保存 data 字段,避免占用大量存储空间 + } : null, + audioFile: formData.audioFile ? { + name: formData.audioFile.name, + type: formData.audioFile.type, + size: formData.audioFile.size + // 不再保存 data 字段,避免占用大量存储空间 + } : null + }; + + console.log('保存任务历史(仅元数据):', historyItem); + saveTaskHistoryItem(historyItem); + }; + + // 保存任务历史项到localStorage + const saveTaskHistoryItem = (historyItem) => { + try { + const existingHistory = JSON.parse(localStorage.getItem('taskHistory') || '[]'); + + // 避免重复添加(基于提示词、任务类型、图片和音频) + const isDuplicate = existingHistory.some(item => { + const samePrompt = item.prompt === historyItem.prompt; + const sameTaskType = item.taskType === historyItem.taskType; + const sameImage = (item.imageFile?.name || '') === (historyItem.imageFile?.name || ''); + const sameAudio = (item.audioFile?.name || '') === (historyItem.audioFile?.name || ''); + + return samePrompt && sameTaskType && sameImage && sameAudio; + }); + + if (!isDuplicate) { + // 按时间戳排序,确保最新的记录在最后 + existingHistory.push(historyItem); + existingHistory.sort((a, b) => new Date(a.timestamp) - new Date(b.timestamp)); + + // 限制历史记录数量为10条(不再保存文件内容,所以可以适当减少) + if (existingHistory.length > 10) { + existingHistory.splice(0, existingHistory.length - 10); + } + + // 保存到localStorage + try { + localStorage.setItem('taskHistory', JSON.stringify(existingHistory)); + console.log('任务历史已保存(仅元数据):', historyItem); + } catch (storageError) { + if (storageError.name === 'QuotaExceededError') { + console.warn('localStorage空间不足,尝试清理旧数据...'); + + // 清理策略1:只保留最新的5条记录 + const cleanedHistory = existingHistory.slice(-5); + + try { + localStorage.setItem('taskHistory', JSON.stringify(cleanedHistory)); + console.log('任务历史已保存(清理后):', historyItem); + } catch (secondError) { + console.error('清理后仍无法保存,尝试清理所有缓存...'); + + // 清理策略2:清理所有任务历史,只保存当前这一条 + try { + localStorage.setItem('taskHistory', JSON.stringify([historyItem])); + console.log('任务历史已保存(完全清理后)'); + showAlert(t('historyCleared'), 'info'); + } catch (thirdError) { + console.error('即使完全清理后仍无法保存:', thirdError); + // 不再显示警告,因为历史记录不是必需的功能 + console.warn('历史记录功能暂时不可用,将从任务列表恢复数据'); + } + } + } else { + throw storageError; + } + } + } else { + console.log('任务历史重复,跳过保存:', historyItem); + } + } catch (error) { + console.error('保存任务历史失败:', error); + // 不再显示警告给用户,因为可以从任务列表恢复数据 + console.warn('历史记录保存失败,将依赖任务列表数据'); + } + }; + + // 获取本地存储的任务历史 + const getLocalTaskHistory = async () => { + try { + // 使用Promise模拟异步操作,避免阻塞UI + return await new Promise((resolve) => { + setTimeout(() => { + try { + const history = JSON.parse(localStorage.getItem('taskHistory') || '[]'); + // 按时间戳排序,最新的记录在前 + const sortedHistory = history.sort((a, b) => new Date(b.timestamp) - new Date(a.timestamp)); + resolve(sortedHistory); + } catch (error) { + console.error(t('parseTaskHistoryFailed'), error); + resolve([]); + } + }, 0); + }); + } catch (error) { + console.error(t('getTaskHistoryFailed'), error); + return []; + } + }; + + const selectPromptHistory = (prompt) => { + getCurrentForm().prompt = prompt; + showPromptModal.value = false; + showAlert(t('promptHistoryApplied'), 'success'); + }; + + const clearPromptHistory = () => { + // 清空taskHistory中的prompt相关数据 + localStorage.removeItem('taskHistory'); + 
showAlert(t('promptHistoryCleared'), 'info'); + }; + + // 图片历史记录管理 - 从任务列表获取 + const getImageHistory = async () => { + try { + // 确保任务列表已加载 + if (tasks.value.length === 0) { + await refreshTasks(); + } + + const uniqueImages = []; + const seenImages = new Set(); + + // 遍历任务列表,提取唯一的图片 + for (const task of tasks.value) { + if (task.inputs && task.inputs.input_image && !seenImages.has(task.inputs.input_image)) { + // 获取图片URL + const imageUrl = await getTaskFileUrl(task.task_id, 'input_image'); + if (imageUrl) { + uniqueImages.push({ + filename: task.inputs.input_image, + url: imageUrl, + thumbnail: imageUrl, // 使用URL作为缩略图 + taskId: task.task_id, + timestamp: task.create_t, + taskType: task.task_type + }); + seenImages.add(task.inputs.input_image); + } + } + } + + // 按时间戳排序,最新的在前 + uniqueImages.sort((a, b) => new Date(b.timestamp) - new Date(a.timestamp)); + + imageHistory.value = uniqueImages; + console.log('从任务列表获取图片历史:', uniqueImages.length, '条'); + return uniqueImages; + } catch (error) { + console.error('获取图片历史失败:', error); + imageHistory.value = []; + return []; + } + }; + + // 音频历史记录管理 - 从任务列表获取 + const getAudioHistory = async () => { + try { + // 确保任务列表已加载 + if (tasks.value.length === 0) { + await refreshTasks(); + } + + const uniqueAudios = []; + const seenAudios = new Set(); + + // 遍历任务列表,提取唯一的音频 + for (const task of tasks.value) { + if (task.inputs && task.inputs.input_audio && !seenAudios.has(task.inputs.input_audio)) { + // 获取音频URL + let audioUrl = await getTaskFileUrl(task.task_id, 'input_audio'); + + // 如果返回null,可能是目录类型(多人模式),尝试获取original_audio.wav + if (!audioUrl) { + audioUrl = await getTaskFileUrlFromApi(task.task_id, 'input_audio', 'original_audio.wav'); + } + + const imageUrl = task.inputs.input_image ? await getTaskFileUrl(task.task_id, 'input_image') : null; + if (audioUrl) { + uniqueAudios.push({ + filename: task.inputs.input_audio, + url: audioUrl, + taskId: task.task_id, + timestamp: task.create_t, + taskType: task.task_type, + imageUrl + }); + seenAudios.add(task.inputs.input_audio); + } + } + } + + // 按时间戳排序,最新的在前 + uniqueAudios.sort((a, b) => new Date(b.timestamp) - new Date(a.timestamp)); + + audioHistory.value = uniqueAudios; + console.log('从任务列表获取音频历史:', uniqueAudios.length, '条'); + return uniqueAudios; + } catch (error) { + console.error('获取音频历史失败:', error); + audioHistory.value = []; + return []; + } + }; + + // 选择图片历史记录 - 从URL获取 + const selectImageHistory = async (history) => { + try { + // 确保 URL 有效,如果无效则重新获取 + let imageUrl = history.url; + if (!imageUrl || imageUrl.trim() === '') { + // 如果 URL 为空,尝试重新获取 + if (history.taskId) { + imageUrl = await getTaskFileUrl(history.taskId, 'input_image'); + } + if (!imageUrl || imageUrl.trim() === '') { + throw new Error('图片 URL 无效'); + } + } + + // 从URL获取图片文件 + const response = await fetch(imageUrl); + if (!response.ok) { + throw new Error(`HTTP error! 
status: ${response.status}`); + } + + const blob = await response.blob(); + const file = new File([blob], history.filename, { type: blob.type }); + + // 设置图片预览 + setCurrentImagePreview(imageUrl); + updateUploadedContentStatus(); + + // 更新表单 + const currentForm = getCurrentForm(); + currentForm.imageFile = file; + + // Reset detected faces + if (selectedTaskId.value === 'i2v') { + i2vForm.value.detectedFaces = []; + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.detectedFaces = []; + } + + showImageTemplates.value = false; + showAlert(t('historyImageApplied'), 'success'); + + // Auto detect faces after image is loaded + // 不再自动检测人脸,等待用户手动打开多角色模式开关 + try { + // 如果 URL 是 http/https,直接使用;否则转换为 data URL + if (!imageUrl.startsWith('http://') && !imageUrl.startsWith('https://')) { + // 如果不是 http/https URL,转换为 data URL + const reader = new FileReader(); + reader.onload = async (e) => { + // 不再自动检测人脸 + }; + reader.readAsDataURL(file); + } + } catch (error) { + console.error('Face detection failed:', error); + // Don't show error alert, just log it + } + + } catch (error) { + console.error('应用历史图片失败:', error); + showAlert(t('applyHistoryImageFailed') + ': ' + error.message, 'danger'); + } + }; + + // 选择音频历史记录 - 从URL获取 + const selectAudioHistory = async (history) => { + try { + // 从URL获取音频文件 + const response = await fetch(history.url); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const blob = await response.blob(); + const file = new File([blob], history.filename, { type: blob.type }); + + // 设置音频预览 + setCurrentAudioPreview(history.url); + updateUploadedContentStatus(); + + // 更新表单 + const currentForm = getCurrentForm(); + currentForm.audioFile = file; + + showAudioTemplates.value = false; + showAlert(t('historyAudioApplied'), 'success'); + } catch (error) { + console.error('应用历史音频失败:', error); + showAlert(t('applyHistoryAudioFailed'), 'danger'); + } + }; + + // 全局音频播放状态管理 + let currentPlayingAudio = null; + let audioStopCallback = null; + + // 停止音频播放 + const stopAudioPlayback = () => { + if (currentPlayingAudio) { + currentPlayingAudio.pause(); + currentPlayingAudio.currentTime = 0; + currentPlayingAudio = null; + + // 调用停止回调 + if (audioStopCallback) { + audioStopCallback(); + audioStopCallback = null; + } + } + }; + + // 设置音频停止回调 + const setAudioStopCallback = (callback) => { + audioStopCallback = callback; + }; + + // 预览音频历史记录 - 使用URL + const previewAudioHistory = (history) => { + console.log('预览音频历史:', history); + const audioUrl = history.url; + console.log('音频历史URL:', audioUrl); + if (!audioUrl) { + showAlert(t('audioHistoryUrlFailed'), 'danger'); + return; + } + + // 停止当前播放的音频 + if (currentPlayingAudio) { + currentPlayingAudio.pause(); + currentPlayingAudio.currentTime = 0; + currentPlayingAudio = null; + } + + const audio = new Audio(audioUrl); + currentPlayingAudio = audio; + + // 监听音频播放结束事件 + audio.addEventListener('ended', () => { + currentPlayingAudio = null; + // 调用停止回调 + if (audioStopCallback) { + audioStopCallback(); + audioStopCallback = null; + } + }); + + audio.addEventListener('error', () => { + console.error('音频播放失败:', audio.error); + showAlert(t('audioPlaybackFailed'), 'danger'); + currentPlayingAudio = null; + // 调用停止回调 + if (audioStopCallback) { + audioStopCallback(); + audioStopCallback = null; + } + }); + + audio.play().catch(error => { + console.error('音频播放失败:', error); + showAlert(t('audioPlaybackFailed'), 'danger'); + currentPlayingAudio = null; + }); + }; + + // 清空图片历史记录 + const clearImageHistory = () => { + 
imageHistory.value = []; + showAlert(t('imageHistoryCleared'), 'info'); + }; + + // 清空音频历史记录 + const clearAudioHistory = () => { + audioHistory.value = []; + showAlert(t('audioHistoryCleared'), 'info'); + }; + + // 清理localStorage存储空间 + const clearLocalStorage = () => { + try { + // 清理任务历史 + localStorage.removeItem('taskHistory'); + localStorage.removeItem('refreshToken'); + + // 清理其他可能的缓存数据 + const keysToRemove = []; + for (let i = 0; i < localStorage.length; i++) { + const key = localStorage.key(i); + if (key && (key.includes('template') || key.includes('task') || key.includes('history'))) { + keysToRemove.push(key); + } + } + + keysToRemove.forEach(key => { + localStorage.removeItem(key); + }); + + // 重置相关状态 + imageHistory.value = []; + audioHistory.value = []; + promptHistory.value = []; + + showAlert(t('storageCleared'), 'success'); + console.log('localStorage已清理,释放了存储空间'); + } catch (error) { + console.error('清理localStorage失败:', error); + showAlert(t('clearStorageFailed'), 'danger'); + } + }; + + const getAuthHeaders = () => { + const headers = { + 'Content-Type': 'application/json' + }; + + const token = localStorage.getItem('accessToken'); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + console.log('使用Token进行认证:', token.substring(0, 20) + '...'); + } else { + console.warn('没有找到accessToken'); + } + return headers; + }; + + // 验证token是否有效 + const validateToken = async (token) => { + try { + const response = await fetch('/api/v1/model/list', { + method: 'GET', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json' + } + }); + await new Promise(resolve => setTimeout(resolve, 100)); + return response.ok; + } catch (error) { + console.error('Token validation failed:', error); + return false; + } + }; + + const refreshAccessToken = async () => { + if (refreshPromise) { + return refreshPromise; + } + const refreshToken = localStorage.getItem('refreshToken'); + if (!refreshToken) { + return false; + } + + refreshPromise = (async () => { + try { + const response = await fetch('/auth/refresh', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ refresh_token: refreshToken }) + }); + await new Promise(resolve => setTimeout(resolve, 100)); + if (!response.ok) { + throw new Error(`Refresh failed with status ${response.status}`); + } + + const data = await response.json(); + if (data.access_token) { + localStorage.setItem('accessToken', data.access_token); + } + if (data.refresh_token) { + localStorage.setItem('refreshToken', data.refresh_token); + } + if (data.user_info) { + currentUser.value = data.user_info; + localStorage.setItem('currentUser', JSON.stringify(data.user_info)); + } + return true; + } catch (error) { + console.error('Refresh token failed:', error); + logout(false); + showAlert(t('loginExpiredPleaseRelogin'), 'warning', { + label: t('login'), + onClick: login + }); + return false; + } finally { + refreshPromise = null; + } + })(); + + return refreshPromise; + }; + + // 增强的API请求函数,自动处理认证错误 + const apiRequest = async (url, options = {}, allowRetry = true) => { + const headers = getAuthHeaders(); + + try { + const response = await fetch(url, { + ...options, + headers: { + ...headers, + ...options.headers + } + }); + await new Promise(resolve => setTimeout(resolve, 100)); + // 检查是否是认证错误 + if ((response.status === 401 || response.status === 403) && allowRetry) { + const refreshed = await refreshAccessToken(); + if (refreshed) { + return await apiRequest(url, options, false); + } + return 
null; + } + + return response; + } catch (error) { + console.error('API request failed:', error); + showAlert(t('networkRequestFailed'), 'danger'); + return null; + } + }; + + // 侧边栏拖拽调整功能 + const sidebar = ref(null); + const sidebarWidth = ref(256); // 默认宽度 256px (w-64) + let isResizing = false; + let startX = 0; + let startWidth = 0; + + // 更新悬浮按钮位置 + const updateFloatingButtonPosition = (width) => { + const floatingBtn = document.querySelector('.floating-toggle-btn'); + if (floatingBtn) { + if (sidebarCollapsed.value) { + // 收起状态时,按钮位于屏幕左侧 + floatingBtn.style.left = '0px'; + floatingBtn.style.right = 'auto'; + } else { + // 展开状态时,按钮位于历史任务栏右侧 + floatingBtn.style.left = width + 'px'; + floatingBtn.style.right = 'auto'; + } + } + }; + + const startResize = (e) => { + e.preventDefault(); + console.log('startResize called'); + + isResizing = true; + startX = e.clientX; + startWidth = sidebar.value.offsetWidth; + console.log('Resize started, width:', startWidth); + + document.body.classList.add('resizing'); + document.addEventListener('mousemove', handleResize); + document.addEventListener('mouseup', stopResize); + }; + + const handleResize = (e) => { + if (!isResizing) return; + + const deltaX = e.clientX - startX; + const newWidth = startWidth + deltaX; + const minWidth = 200; + const maxWidth = 500; + + if (newWidth >= minWidth && newWidth <= maxWidth) { + // 立即更新悬浮按钮位置,不等待其他更新 + const floatingBtn = document.querySelector('.floating-toggle-btn'); + if (floatingBtn && !sidebarCollapsed.value) { + floatingBtn.style.left = newWidth + 'px'; + } + + sidebarWidth.value = newWidth; // 更新响应式变量 + sidebar.value.style.setProperty('width', newWidth + 'px', 'important'); + + // 同时调整主内容区域宽度 + const mainContent = document.querySelector('.main-container main'); + if (mainContent) { + mainContent.style.setProperty('width', `calc(100% - ${newWidth}px)`, 'important'); + } else { + const altMain = document.querySelector('main'); + if (altMain) { + altMain.style.setProperty('width', `calc(100% - ${newWidth}px)`, 'important'); + } + } + } else { + console.log('Width out of range:', newWidth); + } + }; + + const stopResize = () => { + isResizing = false; + document.body.classList.remove('resizing'); + document.removeEventListener('mousemove', handleResize); + document.removeEventListener('mouseup', stopResize); + + // 保存当前宽度到localStorage + if (sidebar.value) { + localStorage.setItem('sidebarWidth', sidebar.value.offsetWidth); + } + }; + + // 应用响应式侧边栏宽度 + const applyResponsiveWidth = () => { + if (!sidebar.value) return; + + const windowWidth = window.innerWidth; + let sidebarWidthPx; + + if (windowWidth <= 768) { + sidebarWidthPx = 200; + } else if (windowWidth <= 1200) { + sidebarWidthPx = 250; + } else { + // 大屏幕时使用保存的宽度或默认宽度 + const savedWidth = localStorage.getItem('sidebarWidth'); + if (savedWidth) { + const width = parseInt(savedWidth); + if (width >= 200 && width <= 500) { + sidebarWidthPx = width; + } else { + sidebarWidthPx = 256; // 默认 w-64 + } + } else { + sidebarWidthPx = 256; // 默认 w-64 + } + } + + sidebarWidth.value = sidebarWidthPx; // 更新响应式变量 + sidebar.value.style.width = sidebarWidthPx + 'px'; + + // 更新悬浮按钮位置 + updateFloatingButtonPosition(sidebarWidthPx); + + const mainContent = document.querySelector('main'); + if (mainContent) { + mainContent.style.width = `calc(100% - ${sidebarWidthPx}px)`; + } + }; + + // 新增:视图切换方法 + const switchToCreateView = () => { + // 生成页面的查询参数 + const generateQuery = {}; + + // 保留任务类型选择 + if (selectedTaskId.value) { + generateQuery.taskType = selectedTaskId.value; + } 
+ + // 保留模型选择 + if (selectedModel.value) { + generateQuery.model = selectedModel.value; + } + + // 保留创作区域展开状态 + if (isCreationAreaExpanded.value) { + generateQuery.expanded = 'true'; + } + + router.push({ path: '/generate', query: generateQuery }); + + // 如果之前有展开过创作区域,保持展开状态 + if (isCreationAreaExpanded.value) { + // 延迟一点时间确保DOM更新完成 + setTimeout(() => { + const creationArea = document.querySelector('.creation-area'); + if (creationArea) { + creationArea.classList.add('show'); + } + }, 50); + } + }; + + const switchToProjectsView = (forceRefresh = false) => { + // 项目页面的查询参数 + const projectsQuery = {}; + + // 保留搜索查询 + if (taskSearchQuery.value) { + projectsQuery.search = taskSearchQuery.value; + } + + // 保留状态筛选 + if (statusFilter.value) { + projectsQuery.status = statusFilter.value; + } + + // 保留当前页码 + if (currentTaskPage.value > 1) { + projectsQuery.page = currentTaskPage.value.toString(); + } + + router.push({ path: '/projects', query: projectsQuery }); + // 刷新任务列表 + refreshTasks(forceRefresh); + }; + + const switchToInspirationView = () => { + // 灵感页面的查询参数 + const inspirationQuery = {}; + + // 保留搜索查询 + if (inspirationSearchQuery.value) { + inspirationQuery.search = inspirationSearchQuery.value; + } + + // 保留分类筛选 + if (selectedInspirationCategory.value) { + inspirationQuery.category = selectedInspirationCategory.value; + } + + // 保留当前页码 + if (inspirationCurrentPage.value > 1) { + inspirationQuery.page = inspirationCurrentPage.value.toString(); + } + + router.push({ path: '/inspirations', query: inspirationQuery }); + // 加载灵感数据 + loadInspirationData(); + }; + + const switchToLoginView = () => { + router.push('/login'); + + }; + + // 日期格式化函数 + const formatDate = (date) => { + if (!date) return ''; + const d = new Date(date); + return d.toLocaleDateString('zh-CN', { + year: 'numeric', + month: '2-digit', + day: '2-digit' + }); + }; + + // 灵感广场相关方法 + const loadInspirationData = async (forceRefresh = false) => { + try { + // 如果不是强制刷新,先尝试从缓存加载 + // 构建缓存键,包含分页和过滤条件 + const cacheKey = `${TEMPLATES_CACHE_KEY}_${inspirationCurrentPage.value}_${inspirationPageSize.value}_${selectedInspirationCategory.value}_${inspirationSearchQuery.value}`; + + if (!forceRefresh) { + const cachedData = loadFromCache(cacheKey, TEMPLATES_CACHE_EXPIRY); + if (cachedData && cachedData.templates) { + console.log(`成功从缓存加载灵感模板数据${cacheKey}:`, cachedData.templates); + inspirationItems.value = cachedData.templates; + InspirationCategories.value = cachedData.all_categories; + // 如果有分页信息也加载 + if (cachedData.pagination) { + inspirationPagination.value = cachedData.pagination; + } + preloadTemplateFilesUrl(inspirationItems.value); + return; + } + } + + // 缓存中没有或强制刷新,从API加载 + const params = new URLSearchParams(); + if (selectedInspirationCategory.value) { + params.append('category', selectedInspirationCategory.value); + } + if (inspirationSearchQuery.value) { + params.append('search', inspirationSearchQuery.value); + } + if (inspirationCurrentPage.value) { + params.append('page', inspirationCurrentPage.value.toString()); + } + if (inspirationPageSize.value) { + params.append('page_size', inspirationPageSize.value.toString()); + } + + const apiUrl = `/api/v1/template/tasks${params.toString() ? '?' 
+ params.toString() : ''}`; + const response = await publicApiCall(apiUrl); + if (response.ok) { + const data = await response.json(); + inspirationItems.value = data.templates || []; + InspirationCategories.value = data.categories || []; + inspirationPagination.value = data.pagination || null; + + // 缓存模板数据 + saveToCache(cacheKey, { + templates: inspirationItems.value, + pagination: inspirationPagination.value, + all_categories: InspirationCategories.value, + category: selectedInspirationCategory.value, + search: inspirationSearchQuery.value, + page: inspirationCurrentPage.value, + page_size: inspirationPageSize.value, + }); + + console.log('缓存灵感模板数据成功:', inspirationItems.value.length, '个模板'); + // 强制触发响应式更新 + await nextTick(); + + // 强制刷新分页组件 + inspirationPaginationKey.value++; + + // 使用新的模板文件预加载逻辑 + preloadTemplateFilesUrl(inspirationItems.value); + } else { + console.warn('加载模板数据失败'); + } + } catch (error) { + console.warn('加载模板数据失败:', error); + } + }; + + + // 选择分类 + const selectInspirationCategory = async (category) => { + isPageLoading.value = true; + // 如果点击的是当前分类,不重复请求 + if (selectedInspirationCategory.value === category) { + isPageLoading.value = false; + return; + } + + // 更新分类 + selectedInspirationCategory.value = category; + + // 重置页码为1 + inspirationCurrentPage.value = 1; + inspirationPageInput.value = 1; + + // 清空当前数据,显示加载状态 + inspirationItems.value = []; + inspirationPagination.value = null; + + // 重新加载数据 + await loadInspirationData(); // 强制刷新,不使用缓存 + isPageLoading.value = false; + }; + + // 搜索防抖定时器 + let searchTimeout = null; + + // 处理搜索 + const handleInspirationSearch = async () => { + isLoading.value = true; + // 清除之前的定时器 + if (searchTimeout) { + clearTimeout(searchTimeout); + } + + // 设置防抖延迟 + searchTimeout = setTimeout(async () => { + // 重置页码为1 + inspirationCurrentPage.value = 1; + inspirationPageInput.value = 1; + + // 清空当前数据,显示加载状态 + inspirationItems.value = []; + inspirationPagination.value = null; + + // 重新加载数据 + await loadInspirationData(); // 强制刷新,不使用缓存 + isPageLoading.value = false; + }, 500); // 500ms 防抖延迟 + }; + + // 全局视频播放管理 + let currentPlayingVideo = null; + let currentLoadingVideo = null; // 跟踪正在等待加载的视频 + + // 更新视频播放按钮图标 + const updateVideoIcon = (video, isPlaying) => { + // 查找视频容器中的播放按钮 + const container = video.closest('.relative'); + if (!container) return; + + // 查找移动端播放按钮 + const playButton = container.querySelector('button[class*="absolute"][class*="bottom-3"]'); + if (playButton) { + const icon = playButton.querySelector('i'); + if (icon) { + icon.className = isPlaying ? 
'fas fa-pause text-sm' : 'fas fa-play text-sm'; + } + } + }; + + // 处理视频播放结束 + const onVideoEnded = (event) => { + const video = event.target; + console.log('视频播放完毕:', video.src); + + // 重置视频到开始位置 + video.currentTime = 0; + + // 更新播放按钮图标为播放状态 + updateVideoIcon(video, false); + + // 如果播放完毕的是当前播放的视频,清除引用 + if (currentPlayingVideo === video) { + currentPlayingVideo = null; + console.log('当前播放视频播放完毕'); + } + }; + + // 视频播放控制 + const playVideo = (event) => { + const video = event.target; + + // 检查视频是否已加载完成 + if (video.readyState < 2) { // HAVE_CURRENT_DATA + console.log('视频还没加载完成,忽略鼠标悬停播放'); + return; + } + + // 如果当前有视频在播放,先暂停它 + if (currentPlayingVideo && currentPlayingVideo !== video) { + currentPlayingVideo.pause(); + currentPlayingVideo.currentTime = 0; + // 更新上一个视频的图标 + updateVideoIcon(currentPlayingVideo, false); + console.log('暂停上一个视频'); + } + + // 视频已加载完成,可以播放 + video.currentTime = 0; // 从头开始播放 + video.play().then(() => { + // 播放成功,更新当前播放视频 + currentPlayingVideo = video; + console.log('开始播放新视频'); + }).catch(e => { + console.log('视频播放失败:', e); + currentPlayingVideo = null; + video.pause(); + video.currentTime = 0; + }); + }; + + const pauseVideo = (event) => { + const video = event.target; + + // 检查视频是否已加载完成 + if (video.readyState < 2) { // HAVE_CURRENT_DATA + console.log('视频还没加载完成,忽略鼠标离开暂停'); + return; + } + + video.pause(); + video.currentTime = 0; + + // 更新视频图标 + updateVideoIcon(video, false); + + // 如果暂停的是当前播放的视频,清除引用 + if (currentPlayingVideo === video) { + currentPlayingVideo = null; + console.log('暂停当前播放视频'); + } + }; + + // 移动端视频播放切换 + const toggleVideoPlay = (event) => { + const button = event.target.closest('button'); + if (!button) { + console.error('toggleVideoPlay: 未找到按钮元素'); + return; + } + + const video = button.parentElement.querySelector('video'); + if (!video) { + console.error('toggleVideoPlay: 未找到视频元素'); + return; + } + + const icon = button.querySelector('i'); + + if (video.paused) { + // 如果当前有视频在播放,先暂停它 + if (currentPlayingVideo && currentPlayingVideo !== video) { + currentPlayingVideo.pause(); + currentPlayingVideo.currentTime = 0; + // 更新上一个视频的图标 + updateVideoIcon(currentPlayingVideo, false); + console.log('暂停上一个视频(移动端)'); + } + + // 如果当前有视频在等待加载,取消它的等待状态 + if (currentLoadingVideo && currentLoadingVideo !== video) { + currentLoadingVideo = null; + console.log('取消上一个视频的加载等待(移动端)'); + } + + // 检查视频是否已加载完成 + if (video.readyState >= 2) { // HAVE_CURRENT_DATA + // 视频已加载完成,直接播放 + video.currentTime = 0; + video.play().then(() => { + icon.className = 'fas fa-pause text-sm'; + currentPlayingVideo = video; + console.log('开始播放新视频(移动端)'); + }).catch(e => { + console.log('视频播放失败:', e); + icon.className = 'fas fa-play text-sm'; + currentPlayingVideo = null; + }); + } else { + // 视频未加载完成,显示loading并等待 + console.log('视频还没加载完成,等待加载(移动端), readyState:', video.readyState); + icon.className = 'fas fa-spinner fa-spin text-sm'; + currentLoadingVideo = video; + + // 主动触发视频加载 + video.load(); + + // 设置超时保护(10秒后如果还未加载完成,重置状态) + const loadingTimeout = setTimeout(() => { + if (currentLoadingVideo === video) { + console.warn('视频加载超时(移动端)'); + icon.className = 'fas fa-play text-sm'; + currentLoadingVideo = null; + showAlert(t('videoLoadTimeout'), 'warning'); + } + }, 10000); + + // 等待视频可以播放 + const playHandler = () => { + clearTimeout(loadingTimeout); + + // 检查这个视频是否仍然是当前等待加载的视频 + if (currentLoadingVideo === video) { + currentLoadingVideo = null; + video.currentTime = 0; + video.play().then(() => { + icon.className = 'fas fa-pause text-sm'; + currentPlayingVideo = video; + 
console.log('开始播放新视频(移动端-延迟加载)'); + }).catch(e => { + console.log('视频播放失败:', e); + icon.className = 'fas fa-play text-sm'; + currentPlayingVideo = null; + }); + } else { + // 这个视频的加载等待已被取消,重置图标 + icon.className = 'fas fa-play text-sm'; + console.log('视频加载完成但等待已被取消(移动端)'); + } + + // 移除事件监听器 + video.removeEventListener('canplay', playHandler); + video.removeEventListener('error', errorHandler); + }; + + const errorHandler = () => { + clearTimeout(loadingTimeout); + console.error('视频加载失败(移动端)'); + icon.className = 'fas fa-play text-sm'; + currentLoadingVideo = null; + + // 移除事件监听器 + video.removeEventListener('canplay', playHandler); + video.removeEventListener('error', errorHandler); + }; + + // 使用 canplay 事件,比 loadeddata 更适合移动端 + video.addEventListener('canplay', playHandler, { once: true }); + video.addEventListener('error', errorHandler, { once: true }); + } + } else { + video.pause(); + video.currentTime = 0; + icon.className = 'fas fa-play text-sm'; + + // 如果暂停的是当前播放的视频,清除引用 + if (currentPlayingVideo === video) { + currentPlayingVideo = null; + console.log('暂停当前播放视频(移动端)'); + } + + // 如果暂停的是当前等待加载的视频,清除引用 + if (currentLoadingVideo === video) { + currentLoadingVideo = null; + console.log('取消当前等待加载的视频(移动端)'); + } + } + }; + + // 暂停所有视频 + const pauseAllVideos = () => { + if (currentPlayingVideo) { + currentPlayingVideo.pause(); + currentPlayingVideo.currentTime = 0; + // 更新视频图标 + updateVideoIcon(currentPlayingVideo, false); + currentPlayingVideo = null; + console.log('暂停所有视频'); + } + + // 清理等待加载的视频状态 + if (currentLoadingVideo) { + // 重置等待加载的视频图标 + const loadingContainer = currentLoadingVideo.closest('.relative'); + if (loadingContainer) { + const loadingButton = loadingContainer.querySelector('button[class*="absolute"][class*="bottom-3"]'); + if (loadingButton) { + const loadingIcon = loadingButton.querySelector('i'); + if (loadingIcon) { + loadingIcon.className = 'fas fa-play text-sm'; + } + } + } + currentLoadingVideo = null; + console.log('取消所有等待加载的视频'); + } + }; + + const onVideoLoaded = (event) => { + const video = event.target; + // 视频加载完成,准备播放 + console.log('视频加载完成:', video.src); + + // 更新视频加载状态(使用视频的实际src) + setVideoLoaded(video.src, true); + + // 触发Vue的响应式更新 + videoLoadedStates.value = new Map(videoLoadedStates.value); + }; + + const onVideoError = (event) => { + const video = event.target; + console.error('视频加载失败:', video.src, event); + const img = event.target; + const parent = img.parentElement; + parent.innerHTML = '
'; + // 回退到图片 + }; + + // 预览模板详情 + const previewTemplateDetail = (item, updateRoute = true) => { + selectedTemplate.value = item; + showTemplateDetailModal.value = true; + + // 只在需要时更新路由到模板详情页面 + if (updateRoute && item?.task_id) { + router.push(`/template/${item.task_id}`); + } + }; + + // 关闭模板详情弹窗 + const closeTemplateDetailModal = () => { + showTemplateDetailModal.value = false; + selectedTemplate.value = null; + // 移除自动路由跳转,让调用方决定路由行为 + }; + + // 显示图片放大 + const showImageZoom = (imageUrl) => { + zoomedImageUrl.value = imageUrl; + showImageZoomModal.value = true; + }; + + // 关闭图片放大弹窗 + const closeImageZoomModal = () => { + showImageZoomModal.value = false; + zoomedImageUrl.value = ''; + }; + + // 通过后端API代理获取文件(避免CORS问题) + const fetchFileThroughProxy = async (fileKey, fileType) => { + try { + // 尝试通过后端API代理获取文件 + const proxyUrl = `/api/v1/template/asset/${fileType}/${fileKey}`; + const response = await apiRequest(proxyUrl); + + if (response && response.ok) { + return await response.blob(); + } + + // 如果代理API不存在,尝试直接获取URL然后fetch + const fileUrl = await getTemplateFileUrlAsync(fileKey, fileType); + if (!fileUrl) { + return null; + } + + // 检查是否是同源URL + const urlObj = new URL(fileUrl, window.location.origin); + const isSameOrigin = urlObj.origin === window.location.origin; + + if (isSameOrigin) { + // 同源,直接fetch + const directResponse = await fetch(fileUrl); + if (directResponse.ok) { + return await directResponse.blob(); + } + } else { + // 跨域,尝试使用no-cors模式(但这样无法读取响应) + // 或者使用img/audio元素加载(不适用于需要File对象的情况) + // 这里我们尝试直接fetch,如果失败会抛出错误 + try { + const directResponse = await fetch(fileUrl, { mode: 'cors' }); + if (directResponse.ok) { + return await directResponse.blob(); + } + } catch (corsError) { + console.warn('CORS错误,尝试使用代理:', corsError); + // 如果后端有代理API,应该使用上面的代理方式 + // 如果没有,这里会返回null,然后调用方会显示错误 + } + } + + return null; + } catch (error) { + console.error('获取文件失败:', error); + return null; + } + }; + + // 应用模板图片 + const applyTemplateImage = async (template) => { + if (!template?.inputs?.input_image) { + showAlert(t('applyImageFailed'), 'danger'); + return; + } + + try { + // 先设置任务类型(如果模板有任务类型) + if (template.task_type && (template.task_type === 'i2v' || template.task_type === 's2v')) { + selectedTaskId.value = template.task_type; + } + + // 检查当前任务类型是否支持图片 + if (selectedTaskId.value !== 'i2v' && selectedTaskId.value !== 's2v') { + showAlert(t('applyImageFailed'), 'danger'); + return; + } + + // 获取图片URL(用于预览) + const imageUrl = await getTemplateFileUrlAsync(template.inputs.input_image, 'images'); + if (!imageUrl) { + console.error('无法获取模板图片URL:', template.inputs.input_image); + showAlert(t('applyImageFailed'), 'danger'); + return; + } + + // 根据任务类型设置图片 + const currentForm = getCurrentForm(); + if (currentForm) { + currentForm.imageUrl = imageUrl; + // Reset detected faces + if (selectedTaskId.value === 'i2v') { + i2vForm.value.detectedFaces = []; + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.detectedFaces = []; + } + } + + // 设置预览 + setCurrentImagePreview(imageUrl); + + // 加载图片文件(与useTemplate相同的逻辑) + try { + // 直接使用获取到的URL fetch(与useTemplate相同) + const imageResponse = await fetch(imageUrl); + if (imageResponse.ok) { + const blob = await imageResponse.blob(); + // 验证返回的是图片而不是HTML + if (blob.type && blob.type.startsWith('text/html')) { + console.error('返回的是HTML而不是图片:', blob.type); + showAlert(t('applyImageFailed'), 'danger'); + return; + } + const filename = template.inputs.input_image || 'template_image.jpg'; + const file = new File([blob], filename, { type: blob.type || 
'image/jpeg' }); + if (currentForm) { + currentForm.imageFile = file; + } + console.log('模板图片文件已加载'); + + // 不再自动检测人脸,等待用户手动打开多角色模式开关 + } else { + console.warn('Failed to fetch image from URL:', imageUrl); + showAlert(t('applyImageFailed'), 'danger'); + return; + } + } catch (error) { + console.error('Failed to load template image file:', error); + showAlert(t('applyImageFailed'), 'danger'); + return; + } + updateUploadedContentStatus(); + + // 关闭所有弹窗的辅助函数 + const closeAllModals = () => { + closeTaskDetailModal(); // 使用函数确保状态完全重置 + showVoiceTTSModal.value = false; + closeTemplateDetailModal(); // 使用函数确保状态完全重置 + showImageTemplates.value = false; + showAudioTemplates.value = false; + showPromptModal.value = false; + closeImageZoomModal(); // 使用函数确保状态完全重置 + }; + + // 跳转到创作区域的函数 + const scrollToCreationArea = () => { + // 先关闭所有弹窗 + closeAllModals(); + + // 如果不在生成页面,先切换视图 + if (router.currentRoute.value.path !== '/generate') { + switchToCreateView(); + // 等待路由切换完成后再展开和滚动 + setTimeout(() => { + expandCreationArea(); + setTimeout(() => { + // 滚动到顶部(TopBar 之后的位置,约60px) + const mainScrollable = document.querySelector('.main-scrollbar'); + if (mainScrollable) { + mainScrollable.scrollTo({ + top: 0, + behavior: 'smooth' + }); + } + }, 100); + }, 100); + } else { + // 已经在生成页面,直接展开和滚动 + expandCreationArea(); + setTimeout(() => { + // 滚动到顶部(TopBar 之后的位置,约60px) + const mainScrollable = document.querySelector('.main-scrollbar'); + if (mainScrollable) { + mainScrollable.scrollTo({ + top: 0, + behavior: 'smooth' + }); + } + }, 100); + } + }; + + showAlert(t('imageApplied'), 'success', { + label: t('view'), + onClick: scrollToCreationArea + }); + } catch (error) { + console.error('应用图片失败:', error); + showAlert(t('applyImageFailed'), 'danger'); + } + }; + + // 应用模板音频 + const applyTemplateAudio = async (template) => { + if (!template?.inputs?.input_audio) { + showAlert(t('applyAudioFailed'), 'danger'); + return; + } + + try { + // 先设置任务类型(如果模板有任务类型) + if (template.task_type && template.task_type === 's2v') { + selectedTaskId.value = template.task_type; + } + + // 检查当前任务类型是否支持音频 + if (selectedTaskId.value !== 's2v') { + showAlert(t('applyAudioFailed'), 'danger'); + return; + } + + // 获取音频URL(用于预览) + const audioUrl = await getTemplateFileUrlAsync(template.inputs.input_audio, 'audios'); + if (!audioUrl) { + console.error('无法获取模板音频URL:', template.inputs.input_audio); + showAlert(t('applyAudioFailed'), 'danger'); + return; + } + + // 设置音频文件 + const currentForm = getCurrentForm(); + if (currentForm) { + currentForm.audioUrl = audioUrl; + } + + // 设置预览 + setCurrentAudioPreview(audioUrl); + + // 加载音频文件(与useTemplate相同的逻辑) + try { + // 直接使用获取到的URL fetch(与useTemplate相同) + const audioResponse = await fetch(audioUrl); + if (audioResponse.ok) { + const blob = await audioResponse.blob(); + // 验证返回的是音频而不是HTML + if (blob.type && blob.type.startsWith('text/html')) { + console.error('返回的是HTML而不是音频:', blob.type); + showAlert(t('applyAudioFailed'), 'danger'); + return; + } + const filename = template.inputs.input_audio || 'template_audio.mp3'; + + // 根据文件扩展名确定正确的MIME类型 + let mimeType = blob.type; + if (!mimeType || mimeType === 'application/octet-stream') { + const ext = filename.toLowerCase().split('.').pop(); + const mimeTypes = { + 'mp3': 'audio/mpeg', + 'wav': 'audio/wav', + 'mp4': 'audio/mp4', + 'aac': 'audio/aac', + 'ogg': 'audio/ogg', + 'm4a': 'audio/mp4' + }; + mimeType = mimeTypes[ext] || 'audio/mpeg'; + } + + const file = new File([blob], filename, { type: mimeType }); + if (currentForm) { + currentForm.audioFile = file; 
+ } + console.log('模板音频文件已加载'); + } else { + console.warn('Failed to fetch audio from URL:', audioUrl); + showAlert(t('applyAudioFailed'), 'danger'); + return; + } + } catch (error) { + console.error('Failed to load template audio file:', error); + showAlert(t('applyAudioFailed'), 'danger'); + return; + } + updateUploadedContentStatus(); + + // 关闭所有弹窗的辅助函数 + const closeAllModals = () => { + closeTaskDetailModal(); // 使用函数确保状态完全重置 + showVoiceTTSModal.value = false; + closeTemplateDetailModal(); // 使用函数确保状态完全重置 + showImageTemplates.value = false; + showAudioTemplates.value = false; + showPromptModal.value = false; + closeImageZoomModal(); // 使用函数确保状态完全重置 + }; + + // 跳转到创作区域的函数 + const scrollToCreationArea = () => { + // 先关闭所有弹窗 + closeAllModals(); + + // 如果不在生成页面,先切换视图 + if (router.currentRoute.value.path !== '/generate') { + switchToCreateView(); + // 等待路由切换完成后再展开和滚动 + setTimeout(() => { + expandCreationArea(); + setTimeout(() => { + // 滚动到顶部(TopBar 之后的位置,约60px) + const mainScrollable = document.querySelector('.main-scrollbar'); + if (mainScrollable) { + mainScrollable.scrollTo({ + top: 0, + behavior: 'smooth' + }); + } + }, 100); + }, 100); + } else { + // 已经在生成页面,直接展开和滚动 + expandCreationArea(); + setTimeout(() => { + // 滚动到顶部(TopBar 之后的位置,约60px) + const mainScrollable = document.querySelector('.main-scrollbar'); + if (mainScrollable) { + mainScrollable.scrollTo({ + top: 0, + behavior: 'smooth' + }); + } + }, 100); + } + }; + + showAlert(t('audioApplied'), 'success', { + label: t('view'), + onClick: scrollToCreationArea + }); + } catch (error) { + console.error('应用音频失败:', error); + showAlert(t('applyAudioFailed'), 'danger'); + } + }; + + // 应用模板Prompt + const applyTemplatePrompt = (template) => { + if (template?.params?.prompt) { + const currentForm = getCurrentForm(); + if (currentForm) { + currentForm.prompt = template.params.prompt; + updateUploadedContentStatus(); + showAlert(t('promptApplied'), 'success'); + } + } + }; + + // 复制文本到剪贴板的辅助函数(支持移动端降级) + const copyToClipboard = async (text) => { + // 检查是否支持现代 Clipboard API + if (navigator.clipboard && navigator.clipboard.writeText) { + try { + await navigator.clipboard.writeText(text); + return true; + } catch (error) { + console.warn('Clipboard API 失败,尝试降级方案:', error); + // 降级到传统方法 + } + } + + // 降级方案:使用传统方法(适用于移动端和不支持Clipboard API的浏览器) + try { + const textArea = document.createElement('textarea'); + textArea.value = text; + + // 移动端需要元素可见且可聚焦,所以先设置可见样式 + textArea.style.position = 'fixed'; + textArea.style.left = '0'; + textArea.style.top = '0'; + textArea.style.width = '2em'; + textArea.style.height = '2em'; + textArea.style.padding = '0'; + textArea.style.border = 'none'; + textArea.style.outline = 'none'; + textArea.style.boxShadow = 'none'; + textArea.style.background = 'transparent'; + textArea.style.opacity = '0'; + textArea.style.zIndex = '-1'; + textArea.setAttribute('readonly', ''); + textArea.setAttribute('aria-hidden', 'true'); + textArea.setAttribute('tabindex', '-1'); + + document.body.appendChild(textArea); + + // 聚焦元素(移动端需要) + textArea.focus(); + textArea.select(); + + // 移动端需要 setSelectionRange + if (textArea.setSelectionRange) { + textArea.setSelectionRange(0, text.length); + } + + // 尝试复制 + let successful = false; + try { + successful = document.execCommand('copy'); + } catch (e) { + console.warn('execCommand 执行失败:', e); + } + + // 立即移除元素 + document.body.removeChild(textArea); + + if (successful) { + return true; + } else { + // 如果仍然失败,尝试另一种方法:在视口中心创建可见的输入框 + return await fallbackCopyToClipboard(text); + } + } catch 
(error) { + console.error('复制失败,尝试备用方案:', error); + // 尝试备用方案 + return await fallbackCopyToClipboard(text); + } + }; + + // 备用复制方案:显示一个可选择的文本区域(Apple风格) + const fallbackCopyToClipboard = async (text) => { + return new Promise((resolve) => { + // 创建遮罩层 + const overlay = document.createElement('div'); + overlay.style.cssText = ` + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.5); + backdrop-filter: blur(8px); + -webkit-backdrop-filter: blur(8px); + z-index: 10000; + display: flex; + align-items: center; + justify-content: center; + padding: 20px; + `; + + // 创建弹窗容器(Apple风格) + const container = document.createElement('div'); + container.style.cssText = ` + background: rgba(255, 255, 255, 0.95); + backdrop-filter: blur(20px) saturate(180%); + -webkit-backdrop-filter: blur(20px) saturate(180%); + border-radius: 20px; + padding: 24px; + max-width: 90%; + width: 100%; + max-width: 500px; + box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3); + `; + + // 深色模式支持 + if (document.documentElement.classList.contains('dark')) { + container.style.background = 'rgba(30, 30, 30, 0.95)'; + } + + const title = document.createElement('div'); + title.textContent = t('copyLink') || '复制链接'; + title.style.cssText = ` + font-size: 18px; + font-weight: 600; + color: #1d1d1f; + margin-bottom: 12px; + text-align: center; + `; + if (document.documentElement.classList.contains('dark')) { + title.style.color = '#f5f5f7'; + } + + const message = document.createElement('div'); + message.textContent = t('pleaseCopyManually') || '请手动选择并复制下面的文本'; + message.style.cssText = ` + color: #86868b; + font-size: 14px; + margin-bottom: 16px; + text-align: center; + `; + if (document.documentElement.classList.contains('dark')) { + message.style.color = '#98989d'; + } + + const input = document.createElement('input'); + input.type = 'text'; + input.value = text; + input.readOnly = true; + input.style.cssText = ` + width: 100%; + padding: 12px 16px; + font-size: 14px; + border: 1px solid rgba(0, 0, 0, 0.1); + border-radius: 12px; + background: rgba(255, 255, 255, 0.8); + color: #1d1d1f; + margin-bottom: 16px; + box-sizing: border-box; + -webkit-appearance: none; + appearance: none; + `; + if (document.documentElement.classList.contains('dark')) { + input.style.border = '1px solid rgba(255, 255, 255, 0.1)'; + input.style.background = 'rgba(44, 44, 46, 0.8)'; + input.style.color = '#f5f5f7'; + } + + const button = document.createElement('button'); + button.textContent = t('close') || '关闭'; + button.style.cssText = ` + width: 100%; + padding: 12px 24px; + background: var(--brand-primary, #007AFF); + color: white; + border: none; + border-radius: 12px; + cursor: pointer; + font-size: 15px; + font-weight: 600; + transition: all 0.2s; + `; + button.onmouseover = () => { + button.style.opacity = '0.9'; + button.style.transform = 'scale(1.02)'; + }; + button.onmouseout = () => { + button.style.opacity = '1'; + button.style.transform = 'scale(1)'; + }; + + container.appendChild(title); + container.appendChild(message); + container.appendChild(input); + container.appendChild(button); + overlay.appendChild(container); + + const close = () => { + document.body.removeChild(overlay); + resolve(false); // 返回false表示需要用户手动复制 + }; + + button.onclick = close; + overlay.onclick = (e) => { + if (e.target === overlay) close(); + }; + + document.body.appendChild(overlay); + + // 选中文本(延迟以确保DOM已渲染) + setTimeout(() => { + input.focus(); + input.select(); + if (input.setSelectionRange) { + input.setSelectionRange(0, 
text.length); + } + }, 150); + }); + }; + + // 复制Prompt到剪贴板 + const copyPrompt = async (promptText) => { + if (!promptText) return; + + try { + // 使用辅助函数复制,支持移动端 + const success = await copyToClipboard(promptText); + if (success) { + showAlert(t('promptCopied'), 'success'); + } + // 如果返回false,说明已经显示了手动复制的弹窗,不需要额外提示 + } catch (error) { + console.error('复制Prompt失败:', error); + showAlert(t('copyFailed'), 'error'); + } + }; + + // 使用模板 + const useTemplate = async (item) => { + if (!item) { + showAlert(t('templateDataIncomplete'), 'danger'); + return; + } + console.log('使用模板:', item); + + try { + // 开始模板加载 + templateLoading.value = true; + templateLoadingMessage.value = t('prefillLoadingTemplate'); + + // 先设置任务类型 + selectedTaskId.value = item.task_type; + + // 获取当前表单 + const currentForm = getCurrentForm(); + + // 设置表单数据 + currentForm.prompt = item.params?.prompt || ''; + currentForm.negative_prompt = item.params?.negative_prompt || ''; + currentForm.seed = item.params?.seed || 42; + currentForm.model_cls = item.model_cls || ''; + currentForm.stage = item.stage || ''; + + // 立即关闭模板详情并切换到创建视图,后续资源异步加载 + showTemplateDetailModal.value = false; + selectedTemplate.value = null; + isCreationAreaExpanded.value = true; + switchToCreateView(); + + // 创建加载Promise数组 + const loadingPromises = []; + + // 如果有输入图片,先获取正确的URL,然后加载文件 + if (item.inputs && item.inputs.input_image) { + // 异步获取图片URL + const imageLoadPromise = new Promise(async (resolve) => { + try { + // 先获取正确的URL + const imageUrl = await getTemplateFileUrlAsync(item.inputs.input_image, 'images'); + if (!imageUrl) { + console.warn('无法获取模板图片URL:', item.inputs.input_image); + resolve(); + return; + } + + currentForm.imageUrl = imageUrl; + setCurrentImagePreview(imageUrl); // 设置正确的URL作为预览 + console.log('模板输入图片URL:', imageUrl); + + // Reset detected faces + if (selectedTaskId.value === 'i2v') { + i2vForm.value.detectedFaces = []; + } else if (selectedTaskId.value === 's2v') { + s2vForm.value.detectedFaces = []; + } + + // 加载图片文件 + const imageResponse = await fetch(imageUrl); + if (imageResponse.ok) { + const blob = await imageResponse.blob(); + const filename = item.inputs.input_image; + const file = new File([blob], filename, { type: blob.type }); + currentForm.imageFile = file; + console.log('模板图片文件已加载'); + + // 不再自动检测人脸,等待用户手动打开多角色模式开关 + } else { + console.warn('Failed to fetch image from URL:', imageUrl); + } + } catch (error) { + console.warn('Failed to load template image file:', error); + } + resolve(); + }); + loadingPromises.push(imageLoadPromise); + } + + // 如果有输入音频,先获取正确的URL,然后加载文件 + if (item.inputs && item.inputs.input_audio) { + // 异步获取音频URL + const audioLoadPromise = new Promise(async (resolve) => { + try { + // 先获取正确的URL + const audioUrl = await getTemplateFileUrlAsync(item.inputs.input_audio, 'audios'); + if (!audioUrl) { + console.warn('无法获取模板音频URL:', item.inputs.input_audio); + resolve(); + return; + } + + currentForm.audioUrl = audioUrl; + setCurrentAudioPreview(audioUrl); // 设置正确的URL作为预览 + console.log('模板输入音频URL:', audioUrl); + + // 加载音频文件 + const audioResponse = await fetch(audioUrl); + if (audioResponse.ok) { + const blob = await audioResponse.blob(); + const filename = item.inputs.input_audio; + + // 根据文件扩展名确定正确的MIME类型 + let mimeType = blob.type; + if (!mimeType || mimeType === 'application/octet-stream') { + const ext = filename.toLowerCase().split('.').pop(); + const mimeTypes = { + 'mp3': 'audio/mpeg', + 'wav': 'audio/wav', + 'mp4': 'audio/mp4', + 'aac': 'audio/aac', + 'ogg': 'audio/ogg', + 'm4a': 'audio/mp4' + }; + mimeType = 
mimeTypes[ext] || 'audio/mpeg'; + } + + const file = new File([blob], filename, { type: mimeType }); + currentForm.audioFile = file; + console.log('模板音频文件已加载'); + // 使用FileReader生成data URL,与正常上传保持一致 + const reader = new FileReader(); + reader.onload = (e) => { + setCurrentAudioPreview(e.target.result); + console.log('模板音频预览已设置:', e.target.result.substring(0, 50) + '...'); + }; + reader.readAsDataURL(file); + } else { + console.warn('Failed to fetch audio from URL:', audioUrl); + } + } catch (error) { + console.warn('Failed to load template audio file:', error); + } + resolve(); + }); + loadingPromises.push(audioLoadPromise); + } + + // 等待所有文件加载完成 + if (loadingPromises.length > 0) { + await Promise.all(loadingPromises); + } + + showAlert(`模板加载完成`, 'success'); + } catch (error) { + console.error('应用模板失败:', error); + showAlert(`应用模板失败: ${error.message}`, 'danger'); + } finally { + // 结束模板加载 + templateLoading.value = false; + templateLoadingMessage.value = ''; + } + }; + + // 加载更多灵感 + const loadMoreInspiration = () => { + showAlert(t('loadMoreInspirationComingSoon'), 'info'); + }; + + // 新增:任务详情弹窗方法 + const openTaskDetailModal = (task) => { + console.log('openTaskDetailModal called with task:', task); + modalTask.value = task; + showTaskDetailModal.value = true; + // 只有不在 /generate 页面时才更新路由 + // 在 /generate 页面打开任务详情时,保持在当前页面 + const currentRoute = router.currentRoute.value; + if (task?.task_id && currentRoute.path !== '/generate') { + router.push(`/task/${task.task_id}`); + } + }; + + const closeTaskDetailModal = () => { + showTaskDetailModal.value = false; + modalTask.value = null; + // 只有当前路由是 /task/:id 时才跳转回 Projects + // 如果在其他页面(如 /generate)打开的弹窗,关闭时保持在原页面 + const currentRoute = router.currentRoute.value; + if (currentRoute.path.startsWith('/task/')) { + // 从任务详情路由打开的,返回 Projects 页面 + router.push({ name: 'Projects' }); + } + // 如果不是任务详情路由,不做任何路由跳转,保持在当前页面 + }; + + // 新增:分享功能相关方法 + const generateShareUrl = (taskId) => { + const baseUrl = window.location.origin; + return `${baseUrl}/share/${taskId}`; + }; + + const copyShareLink = async (taskId, shareType = 'task') => { + try { + const token = localStorage.getItem('accessToken'); + if (!token) { + showAlert(t('pleaseLoginFirst'), 'warning'); + return; + } + + // 调用后端接口创建分享链接 + const response = await fetch('/api/v1/share/create', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}` + }, + body: JSON.stringify({ + task_id: taskId, + share_type: shareType + }) + }); + + if (!response.ok) { + throw new Error('创建分享链接失败'); + } + + const data = await response.json(); + const shareUrl = `${window.location.origin}${data.share_url}`; + + // 使用辅助函数复制,支持移动端 + const success = await copyToClipboard(shareUrl); + + // 如果成功复制,显示成功提示 + if (success) { + // 显示带操作按钮的alert + showAlert(t('shareLinkCopied'), 'success', { + label: t('view'), + onClick: () => { + window.open(shareUrl, '_blank'); + } + }); + } + // 如果返回false,说明已经显示了手动复制的弹窗,不需要额外提示 + } catch (err) { + console.error('复制失败:', err); + showAlert(t('copyFailed'), 'error'); + } + }; + + const shareToSocial = (taskId, platform) => { + const shareUrl = generateShareUrl(taskId); + const task = modalTask.value; + const title = task?.params?.prompt || t('aiGeneratedVideo'); + const description = t('checkOutThisAIGeneratedVideo'); + + let shareUrlWithParams = ''; + + switch (platform) { + case 'twitter': + shareUrlWithParams = `https://twitter.com/intent/tweet?text=${encodeURIComponent(title)}&url=${encodeURIComponent(shareUrl)}`; + break; + case 'facebook': + 
shareUrlWithParams = `https://www.facebook.com/sharer/sharer.php?u=${encodeURIComponent(shareUrl)}`; + break; + case 'linkedin': + shareUrlWithParams = `https://www.linkedin.com/sharing/share-offsite/?url=${encodeURIComponent(shareUrl)}`; + break; + case 'whatsapp': + shareUrlWithParams = `https://wa.me/?text=${encodeURIComponent(title + ' ' + shareUrl)}`; + break; + case 'telegram': + shareUrlWithParams = `https://t.me/share/url?url=${encodeURIComponent(shareUrl)}&text=${encodeURIComponent(title)}`; + break; + case 'weibo': + shareUrlWithParams = `https://service.weibo.com/share/share.php?url=${encodeURIComponent(shareUrl)}&title=${encodeURIComponent(title)}`; + break; + default: + return; + } + + window.open(shareUrlWithParams, '_blank', 'width=600,height=400'); + }; + + // 新增:从路由参数打开任务详情 + const openTaskFromRoute = async (taskId) => { + try { + // 如果任务列表为空,先加载任务数据 + if (tasks.value.length === 0) { + await refreshTasks(); + } + + if (showTaskDetailModal.value && modalTask.value?.task_id === taskId) { + console.log('任务详情已打开,不重复打开'); + return; + } + + // 查找任务 + const task = tasks.value.find(t => t.task_id === taskId); + if (task) { + modalTask.value = task; + openTaskDetailModal(task); + } else { + // 如果任务不在当前列表中,尝试从API获取 + showAlert(t('taskNotFound'), 'error'); + router.push({ name: 'Projects' }); + } + } catch (error) { + console.error('打开任务失败:', error); + showAlert(t('openTaskFailed'), 'error'); + router.push({ name: 'Projects' }); + } + }; + + // 新增:模板分享功能相关方法 + const generateTemplateShareUrl = (templateId) => { + const baseUrl = window.location.origin; + return `${baseUrl}/template/${templateId}`; + }; + + const copyTemplateShareLink = async (templateId) => { + try { + const shareUrl = generateTemplateShareUrl(templateId); + // 使用辅助函数复制,支持移动端 + const success = await copyToClipboard(shareUrl); + + // 如果成功复制,显示成功提示 + if (success) { + showAlert(t('templateShareLinkCopied'), 'success', { + label: t('view'), + onClick: () => { + window.open(shareUrl, '_blank'); + } + }); + } + // 如果返回false,说明已经显示了手动复制的弹窗,不需要额外提示 + } catch (err) { + console.error('复制模板分享链接失败:', err); + showAlert(t('copyFailed'), 'error'); + } + }; + + const shareTemplateToSocial = (templateId, platform) => { + const shareUrl = generateTemplateShareUrl(templateId); + const template = selectedTemplate.value; + const title = template?.params?.prompt || t('aiGeneratedTemplate'); + const description = t('checkOutThisAITemplate'); + + let shareUrlWithParams = ''; + + switch (platform) { + case 'twitter': + shareUrlWithParams = `https://twitter.com/intent/tweet?text=${encodeURIComponent(title)}&url=${encodeURIComponent(shareUrl)}`; + break; + case 'facebook': + shareUrlWithParams = `https://www.facebook.com/sharer/sharer.php?u=${encodeURIComponent(shareUrl)}`; + break; + case 'linkedin': + shareUrlWithParams = `https://www.linkedin.com/sharing/share-offsite/?url=${encodeURIComponent(shareUrl)}`; + break; + case 'whatsapp': + shareUrlWithParams = `https://wa.me/?text=${encodeURIComponent(title + ' ' + shareUrl)}`; + break; + case 'telegram': + shareUrlWithParams = `https://t.me/share/url?url=${encodeURIComponent(shareUrl)}&text=${encodeURIComponent(title)}`; + break; + case 'weibo': + shareUrlWithParams = `https://service.weibo.com/share/share.php?url=${encodeURIComponent(shareUrl)}&title=${encodeURIComponent(title)}`; + break; + default: + return; + } + + window.open(shareUrlWithParams, '_blank', 'width=600,height=400'); + }; + + // 新增:从路由参数打开模板详情 + const openTemplateFromRoute = async (templateId) => { + try { + // 
如果模板列表为空,先加载模板数据 + if (inspirationItems.value.length === 0) { + await loadInspirationData(); + } + + if (showTemplateDetailModal.value && selectedTemplate.value?.task_id === templateId) { + console.log('模板详情已打开,不重复打开'); + return; + } + + // 查找模板 + const template = inspirationItems.value.find(t => t.task_id === templateId); + if (template) { + selectedTemplate.value = template; + previewTemplateDetail(template); + } else { + // 如果模板不在当前列表中,尝试从API获取 + showAlert(t('templateNotFound'), 'error'); + router.push({ name: 'Inspirations' }); + } + } catch (error) { + console.error('打开模板失败:', error); + showAlert(t('openTemplateFailed'), 'error'); + router.push({ name: 'Inspirations' }); + } + }; + + // 精选模版相关数据 + const featuredTemplates = ref([]); + const featuredTemplatesLoading = ref(false); + + // 主题管理 + const theme = ref('dark'); // 'light', 'dark' - 默认深色模式 + + // 初始化主题 + const initTheme = () => { + const savedTheme = localStorage.getItem('theme') || 'dark'; // 默认深色模式 + theme.value = savedTheme; + applyTheme(savedTheme); + }; + + // 应用主题(优化版本,减少延迟) + const applyTheme = (newTheme) => { + const html = document.documentElement; + + // 使用 requestAnimationFrame 优化 DOM 操作 + requestAnimationFrame(() => { + // 临时禁用过渡动画以提高切换速度 + html.classList.add('theme-transitioning'); + + if (newTheme === 'dark') { + html.classList.add('dark'); + html.style.colorScheme = 'dark'; + } else { + html.classList.remove('dark'); + html.style.colorScheme = 'light'; + } + + // 短暂延迟后移除过渡禁用类,恢复平滑过渡 + setTimeout(() => { + html.classList.remove('theme-transitioning'); + }, 50); + }); + }; + + // 切换主题(优化版本) + const toggleTheme = () => { + const themes = ['light', 'dark']; + const currentIndex = themes.indexOf(theme.value); + const nextIndex = (currentIndex + 1) % themes.length; + const nextTheme = themes[nextIndex]; + + // 立即更新状态 + theme.value = nextTheme; + + // 异步保存到 localStorage,不阻塞 UI + if (window.requestIdleCallback) { + requestIdleCallback(() => { + localStorage.setItem('theme', nextTheme); + }, { timeout: 100 }); + } else { + // 回退方案:使用 setTimeout + setTimeout(() => { + localStorage.setItem('theme', nextTheme); + }, 0); + } + + // 立即应用主题 + applyTheme(nextTheme); + + // 延迟显示提示,避免阻塞主题切换 + const themeNames = { + 'light': '浅色模式', + 'dark': '深色模式' + }; + setTimeout(() => { + showAlert(`已切换到${themeNames[nextTheme]}`, 'info'); + }, 100); + }; + + // 获取主题图标 + const getThemeIcon = () => { + const iconMap = { + 'light': 'fas fa-sun', + 'dark': 'fas fa-moon' + }; + return iconMap[theme.value] || 'fas fa-moon'; + }; + + // 不需要认证的API调用(用于获取模版数据) + const publicApiCall = async (endpoint, options = {}) => { + const url = `${endpoint}`; + const headers = { + 'Content-Type': 'application/json', + ...options.headers + }; + + const response = await fetch(url, { + ...options, + headers + }); + + if (response.status === 400) { + const error = await response.json(); + showAlert(error.message, 'danger'); + throw new Error(error.message); + } + + // 添加50ms延迟,防止触发服务端频率限制 + await new Promise(resolve => setTimeout(resolve, 50)); + + return response; + }; + + // 获取精选模版数据 + const loadFeaturedTemplates = async (forceRefresh = false) => { + try { + featuredTemplatesLoading.value = true; + + // 构建缓存键 + const cacheKey = `featured_templates_cache`; + + if (!forceRefresh) { + const cachedData = loadFromCache(cacheKey, TEMPLATES_CACHE_EXPIRY); + if (cachedData && cachedData.templates) { + console.log('从缓存加载精选模版数据:', cachedData.templates.length, '个'); + featuredTemplates.value = cachedData.templates; + featuredTemplatesLoading.value = false; + return; + } + } + + 
// 从API获取精选模版数据(不需要认证) + const params = new URLSearchParams(); + params.append('category', '精选'); + params.append('page_size', '50'); // 获取更多数据用于随机选择 + + const apiUrl = `/api/v1/template/tasks?${params.toString()}`; + const response = await publicApiCall(apiUrl); + + if (response.ok) { + const data = await response.json(); + const templates = data.templates || []; + + // 缓存数据 + saveToCache(cacheKey, { + templates: templates, + timestamp: Date.now() + }); + + featuredTemplates.value = templates; + console.log('成功加载精选模版数据:', templates.length, '个模版'); + } else { + console.warn('加载精选模版数据失败'); + featuredTemplates.value = []; + } + } catch (error) { + console.warn('加载精选模版数据失败:', error); + featuredTemplates.value = []; + } finally { + featuredTemplatesLoading.value = false; + } + }; + + // 获取随机精选模版 + const getRandomFeaturedTemplates = async (count = 10) => { + try { + featuredTemplatesLoading.value = true; + + // 如果当前没有数据,先加载 + if (featuredTemplates.value.length === 0) { + await loadFeaturedTemplates(); + } + + // 如果数据仍然为空,返回空数组 + if (featuredTemplates.value.length === 0) { + return []; + } + + // 随机选择指定数量的模版 + const shuffled = [...featuredTemplates.value].sort(() => 0.5 - Math.random()); + const randomTemplates = shuffled.slice(0, count); + + return randomTemplates; + } catch (error) { + console.error('获取随机精选模版失败:', error); + return []; + } finally { + featuredTemplatesLoading.value = false; + } + }; + const removeTtsHistoryEntry = (entryId) => { + if (!entryId) return; + const currentHistory = loadTtsHistory().filter(entry => entry.id !== entryId); + saveTtsHistory(currentHistory); + }; + + const loadTtsHistory = () => { + try { + const stored = localStorage.getItem('ttsHistory'); + if (!stored) return []; + const parsed = JSON.parse(stored); + ttsHistory.value = Array.isArray(parsed) ? 
parsed : []; + return ttsHistory.value; + } catch (error) { + console.error('加载TTS历史失败:', error); + ttsHistory.value = []; + return []; + } + }; + + const saveTtsHistory = (historyList) => { + try { + localStorage.setItem('ttsHistory', JSON.stringify(historyList)); + ttsHistory.value = historyList; + } catch (error) { + console.error('保存TTS历史失败:', error); + } + }; + + const addTtsHistoryEntry = (text = '', instruction = '') => { + const trimmedText = (text || '').trim(); + const trimmedInstruction = (instruction || '').trim(); + + if (!trimmedText && !trimmedInstruction) { + return; + } + + const currentHistory = loadTtsHistory(); + + const existingIndex = currentHistory.findIndex(entry => + entry.text === trimmedText && entry.instruction === trimmedInstruction + ); + + const timestamp = new Date().toISOString(); + + if (existingIndex !== -1) { + const existingEntry = currentHistory.splice(existingIndex, 1)[0]; + existingEntry.timestamp = timestamp; + currentHistory.unshift(existingEntry); + } else { + currentHistory.unshift({ + id: Date.now(), + text: trimmedText, + instruction: trimmedInstruction, + timestamp + }); + } + + if (currentHistory.length > 20) { + currentHistory.length = 20; + } + + saveTtsHistory(currentHistory); + }; + + const clearTtsHistory = () => { + ttsHistory.value = []; + localStorage.removeItem('ttsHistory'); + }; + +export { + // 任务类型下拉菜单 + showTaskTypeMenu, + showModelMenu, + isLoggedIn, + loading, + loginLoading, + initLoading, + downloadLoading, + downloadLoadingMessage, + isLoading, + isPageLoading, + + // 录音相关 + isRecording, + recordingDuration, + startRecording, + stopRecording, + formatRecordingDuration, + + loginWithGitHub, + loginWithGoogle, + // 短信登录相关 + phoneNumber, + verifyCode, + smsCountdown, + showSmsForm, + sendSmsCode, + loginWithSms, + handlePhoneEnter, + handleVerifyCodeEnter, + toggleSmsLogin, + submitting, + templateLoading, + templateLoadingMessage, + taskSearchQuery, + currentUser, + models, + tasks, + alert, + showErrorDetails, + showFailureDetails, + confirmDialog, + showConfirmDialog, + showTaskDetailModal, + modalTask, + showVoiceTTSModal, + showPodcastModal, + currentTask, + t2vForm, + i2vForm, + s2vForm, + getCurrentForm, + i2vImagePreview, + s2vImagePreview, + s2vAudioPreview, + getCurrentImagePreview, + getCurrentAudioPreview, + getCurrentVideoPreview, + setCurrentImagePreview, + setCurrentAudioPreview, + setCurrentVideoPreview, + updateUploadedContentStatus, + availableTaskTypes, + availableModelClasses, + currentTaskHints, + currentHintIndex, + startHintRotation, + stopHintRotation, + filteredTasks, + selectedTaskId, + selectedTask, + selectedModel, + selectedTaskFiles, + loadingTaskFiles, + statusFilter, + pagination, + paginationInfo, + currentTaskPage, + taskPageSize, + taskPageInput, + paginationKey, + taskMenuVisible, + toggleTaskMenu, + closeAllTaskMenus, + handleClickOutside, + showAlert, + setLoading, + apiCall, + logout, + login, + loadModels, + sidebarCollapsed, + sidebarWidth, + showExpandHint, + showGlow, + isDefaultStateHidden, + hideDefaultState, + showDefaultState, + isCreationAreaExpanded, + hasUploadedContent, + isContracting, + expandCreationArea, + contractCreationArea, + taskFileCache, + taskFileCacheLoaded, + templateFileCache, + templateFileCacheLoaded, + loadTaskFiles, + downloadFile, + handleDownloadFile, + viewFile, + handleImageUpload, + detectFacesInImage, + faceDetecting, + audioSeparating, + cropFaceImage, + updateFaceRoleName, + toggleFaceEditing, + saveFaceRoleName, + selectTask, + selectModel, + 
resetForm, + triggerImageUpload, + triggerAudioUpload, + removeImage, + removeAudio, + removeVideo, + handleAudioUpload, + handleVideoUpload, + separateAudioTracks, + updateSeparatedAudioRole, + updateSeparatedAudioName, + toggleSeparatedAudioEditing, + saveSeparatedAudioName, + loadImageAudioTemplates, + selectImageTemplate, + selectAudioTemplate, + previewAudioTemplate, + stopAudioPlayback, + setAudioStopCallback, + getTemplateFile, + imageTemplates, + audioTemplates, + mergedTemplates, + showImageTemplates, + showAudioTemplates, + mediaModalTab, + templatePagination, + templatePaginationInfo, + templateCurrentPage, + templatePageSize, + templatePageInput, + templatePaginationKey, + imageHistory, + audioHistory, + showTemplates, + showHistory, + showPromptModal, + promptModalTab, + submitTask, + fileToBase64, + formatTime, + refreshTasks, + goToPage, + jumpToPage, + getVisiblePages, + goToTemplatePage, + jumpToTemplatePage, + getVisibleTemplatePages, + goToInspirationPage, + jumpToInspirationPage, + getVisibleInspirationPages, + preloadTaskFilesUrl, + preloadTemplateFilesUrl, + loadTaskFilesFromCache, + saveTaskFilesToCache, + getTaskFileFromCache, + setTaskFileToCache, + getTaskFileUrlFromApi, + getTaskFileUrlSync, + // Podcast 音频缓存 + podcastAudioCache, + podcastAudioCacheLoaded, + loadPodcastAudioFromCache, + savePodcastAudioToCache, + getPodcastAudioFromCache, + setPodcastAudioToCache, + getPodcastAudioUrlFromApi, + getTemplateFileUrlFromApi, + getTemplateFileUrl, + getTemplateFileUrlAsync, + createTemplateFileUrlRef, + createTaskFileUrlRef, + loadTemplateFilesFromCache, + saveTemplateFilesToCache, + loadFromCache, + saveToCache, + clearAllCache, + getStatusBadgeClass, + viewSingleResult, + cancelTask, + resumeTask, + deleteTask, + startPollingTask, + stopPollingTask, + reuseTask, + showTaskCreator, + toggleSidebar, + clearPrompt, + getTaskItemClass, + getStatusIndicatorClass, + getTaskTypeBtnClass, + getModelBtnClass, + getTaskTypeIcon, + getTaskTypeName, + getPromptPlaceholder, + getStatusTextClass, + getImagePreview, + getTaskInputUrl, + getTaskInputImage, + getTaskInputAudio, + getTaskFileUrl, + getHistoryImageUrl, + getUserAvatarUrl, + getCurrentImagePreviewUrl, + getCurrentAudioPreviewUrl, + getCurrentVideoPreviewUrl, + handleThumbnailError, + handleImageError, + handleImageLoad, + handleAudioError, + handleAudioLoad, + getTaskStatusDisplay, + getTaskStatusColor, + getTaskStatusIcon, + getTaskDuration, + getRelativeTime, + getTaskHistory, + getActiveTasks, + getOverallProgress, + getProgressTitle, + getProgressInfo, + getSubtaskProgress, + getSubtaskStatusText, + formatEstimatedTime, + formatDuration, + searchTasks, + filterTasksByStatus, + filterTasksByType, + getAlertClass, + getAlertBorderClass, + getAlertTextClass, + getAlertIcon, + getAlertIconBgClass, + getPromptTemplates, + selectPromptTemplate, + promptHistory, + getPromptHistory, + addTaskToHistory, + getLocalTaskHistory, + selectPromptHistory, + clearPromptHistory, + getImageHistory, + getAudioHistory, + selectImageHistory, + selectAudioHistory, + previewAudioHistory, + clearImageHistory, + clearAudioHistory, + clearLocalStorage, + getAudioMimeType, + getAuthHeaders, + startResize, + sidebar, + switchToCreateView, + switchToProjectsView, + switchToInspirationView, + switchToLoginView, + openTaskDetailModal, + closeTaskDetailModal, + generateShareUrl, + copyShareLink, + shareToSocial, + openTaskFromRoute, + generateTemplateShareUrl, + copyTemplateShareLink, + shareTemplateToSocial, + openTemplateFromRoute, + // 灵感广场相关 + 
inspirationSearchQuery, + selectedInspirationCategory, + inspirationItems, + InspirationCategories, + loadInspirationData, + selectInspirationCategory, + handleInspirationSearch, + loadMoreInspiration, + inspirationPagination, + inspirationPaginationInfo, + // 精选模版相关 + featuredTemplates, + featuredTemplatesLoading, + loadFeaturedTemplates, + getRandomFeaturedTemplates, + inspirationCurrentPage, + inspirationPageSize, + inspirationPageInput, + inspirationPaginationKey, + // 工具函数 + formatDate, + // 模板详情弹窗相关 + showTemplateDetailModal, + selectedTemplate, + previewTemplateDetail, + closeTemplateDetailModal, + useTemplate, + // 图片放大弹窗相关 + showImageZoomModal, + zoomedImageUrl, + showImageZoom, + closeImageZoomModal, + // 模板素材应用相关 + applyTemplateImage, + applyTemplateAudio, + applyTemplatePrompt, + copyPrompt, + // 视频播放控制 + playVideo, + pauseVideo, + toggleVideoPlay, + pauseAllVideos, + updateVideoIcon, + onVideoLoaded, + onVideoError, + onVideoEnded, + applyMobileStyles, + handleLoginCallback, + init, + validateToken, + pollingInterval, + pollingTasks, + apiRequest, + // 主题相关 + theme, + initTheme, + toggleTheme, + getThemeIcon, + loadTtsHistory, + removeTtsHistoryEntry, + ttsHistory, + addTtsHistoryEntry, + saveTtsHistory, + clearTtsHistory, + }; diff --git a/lightx2v/deploy/server/frontend/src/views/404.vue b/lightx2v/deploy/server/frontend/src/views/404.vue new file mode 100644 index 0000000000000000000000000000000000000000..a868dbe3d88302b87bf1a163dbb504377822dce1 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/views/404.vue @@ -0,0 +1,56 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/views/Layout.vue b/lightx2v/deploy/server/frontend/src/views/Layout.vue new file mode 100644 index 0000000000000000000000000000000000000000..e6044d18d7e5e75896cb7c6a91fd90df959329aa --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/views/Layout.vue @@ -0,0 +1,85 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/views/Login.vue b/lightx2v/deploy/server/frontend/src/views/Login.vue new file mode 100644 index 0000000000000000000000000000000000000000..d73f74f652c1e0e1a45245fbb1ca5fe942227320 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/views/Login.vue @@ -0,0 +1,108 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/views/PodcastGenerate.vue b/lightx2v/deploy/server/frontend/src/views/PodcastGenerate.vue new file mode 100644 index 0000000000000000000000000000000000000000..eee9da8aadb89f68f9d997ba6ac38af7c38346f9 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/views/PodcastGenerate.vue @@ -0,0 +1,3893 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/src/views/Share.vue b/lightx2v/deploy/server/frontend/src/views/Share.vue new file mode 100644 index 0000000000000000000000000000000000000000..189da49b0c9050d9899074c6f974da9bb32bd159 --- /dev/null +++ b/lightx2v/deploy/server/frontend/src/views/Share.vue @@ -0,0 +1,511 @@ + + + + + diff --git a/lightx2v/deploy/server/frontend/vite.config.js b/lightx2v/deploy/server/frontend/vite.config.js new file mode 100644 index 0000000000000000000000000000000000000000..4815a12918a6d10bc79dd5833bcc4de5824e446f --- /dev/null +++ b/lightx2v/deploy/server/frontend/vite.config.js @@ -0,0 +1,8 @@ +import { defineConfig } from 'vite' +import vue from '@vitejs/plugin-vue' +import tailwindcss from '@tailwindcss/vite' + +// https://vite.dev/config/ +export default defineConfig({ + plugins: [vue(), tailwindcss()], +}) diff --git a/lightx2v/deploy/server/metrics.py b/lightx2v/deploy/server/metrics.py new file mode 
100644 index 0000000000000000000000000000000000000000..62ad629521c28d2706b2b57beb030df0b93a787a --- /dev/null +++ b/lightx2v/deploy/server/metrics.py @@ -0,0 +1,65 @@ +from loguru import logger +from prometheus_client import Counter, Gauge, Summary, generate_latest +from prometheus_client.core import CollectorRegistry + +from lightx2v.deploy.task_manager import ActiveStatus, FinishedStatus, TaskStatus + +REGISTRY = CollectorRegistry() + + +class MetricMonitor: + def __init__(self): + self.task_all = Counter("task_all_total", "Total count of all tasks", ["task_type", "model_cls", "stage"], registry=REGISTRY) + self.task_end = Counter("task_end_total", "Total count of ended tasks", ["task_type", "model_cls", "stage", "status"], registry=REGISTRY) + self.task_active = Gauge("task_active_size", "Current count of active tasks", ["task_type", "model_cls", "stage"], registry=REGISTRY) + self.task_elapse = Summary("task_elapse_seconds", "Elapse time of tasks", ["task_type", "model_cls", "stage", "end_status"], registry=REGISTRY) + self.subtask_all = Counter("subtask_all_total", "Total count of all subtasks", ["queue"], registry=REGISTRY) + self.subtask_end = Counter("subtask_end_total", "Total count of ended subtasks", ["queue", "status"], registry=REGISTRY) + self.subtask_active = Gauge("subtask_active_size", "Current count of active subtasks", ["queue", "status"], registry=REGISTRY) + self.subtask_elapse = Summary("subtask_elapse_seconds", "Elapse time of subtasks", ["queue", "elapse_key"], registry=REGISTRY) + + def record_task_start(self, task): + self.task_all.labels(task["task_type"], task["model_cls"], task["stage"]).inc() + self.task_active.labels(task["task_type"], task["model_cls"], task["stage"]).inc() + logger.info(f"Metrics task_all + 1, task_active +1") + + def record_task_end(self, task, status, elapse): + self.task_end.labels(task["task_type"], task["model_cls"], task["stage"], status.name).inc() + self.task_active.labels(task["task_type"], task["model_cls"], task["stage"]).dec() + self.task_elapse.labels(task["task_type"], task["model_cls"], task["stage"], status.name).observe(elapse) + logger.info(f"Metrics task_end + 1, task_active -1, task_elapse observe {elapse}") + + def record_subtask_change(self, subtask, old_status, new_status, elapse_key, elapse): + if old_status in ActiveStatus and new_status in FinishedStatus: + self.subtask_end.labels(subtask["queue"], elapse_key).inc() + logger.info(f"Metrics subtask_end + 1") + if old_status in ActiveStatus: + self.subtask_active.labels(subtask["queue"], old_status.name).dec() + logger.info(f"Metrics subtask_active {old_status.name} -1") + if new_status in ActiveStatus: + self.subtask_active.labels(subtask["queue"], new_status.name).inc() + logger.info(f"Metrics subtask_active {new_status.name} + 1") + if new_status == TaskStatus.CREATED: + self.subtask_all.labels(subtask["queue"]).inc() + logger.info(f"Metrics subtask_all + 1") + if elapse and elapse_key: + self.subtask_elapse.labels(subtask["queue"], elapse_key).observe(elapse) + logger.info(f"Metrics subtask_elapse observe {elapse}") + + # restart server, we should recover active tasks in data_manager + def record_task_recover(self, tasks): + for task in tasks: + if task["status"] in ActiveStatus: + self.record_task_start(task) + + # restart server, we should recover active tasks in data_manager + def record_subtask_recover(self, subtasks): + for subtask in subtasks: + if subtask["status"] in ActiveStatus: + self.subtask_all.labels(subtask["queue"]).inc() + 
self.subtask_active.labels(subtask["queue"], subtask["status"].name).inc() + logger.info(f"Metrics subtask_active {subtask['status'].name} + 1") + logger.info(f"Metrics subtask_all + 1") + + def get_metrics(self): + return generate_latest(REGISTRY) diff --git a/lightx2v/deploy/server/monitor.py b/lightx2v/deploy/server/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6528024705ff25945714e7c8c9febda6006bcd --- /dev/null +++ b/lightx2v/deploy/server/monitor.py @@ -0,0 +1,375 @@ +import asyncio +import time +from enum import Enum + +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.deploy.task_manager import TaskStatus + + +class WorkerStatus(Enum): + FETCHING = 1 + FETCHED = 2 + DISCONNECT = 3 + REPORT = 4 + PING = 5 + + +class CostWindow: + def __init__(self, window): + self.window = window + self.costs = [] + self.avg = None + + def append(self, cost): + self.costs.append(cost) + if len(self.costs) > self.window: + self.costs.pop(0) + self.avg = sum(self.costs) / len(self.costs) + + +class WorkerClient: + def __init__(self, queue, identity, infer_timeout, offline_timeout, avg_window, ping_timeout, fetching_timeout): + self.queue = queue + self.identity = identity + self.status = None + self.update_t = time.time() + self.fetched_t = None + self.infer_cost = CostWindow(avg_window) + self.offline_cost = CostWindow(avg_window) + self.infer_timeout = infer_timeout + self.offline_timeout = offline_timeout + self.ping_timeout = ping_timeout + self.fetching_timeout = fetching_timeout + + # FETCHING -> FETCHED -> PING * n -> REPORT -> FETCHING + # FETCHING -> DISCONNECT -> FETCHING + def update(self, status: WorkerStatus): + pre_status = self.status + pre_t = self.update_t + self.status = status + self.update_t = time.time() + + if status == WorkerStatus.FETCHING: + if pre_status in [WorkerStatus.DISCONNECT, WorkerStatus.REPORT] and pre_t is not None: + cur_cost = self.update_t - pre_t + if cur_cost < self.offline_timeout: + self.offline_cost.append(max(cur_cost, 1)) + + elif status == WorkerStatus.REPORT: + if self.fetched_t is not None: + cur_cost = self.update_t - self.fetched_t + self.fetched_t = None + if cur_cost < self.infer_timeout: + self.infer_cost.append(max(cur_cost, 1)) + logger.info(f"Worker {self.identity} {self.queue} avg infer cost update: {self.infer_cost.avg:.2f} s") + + elif status == WorkerStatus.FETCHED: + self.fetched_t = time.time() + + def check(self): + # infer too long + if self.fetched_t is not None: + elapse = time.time() - self.fetched_t + if self.infer_cost.avg is not None and elapse > self.infer_cost.avg * 5: + logger.warning(f"Worker {self.identity} {self.queue} infer timeout: {elapse:.2f} s") + return False + if elapse > self.infer_timeout: + logger.warning(f"Worker {self.identity} {self.queue} infer timeout2: {elapse:.2f} s") + return False + + elapse = time.time() - self.update_t + # no ping too long + if self.status in [WorkerStatus.FETCHED, WorkerStatus.PING]: + if elapse > self.ping_timeout: + logger.warning(f"Worker {self.identity} {self.queue} ping timeout: {elapse:.2f} s") + return False + # offline too long + elif self.status in [WorkerStatus.DISCONNECT, WorkerStatus.REPORT]: + if self.offline_cost.avg is not None and elapse > self.offline_cost.avg * 5: + logger.warning(f"Worker {self.identity} {self.queue} offline timeout: {elapse:.2f} s") + return False + if elapse > self.offline_timeout: + logger.warning(f"Worker {self.identity} {self.queue} offline timeout2: 
{elapse:.2f} s") + return False + # fetching too long + elif self.status == WorkerStatus.FETCHING: + if elapse > self.fetching_timeout: + logger.warning(f"Worker {self.identity} {self.queue} fetching timeout: {elapse:.2f} s") + return False + return True + + +class ServerMonitor: + def __init__(self, model_pipelines, task_manager, queue_manager): + self.model_pipelines = model_pipelines + self.task_manager = task_manager + self.queue_manager = queue_manager + self.stop = False + self.worker_clients = {} + self.subtask_run_timeouts = {} + self.pending_subtasks = {} + + self.all_queues = self.model_pipelines.get_queues() + self.config = self.model_pipelines.get_monitor_config() + self.interval = self.config.get("monitor_interval", 30) + self.fetching_timeout = self.config.get("fetching_timeout", 1000) + + for queue in self.all_queues: + self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 3600) + self.subtask_created_timeout = self.config["subtask_created_timeout"] + self.subtask_pending_timeout = self.config["subtask_pending_timeout"] + self.worker_avg_window = self.config["worker_avg_window"] + self.worker_offline_timeout = self.config["worker_offline_timeout"] + self.worker_min_capacity = self.config["worker_min_capacity"] + self.task_timeout = self.config["task_timeout"] + self.ping_timeout = self.config["ping_timeout"] + + self.user_visits = {} # user_id -> last_visit_t + self.user_max_active_tasks = self.config["user_max_active_tasks"] + self.user_max_daily_tasks = self.config["user_max_daily_tasks"] + self.user_visit_frequency = self.config["user_visit_frequency"] + + assert self.worker_avg_window > 0 + assert self.worker_offline_timeout > 0 + assert self.worker_min_capacity > 0 + assert self.task_timeout > 0 + assert self.ping_timeout > 0 + assert self.user_max_active_tasks > 0 + assert self.user_max_daily_tasks > 0 + assert self.user_visit_frequency > 0 + + async def init(self): + await self.init_pending_subtasks() + + async def loop(self): + while True: + if self.stop: + break + await self.clean_workers() + await self.clean_subtasks() + await asyncio.sleep(self.interval) + logger.info("ServerMonitor stopped") + + async def close(self): + self.stop = True + self.model_pipelines = None + self.task_manager = None + self.queue_manager = None + self.worker_clients = None + + def init_worker(self, queue, identity): + if queue not in self.worker_clients: + self.worker_clients[queue] = {} + if identity not in self.worker_clients[queue]: + infer_timeout = self.subtask_run_timeouts[queue] + self.worker_clients[queue][identity] = WorkerClient(queue, identity, infer_timeout, self.worker_offline_timeout, self.worker_avg_window, self.ping_timeout, self.fetching_timeout) + return self.worker_clients[queue][identity] + + @class_try_catch_async + async def worker_update(self, queue, identity, status): + worker = self.init_worker(queue, identity) + worker.update(status) + logger.info(f"Worker {identity} {queue} update [{status}]") + + @class_try_catch_async + async def clean_workers(self): + qs = list(self.worker_clients.keys()) + for queue in qs: + idens = list(self.worker_clients[queue].keys()) + for identity in idens: + if not self.worker_clients[queue][identity].check(): + self.worker_clients[queue].pop(identity) + logger.warning(f"Worker {queue} {identity} out of contact removed, remain {self.worker_clients[queue]}") + + @class_try_catch_async + async def clean_subtasks(self): + created_end_t = time.time() - self.subtask_created_timeout + pending_end_t = 
time.time() - self.subtask_pending_timeout + ping_end_t = time.time() - self.ping_timeout + fails = set() + + created_tasks = await self.task_manager.list_tasks(status=TaskStatus.CREATED, subtasks=True, end_updated_t=created_end_t) + pending_tasks = await self.task_manager.list_tasks(status=TaskStatus.PENDING, subtasks=True, end_updated_t=pending_end_t) + + def fmt_subtask(t): + return f"({t['task_id']}, {t['worker_name']}, {t['queue']}, {t['worker_identity']})" + + for t in created_tasks + pending_tasks: + if t["task_id"] in fails: + continue + elapse = time.time() - t["update_t"] + logger.warning(f"Subtask {fmt_subtask(t)} CREATED / PENDING timeout: {elapse:.2f} s") + await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"CREATED / PENDING timeout: {elapse:.2f} s") + fails.add(t["task_id"]) + + running_tasks = await self.task_manager.list_tasks(status=TaskStatus.RUNNING, subtasks=True) + + for t in running_tasks: + if t["task_id"] in fails: + continue + if t["ping_t"] > 0: + ping_elapse = time.time() - t["ping_t"] + if ping_elapse >= self.ping_timeout: + logger.warning(f"Subtask {fmt_subtask(t)} PING timeout: {ping_elapse:.2f} s") + await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"PING timeout: {ping_elapse:.2f} s") + fails.add(t["task_id"]) + elapse = time.time() - t["update_t"] + limit = self.subtask_run_timeouts[t["queue"]] + if elapse >= limit: + logger.warning(f"Subtask {fmt_subtask(t)} RUNNING timeout: {elapse:.2f} s") + await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"RUNNING timeout: {elapse:.2f} s") + fails.add(t["task_id"]) + + @class_try_catch_async + async def get_avg_worker_infer_cost(self, queue): + if queue not in self.worker_clients: + self.worker_clients[queue] = {} + infer_costs = [] + for _, client in self.worker_clients[queue].items(): + if client.infer_cost.avg is not None: + infer_costs.append(client.infer_cost.avg) + if len(infer_costs) <= 0: + return self.subtask_run_timeouts[queue] + return sum(infer_costs) / len(infer_costs) + + @class_try_catch_async + async def check_user_busy(self, user_id, active_new_task=False): + # check if user visit too frequently + cur_t = time.time() + if user_id in self.user_visits: + elapse = cur_t - self.user_visits[user_id] + if elapse <= self.user_visit_frequency: + return f"User {user_id} visit too frequently, {elapse:.2f} s vs {self.user_visit_frequency:.2f} s" + self.user_visits[user_id] = cur_t + + if active_new_task: + # check if user has too many active tasks + active_statuses = [TaskStatus.RUNNING, TaskStatus.PENDING, TaskStatus.CREATED] + active_tasks = await self.task_manager.list_tasks(status=active_statuses, user_id=user_id) + if len(active_tasks) >= self.user_max_active_tasks: + return f"User {user_id} has too many active tasks, {len(active_tasks)} vs {self.user_max_active_tasks}" + + # check if user has too many daily tasks + daily_statuses = active_statuses + [TaskStatus.SUCCEED, TaskStatus.CANCEL, TaskStatus.FAILED] + daily_tasks = await self.task_manager.list_tasks(status=daily_statuses, user_id=user_id, start_created_t=cur_t - 86400, include_delete=True) + if len(daily_tasks) >= self.user_max_daily_tasks: + return f"User {user_id} has too many daily tasks, {len(daily_tasks)} vs {self.user_max_daily_tasks}" + + return True + + # check if a task can be published to queues + @class_try_catch_async + async def check_queue_busy(self, 
keys, queues): + wait_time = 0 + + for queue in queues: + avg_cost = await self.get_avg_worker_infer_cost(queue) + worker_cnt = await self.get_ready_worker_count(queue) + subtask_pending = await self.queue_manager.pending_num(queue) + capacity = self.task_timeout * max(worker_cnt, 1) // avg_cost + capacity = max(self.worker_min_capacity, capacity) + + if subtask_pending >= capacity: + ss = f"pending={subtask_pending}, capacity={capacity}" + logger.warning(f"Queue {queue} busy, {ss}, task {keys} cannot be publised!") + return None + wait_time += avg_cost * subtask_pending / max(worker_cnt, 1) + return wait_time + + @class_try_catch_async + async def init_pending_subtasks(self): + # query all pending subtasks in task_manager + subtasks = {} + rows = await self.task_manager.list_tasks(status=TaskStatus.PENDING, subtasks=True, sort_by_update_t=True) + for row in rows: + if row["queue"] not in subtasks: + subtasks[row["queue"]] = [] + subtasks[row["queue"]].append(row["task_id"]) + for queue in self.all_queues: + if queue not in subtasks: + subtasks[queue] = [] + + # self.pending_subtasks = {queue: {"consume_count": int, "max_count": int, subtasks: {task_id: order}} + for queue, task_ids in subtasks.items(): + pending_num = await self.queue_manager.pending_num(queue) + self.pending_subtasks[queue] = {"consume_count": 0, "max_count": pending_num, "subtasks": {}} + for i, task_id in enumerate(task_ids): + self.pending_subtasks[queue]["subtasks"][task_id] = max(pending_num - i, 1) + logger.info(f"Init pending subtasks: {self.pending_subtasks}") + + @class_try_catch_async + async def pending_subtasks_add(self, queue, task_id): + if queue not in self.pending_subtasks: + logger.warning(f"Queue {queue} not found in self.pending_subtasks") + return + max_count = self.pending_subtasks[queue]["max_count"] + self.pending_subtasks[queue]["subtasks"][task_id] = max_count + 1 + self.pending_subtasks[queue]["max_count"] = max_count + 1 + # logger.warning(f"Pending subtasks {queue} add {task_id}: {self.pending_subtasks[queue]}") + + @class_try_catch_async + async def pending_subtasks_sub(self, queue, task_id): + if queue not in self.pending_subtasks: + logger.warning(f"Queue {queue} not found in self.pending_subtasks") + return + self.pending_subtasks[queue]["consume_count"] += 1 + if task_id in self.pending_subtasks[queue]["subtasks"]: + self.pending_subtasks[queue]["subtasks"].pop(task_id) + # logger.warning(f"Pending subtasks {queue} sub {task_id}: {self.pending_subtasks[queue]}") + + @class_try_catch_async + async def pending_subtasks_get_order(self, queue, task_id): + if queue not in self.pending_subtasks: + logger.warning(f"Queue {queue} not found in self.pending_subtasks") + return None + if task_id not in self.pending_subtasks[queue]["subtasks"]: + logger.warning(f"Task {task_id} not found in self.pending_subtasks[queue]['subtasks']") + return None + order = self.pending_subtasks[queue]["subtasks"][task_id] + consume = self.pending_subtasks[queue]["consume_count"] + real_order = max(order - consume, 1) + # logger.warning(f"Pending subtasks {queue} get order {task_id}: real={real_order} order={order} consume={consume}") + return real_order + + @class_try_catch_async + async def get_ready_worker_count(self, queue): + if queue not in self.worker_clients: + self.worker_clients[queue] = {} + return len(self.worker_clients[queue]) + + @class_try_catch_async + async def format_subtask(self, subtasks): + ret = [] + for sub in subtasks: + cur = { + "status": sub["status"].name, + "worker_name": 
sub["worker_name"], + "fail_msg": None, + "elapses": {}, + "estimated_pending_order": None, + "estimated_pending_secs": None, + "estimated_running_secs": None, + "ready_worker_count": None, + } + if sub["status"] in [TaskStatus.PENDING, TaskStatus.RUNNING]: + cur["estimated_running_secs"] = await self.get_avg_worker_infer_cost(sub["queue"]) + cur["ready_worker_count"] = await self.get_ready_worker_count(sub["queue"]) + if sub["status"] == TaskStatus.PENDING: + order = await self.pending_subtasks_get_order(sub["queue"], sub["task_id"]) + worker_count = max(cur["ready_worker_count"], 1e-7) + if order is not None: + cur["estimated_pending_order"] = order + wait_cycle = (order - 1) // worker_count + 1 + cur["estimated_pending_secs"] = cur["estimated_running_secs"] * wait_cycle + + if isinstance(sub["extra_info"], dict): + if "elapses" in sub["extra_info"]: + cur["elapses"] = sub["extra_info"]["elapses"] + if "start_t" in sub["extra_info"]: + cur["elapses"][f"{cur['status']}-"] = time.time() - sub["extra_info"]["start_t"] + if "fail_msg" in sub["extra_info"]: + cur["fail_msg"] = sub["extra_info"]["fail_msg"] + ret.append(cur) + return ret diff --git a/lightx2v/deploy/server/redis_client.py b/lightx2v/deploy/server/redis_client.py new file mode 100644 index 0000000000000000000000000000000000000000..52808a4a2e54fb0ce9d37897964f8caba53df2a2 --- /dev/null +++ b/lightx2v/deploy/server/redis_client.py @@ -0,0 +1,247 @@ +import asyncio +import json +import traceback + +from loguru import logger +from redis import asyncio as aioredis + +from lightx2v.deploy.common.utils import class_try_catch_async + + +class RedisClient: + def __init__(self, redis_url, retry_times=3): + self.redis_url = redis_url + self.client = None + self.retry_times = retry_times + self.base_key = "lightx2v" + self.init_scriptss() + + def init_scriptss(self): + self.script_create_if_not_exists = """ + local key = KEYS[1] + local data_json = ARGV[1] + if redis.call('EXISTS', key) == 0 then + local data = cjson.decode(data_json) + for field, value in pairs(data) do + redis.call('HSET', key, field, value) + end + return 1 + else + return 0 + end + """ + self.script_increment_and_get = """ + local key = KEYS[1] + local field = ARGV[1] + local diff = tonumber(ARGV[2]) + local new_value = redis.call('HINCRBY', key, field, diff) + return new_value + """ + self.script_correct_pending_info = """ + local key = KEYS[1] + local pending_num = tonumber(ARGV[1]) + if redis.call('EXISTS', key) ~= 0 then + local consume_count = redis.call('HGET', key, 'consume_count') + local max_count = redis.call('HGET', key, 'max_count') + local redis_pending = tonumber(max_count) - tonumber(consume_count) + if redis_pending > pending_num then + redis.call('HINCRBY', key, 'consume_count', redis_pending - pending_num) + return 'consume_count added ' .. (redis_pending - pending_num) + end + if redis_pending < pending_num then + redis.call('HINCRBY', key, 'max_count', pending_num - redis_pending) + return 'max_count added ' .. (pending_num - redis_pending) + end + return 'pending equal ' .. pending_num .. ' vs ' .. 
redis_pending + else + return 'key not exists' + end + """ + self.script_list_push = """ + local key = KEYS[1] + local value = ARGV[1] + local limit = tonumber(ARGV[2]) + redis.call('LPUSH', key, value) + redis.call('LTRIM', key, 0, limit) + return 1 + """ + self.script_list_avg = """ + local key = KEYS[1] + local limit = tonumber(ARGV[1]) + local values = redis.call('LRANGE', key, 0, limit) + local sum = 0.0 + local count = 0.0 + for _, value in ipairs(values) do + sum = sum + tonumber(value) + count = count + 1 + end + if count == 0 then + return "-1" + end + return tostring(sum / count) + """ + + async def init(self): + for i in range(self.retry_times): + try: + self.client = aioredis.Redis.from_url(self.redis_url, protocol=3) + ret = await self.client.ping() + logger.info(f"Redis connection initialized, ping: {ret}") + assert ret, "Redis connection failed" + break + except Exception: + logger.warning(f"Redis connection failed, retry {i + 1}/{self.retry_times}: {traceback.format_exc()}") + await asyncio.sleep(1) + + def fmt_key(self, key): + return f"{self.base_key}:{key}" + + @class_try_catch_async + async def correct_pending_info(self, key, pending_num): + key = self.fmt_key(key) + script = self.client.register_script(self.script_correct_pending_info) + result = await script(keys=[key], args=[pending_num]) + logger.warning(f"Redis correct pending info {key} with {pending_num}: {result}") + return result + + @class_try_catch_async + async def create_if_not_exists(self, key, value): + key = self.fmt_key(key) + script = self.client.register_script(self.script_create_if_not_exists) + result = await script(keys=[key], args=[json.dumps(value)]) + if result == 1: + logger.info(f"Redis key '{key}' created successfully.") + else: + logger.warning(f"Redis key '{key}' already exists, not set.") + + @class_try_catch_async + async def increment_and_get(self, key, field, diff): + key = self.fmt_key(key) + script = self.client.register_script(self.script_increment_and_get) + result = await script(keys=[key], args=[field, diff]) + return result + + @class_try_catch_async + async def hset(self, key, field, value): + key = self.fmt_key(key) + return await self.client.hset(key, field, value) + + @class_try_catch_async + async def hget(self, key, field): + key = self.fmt_key(key) + result = await self.client.hget(key, field) + return result + + @class_try_catch_async + async def hgetall(self, key): + key = self.fmt_key(key) + result = await self.client.hgetall(key) + return result + + @class_try_catch_async + async def hdel(self, key, field): + key = self.fmt_key(key) + return await self.client.hdel(key, field) + + @class_try_catch_async + async def hlen(self, key): + key = self.fmt_key(key) + result = await self.client.hlen(key) + return result + + @class_try_catch_async + async def set(self, key, value, nx=False): + key = self.fmt_key(key) + result = await self.client.set(key, value, nx=nx) + if result is not True: + logger.warning(f"redis set {key} = {value} failed") + return result + + @class_try_catch_async + async def get(self, key): + key = self.fmt_key(key) + result = await self.client.get(key) + return result + + @class_try_catch_async + async def delete_key(self, key): + key = self.fmt_key(key) + return await self.client.delete(key) + + @class_try_catch_async + async def list_push(self, key, value, limit): + key = self.fmt_key(key) + script = self.client.register_script(self.script_list_push) + result = await script(keys=[key], args=[value, limit]) + return result + + @class_try_catch_async + 
async def list_avg(self, key, limit): + key = self.fmt_key(key) + script = self.client.register_script(self.script_list_avg) + result = await script(keys=[key], args=[limit]) + return float(result) + + async def close(self): + try: + if self.client: + await self.client.aclose() + logger.info("Redis connection closed") + except Exception: + logger.warning(f"Error closing Redis connection: {traceback.format_exc()}") + + +async def main(): + redis_url = "redis://user:password@localhost:6379/1?decode_responses=True&socket_timeout=5" + r = RedisClient(redis_url) + await r.init() + + v1 = await r.set("test2", "1", nx=True) + logger.info(f"set test2=1: {v1}, {await r.get('test2')}") + v2 = await r.set("test2", "2", nx=True) + logger.info(f"set test2=2: {v2}, {await r.get('test2')}") + + await r.create_if_not_exists("test", {"a": 1, "b": 2}) + logger.info(f"create test: {await r.hgetall('test')}") + await r.create_if_not_exists("test", {"a": 2, "b": 3}) + logger.info(f"create test again: {await r.hgetall('test')}") + logger.info(f"hlen test: {await r.hlen('test')}") + + v = await r.increment_and_get("test", "a", 1) + logger.info(f"a+1: {v}, a={await r.hget('test', 'a')}") + v = await r.increment_and_get("test", "b", 3) + logger.info(f"b+3: {v}, b={await r.hget('test', 'b')}") + + await r.hset("test", "a", 233) + logger.info(f"set a=233: a={await r.hget('test', 'a')}") + await r.hset("test", "c", 456) + logger.info(f"set c=456: c={await r.hget('test', 'c')}") + logger.info(f"all: {await r.hgetall('test')}") + logger.info(f"hlen test: {await r.hlen('test')}") + logger.info(f"get unknown key: {await r.hget('test', 'unknown')}") + + await r.list_push("test_list", 1, 20) + logger.info(f"list push 1: {await r.list_avg('test_list', 20)}") + await r.list_push("test_list", 2, 20) + logger.info(f"list push 2: {await r.list_avg('test_list', 20)}") + await r.list_push("test_list", 3, 20) + logger.info(f"list push 3: {await r.list_avg('test_list', 20)}") + + await r.delete_key("test_list") + logger.info(f"delete test_list: {await r.list_avg('test_list', 20)}") + + await r.delete_key("test2") + logger.info(f"delete test2: {await r.get('test2')}") + + await r.hdel("test", "a") + logger.info(f"hdel test a: {await r.hgetall('test')}") + + await r.delete_key("test") + logger.info(f"delete test: {await r.hgetall('test')}") + logger.info(f"hlen test: {await r.hlen('test')}") + + await r.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lightx2v/deploy/server/redis_monitor.py b/lightx2v/deploy/server/redis_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..9298f553049a593a07906f1b5452413dd9001be6 --- /dev/null +++ b/lightx2v/deploy/server/redis_monitor.py @@ -0,0 +1,166 @@ +import asyncio +import json +import time + +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus +from lightx2v.deploy.server.redis_client import RedisClient + + +class RedisServerMonitor(ServerMonitor): + def __init__(self, model_pipelines, task_manager, queue_manager, redis_url): + super().__init__(model_pipelines, task_manager, queue_manager) + self.redis_url = redis_url + self.redis_client = RedisClient(redis_url) + self.last_correct = None + self.correct_interval = 60 * 60 * 24 + + async def init(self): + await self.redis_client.init() + await self.init_pending_subtasks() + + async def loop(self): + while True: + if self.stop: + break + if self.last_correct is None or time.time() - 
self.last_correct > self.correct_interval: + self.last_correct = time.time() + await self.correct_pending_info() + await self.clean_workers() + await self.clean_subtasks() + await asyncio.sleep(self.interval) + logger.info("RedisServerMonitor stopped") + + async def close(self): + await super().close() + await self.redis_client.close() + + @class_try_catch_async + async def worker_update(self, queue, identity, status): + status = status.name + key = f"workers:{queue}:workers" + infer_key = f"workers:{queue}:infer_cost" + + update_t = time.time() + worker = await self.redis_client.hget(key, identity) + if worker is None: + worker = {"status": "", "fetched_t": 0, "update_t": update_t} + await self.redis_client.hset(key, identity, json.dumps(worker)) + else: + worker = json.loads(worker) + + pre_status = worker["status"] + pre_fetched_t = float(worker["fetched_t"]) + worker["status"] = status + worker["update_t"] = update_t + + if status == WorkerStatus.REPORT.name and pre_fetched_t > 0: + cur_cost = update_t - pre_fetched_t + worker["fetched_t"] = 0.0 + if cur_cost < self.subtask_run_timeouts[queue]: + await self.redis_client.list_push(infer_key, max(cur_cost, 1), self.worker_avg_window) + logger.info(f"Worker {identity} {queue} avg infer cost update: {cur_cost:.2f} s") + + elif status == WorkerStatus.FETCHED.name: + worker["fetched_t"] = update_t + + await self.redis_client.hset(key, identity, json.dumps(worker)) + logger.info(f"Worker {identity} {queue} update [{status}]") + + @class_try_catch_async + async def clean_workers(self): + for queue in self.all_queues: + key = f"workers:{queue}:workers" + workers = await self.redis_client.hgetall(key) + + for identity, worker in workers.items(): + worker = json.loads(worker) + fetched_t = float(worker["fetched_t"]) + update_t = float(worker["update_t"]) + status = worker["status"] + # logger.warning(f"{queue} avg infer cost {infer_avg:.2f} s, worker: {worker}") + + # infer too long + if fetched_t > 0: + elapse = time.time() - fetched_t + if elapse > self.subtask_run_timeouts[queue]: + logger.warning(f"Worker {identity} {queue} infer timeout2: {elapse:.2f} s") + await self.redis_client.hdel(key, identity) + continue + + elapse = time.time() - update_t + # no ping too long + if status in [WorkerStatus.FETCHED.name, WorkerStatus.PING.name]: + if elapse > self.ping_timeout: + logger.warning(f"Worker {identity} {queue} ping timeout: {elapse:.2f} s") + await self.redis_client.hdel(key, identity) + continue + + # offline too long + elif status in [WorkerStatus.DISCONNECT.name, WorkerStatus.REPORT.name]: + if elapse > self.worker_offline_timeout: + logger.warning(f"Worker {identity} {queue} offline timeout2: {elapse:.2f} s") + await self.redis_client.hdel(key, identity) + continue + + # fetching too long + elif status == WorkerStatus.FETCHING.name: + if elapse > self.fetching_timeout: + logger.warning(f"Worker {identity} {queue} fetching timeout: {elapse:.2f} s") + await self.redis_client.hdel(key, identity) + continue + + async def get_ready_worker_count(self, queue): + key = f"workers:{queue}:workers" + worker_count = await self.redis_client.hlen(key) + return worker_count + + async def get_avg_worker_infer_cost(self, queue): + infer_key = f"workers:{queue}:infer_cost" + infer_cost = await self.redis_client.list_avg(infer_key, self.worker_avg_window) + if infer_cost < 0: + return self.subtask_run_timeouts[queue] + return infer_cost + + async def correct_pending_info(self): + for queue in self.all_queues: + pending_num = await 
self.queue_manager.pending_num(queue) + await self.redis_client.correct_pending_info(f"pendings:{queue}:info", pending_num) + + @class_try_catch_async + async def init_pending_subtasks(self): + await super().init_pending_subtasks() + # save to redis if not exists + for queue, v in self.pending_subtasks.items(): + subtasks = v.pop("subtasks", {}) + await self.redis_client.create_if_not_exists(f"pendings:{queue}:info", v) + for task_id, order_id in subtasks.items(): + await self.redis_client.set(f"pendings:{queue}:subtasks:{task_id}", order_id, nx=True) + self.pending_subtasks = None + logger.info(f"Inited pending subtasks to redis") + + @class_try_catch_async + async def pending_subtasks_add(self, queue, task_id): + max_count = await self.redis_client.increment_and_get(f"pendings:{queue}:info", "max_count", 1) + await self.redis_client.set(f"pendings:{queue}:subtasks:{task_id}", max_count) + # logger.warning(f"Redis pending subtasks {queue} add {task_id}: {max_count}") + + @class_try_catch_async + async def pending_subtasks_sub(self, queue, task_id): + consume_count = await self.redis_client.increment_and_get(f"pendings:{queue}:info", "consume_count", 1) + await self.redis_client.delete_key(f"pendings:{queue}:subtasks:{task_id}") + # logger.warning(f"Redis pending subtasks {queue} sub {task_id}: {consume_count}") + + @class_try_catch_async + async def pending_subtasks_get_order(self, queue, task_id): + order = await self.redis_client.get(f"pendings:{queue}:subtasks:{task_id}") + if order is None: + return None + consume = await self.redis_client.hget(f"pendings:{queue}:info", "consume_count") + if consume is None: + return None + real_order = max(int(order) - int(consume), 1) + # logger.warning(f"Redis pending subtasks {queue} get order {task_id}: real={real_order} order={order} consume={consume}") + return real_order diff --git a/lightx2v/deploy/server/static/assets b/lightx2v/deploy/server/static/assets new file mode 120000 index 0000000000000000000000000000000000000000..da5653399cf608aa1097bd055ebe3a4d2a0af31a --- /dev/null +++ b/lightx2v/deploy/server/static/assets @@ -0,0 +1 @@ +../frontend/dist/assets \ No newline at end of file diff --git a/lightx2v/deploy/server/static/icon/logoblack.png b/lightx2v/deploy/server/static/icon/logoblack.png new file mode 100644 index 0000000000000000000000000000000000000000..92915c40ee569e203a15818bdd78d098bb4c0b91 Binary files /dev/null and b/lightx2v/deploy/server/static/icon/logoblack.png differ diff --git a/lightx2v/deploy/server/static/icon/seko_logo.svg b/lightx2v/deploy/server/static/icon/seko_logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..36a854b5ba5dd75ff01748ac48ce85965cf3e65c --- /dev/null +++ b/lightx2v/deploy/server/static/icon/seko_logo.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/lightx2v/deploy/server/static/icon/seko_logo_nobg.png b/lightx2v/deploy/server/static/icon/seko_logo_nobg.png new file mode 100644 index 0000000000000000000000000000000000000000..dab3fe97a146e90d63ebec882b2cd8b0ba24606d Binary files /dev/null and b/lightx2v/deploy/server/static/icon/seko_logo_nobg.png differ diff --git a/lightx2v/deploy/server/static/icon/seko_logo_white.svg b/lightx2v/deploy/server/static/icon/seko_logo_white.svg new file mode 100644 index 0000000000000000000000000000000000000000..ad79938dce365de73204f9707e4c04cd7ff5fd89 --- /dev/null +++ b/lightx2v/deploy/server/static/icon/seko_logo_white.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/lightx2v/deploy/server/static/index.html 
b/lightx2v/deploy/server/static/index.html new file mode 120000 index 0000000000000000000000000000000000000000..c867cf2b7234c797489a5d85b162f8ccd7857f2f --- /dev/null +++ b/lightx2v/deploy/server/static/index.html @@ -0,0 +1 @@ +../frontend/dist/index.html \ No newline at end of file diff --git a/lightx2v/deploy/server/static/index_old.html b/lightx2v/deploy/server/static/index_old.html new file mode 100644 index 0000000000000000000000000000000000000000..3002f261f4bf8c69e7a27e922def98284612467a --- /dev/null +++ b/lightx2v/deploy/server/static/index_old.html @@ -0,0 +1,1124 @@ + + + + + + LightX2V 文生视频服务 + + + + + +
+ <!-- [index_old.html body: legacy standalone single-page UI; only text labels are recoverable:
+      an available-models panel (可用模型) listing 任务类型 (task type), 模型名称 (model name), and 推理模式 (inference stage), with an empty state (暂无可用模型);
+      a new-task submission form (提交新任务) with several input fields and an image preview (预览图片);
+      a paginated task list (任务列表) showing task id, type, model, stage, a truncated 提示词 (prompt), 种子值 (seed), 创建时间 (creation time), 输出 (outputs), and status, with an empty state (暂无任务记录);
+      detail modals for the full prompt (完整提示词), task parameters (任务参数), time info (时间信息: 创建时间 / 更新时间), and input/output files (输入文件/输出结果);
+      pagination controls (显示第 … 条 / 每页显示), a loading indicator (加载中...), and alert toasts ({{ alert.message }}).
+      Markup not reproduced here.] -->
+ + + + + + + diff --git a/lightx2v/deploy/task_manager/__init__.py b/lightx2v/deploy/task_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f8deb85f021491468dff1d9e012879c3b15a10ea --- /dev/null +++ b/lightx2v/deploy/task_manager/__init__.py @@ -0,0 +1,313 @@ +import uuid +from enum import Enum +from re import T + +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch, current_time, data_name + + +class TaskStatus(Enum): + CREATED = 1 + PENDING = 2 + RUNNING = 3 + SUCCEED = 4 + FAILED = 5 + CANCEL = 6 + + +ActiveStatus = [TaskStatus.CREATED, TaskStatus.PENDING, TaskStatus.RUNNING] +FinishedStatus = [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL] + + +class BaseTaskManager: + def __init__(self): + pass + + async def init(self): + pass + + async def close(self): + pass + + async def insert_user_if_not_exists(self, user_info): + raise NotImplementedError + + async def query_user(self, user_id): + raise NotImplementedError + + async def insert_task(self, task, subtasks): + raise NotImplementedError + + async def list_tasks(self, **kwargs): + raise NotImplementedError + + async def query_task(self, task_id, user_id=None, only_task=True): + raise NotImplementedError + + async def next_subtasks(self, task_id): + raise NotImplementedError + + async def run_subtasks(self, subtasks, worker_identity): + raise NotImplementedError + + async def ping_subtask(self, task_id, worker_name, worker_identity): + raise NotImplementedError + + async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False): + raise NotImplementedError + + async def cancel_task(self, task_id, user_id=None): + raise NotImplementedError + + async def resume_task(self, task_id, all_subtask=False, user_id=None): + raise NotImplementedError + + async def delete_task(self, task_id, user_id=None): + raise NotImplementedError + + async def insert_share(self, share_info): + raise NotImplementedError + + async def query_share(self, share_id): + raise NotImplementedError + + async def insert_podcast(self, podcast): + raise NotImplementedError + + async def query_podcast(self, session_id, user_id=None): + raise NotImplementedError + + async def list_podcasts(self, **kwargs): + raise NotImplementedError + + async def delete_podcast(self, session_id, user_id): + raise NotImplementedError + + def fmt_dict(self, data): + for k in ["status"]: + if k in data: + data[k] = data[k].name + + def parse_dict(self, data): + for k in ["status"]: + if k in data: + data[k] = TaskStatus[data[k]] + + def align_extra_inputs(self, task, subtask): + if "extra_inputs" in task.get("params", {}): + for inp, fs in task["params"]["extra_inputs"].items(): + if inp in subtask["inputs"]: + for f in fs: + subtask["inputs"][f] = task["inputs"][f] + logger.info(f"Align extra input: {f} for subtask {subtask['task_id']} {subtask['worker_name']}") + + async def create_share(self, task_id, user_id, share_type, valid_days, auth_type, auth_value): + assert share_type in ["task", "template"], f"do not support {share_type} share type!" + assert auth_type in ["public", "login", "user_id"], f"do not support {auth_type} auth type!" + assert valid_days > 0, f"valid_days must be greater than 0!" 
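+        # A share is keyed by a fresh UUID; valid_t is derived as
+        # create_t + valid_days * 24 * 3600 and is checked on later lookups.
+        # Hypothetical call for illustration (values are examples only):
+        #   await manager.create_share(task_id, user_id, share_type="task",
+        #                              valid_days=7, auth_type="login", auth_value="")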
+ share_id = str(uuid.uuid4()) + cur_t = current_time() + share_info = { + "share_id": share_id, + "task_id": task_id, + "user_id": user_id, + "share_type": share_type, + "create_t": cur_t, + "update_t": cur_t, + "valid_days": valid_days, + "valid_t": cur_t + valid_days * 24 * 3600, + "auth_type": auth_type, + "auth_value": auth_value, + "extra_info": "", + "tag": "", + } + assert await self.insert_share(share_info), f"create share {share_info} failed" + return share_id + + async def create_user(self, user_info): + assert user_info["source"] in ["github", "google", "phone"], f"do not support {user_info['source']} user!" + cur_t = current_time() + user_id = f"{user_info['source']}_{user_info['id']}" + data = { + "user_id": user_id, + "source": user_info["source"], + "id": user_info["id"], + "username": user_info["username"], + "email": user_info["email"], + "homepage": user_info["homepage"], + "avatar_url": user_info["avatar_url"], + "create_t": cur_t, + "update_t": cur_t, + "extra_info": "", + "tag": "", + } + assert await self.insert_user_if_not_exists(data), f"create user {data} failed" + return user_id + + async def create_task(self, worker_keys, workers, params, inputs, outputs, user_id): + task_type, model_cls, stage = worker_keys + cur_t = current_time() + task_id = str(uuid.uuid4()) + extra_inputs = [] + for fs in params.get("extra_inputs", {}).values(): + extra_inputs.extend(fs) + task = { + "task_id": task_id, + "task_type": task_type, + "model_cls": model_cls, + "stage": stage, + "params": params, + "create_t": cur_t, + "update_t": cur_t, + "status": TaskStatus.CREATED, + "extra_info": "", + "tag": "", + "inputs": {x: data_name(x, task_id) for x in inputs + extra_inputs}, + "outputs": {x: data_name(x, task_id) for x in outputs}, + "user_id": user_id, + } + records = [] + self.mark_task_start(records, task) + subtasks = [] + for worker_name, worker_item in workers.items(): + subtasks.append( + { + "task_id": task_id, + "worker_name": worker_name, + "inputs": {x: data_name(x, task_id) for x in worker_item["inputs"]}, + "outputs": {x: data_name(x, task_id) for x in worker_item["outputs"]}, + "queue": worker_item["queue"], + "previous": worker_item["previous"], + "status": TaskStatus.CREATED, + "worker_identity": "", + "result": "", + "fail_time": 0, + "extra_info": "", + "create_t": cur_t, + "update_t": cur_t, + "ping_t": 0.0, + "infer_cost": -1.0, + } + ) + self.mark_subtask_change(records, subtasks[-1], None, TaskStatus.CREATED) + ret = await self.insert_task(task, subtasks) + assert ret, f"create task {task_id} failed" + self.metrics_commit(records) + return task_id + + async def create_podcast(self, session_id, user_id, user_input, audio_path, rounds): + cur_t = current_time() + podcast = { + "session_id": session_id, + "user_id": user_id, + "user_input": user_input, + "create_t": cur_t, + "update_t": cur_t, + "has_audio": True, + "audio_path": audio_path, + "metadata_path": "", + "rounds": rounds, + "subtitles": [], + "extra_info": {}, + "tag": "", + } + assert await self.insert_podcast(podcast), f"create podcast {podcast} failed" + + async def mark_server_restart(self): + pass + # only for start server with active tasks + # if self.metrics_monitor: + # tasks = await self.list_tasks(status=ActiveStatus) + # subtasks = await self.list_tasks(status=ActiveStatus, subtasks=True) + # logger.warning(f"Mark system restart, {len(tasks)} tasks, {len(subtasks)} subtasks") + # self.metrics_monitor.record_task_recover(tasks) + # self.metrics_monitor.record_subtask_recover(subtasks) + + def 
mark_task_start(self, records, task): + t = current_time() + if not isinstance(task["extra_info"], dict): + task["extra_info"] = {} + if "active_elapse" in task["extra_info"]: + del task["extra_info"]["active_elapse"] + task["extra_info"]["start_t"] = t + logger.info(f"Task {task['task_id']} active start") + if self.metrics_monitor: + records.append( + [ + self.metrics_monitor.record_task_start, + [task], + ] + ) + + def mark_task_end(self, records, task, end_status): + if "start_t" not in task["extra_info"]: + logger.warning(f"Task {task} has no start time") + else: + elapse = current_time() - task["extra_info"]["start_t"] + task["extra_info"]["active_elapse"] = elapse + del task["extra_info"]["start_t"] + + logger.info(f"Task {task['task_id']} active end with [{end_status}], elapse: {elapse}") + if self.metrics_monitor: + records.append( + [ + self.metrics_monitor.record_task_end, + [task, end_status, elapse], + ] + ) + + def mark_subtask_change(self, records, subtask, old_status, new_status, fail_msg=None): + t = current_time() + if not isinstance(subtask["extra_info"], dict): + subtask["extra_info"] = {} + if isinstance(fail_msg, str) and len(fail_msg) > 0: + subtask["extra_info"]["fail_msg"] = fail_msg + elif "fail_msg" in subtask["extra_info"]: + del subtask["extra_info"]["fail_msg"] + + if old_status == new_status: + logger.warning(f"Subtask {subtask} update same status: {old_status} vs {new_status}") + return + + elapse, elapse_key = None, None + if old_status in ActiveStatus: + if "start_t" not in subtask["extra_info"]: + logger.warning(f"Subtask {subtask} has no start time, status: {old_status}") + else: + elapse = t - subtask["extra_info"]["start_t"] + elapse_key = f"{old_status.name}-{new_status.name}" + if "elapses" not in subtask["extra_info"]: + subtask["extra_info"]["elapses"] = {} + subtask["extra_info"]["elapses"][elapse_key] = elapse + del subtask["extra_info"]["start_t"] + + if new_status in ActiveStatus: + subtask["extra_info"]["start_t"] = t + if new_status == TaskStatus.CREATED and "elapses" in subtask["extra_info"]: + del subtask["extra_info"]["elapses"] + + logger.info( + f"Subtask {subtask['task_id']} {subtask['worker_name']} status changed: \ + [{old_status}] -> [{new_status}], {elapse_key}: {elapse}, fail_msg: {fail_msg}" + ) + + if self.metrics_monitor: + records.append( + [ + self.metrics_monitor.record_subtask_change, + [subtask, old_status, new_status, elapse_key, elapse], + ] + ) + + @class_try_catch + def metrics_commit(self, records): + for func, args in records: + func(*args) + + +# Import task manager implementations +from .local_task_manager import LocalTaskManager # noqa +from .sql_task_manager import PostgresSQLTaskManager # noqa + +__all__ = ["BaseTaskManager", "LocalTaskManager", "PostgresSQLTaskManager"] diff --git a/lightx2v/deploy/task_manager/local_task_manager.py b/lightx2v/deploy/task_manager/local_task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f92b977d3a61e2bc2dd29b6c49bc0c4b142b64be --- /dev/null +++ b/lightx2v/deploy/task_manager/local_task_manager.py @@ -0,0 +1,476 @@ +import asyncio +import json +import os + +from lightx2v.deploy.common.utils import class_try_catch_async, current_time, str2time, time2str +from lightx2v.deploy.task_manager import ActiveStatus, BaseTaskManager, FinishedStatus, TaskStatus + + +class LocalTaskManager(BaseTaskManager): + def __init__(self, local_dir, metrics_monitor=None): + self.local_dir = local_dir + if not os.path.exists(self.local_dir): + os.makedirs(self.local_dir) 
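+        # Every record type is persisted as a standalone JSON file under local_dir:
+        # task_<task_id>.json, user_<user_id>.json, share_<share_id>.json,
+        # podcast_<session_id>.json (see the get_*_filename helpers below).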
+ self.metrics_monitor = metrics_monitor + + def get_task_filename(self, task_id): + return os.path.join(self.local_dir, f"task_{task_id}.json") + + def get_user_filename(self, user_id): + return os.path.join(self.local_dir, f"user_{user_id}.json") + + def get_podcast_filename(self, session_id): + return os.path.join(self.local_dir, f"podcast_{session_id}.json") + + def fmt_dict(self, data): + super().fmt_dict(data) + for k in ["create_t", "update_t", "ping_t", "valid_t"]: + if k in data: + data[k] = time2str(data[k]) + + def parse_dict(self, data): + super().parse_dict(data) + for k in ["create_t", "update_t", "ping_t", "valid_t"]: + if k in data: + data[k] = str2time(data[k]) + + def save(self, task, subtasks, with_fmt=True): + info = {"task": task, "subtasks": subtasks} + if with_fmt: + self.fmt_dict(info["task"]) + [self.fmt_dict(x) for x in info["subtasks"]] + out_name = self.get_task_filename(task["task_id"]) + with open(out_name, "w") as fout: + fout.write(json.dumps(info, indent=4, ensure_ascii=False)) + + def load(self, task_id, user_id=None, only_task=False): + fpath = self.get_task_filename(task_id) + info = json.load(open(fpath)) + task, subtasks = info["task"], info["subtasks"] + if user_id is not None and task["user_id"] != user_id: + raise Exception(f"Task {task_id} is not belong to user {user_id}") + if task["tag"] == "delete": + raise Exception(f"Task {task_id} is deleted") + self.parse_dict(task) + if only_task: + return task + for sub in subtasks: + self.parse_dict(sub) + return task, subtasks + + def save_podcast(self, podcast, with_fmt=True): + if with_fmt: + self.fmt_dict(podcast) + out_name = self.get_podcast_filename(podcast["session_id"]) + with open(out_name, "w") as fout: + fout.write(json.dumps(podcast, indent=4, ensure_ascii=False)) + + def load_podcast(self, session_id, user_id=None): + fpath = self.get_podcast_filename(session_id) + data = json.load(open(fpath)) + if user_id is not None and data.get("user_id") != user_id: + raise Exception(f"Podcast {session_id} is not belong to user {user_id}") + if data["tag"] == "delete": + raise Exception(f"Podcast {session_id} is deleted") + self.parse_dict(data) + return data + + @class_try_catch_async + async def insert_task(self, task, subtasks): + self.save(task, subtasks) + return True + + @class_try_catch_async + async def list_tasks(self, **kwargs): + tasks = [] + fs = [os.path.join(self.local_dir, f) for f in os.listdir(self.local_dir)] + for f in os.listdir(self.local_dir): + if not f.startswith("task_"): + continue + fpath = os.path.join(self.local_dir, f) + info = json.load(open(fpath)) + if kwargs.get("subtasks", False): + items = info["subtasks"] + assert "user_id" not in kwargs, "user_id is not allowed when subtasks is True" + else: + items = [info["task"]] + for task in items: + self.parse_dict(task) + if "user_id" in kwargs and task["user_id"] != kwargs["user_id"]: + continue + if "status" in kwargs: + if isinstance(kwargs["status"], list) and task["status"] not in kwargs["status"]: + continue + elif kwargs["status"] != task["status"]: + continue + if "start_created_t" in kwargs and kwargs["start_created_t"] > task["create_t"]: + continue + if "end_created_t" in kwargs and kwargs["end_created_t"] < task["create_t"]: + continue + if "start_updated_t" in kwargs and kwargs["start_updated_t"] > task["update_t"]: + continue + if "end_updated_t" in kwargs and kwargs["end_updated_t"] < task["update_t"]: + continue + if "start_ping_t" in kwargs and kwargs["start_ping_t"] > task["ping_t"]: + continue + if 
"end_ping_t" in kwargs and kwargs["end_ping_t"] < task["ping_t"]: + continue + if not kwargs.get("include_delete", False) and task.get("tag", "") == "delete": + continue + + # 如果不是查询子任务,则添加子任务信息到任务中 + if not kwargs.get("subtasks", False): + task["subtasks"] = info.get("subtasks", []) + + tasks.append(task) + if "count" in kwargs: + return len(tasks) + + sort_key = "update_t" if kwargs.get("sort_by_update_t", False) else "create_t" + tasks = sorted(tasks, key=lambda x: x[sort_key], reverse=True) + + if "offset" in kwargs: + tasks = tasks[kwargs["offset"] :] + if "limit" in kwargs: + tasks = tasks[: kwargs["limit"]] + return tasks + + @class_try_catch_async + async def query_task(self, task_id, user_id=None, only_task=True): + return self.load(task_id, user_id, only_task) + + @class_try_catch_async + async def next_subtasks(self, task_id): + records = [] + task, subtasks = self.load(task_id) + if task["status"] not in ActiveStatus: + return [] + succeeds = set() + for sub in subtasks: + if sub["status"] == TaskStatus.SUCCEED: + succeeds.add(sub["worker_name"]) + nexts = [] + for sub in subtasks: + if sub["status"] == TaskStatus.CREATED: + dep_ok = True + for prev in sub["previous"]: + if prev not in succeeds: + dep_ok = False + break + if dep_ok: + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.PENDING) + sub["params"] = task["params"] + sub["status"] = TaskStatus.PENDING + sub["update_t"] = current_time() + self.align_extra_inputs(task, sub) + nexts.append(sub) + if len(nexts) > 0: + task["status"] = TaskStatus.PENDING + task["update_t"] = current_time() + self.save(task, subtasks) + self.metrics_commit(records) + return nexts + + @class_try_catch_async + async def run_subtasks(self, cands, worker_identity): + records = [] + valids = [] + for cand in cands: + task_id = cand["task_id"] + worker_name = cand["worker_name"] + task, subtasks = self.load(task_id) + if task["status"] in [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]: + continue + for sub in subtasks: + if sub["worker_name"] == worker_name: + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.RUNNING) + sub["status"] = TaskStatus.RUNNING + sub["worker_identity"] = worker_identity + sub["update_t"] = current_time() + task["status"] = TaskStatus.RUNNING + task["update_t"] = current_time() + task["ping_t"] = current_time() + self.save(task, subtasks) + valids.append(cand) + break + self.metrics_commit(records) + return valids + + @class_try_catch_async + async def ping_subtask(self, task_id, worker_name, worker_identity): + task, subtasks = self.load(task_id) + for sub in subtasks: + if sub["worker_name"] == worker_name: + pre = sub["worker_identity"] + assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}" + sub["ping_t"] = current_time() + self.save(task, subtasks) + return True + return False + + @class_try_catch_async + async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False): + records = [] + task, subtasks = self.load(task_id) + subs = subtasks + + if worker_name: + subs = [sub for sub in subtasks if sub["worker_name"] == worker_name] + assert len(subs) >= 1, f"no worker task_id={task_id}, name={worker_name}" + + if worker_identity: + pre = subs[0]["worker_identity"] + assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}" + + assert status in [TaskStatus.SUCCEED, TaskStatus.FAILED], f"invalid finish status: {status}" + for sub in subs: + if 
sub["status"] not in FinishedStatus: + if should_running and sub["status"] != TaskStatus.RUNNING: + print(f"task {task_id} is not running, skip finish subtask: {sub}") + continue + self.mark_subtask_change(records, sub, sub["status"], status, fail_msg=fail_msg) + sub["status"] = status + sub["update_t"] = current_time() + + if task["status"] == TaskStatus.CANCEL: + self.save(task, subtasks) + self.metrics_commit(records) + return TaskStatus.CANCEL + + running_subs = [] + failed_sub = False + for sub in subtasks: + if sub["status"] not in FinishedStatus: + running_subs.append(sub) + if sub["status"] == TaskStatus.FAILED: + failed_sub = True + + # some subtask failed, we should fail all other subtasks + if failed_sub: + if task["status"] != TaskStatus.FAILED: + self.mark_task_end(records, task, TaskStatus.FAILED) + task["status"] = TaskStatus.FAILED + task["update_t"] = current_time() + for sub in running_subs: + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed") + sub["status"] = TaskStatus.FAILED + sub["update_t"] = current_time() + self.save(task, subtasks) + self.metrics_commit(records) + return TaskStatus.FAILED + + # all subtasks finished and all succeed + elif len(running_subs) == 0: + if task["status"] != TaskStatus.SUCCEED: + self.mark_task_end(records, task, TaskStatus.SUCCEED) + task["status"] = TaskStatus.SUCCEED + task["update_t"] = current_time() + self.save(task, subtasks) + self.metrics_commit(records) + return TaskStatus.SUCCEED + + self.save(task, subtasks) + self.metrics_commit(records) + return None + + @class_try_catch_async + async def cancel_task(self, task_id, user_id=None): + records = [] + task, subtasks = self.load(task_id, user_id) + if task["status"] not in ActiveStatus: + return f"Task {task_id} is not in active status (current status: {task['status']}). Only tasks with status CREATED, PENDING, or RUNNING can be cancelled." 
+ + for sub in subtasks: + if sub["status"] not in FinishedStatus: + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.CANCEL) + sub["status"] = TaskStatus.CANCEL + sub["update_t"] = current_time() + self.mark_task_end(records, task, TaskStatus.CANCEL) + task["status"] = TaskStatus.CANCEL + task["update_t"] = current_time() + self.save(task, subtasks) + self.metrics_commit(records) + return True + + @class_try_catch_async + async def resume_task(self, task_id, all_subtask=False, user_id=None): + records = [] + task, subtasks = self.load(task_id, user_id) + # the task is not finished + if task["status"] not in FinishedStatus: + return "Active task cannot be resumed" + # the task is no need to resume + if not all_subtask and task["status"] == TaskStatus.SUCCEED: + return "Succeed task cannot be resumed" + for sub in subtasks: + if all_subtask or sub["status"] != TaskStatus.SUCCEED: + self.mark_subtask_change(records, sub, None, TaskStatus.CREATED) + sub["status"] = TaskStatus.CREATED + sub["update_t"] = current_time() + sub["ping_t"] = 0.0 + self.mark_task_start(records, task) + task["status"] = TaskStatus.CREATED + task["update_t"] = current_time() + self.save(task, subtasks) + self.metrics_commit(records) + return True + + @class_try_catch_async + async def delete_task(self, task_id, user_id=None): + task, subtasks = self.load(task_id, user_id) + # only allow to delete finished tasks + if task["status"] not in FinishedStatus: + return False + # delete task file + task["tag"] = "delete" + task["update_t"] = current_time() + self.save(task, subtasks) + return True + + def get_share_filename(self, share_id): + return os.path.join(self.local_dir, f"share_{share_id}.json") + + @class_try_catch_async + async def insert_share(self, share_info): + fpath = self.get_share_filename(share_info["share_id"]) + self.fmt_dict(share_info) + with open(fpath, "w") as fout: + fout.write(json.dumps(share_info, indent=4, ensure_ascii=False)) + return True + + @class_try_catch_async + async def query_share(self, share_id): + fpath = self.get_share_filename(share_id) + if not os.path.exists(fpath): + return None + data = json.load(open(fpath)) + self.parse_dict(data) + if data["tag"] == "delete": + raise Exception(f"Share {share_id} is deleted") + if data["valid_t"] < current_time(): + raise Exception(f"Share {share_id} has expired") + return data + + @class_try_catch_async + async def insert_user_if_not_exists(self, user_info): + fpath = self.get_user_filename(user_info["user_id"]) + if os.path.exists(fpath): + return True + self.fmt_dict(user_info) + with open(fpath, "w") as fout: + fout.write(json.dumps(user_info, indent=4, ensure_ascii=False)) + return True + + @class_try_catch_async + async def query_user(self, user_id): + fpath = self.get_user_filename(user_id) + if not os.path.exists(fpath): + return None + data = json.load(open(fpath)) + self.parse_dict(data) + return data + + @class_try_catch_async + async def insert_podcast(self, podcast): + self.save_podcast(podcast) + return True + + @class_try_catch_async + async def query_podcast(self, session_id, user_id=None): + fpath = self.get_podcast_filename(session_id) + if not os.path.exists(fpath): + return None + data = json.load(open(fpath)) + self.parse_dict(data) + return data + + @class_try_catch_async + async def list_podcasts(self, **kwargs): + sessions = [] + for f in os.listdir(self.local_dir): + if not f.startswith("podcast_"): + continue + fpath = os.path.join(self.local_dir, f) + session = json.load(open(fpath)) + 
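+            # Convert stored time strings back to timestamps before applying the
+            # user_id / has_audio / delete-tag filters and the sorting below.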
self.parse_dict(session) + if "user_id" in kwargs and session["user_id"] != kwargs["user_id"]: + continue + if "has_audio" in kwargs and session["has_audio"] != kwargs["has_audio"]: + continue + if not kwargs.get("include_delete", False) and session.get("tag", "") == "delete": + continue + sessions.append(session) + if "count" in kwargs: + return len(sessions) + sort_key = "update_t" if kwargs.get("sort_by_update_t", False) else "create_t" + sessions = sorted(sessions, key=lambda x: x[sort_key], reverse=True) + if "offset" in kwargs: + sessions = sessions[kwargs["offset"] :] + if "limit" in kwargs: + sessions = sessions[: kwargs["limit"]] + return sessions + + +async def test(): + from lightx2v.deploy.common.pipeline import Pipeline + + p = Pipeline("/data/nvme1/liuliang1/lightx2v/configs/model_pipeline.json") + m = LocalTaskManager("/data/nvme1/liuliang1/lightx2v/local_task") + await m.init() + + keys = ["t2v", "wan2.1", "multi_stage"] + workers = p.get_workers(keys) + inputs = p.get_inputs(keys) + outputs = p.get_outputs(keys) + params = { + "prompt": "fake input prompts", + "resolution": { + "height": 233, + "width": 456, + }, + } + + user_info = { + "source": "github", + "id": "test-id-233", + "username": "test-username-233", + "email": "test-email-233@test.com", + "homepage": "https://test.com", + "avatar_url": "https://test.com/avatar.png", + } + user_id = await m.create_user(user_info) + print(" - create_user:", user_id) + + user = await m.query_user(user_id) + print(" - query_user:", user) + + task_id = await m.create_task(keys, workers, params, inputs, outputs, user_id) + print(" - create_task:", task_id) + + tasks = await m.list_tasks() + print(" - list_tasks:", tasks) + + task = await m.query_task(task_id) + print(" - query_task:", task) + + subtasks = await m.next_subtasks(task_id) + print(" - next_subtasks:", subtasks) + + await m.run_subtasks(subtasks, "fake-worker") + await m.finish_subtasks(task_id, TaskStatus.FAILED) + await m.cancel_task(task_id) + await m.resume_task(task_id) + for sub in subtasks: + await m.finish_subtasks(sub["task_id"], TaskStatus.SUCCEED, worker_name=sub["worker_name"], worker_identity="fake-worker") + + subtasks = await m.next_subtasks(task_id) + print(" - final next_subtasks:", subtasks) + + task = await m.query_task(task_id) + print(" - final task:", task) + + await m.close() + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/lightx2v/deploy/task_manager/sql_task_manager.py b/lightx2v/deploy/task_manager/sql_task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..35b0ee52d69f3eed5361ab486762550b0c319600 --- /dev/null +++ b/lightx2v/deploy/task_manager/sql_task_manager.py @@ -0,0 +1,1112 @@ +import asyncio +import json +import traceback +from datetime import datetime + +import asyncpg +from loguru import logger + +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.deploy.task_manager import ActiveStatus, BaseTaskManager, FinishedStatus, TaskStatus + + +class PostgresSQLTaskManager(BaseTaskManager): + def __init__(self, db_url, metrics_monitor=None): + self.db_url = db_url + self.table_tasks = "tasks" + self.table_subtasks = "subtasks" + self.table_users = "users" + self.table_versions = "versions" + self.table_shares = "shares" + self.table_podcasts = "podcasts" + self.pool = None + self.metrics_monitor = metrics_monitor + self.time_keys = ["create_t", "update_t", "ping_t", "valid_t"] + self.json_keys = ["params", "extra_info", "inputs", "outputs", "previous", "rounds", 
"subtitles"] + + async def init(self): + await self.upgrade_db() + + async def close(self): + if self.pool: + await self.pool.close() + + def fmt_dict(self, data): + super().fmt_dict(data) + for k in self.time_keys: + if k in data and isinstance(data[k], float): + data[k] = datetime.fromtimestamp(data[k]) + for k in self.json_keys: + if k in data: + data[k] = json.dumps(data[k], ensure_ascii=False) + + def parse_dict(self, data): + super().parse_dict(data) + for k in self.json_keys: + if k in data: + data[k] = json.loads(data[k]) + for k in self.time_keys: + if k in data: + data[k] = data[k].timestamp() + + async def get_conn(self): + if self.pool is None: + self.pool = await asyncpg.create_pool(self.db_url) + return await self.pool.acquire() + + async def release_conn(self, conn): + await self.pool.release(conn) + + async def query_version(self): + conn = await self.get_conn() + try: + row = await conn.fetchrow(f"SELECT version FROM {self.table_versions} ORDER BY create_t DESC LIMIT 1") + row = dict(row) + return row["version"] if row else 0 + except: # noqa + logger.error(f"query_version error: {traceback.format_exc()}") + return 0 + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def upgrade_db(self): + versions = [ + (1, "Init tables", self.upgrade_v1), + (2, "Add shares table", self.upgrade_v2), + (3, "Add podcasts table", self.upgrade_v3), + ] + logger.info(f"upgrade_db: {self.db_url}") + cur_ver = await self.query_version() + for ver, description, func in versions: + if cur_ver < ver: + logger.info(f"Upgrade to version {ver}: {description}") + if not await func(ver, description): + logger.error(f"Upgrade to version {ver}: {description} func failed") + return False + cur_ver = ver + logger.info(f"upgrade_db: {self.db_url} done") + return True + + async def upgrade_v1(self, version, description): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + # create users table + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_users} ( + user_id VARCHAR(256) PRIMARY KEY, + source VARCHAR(32), + id VARCHAR(200), + username VARCHAR(256), + email VARCHAR(256), + homepage VARCHAR(256), + avatar_url VARCHAR(256), + create_t TIMESTAMPTZ, + update_t TIMESTAMPTZ, + extra_info JSONB, + tag VARCHAR(64) + ) + """) + # create tasks table + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_tasks} ( + task_id VARCHAR(128) PRIMARY KEY, + task_type VARCHAR(64), + model_cls VARCHAR(64), + stage VARCHAR(64), + params JSONB, + create_t TIMESTAMPTZ, + update_t TIMESTAMPTZ, + status VARCHAR(64), + extra_info JSONB, + tag VARCHAR(64), + inputs JSONB, + outputs JSONB, + user_id VARCHAR(256), + FOREIGN KEY (user_id) REFERENCES {self.table_users}(user_id) ON DELETE CASCADE + ) + """) + # create subtasks table + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_subtasks} ( + task_id VARCHAR(128), + worker_name VARCHAR(128), + inputs JSONB, + outputs JSONB, + queue VARCHAR(128), + previous JSONB, + status VARCHAR(64), + worker_identity VARCHAR(128), + result VARCHAR(128), + fail_time INTEGER, + extra_info JSONB, + create_t TIMESTAMPTZ, + update_t TIMESTAMPTZ, + ping_t TIMESTAMPTZ, + infer_cost FLOAT, + PRIMARY KEY (task_id, worker_name), + FOREIGN KEY (task_id) REFERENCES {self.table_tasks}(task_id) ON DELETE CASCADE + ) + """) + # create versions table + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_versions} ( + version INTEGER PRIMARY KEY, + description VARCHAR(255), + create_t 
TIMESTAMPTZ NOT NULL + ) + """) + # create indexes + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_users}_source ON {self.table_users}(source)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_users}_id ON {self.table_users}(id)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_tasks}_status ON {self.table_tasks}(status)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_tasks}_create_t ON {self.table_tasks}(create_t)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_tasks}_tag ON {self.table_tasks}(tag)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_subtasks}_task_id ON {self.table_subtasks}(task_id)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_subtasks}_worker_name ON {self.table_subtasks}(worker_name)") + await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_subtasks}_status ON {self.table_subtasks}(status)") + + # update version + await conn.execute(f"INSERT INTO {self.table_versions} (version, description, create_t) VALUES ($1, $2, $3)", version, description, datetime.now()) + return True + except: # noqa + logger.error(f"upgrade_v1 error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + async def upgrade_v2(self, version, description): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + # create shares table + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_shares} ( + share_id VARCHAR(128) PRIMARY KEY, + task_id VARCHAR(128), + user_id VARCHAR(256), + share_type VARCHAR(32), + create_t TIMESTAMPTZ, + update_t TIMESTAMPTZ, + valid_days INTEGER, + valid_t TIMESTAMPTZ, + auth_type VARCHAR(32), + auth_value VARCHAR(256), + extra_info JSONB, + tag VARCHAR(64), + FOREIGN KEY (user_id) REFERENCES {self.table_users}(user_id) ON DELETE CASCADE + ) + """) + + # update version + await conn.execute(f"INSERT INTO {self.table_versions} (version, description, create_t) VALUES ($1, $2, $3)", version, description, datetime.now()) + return True + except: # noqa + logger.error(f"upgrade_v2 error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + async def upgrade_v3(self, version, description): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + # create shares table + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_podcasts} ( + session_id VARCHAR(128) PRIMARY KEY, + user_id VARCHAR(256) NOT NULL, + user_input TEXT, + create_t TIMESTAMPTZ NOT NULL, + update_t TIMESTAMPTZ NOT NULL, + has_audio BOOLEAN DEFAULT FALSE, + audio_path TEXT, + metadata_path TEXT, + rounds JSONB, + subtitles JSONB, + extra_info JSONB, + tag VARCHAR(64), + FOREIGN KEY (user_id) REFERENCES {self.table_users}(user_id) ON DELETE CASCADE + ) + """) + + # update version + await conn.execute(f"INSERT INTO {self.table_versions} (version, description, create_t) VALUES ($1, $2, $3)", version, description, datetime.now()) + return True + except: # noqa + logger.error(f"upgrade_v3 error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + async def load(self, conn, task_id, user_id=None, only_task=False, worker_name=None): + query = f"SELECT * FROM {self.table_tasks} WHERE task_id = $1 AND tag != 'delete'" + params = [task_id] + if user_id is not None: + query += " AND user_id = $2" + params.append(user_id) + row = await conn.fetchrow(query, *params) + 
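+        # fetchrow returns None when no matching, non-deleted task exists;
+        # dict(None) raises before the assert below, so a missing task always
+        # surfaces as an exception to the caller.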
task = dict(row) + assert task, f"query_task: task not found: {task_id} {user_id}" + self.parse_dict(task) + if only_task: + return task + query2 = f"SELECT * FROM {self.table_subtasks} WHERE task_id = $1" + params2 = [task_id] + if worker_name is not None: + query2 += " AND worker_name = $2" + params2.append(worker_name) + rows = await conn.fetch(query2, *params2) + subtasks = [] + for row in rows: + sub = dict(row) + self.parse_dict(sub) + subtasks.append(sub) + return task, subtasks + + def check_update_valid(self, ret, prefix, query, params): + if ret.startswith("UPDATE "): + count = int(ret.split(" ")[1]) + assert count > 0, f"{prefix}: no row changed: {query} {params}" + return count + else: + logger.warning(f"parse postsql update ret failed: {ret}") + return 0 + + async def update_task(self, conn, task_id, **kwargs): + query = f"UPDATE {self.table_tasks} SET " + conds = ["update_t = $1"] + params = [datetime.now()] + param_idx = 1 + if "status" in kwargs: + param_idx += 1 + conds.append(f"status = ${param_idx}") + params.append(kwargs["status"].name) + if "extra_info" in kwargs: + param_idx += 1 + conds.append(f"extra_info = ${param_idx}") + params.append(json.dumps(kwargs["extra_info"], ensure_ascii=False)) + + limit_conds = [f"task_id = ${param_idx + 1}"] + param_idx += 1 + params.append(task_id) + + if "src_status" in kwargs: + param_idx += 1 + limit_conds.append(f"status = ${param_idx}") + params.append(kwargs["src_status"].name) + + query += " ,".join(conds) + " WHERE " + " AND ".join(limit_conds) + ret = await conn.execute(query, *params) + return self.check_update_valid(ret, "update_task", query, params) + + async def update_subtask(self, conn, task_id, worker_name, **kwargs): + query = f"UPDATE {self.table_subtasks} SET " + conds = [] + params = [] + param_idx = 0 + if kwargs.get("update_t", True): + param_idx += 1 + conds.append(f"update_t = ${param_idx}") + params.append(datetime.now()) + if kwargs.get("ping_t", False): + param_idx += 1 + conds.append(f"ping_t = ${param_idx}") + params.append(datetime.now()) + if kwargs.get("reset_ping_t", False): + param_idx += 1 + conds.append(f"ping_t = ${param_idx}") + params.append(datetime.fromtimestamp(0)) + if "status" in kwargs: + param_idx += 1 + conds.append(f"status = ${param_idx}") + params.append(kwargs["status"].name) + if "worker_identity" in kwargs: + param_idx += 1 + conds.append(f"worker_identity = ${param_idx}") + params.append(kwargs["worker_identity"]) + if "infer_cost" in kwargs: + param_idx += 1 + conds.append(f"infer_cost = ${param_idx}") + params.append(kwargs["infer_cost"]) + if "extra_info" in kwargs: + param_idx += 1 + conds.append(f"extra_info = ${param_idx}") + params.append(json.dumps(kwargs["extra_info"], ensure_ascii=False)) + + limit_conds = [f"task_id = ${param_idx + 1}", f"worker_name = ${param_idx + 2}"] + param_idx += 2 + params.extend([task_id, worker_name]) + + if "src_status" in kwargs: + param_idx += 1 + limit_conds.append(f"status = ${param_idx}") + params.append(kwargs["src_status"].name) + + query += " ,".join(conds) + " WHERE " + " AND ".join(limit_conds) + ret = await conn.execute(query, *params) + return self.check_update_valid(ret, "update_subtask", query, params) + + @class_try_catch_async + async def insert_task(self, task, subtasks): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + self.fmt_dict(task) + await conn.execute( + f""" + INSERT INTO {self.table_tasks} + (task_id, task_type, model_cls, stage, params, create_t, + update_t, 
status, extra_info, tag, inputs, outputs, user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + """, + task["task_id"], + task["task_type"], + task["model_cls"], + task["stage"], + task["params"], + task["create_t"], + task["update_t"], + task["status"], + task["extra_info"], + task["tag"], + task["inputs"], + task["outputs"], + task["user_id"], + ) + for sub in subtasks: + self.fmt_dict(sub) + await conn.execute( + f""" + INSERT INTO {self.table_subtasks} + (task_id, worker_name, inputs, outputs, queue, previous, status, + worker_identity, result, fail_time, extra_info, create_t, update_t, + ping_t, infer_cost) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + """, + sub["task_id"], + sub["worker_name"], + sub["inputs"], + sub["outputs"], + sub["queue"], + sub["previous"], + sub["status"], + sub["worker_identity"], + sub["result"], + sub["fail_time"], + sub["extra_info"], + sub["create_t"], + sub["update_t"], + sub["ping_t"], + sub["infer_cost"], + ) + return True + except: # noqa + logger.error(f"insert_task error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def list_tasks(self, **kwargs): + conn = await self.get_conn() + try: + count = kwargs.get("count", False) + query = f"SELECT * FROM " + if count: + query = f"SELECT COUNT(*) FROM " + assert "limit" not in kwargs, "limit is not allowed when count is True" + assert "offset" not in kwargs, "offset is not allowed when count is True" + params = [] + conds = [] + param_idx = 0 + if kwargs.get("subtasks", False): + query += self.table_subtasks + assert "user_id" not in kwargs, "user_id is not allowed when subtasks is True" + else: + query += self.table_tasks + if not kwargs.get("include_delete", False): + param_idx += 1 + conds.append(f"tag != ${param_idx}") + params.append("delete") + + if "status" in kwargs: + param_idx += 1 + if isinstance(kwargs["status"], list): + next_idx = param_idx + len(kwargs["status"]) + placeholders = ",".join([f"${i}" for i in range(param_idx, next_idx)]) + conds.append(f"status IN ({placeholders})") + params.extend([x.name for x in kwargs["status"]]) + param_idx = next_idx - 1 + else: + conds.append(f"status = ${param_idx}") + params.append(kwargs["status"].name) + + if "start_created_t" in kwargs: + param_idx += 1 + conds.append(f"create_t >= ${param_idx}") + params.append(datetime.fromtimestamp(kwargs["start_created_t"])) + + if "end_created_t" in kwargs: + param_idx += 1 + conds.append(f"create_t <= ${param_idx}") + params.append(datetime.fromtimestamp(kwargs["end_created_t"])) + + if "start_updated_t" in kwargs: + param_idx += 1 + conds.append(f"update_t >= ${param_idx}") + params.append(datetime.fromtimestamp(kwargs["start_updated_t"])) + + if "end_updated_t" in kwargs: + param_idx += 1 + conds.append(f"update_t <= ${param_idx}") + params.append(datetime.fromtimestamp(kwargs["end_updated_t"])) + + if "start_ping_t" in kwargs: + param_idx += 1 + conds.append(f"ping_t >= ${param_idx}") + params.append(datetime.fromtimestamp(kwargs["start_ping_t"])) + + if "end_ping_t" in kwargs: + param_idx += 1 + conds.append(f"ping_t <= ${param_idx}") + params.append(datetime.fromtimestamp(kwargs["end_ping_t"])) + + if "user_id" in kwargs: + param_idx += 1 + conds.append(f"user_id = ${param_idx}") + params.append(kwargs["user_id"]) + + if conds: + query += " WHERE " + " AND ".join(conds) + + if not count: + sort_key = "update_t" if kwargs.get("sort_by_update_t", False) else "create_t" + query += 
f" ORDER BY {sort_key} DESC" + + if "limit" in kwargs: + param_idx += 1 + query += f" LIMIT ${param_idx}" + params.append(kwargs["limit"]) + + if "offset" in kwargs: + param_idx += 1 + query += f" OFFSET ${param_idx}" + params.append(kwargs["offset"]) + + rows = await conn.fetch(query, *params) + if count: + return rows[0]["count"] + + # query subtasks with task + subtasks = {} + if not kwargs.get("subtasks", False): + subtask_query = f"SELECT {self.table_subtasks}.* FROM ({query}) AS t \ + JOIN {self.table_subtasks} ON t.task_id = {self.table_subtasks}.task_id" + subtask_rows = await conn.fetch(subtask_query, *params) + for row in subtask_rows: + sub = dict(row) + self.parse_dict(sub) + if sub["task_id"] not in subtasks: + subtasks[sub["task_id"]] = [] + subtasks[sub["task_id"]].append(sub) + + tasks = [] + for row in rows: + task = dict(row) + self.parse_dict(task) + if not kwargs.get("subtasks", False): + task["subtasks"] = subtasks.get(task["task_id"], []) + tasks.append(task) + return tasks + except: # noqa + logger.error(f"list_tasks error: {traceback.format_exc()}") + return [] + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def query_task(self, task_id, user_id=None, only_task=True): + conn = await self.get_conn() + try: + return await self.load(conn, task_id, user_id, only_task=only_task) + except: # noqa + logger.error(f"query_task error: {traceback.format_exc()}") + return None + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def next_subtasks(self, task_id): + conn = await self.get_conn() + records = [] + try: + async with conn.transaction(isolation="read_uncommitted"): + task, subtasks = await self.load(conn, task_id) + if task["status"] not in ActiveStatus: + return [] + succeeds = set() + for sub in subtasks: + if sub["status"] == TaskStatus.SUCCEED: + succeeds.add(sub["worker_name"]) + nexts = [] + for sub in subtasks: + if sub["status"] == TaskStatus.CREATED: + dep_ok = True + for prev in sub["previous"]: + if prev not in succeeds: + dep_ok = False + break + if dep_ok: + sub["params"] = task["params"] + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.PENDING) + await self.update_subtask( + conn, + task_id, + sub["worker_name"], + status=TaskStatus.PENDING, + extra_info=sub["extra_info"], + src_status=sub["status"], + ) + self.align_extra_inputs(task, sub) + nexts.append(sub) + if len(nexts) > 0: + await self.update_task(conn, task_id, status=TaskStatus.PENDING) + self.metrics_commit(records) + return nexts + except: # noqa + logger.error(f"next_subtasks error: {traceback.format_exc()}") + return None + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def run_subtasks(self, cands, worker_identity): + conn = await self.get_conn() + records = [] + try: + async with conn.transaction(isolation="read_uncommitted"): + valids = [] + for cand in cands: + task_id = cand["task_id"] + worker_name = cand["worker_name"] + task, subs = await self.load(conn, task_id, worker_name=worker_name) + assert len(subs) == 1, f"task {task_id} has multiple subtasks: {subs} with worker_name: {worker_name}" + if task["status"] in [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]: + continue + + self.mark_subtask_change(records, subs[0], subs[0]["status"], TaskStatus.RUNNING) + await self.update_subtask( + conn, + task_id, + worker_name, + status=TaskStatus.RUNNING, + worker_identity=worker_identity, + ping_t=True, + extra_info=subs[0]["extra_info"], + src_status=subs[0]["status"], + ) + 
await self.update_task(conn, task_id, status=TaskStatus.RUNNING) + valids.append(cand) + break + self.metrics_commit(records) + return valids + except: # noqa + logger.error(f"run_subtasks error: {traceback.format_exc()}") + return [] + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def ping_subtask(self, task_id, worker_name, worker_identity): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + task, subtasks = await self.load(conn, task_id) + for sub in subtasks: + if sub["worker_name"] == worker_name: + pre = sub["worker_identity"] + assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}" + await self.update_subtask(conn, task_id, worker_name, ping_t=True, update_t=False) + return True + return False + except: # noqa + logger.error(f"ping_subtask error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False): + conn = await self.get_conn() + records = [] + try: + async with conn.transaction(isolation="read_uncommitted"): + task, subtasks = await self.load(conn, task_id) + subs = subtasks + if worker_name: + subs = [sub for sub in subtasks if sub["worker_name"] == worker_name] + assert len(subs) >= 1, f"no worker task_id={task_id}, name={worker_name}" + + if worker_identity: + pre = subs[0]["worker_identity"] + assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}" + + assert status in [TaskStatus.SUCCEED, TaskStatus.FAILED], f"invalid finish status: {status}" + for sub in subs: + if sub["status"] not in FinishedStatus: + if should_running and sub["status"] != TaskStatus.RUNNING: + logger.warning(f"task {task_id} is not running, skip finish subtask: {sub}") + continue + self.mark_subtask_change(records, sub, sub["status"], status, fail_msg=fail_msg) + await self.update_subtask( + conn, + task_id, + sub["worker_name"], + status=status, + extra_info=sub["extra_info"], + src_status=sub["status"], + ) + sub["status"] = status + + if task["status"] == TaskStatus.CANCEL: + self.metrics_commit(records) + return TaskStatus.CANCEL + + running_subs = [] + failed_sub = False + for sub in subtasks: + if sub["status"] not in FinishedStatus: + running_subs.append(sub) + if sub["status"] == TaskStatus.FAILED: + failed_sub = True + + # some subtask failed, we should fail all other subtasks + if failed_sub: + if task["status"] != TaskStatus.FAILED: + self.mark_task_end(records, task, TaskStatus.FAILED) + await self.update_task( + conn, + task_id, + status=TaskStatus.FAILED, + extra_info=task["extra_info"], + src_status=task["status"], + ) + for sub in running_subs: + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed") + await self.update_subtask( + conn, + task_id, + sub["worker_name"], + status=TaskStatus.FAILED, + extra_info=sub["extra_info"], + src_status=sub["status"], + ) + self.metrics_commit(records) + return TaskStatus.FAILED + + # all subtasks finished and all succeed + elif len(running_subs) == 0: + if task["status"] != TaskStatus.SUCCEED: + self.mark_task_end(records, task, TaskStatus.SUCCEED) + await self.update_task( + conn, + task_id, + status=TaskStatus.SUCCEED, + extra_info=task["extra_info"], + src_status=task["status"], + ) + self.metrics_commit(records) + return TaskStatus.SUCCEED + + 
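+                # Some subtasks are still active and none has failed: keep the
+                # task in its current status and report no terminal state.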
self.metrics_commit(records) + return None + except: # noqa + logger.error(f"finish_subtasks error: {traceback.format_exc()}") + return None + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def cancel_task(self, task_id, user_id=None): + conn = await self.get_conn() + records = [] + try: + async with conn.transaction(isolation="read_uncommitted"): + task, subtasks = await self.load(conn, task_id, user_id) + if task["status"] not in ActiveStatus: + return f"Task {task_id} is not in active status (current status: {task['status']}). \ + Only tasks with status CREATED, PENDING, or RUNNING can be cancelled." + + for sub in subtasks: + if sub["status"] not in FinishedStatus: + self.mark_subtask_change(records, sub, sub["status"], TaskStatus.CANCEL) + await self.update_subtask( + conn, + task_id, + sub["worker_name"], + status=TaskStatus.CANCEL, + extra_info=sub["extra_info"], + src_status=sub["status"], + ) + + self.mark_task_end(records, task, TaskStatus.CANCEL) + await self.update_task( + conn, + task_id, + status=TaskStatus.CANCEL, + extra_info=task["extra_info"], + src_status=task["status"], + ) + self.metrics_commit(records) + return True + except: # noqa + logger.error(f"cancel_task error: {traceback.format_exc()}") + return "unknown cancel error" + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def resume_task(self, task_id, all_subtask=False, user_id=None): + conn = await self.get_conn() + records = [] + try: + async with conn.transaction(isolation="read_uncommitted"): + task, subtasks = await self.load(conn, task_id, user_id) + # the task is not finished + if task["status"] not in FinishedStatus: + return "Active task cannot be resumed" + # the task is no need to resume + if not all_subtask and task["status"] == TaskStatus.SUCCEED: + return "Succeed task cannot be resumed" + + for sub in subtasks: + if all_subtask or sub["status"] != TaskStatus.SUCCEED: + self.mark_subtask_change(records, sub, None, TaskStatus.CREATED) + await self.update_subtask( + conn, + task_id, + sub["worker_name"], + status=TaskStatus.CREATED, + reset_ping_t=True, + extra_info=sub["extra_info"], + src_status=sub["status"], + ) + + self.mark_task_start(records, task) + await self.update_task( + conn, + task_id, + status=TaskStatus.CREATED, + extra_info=task["extra_info"], + src_status=task["status"], + ) + self.metrics_commit(records) + return True + except: # noqa + logger.error(f"resume_task error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def delete_task(self, task_id, user_id=None): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + task = await self.load(conn, task_id, user_id, only_task=True) + + # only allow to delete finished tasks + if task["status"] not in FinishedStatus: + logger.warning(f"Cannot delete task {task_id} with status {task['status']}, only finished tasks can be deleted") + return False + + # delete task record + await conn.execute(f"UPDATE {self.table_tasks} SET tag = 'delete', update_t = $1 WHERE task_id = $2", datetime.now(), task_id) + logger.info(f"Task {task_id} deleted successfully") + return True + + except: # noqa + logger.error(f"delete_task error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def insert_share(self, share_info): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + 
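+                # fmt_dict converts the float timestamps (create_t, update_t, valid_t)
+                # to datetime and JSON-encodes extra_info so the row matches the
+                # shares schema created in upgrade_v2.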
self.fmt_dict(share_info) + await conn.execute( + f"""INSERT INTO {self.table_shares} + (share_id, task_id, user_id, share_type, create_t, update_t, + valid_days, valid_t, auth_type, auth_value, extra_info, tag) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + """, + share_info["share_id"], + share_info["task_id"], + share_info["user_id"], + share_info["share_type"], + share_info["create_t"], + share_info["update_t"], + share_info["valid_days"], + share_info["valid_t"], + share_info["auth_type"], + share_info["auth_value"], + share_info["extra_info"], + share_info["tag"], + ) + return True + except: # noqa + logger.error(f"create_share_link error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def query_share(self, share_id): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + row = await conn.fetchrow(f"SELECT * FROM {self.table_shares} WHERE share_id = $1 AND tag != 'delete' AND valid_t >= $2", share_id, datetime.now()) + share = dict(row) + self.parse_dict(share) + return share + except: # noqa + logger.error(f"query_share error: {traceback.format_exc()}") + return None + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def insert_user_if_not_exists(self, user_info): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + row = await conn.fetchrow(f"SELECT * FROM {self.table_users} WHERE user_id = $1", user_info["user_id"]) + if row: + logger.info(f"user already exists: {user_info['user_id']}") + return True + self.fmt_dict(user_info) + await conn.execute( + f""" + INSERT INTO {self.table_users} + (user_id, source, id, username, email, homepage, + avatar_url, create_t, update_t, extra_info, tag) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + """, + user_info["user_id"], + user_info["source"], + user_info["id"], + user_info["username"], + user_info["email"], + user_info["homepage"], + user_info["avatar_url"], + user_info["create_t"], + user_info["update_t"], + user_info["extra_info"], + user_info["tag"], + ) + return True + except: # noqa + logger.error(f"insert_user_if_not_exists error: {traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def query_user(self, user_id): + conn = await self.get_conn() + try: + row = await conn.fetchrow(f"SELECT * FROM {self.table_users} WHERE user_id = $1", user_id) + user = dict(row) + self.parse_dict(user) + return user + except: # noqa + logger.error(f"query_user error: {traceback.format_exc()}") + return None + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def insert_podcast(self, podcast): + conn = await self.get_conn() + try: + async with conn.transaction(isolation="read_uncommitted"): + self.fmt_dict(podcast) + await conn.execute( + f"""INSERT INTO {self.table_podcasts} + (session_id, user_id, user_input, create_t, update_t, has_audio, + audio_path, metadata_path, rounds, subtitles, extra_info, tag) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + """, + podcast["session_id"], + podcast["user_id"], + podcast["user_input"], + podcast["create_t"], + podcast["update_t"], + podcast["has_audio"], + podcast["audio_path"], + podcast["metadata_path"], + podcast["rounds"], + podcast["subtitles"], + podcast["extra_info"], + podcast["tag"], + ) + return True + except: # noqa + logger.error(f"insert_podcast error: 
{traceback.format_exc()}") + return False + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def query_podcast(self, session_id, user_id=None): + conn = await self.get_conn() + try: + query = f"SELECT * FROM {self.table_podcasts} WHERE session_id = $1 AND tag != 'delete'" + params = [session_id] + if user_id is not None: + query += " AND user_id = $2" + params.append(user_id) + row = await conn.fetchrow(query, *params) + if row is None: + return None + podcast = dict(row) + self.parse_dict(podcast) + return podcast + except: # noqa + logger.error(f"query_podcast error: {traceback.format_exc()}") + return None + finally: + await self.release_conn(conn) + + @class_try_catch_async + async def list_podcasts(self, **kwargs): + conn = await self.get_conn() + try: + count = kwargs.get("count", False) + query = f"SELECT * FROM " + if count: + query = f"SELECT COUNT(*) FROM " + assert "limit" not in kwargs, "limit is not allowed when count is True" + assert "offset" not in kwargs, "offset is not allowed when count is True" + params = [] + conds = [] + param_idx = 0 + query += self.table_podcasts + + if not kwargs.get("include_delete", False): + param_idx += 1 + conds.append(f"tag != ${param_idx}") + params.append("delete") + + if "has_audio" in kwargs: + param_idx += 1 + conds.append(f"has_audio = ${param_idx}") + params.append(kwargs["has_audio"]) + + if "user_id" in kwargs: + param_idx += 1 + conds.append(f"user_id = ${param_idx}") + params.append(kwargs["user_id"]) + + if conds: + query += " WHERE " + " AND ".join(conds) + + if not count: + sort_key = "update_t" if kwargs.get("sort_by_update_t", False) else "create_t" + query += f" ORDER BY {sort_key} DESC" + + if "limit" in kwargs: + param_idx += 1 + query += f" LIMIT ${param_idx}" + params.append(kwargs["limit"]) + + if "offset" in kwargs: + param_idx += 1 + query += f" OFFSET ${param_idx}" + params.append(kwargs["offset"]) + + rows = await conn.fetch(query, *params) + if count: + return rows[0]["count"] + + podcasts = [] + for row in rows: + podcast = dict(row) + self.parse_dict(podcast) + podcasts.append(podcast) + return podcasts + except: # noqa + logger.error(f"list_podcasts error: {traceback.format_exc()}") + return [] + finally: + await self.release_conn(conn) + + +async def test(): + from lightx2v.deploy.common.pipeline import Pipeline + + p = Pipeline("/data/nvme1/liuliang1/lightx2v/configs/model_pipeline.json") + m = PostgresSQLTaskManager("postgresql://test:test@127.0.0.1:5432/lightx2v_test") + await m.init() + + keys = ["t2v", "wan2.1", "multi_stage"] + workers = p.get_workers(keys) + inputs = p.get_inputs(keys) + outputs = p.get_outputs(keys) + params = { + "prompt": "fake input prompts", + "resolution": { + "height": 233, + "width": 456, + }, + } + + user_info = { + "source": "github", + "id": "4566", + "username": "test-username-233", + "email": "test-email-233@test.com", + "homepage": "https://test.com", + "avatar_url": "https://test.com/avatar.png", + } + user_id = await m.create_user(user_info) + print(" - create_user:", user_id) + + user = await m.query_user(user_id) + print(" - query_user:", user) + + task_id = await m.create_task(keys, workers, params, inputs, outputs, user_id) + print(" - create_task:", task_id) + + tasks = await m.list_tasks() + print(" - list_tasks:", tasks) + + task = await m.query_task(task_id) + print(" - query_task:", task) + + subtasks = await m.next_subtasks(task_id) + print(" - next_subtasks:", subtasks) + + await m.run_subtasks(subtasks, "fake-worker") + await 
m.finish_subtasks(task_id, TaskStatus.FAILED) + await m.cancel_task(task_id) + await m.resume_task(task_id) + for sub in subtasks: + await m.finish_subtasks(sub["task_id"], TaskStatus.SUCCEED, worker_name=sub["worker_name"], worker_identity="fake-worker") + + subtasks = await m.next_subtasks(task_id) + print(" - final next_subtasks:", subtasks) + + task = await m.query_task(task_id) + print(" - final task:", task) + + await m.close() + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/lightx2v/deploy/worker/__init__.py b/lightx2v/deploy/worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/deploy/worker/__main__.py b/lightx2v/deploy/worker/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..718406accd15bf004f6f1db7577f4e53aa11a9c2 --- /dev/null +++ b/lightx2v/deploy/worker/__main__.py @@ -0,0 +1,370 @@ +import argparse +import asyncio +import json +import os +import signal +import sys +import traceback +import uuid + +import aiohttp +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager +from lightx2v.deploy.task_manager import TaskStatus +from lightx2v.deploy.worker.hub import DiTWorker, ImageEncoderWorker, PipelineWorker, SegmentDiTWorker, TextEncoderWorker, VaeDecoderWorker, VaeEncoderWorker +from lightx2v.server.metrics import metrics + +RUNNER_MAP = { + "pipeline": PipelineWorker, + "text_encoder": TextEncoderWorker, + "image_encoder": ImageEncoderWorker, + "vae_encoder": VaeEncoderWorker, + "vae_decoder": VaeDecoderWorker, + "dit": DiTWorker, + "segment_dit": SegmentDiTWorker, +} + +# {task_id: {"server": xx, "worker_name": xx, "identity": xx}} +RUNNING_SUBTASKS = {} +WORKER_SECRET_KEY = os.getenv("WORKER_SECRET_KEY", "worker-secret-key-change-in-production") +HEADERS = {"Authorization": f"Bearer {WORKER_SECRET_KEY}", "Content-Type": "application/json"} +STOPPED = False +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) +RANK = int(os.environ.get("RANK", 0)) +TARGET_RANK = int(os.getenv("WORKER_RANK", "0")) % WORLD_SIZE + + +async def ping_life(server_url, worker_identity, keys): + url = server_url + "/api/v1/worker/ping/life" + params = {"worker_identity": worker_identity, "worker_keys": keys} + while True: + try: + logger.info(f"{worker_identity} pinging life ...") + async with aiohttp.ClientSession() as session: + async with session.post(url, data=json.dumps(params), headers=HEADERS) as ret: + if ret.status == 200: + ret = await ret.json() + logger.info(f"{worker_identity} ping life: {ret}") + if ret["msg"] == "delete": + logger.warning(f"{worker_identity} deleted") + # asyncio.create_task(shutdown(asyncio.get_event_loop())) + return + await asyncio.sleep(10) + else: + error_text = await ret.text() + raise Exception(f"{worker_identity} ping life fail: [{ret.status}], error: {error_text}") + except asyncio.CancelledError: + logger.warning("Ping life cancelled, shutting down...") + raise asyncio.CancelledError + except: # noqa + logger.warning(f"Ping life failed: {traceback.format_exc()}") + await asyncio.sleep(10) + + +async def ping_subtask(server_url, worker_identity, task_id, worker_name, queue, running_task, ping_interval): + url = server_url + "/api/v1/worker/ping/subtask" + params = { + "worker_identity": worker_identity, + "task_id": task_id, + "worker_name": worker_name, + "queue": queue, + } + while True: + try: + 
logger.info(f"{worker_identity} pinging subtask {task_id} {worker_name} ...") + async with aiohttp.ClientSession() as session: + async with session.post(url, data=json.dumps(params), headers=HEADERS) as ret: + if ret.status == 200: + ret = await ret.json() + logger.info(f"{worker_identity} ping subtask {task_id} {worker_name}: {ret}") + if ret["msg"] == "delete": + logger.warning(f"{worker_identity} subtask {task_id} {worker_name} deleted") + running_task.cancel() + return + await asyncio.sleep(ping_interval) + else: + error_text = await ret.text() + raise Exception(f"{worker_identity} ping subtask fail: [{ret.status}], error: {error_text}") + except asyncio.CancelledError: + logger.warning(f"Ping subtask {task_id} {worker_name} cancelled") + raise asyncio.CancelledError + except: # noqa + logger.warning(f"Ping subtask failed: {traceback.format_exc()}") + await asyncio.sleep(10) + + +async def fetch_subtasks(server_url, worker_keys, worker_identity, max_batch, timeout): + url = server_url + "/api/v1/worker/fetch" + params = { + "worker_keys": worker_keys, + "worker_identity": worker_identity, + "max_batch": max_batch, + "timeout": timeout, + } + try: + logger.info(f"{worker_identity} fetching {worker_keys} with timeout: {timeout}s ...") + async with aiohttp.ClientSession() as session: + async with session.post(url, data=json.dumps(params), headers=HEADERS, timeout=timeout + 1) as ret: + if ret.status == 200: + ret = await ret.json() + subtasks = ret["subtasks"] + for sub in subtasks: + sub["server_url"] = server_url + sub["worker_identity"] = worker_identity + RUNNING_SUBTASKS[sub["task_id"]] = sub + logger.info(f"{worker_identity} fetch {worker_keys} ok: {subtasks}") + return subtasks + else: + error_text = await ret.text() + logger.warning(f"{worker_identity} fetch {worker_keys} fail: [{ret.status}], error: {error_text}") + return None + except asyncio.CancelledError: + logger.warning("Fetch subtasks cancelled, shutting down...") + raise asyncio.CancelledError + except: # noqa + logger.warning(f"Fetch subtasks failed: {traceback.format_exc()}") + await asyncio.sleep(10) + + +async def report_task(server_url, task_id, worker_name, status, worker_identity, queue, **kwargs): + url = server_url + "/api/v1/worker/report" + params = { + "task_id": task_id, + "worker_name": worker_name, + "status": status, + "worker_identity": worker_identity, + "queue": queue, + "fail_msg": "" if status == TaskStatus.SUCCEED.name else "worker failed", + } + try: + async with aiohttp.ClientSession() as session: + async with session.post(url, data=json.dumps(params), headers=HEADERS) as ret: + if ret.status == 200: + RUNNING_SUBTASKS.pop(task_id) + ret = await ret.json() + logger.info(f"{worker_identity} report {task_id} {worker_name} {status} ok") + return True + else: + error_text = await ret.text() + logger.warning(f"{worker_identity} report {task_id} {worker_name} {status} fail: [{ret.status}], error: {error_text}") + return False + except asyncio.CancelledError: + logger.warning("Report task cancelled, shutting down...") + raise asyncio.CancelledError + except: # noqa + logger.warning(f"Report task failed: {traceback.format_exc()}") + + +async def boradcast_subtasks(subtasks): + subtasks = [] if subtasks is None else subtasks + if WORLD_SIZE <= 1: + return subtasks + try: + if RANK == TARGET_RANK: + subtasks_data = json.dumps(subtasks, ensure_ascii=False).encode("utf-8") + subtasks_tensor = torch.frombuffer(bytearray(subtasks_data), dtype=torch.uint8).to(device="cuda") + data_size = subtasks_tensor.shape[0] 
+ size_tensor = torch.tensor([data_size], dtype=torch.int32).to(device="cuda") + logger.info(f"rank {RANK} send subtasks: {subtasks_tensor.shape}, {size_tensor}") + else: + size_tensor = torch.zeros(1, dtype=torch.int32, device="cuda") + + dist.broadcast(size_tensor, src=TARGET_RANK) + if RANK != TARGET_RANK: + subtasks_tensor = torch.zeros(size_tensor.item(), dtype=torch.uint8, device="cuda") + dist.broadcast(subtasks_tensor, src=TARGET_RANK) + + if RANK != TARGET_RANK: + subtasks_data = subtasks_tensor.cpu().numpy().tobytes() + subtasks = json.loads(subtasks_data.decode("utf-8")) + logger.info(f"rank {RANK} recv subtasks: {subtasks}") + return subtasks + + except: # noqa + logger.error(f"Broadcast subtasks failed: {traceback.format_exc()}") + return [] + + +async def sync_subtask(): + if WORLD_SIZE <= 1: + return + try: + logger.info(f"Sync subtask {RANK}/{WORLD_SIZE} wait barrier") + dist.barrier() + logger.info(f"Sync subtask {RANK}/{WORLD_SIZE} ok") + except: # noqa + logger.error(f"Sync subtask failed: {traceback.format_exc()}") + + +async def main(args): + if args.model_name == "": + args.model_name = args.model_cls + if args.task_name == "": + args.task_name = args.task + worker_keys = [args.task_name, args.model_name, args.stage, args.worker] + + metrics.server_process(args.metric_port) + + data_manager = None + if args.data_url.startswith("/"): + data_manager = LocalDataManager(args.data_url, None) + elif args.data_url.startswith("{"): + data_manager = S3DataManager(args.data_url, None) + else: + raise NotImplementedError + await data_manager.init() + + if WORLD_SIZE > 1: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) + logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}") + + runner = RUNNER_MAP[args.worker](args) + if WORLD_SIZE > 1: + dist.barrier() + # asyncio.create_task(ping_life(args.server, args.identity, worker_keys)) + + while True: + subtasks = None + if RANK == TARGET_RANK: + subtasks = await fetch_subtasks(args.server, worker_keys, args.identity, args.max_batch, args.timeout) + subtasks = await boradcast_subtasks(subtasks) + + for sub in subtasks: + status = TaskStatus.FAILED.name + ping_task = None + try: + run_task = asyncio.create_task(runner.run(sub["inputs"], sub["outputs"], sub["params"], data_manager)) + if RANK == TARGET_RANK: + ping_task = asyncio.create_task(ping_subtask(args.server, sub["worker_identity"], sub["task_id"], sub["worker_name"], sub["queue"], run_task, args.ping_interval)) + ret = await run_task + if ret is True: + status = TaskStatus.SUCCEED.name + + except asyncio.CancelledError: + if STOPPED: + logger.warning("Main loop cancelled, already stopped, should exit") + return + logger.warning("Main loop cancelled, do not shut down") + + finally: + try: + if ping_task: + ping_task.cancel() + await sync_subtask() + except Exception: + logger.warning(f"Sync subtask failed: {traceback.format_exc()}") + if RANK == TARGET_RANK and sub["task_id"] in RUNNING_SUBTASKS: + try: + await report_task(status=status, **sub) + except Exception: + logger.warning(f"Report failed: {traceback.format_exc()}") + + +async def shutdown(loop): + logger.warning("Received kill signal") + global STOPPED + STOPPED = True + + for t in asyncio.all_tasks(): + if t is not asyncio.current_task(): + logger.warning(f"Cancel async task {t} ...") + t.cancel() + + # Report remaining running subtasks failed + if RANK == TARGET_RANK: + task_ids = list(RUNNING_SUBTASKS.keys()) + for task_id in task_ids: + try: + s = 
RUNNING_SUBTASKS[task_id] + logger.warning(f"Report {task_id} {s['worker_name']} {TaskStatus.FAILED.name} ...") + await report_task(status=TaskStatus.FAILED.name, **s) + except: # noqa + logger.warning(f"Report task {task_id} failed: {traceback.format_exc()}") + + if WORLD_SIZE > 1: + dist.destroy_process_group() + + # Force exit after a short delay to ensure cleanup + def force_exit(): + logger.warning("Force exiting process...") + sys.exit(0) + + loop.call_later(2, force_exit) + + +# align args like infer.py +def align_args(args): + args.seed = 42 + args.sf_model_path = args.sf_model_path if args.sf_model_path else "" + args.use_prompt_enhancer = False + args.prompt = "" + args.negative_prompt = "" + args.image_path = "" + args.last_frame_path = "" + args.audio_path = "" + args.src_pose_path = None + args.src_face_path = None + args.src_bg_path = None + args.src_mask_path = None + args.src_ref_images = None + args.src_video = None + args.src_mask = None + args.save_result_path = "" + args.return_result_tensor = False + args.is_live = True + + +# ========================= +# Main Entry +# ========================= + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + base_dir = os.path.abspath(os.path.join(cur_dir, "../../..")) + dft_data_url = os.path.join(base_dir, "local_data") + + parser.add_argument("--task", type=str, required=True) + parser.add_argument("--task_name", type=str, default="") + parser.add_argument("--model_cls", type=str, required=True) + parser.add_argument("--model_name", type=str, default="") + parser.add_argument("--stage", type=str, required=True) + parser.add_argument("--worker", type=str, required=True) + parser.add_argument("--identity", type=str, default="") + parser.add_argument("--max_batch", type=int, default=1) + parser.add_argument("--timeout", type=int, default=300) + parser.add_argument("--ping_interval", type=int, default=10) + + parser.add_argument("--metric_port", type=int, default=8001) + + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--sf_model_path", type=str, default="") + parser.add_argument("--config_json", type=str, required=True) + + parser.add_argument("--server", type=str, default="http://127.0.0.1:8080") + parser.add_argument("--data_url", type=str, default=dft_data_url) + + args = parser.parse_args() + align_args(args) + if args.identity == "": + # TODO: spec worker instance identity by k8s env + args.identity = "worker-" + str(uuid.uuid4())[:8] + logger.info(f"args: {args}") + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + for s in [signal.SIGINT, signal.SIGTERM]: + loop.add_signal_handler(s, lambda: asyncio.create_task(shutdown(loop))) + + try: + loop.create_task(main(args), name="main") + loop.run_forever() + finally: + loop.close() + logger.warning("Event loop closed") diff --git a/lightx2v/deploy/worker/hub.py b/lightx2v/deploy/worker/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..2863a9e45c2fa6116608be5e570a87229b898db0 --- /dev/null +++ b/lightx2v/deploy/worker/hub.py @@ -0,0 +1,498 @@ +import asyncio +import ctypes +import gc +import json +import os +import sys +import tempfile +import threading +import traceback + +import torch +import torch.distributed as dist +from loguru import logger + +import lightx2v +from lightx2v.deploy.common.utils import class_try_catch_async +from lightx2v.infer import init_runner # noqa +from lightx2v.utils.input_info import set_input_info 
+from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.set_config import set_config, set_parallel_config +from lightx2v.utils.utils import seed_all + + +def init_tools_preprocess(): + preprocess_path = os.path.abspath(os.path.join(lightx2v.__path__[0], "..", "tools", "preprocess")) + assert os.path.exists(preprocess_path), f"lightx2v tools preprocess path not found: {preprocess_path}" + sys.path.append(preprocess_path) + + +class BaseWorker: + @ProfilingContext4DebugL1("Init Worker Worker Cost:") + def __init__(self, args): + config = set_config(args) + logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}") + seed_all(args.seed) + self.rank = 0 + self.world_size = 1 + if config["parallel"]: + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + set_parallel_config(config) + # same as va_recorder rank + self.out_video_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size + torch.set_grad_enabled(False) + self.runner = RUNNER_REGISTER[config["model_cls"]](config) + self.input_info = set_input_info(args) + + def update_input_info(self, kwargs): + for k, v in kwargs.items(): + setattr(self.input_info, k, v) + + def set_inputs(self, params): + self.input_info.prompt = params["prompt"] + self.input_info.negative_prompt = params.get("negative_prompt", "") + self.input_info.image_path = params.get("image_path", "") + self.input_info.save_result_path = params.get("save_result_path", "") + self.input_info.seed = params.get("seed", self.input_info.seed) + self.input_info.audio_path = params.get("audio_path", "") + for k, v in params.get("processed_video_paths", {}).items(): + logger.info(f"set {k} to {v}") + setattr(self.input_info, k, v) + + async def prepare_input_image(self, params, inputs, tmp_dir, data_manager): + input_image_path = inputs.get("input_image", "") + tmp_image_path = os.path.join(tmp_dir, input_image_path) + + # prepare tmp image + if "image_path" in self.input_info.__dataclass_fields__: + img_data = await data_manager.load_bytes(input_image_path) + with open(tmp_image_path, "wb") as fout: + fout.write(img_data) + params["image_path"] = tmp_image_path + + async def prepare_input_video(self, params, inputs, tmp_dir, data_manager): + if not self.is_animate_model(): + return + init_tools_preprocess() + from preprocess_data import get_preprocess_parser, process_input_video + + result_paths = {} + if self.rank == 0: + tmp_image_path = params.get("image_path", "") + assert os.path.exists(tmp_image_path), f"input_image should be save by prepare_input_image but not valid: {tmp_image_path}" + + # prepare tmp input video + input_video_path = inputs.get("input_video", "") + tmp_video_path = os.path.join(tmp_dir, input_video_path) + processed_video_path = os.path.join(tmp_dir, "processe_results") + video_data = await data_manager.load_bytes(input_video_path) + with open(tmp_video_path, "wb") as fout: + fout.write(video_data) + + # prepare preprocess args + pre_args = get_preprocess_parser().parse_args([]) + pre_args.ckpt_path = self.runner.config["model_path"] + "/process_checkpoint" + pre_args.video_path = tmp_video_path + pre_args.refer_path = tmp_image_path + pre_args.save_path = processed_video_path + pre_args.replace_flag = self.runner.config.get("replace_flag", False) + pre_config = self.runner.config.get("preprocess_config", {}) + pre_keys = ["resolution_area", "fps", "replace_flag", "retarget_flag", "use_flux", "iterations", "k", "w_len", "h_len"] + for k in pre_keys: + if 
k in pre_config: + setattr(pre_args, k, pre_config[k]) + + process_input_video(pre_args) + result_paths = { + "src_pose_path": os.path.join(processed_video_path, "src_pose.mp4"), + "src_face_path": os.path.join(processed_video_path, "src_face.mp4"), + "src_ref_images": os.path.join(processed_video_path, "src_ref.png"), + } + if pre_args.replace_flag: + result_paths["src_bg_path"] = os.path.join(processed_video_path, "src_bg.mp4") + result_paths["src_mask_path"] = os.path.join(processed_video_path, "src_mask.mp4") + + # for dist, broadcast the video processed result to all ranks + result_paths = await self.broadcast_data(result_paths, 0) + for p in result_paths.values(): + assert os.path.exists(p), f"Input video processed result not found: {p}!" + params["processed_video_paths"] = result_paths + + async def prepare_input_audio(self, params, inputs, tmp_dir, data_manager): + input_audio_path = inputs.get("input_audio", "") + tmp_audio_path = os.path.join(tmp_dir, input_audio_path) + + # for stream audio input, value is dict + stream_audio_path = params.get("input_audio", None) + if stream_audio_path is not None: + tmp_audio_path = stream_audio_path + + if input_audio_path and self.is_audio_model() and isinstance(tmp_audio_path, str): + extra_audio_inputs = params.get("extra_inputs", {}).get("input_audio", []) + + # for multi-person audio directory input + if len(extra_audio_inputs) > 0: + os.makedirs(tmp_audio_path, exist_ok=True) + for inp in extra_audio_inputs: + tmp_path = os.path.join(tmp_dir, inputs[inp]) + inp_data = await data_manager.load_bytes(inputs[inp]) + with open(tmp_path, "wb") as fout: + fout.write(inp_data) + else: + audio_data = await data_manager.load_bytes(input_audio_path) + with open(tmp_audio_path, "wb") as fout: + fout.write(audio_data) + + params["audio_path"] = tmp_audio_path + + def prepare_output_video(self, params, outputs, tmp_dir, data_manager): + output_video_path = outputs.get("output_video", "") + tmp_video_path = os.path.join(tmp_dir, output_video_path) + if data_manager.name == "local": + tmp_video_path = os.path.join(data_manager.local_dir, output_video_path) + # for stream video output, value is dict + stream_video_path = params.get("output_video", None) + if stream_video_path is not None: + tmp_video_path = stream_video_path + + params["save_result_path"] = tmp_video_path + return tmp_video_path, output_video_path + + async def prepare_dit_inputs(self, inputs, data_manager): + device = torch.device("cuda", self.rank) + text_out = inputs["text_encoder_output"] + text_encoder_output = await data_manager.load_object(text_out, device) + image_encoder_output = None + + if "image_path" in self.input_info.__dataclass_fields__: + clip_path = inputs["clip_encoder_output"] + vae_path = inputs["vae_encoder_output"] + clip_encoder_out = await data_manager.load_object(clip_path, device) + vae_encoder_out = await data_manager.load_object(vae_path, device) + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out["vals"], + } + # apploy the config changes by vae encoder + self.update_input_info(vae_encoder_out["kwargs"]) + + self.runner.inputs = { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + } + + if self.is_audio_model(): + audio_segments, expected_frames = self.runner.read_audio_input() + self.runner.inputs["audio_segments"] = audio_segments + self.runner.inputs["expected_frames"] = expected_frames + + async def save_output_video(self, tmp_video_path, output_video_path, 
data_manager): + # save output video + if data_manager.name != "local" and self.rank == self.out_video_rank and isinstance(tmp_video_path, str): + video_data = open(tmp_video_path, "rb").read() + await data_manager.save_bytes(video_data, output_video_path) + + def is_audio_model(self): + return "audio" in self.runner.config["model_cls"] or "seko_talk" in self.runner.config["model_cls"] + + def is_animate_model(self): + return self.runner.config.get("task") == "animate" + + async def broadcast_data(self, data, src_rank=0): + if self.world_size <= 1: + return data + + if self.rank == src_rank: + val = json.dumps(data, ensure_ascii=False).encode("utf-8") + T = torch.frombuffer(bytearray(val), dtype=torch.uint8).to(device="cuda") + S = torch.tensor([T.shape[0]], dtype=torch.int32).to(device="cuda") + logger.info(f"hub rank {self.rank} send data: {data}") + else: + S = torch.zeros(1, dtype=torch.int32, device="cuda") + + dist.broadcast(S, src=src_rank) + if self.rank != src_rank: + T = torch.zeros(S.item(), dtype=torch.uint8, device="cuda") + dist.broadcast(T, src=src_rank) + + if self.rank != src_rank: + val = T.cpu().numpy().tobytes() + data = json.loads(val.decode("utf-8")) + logger.info(f"hub rank {self.rank} recv data: {data}") + return data + + +class RunnerThread(threading.Thread): + def __init__(self, loop, future, run_func, rank, *args, **kwargs): + super().__init__(daemon=True) + self.loop = loop + self.future = future + self.run_func = run_func + self.args = args + self.kwargs = kwargs + self.rank = rank + + def run(self): + try: + # cuda device bind for each thread + torch.cuda.set_device(self.rank) + res = self.run_func(*self.args, **self.kwargs) + status = True + except: # noqa + logger.error(f"RunnerThread run failed: {traceback.format_exc()}") + res = None + status = False + finally: + + async def set_future_result(): + self.future.set_result((status, res)) + + # add the task of setting future to the loop queue + asyncio.run_coroutine_threadsafe(set_future_result(), self.loop) + + def stop(self): + if self.is_alive(): + try: + logger.warning(f"Force terminate thread {self.ident} ...") + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(self.ident), ctypes.py_object(SystemExit)) + except Exception as e: + logger.error(f"Force terminate thread failed: {e}") + + +def class_try_catch_async_with_thread(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except asyncio.CancelledError: + logger.warning(f"RunnerThread inside {func.__name__} cancelled") + if hasattr(self, "thread"): + # self.thread.stop() + self.runner.stop_signal = True + self.thread.join() + raise asyncio.CancelledError + except Exception: + logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:") + traceback.print_exc() + return None + + return wrapper + + +class PipelineWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + self.runner.init_modules() + self.run_func = self.runner.run_pipeline + + @class_try_catch_async_with_thread + async def run(self, inputs, outputs, params, data_manager): + with tempfile.TemporaryDirectory() as tmp_dir: + await self.prepare_input_image(params, inputs, tmp_dir, data_manager) + await self.prepare_input_audio(params, inputs, tmp_dir, data_manager) + await self.prepare_input_video(params, inputs, tmp_dir, data_manager) + tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager) + logger.info(f"run params: {params}, {inputs}, {outputs}") + + 
self.set_inputs(params) + self.runner.stop_signal = False + + future = asyncio.Future() + self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_func, self.rank, input_info=self.input_info) + self.thread.start() + status, _ = await future + if not status: + return False + await self.save_output_video(tmp_video_path, output_video_path, data_manager) + return True + + +class TextEncoderWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + self.runner.text_encoders = self.runner.load_text_encoder() + + @class_try_catch_async + async def run(self, inputs, outputs, params, data_manager): + logger.info(f"run params: {params}, {inputs}, {outputs}") + input_image_path = inputs.get("input_image", "") + + self.set_inputs(params) + prompt = self.runner.config["prompt"] + img = None + + if self.runner.config["use_prompt_enhancer"]: + prompt = self.runner.config["prompt_enhanced"] + + if self.runner.config["task"] == "i2v" and not self.is_audio_model(): + img = await data_manager.load_image(input_image_path) + img = self.runner.read_image_input(img) + if isinstance(img, tuple): + img = img[0] + + out = self.runner.run_text_encoder(prompt, img) + if self.rank == 0: + await data_manager.save_object(out, outputs["text_encoder_output"]) + + del out + torch.cuda.empty_cache() + gc.collect() + return True + + +class ImageEncoderWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + self.runner.image_encoder = self.runner.load_image_encoder() + + @class_try_catch_async + async def run(self, inputs, outputs, params, data_manager): + logger.info(f"run params: {params}, {inputs}, {outputs}") + self.set_inputs(params) + + img = await data_manager.load_image(inputs["input_image"]) + img = self.runner.read_image_input(img) + if isinstance(img, tuple): + img = img[0] + out = self.runner.run_image_encoder(img) + if self.rank == 0: + await data_manager.save_object(out, outputs["clip_encoder_output"]) + + del out + torch.cuda.empty_cache() + gc.collect() + return True + + +class VaeEncoderWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + self.runner.vae_encoder, vae_decoder = self.runner.load_vae() + del vae_decoder + + @class_try_catch_async + async def run(self, inputs, outputs, params, data_manager): + logger.info(f"run params: {params}, {inputs}, {outputs}") + self.set_inputs(params) + img = await data_manager.load_image(inputs["input_image"]) + # could change config.lat_h, lat_w, tgt_h, tgt_w + img = self.runner.read_image_input(img) + if isinstance(img, tuple): + img = img[1] if self.runner.vae_encoder_need_img_original else img[0] + # run vae encoder changed the config, we use kwargs pass changes + vals = self.runner.run_vae_encoder(img) + out = {"vals": vals, "kwargs": {}} + + for key in ["original_shape", "resized_shape", "latent_shape", "target_shape"]: + if hasattr(self.input_info, key): + out["kwargs"][key] = getattr(self.input_info, key) + + if self.rank == 0: + await data_manager.save_object(out, outputs["vae_encoder_output"]) + + del out, img, vals + torch.cuda.empty_cache() + gc.collect() + return True + + +class DiTWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + self.runner.model = self.runner.load_transformer() + + @class_try_catch_async_with_thread + async def run(self, inputs, outputs, params, data_manager): + logger.info(f"run params: {params}, {inputs}, {outputs}") + self.set_inputs(params) + + await self.prepare_dit_inputs(inputs, data_manager) + self.runner.stop_signal = False + future = 
asyncio.Future() + self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank) + self.thread.start() + status, out = await future + if not status: + return False + + if self.rank == 0: + await data_manager.save_tensor(out, outputs["latents"]) + + del out + torch.cuda.empty_cache() + gc.collect() + return True + + def run_dit(self): + self.runner.init_run() + assert self.runner.video_segment_num == 1, "DiTWorker only support single segment" + latents = self.runner.run_segment() + self.runner.end_run() + return latents + + +class VaeDecoderWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + vae_encoder, self.runner.vae_decoder = self.runner.load_vae() + self.runner.vfi_model = self.runner.load_vfi_model() if "video_frame_interpolation" in self.runner.config else None + del vae_encoder + + @class_try_catch_async + async def run(self, inputs, outputs, params, data_manager): + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager) + logger.info(f"run params: {params}, {inputs}, {outputs}") + self.set_inputs(params) + + device = torch.device("cuda", self.rank) + latents = await data_manager.load_tensor(inputs["latents"], device) + self.runner.gen_video = self.runner.run_vae_decoder(latents) + self.runner.process_images_after_vae_decoder(save_video=True) + + await self.save_output_video(tmp_video_path, output_video_path, data_manager) + + del latents + torch.cuda.empty_cache() + gc.collect() + return True + + +class SegmentDiTWorker(BaseWorker): + def __init__(self, args): + super().__init__(args) + self.runner.model = self.runner.load_transformer() + self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae() + self.runner.vfi_model = self.runner.load_vfi_model() if "video_frame_interpolation" in self.runner.config else None + if self.is_audio_model(): + self.runner.audio_encoder = self.runner.load_audio_encoder() + self.runner.audio_adapter = self.runner.load_audio_adapter() + self.runner.model.set_audio_adapter(self.runner.audio_adapter) + + @class_try_catch_async_with_thread + async def run(self, inputs, outputs, params, data_manager): + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager) + await self.prepare_input_audio(params, inputs, tmp_dir, data_manager) + logger.info(f"run params: {params}, {inputs}, {outputs}") + self.set_inputs(params) + + await self.prepare_dit_inputs(inputs, data_manager) + self.runner.stop_signal = False + future = asyncio.Future() + self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank) + self.thread.start() + status, _ = await future + if not status: + return False + + await self.save_output_video(tmp_video_path, output_video_path, data_manager) + + torch.cuda.empty_cache() + gc.collect() + return True + + def run_dit(self): + self.runner.run_main() + self.runner.process_images_after_vae_decoder(save_video=True) diff --git a/lightx2v/infer.py b/lightx2v/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb13264e4ac97dc7ec616a73ee35793c9620b77 --- /dev/null +++ b/lightx2v/infer.py @@ -0,0 +1,146 @@ +import argparse + +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.common.ops import * +from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401 
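+# NOTE: the runner imports above and below are marked `# noqa: F401` because they appear to be
+# kept only for their import-time side effect of registering each runner class in RUNNER_REGISTER.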
+from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 +from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401 +from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401 +from lightx2v.utils.envs import * +from lightx2v.utils.input_info import set_input_info +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.set_config import print_config, set_config, set_parallel_config +from lightx2v.utils.utils import seed_all +from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + + +def init_runner(config): + torch.set_grad_enabled(False) + runner = RUNNER_REGISTER[config["model_cls"]](config) + runner.init_modules() + return runner + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42, help="The seed for random generator") + parser.add_argument( + "--model_cls", + type=str, + required=True, + choices=[ + "wan2.1", + "wan2.1_distill", + "wan2.1_vace", + "wan2.1_sf", + "wan2.1_sf_mtxg2", + "seko_talk", + "wan2.2_moe", + "wan2.2", + "wan2.2_moe_audio", + "wan2.2_audio", + "wan2.2_moe_distill", + "qwen_image", + "wan2.2_animate", + "hunyuan_video_1.5", + "hunyuan_video_1.5_distill", + ], + default="wan2.1", + ) + + parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v") + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--sf_model_path", type=str, required=False) + parser.add_argument("--config_json", type=str, required=True) + parser.add_argument("--use_prompt_enhancer", action="store_true") + + parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation") + parser.add_argument("--negative_prompt", type=str, default="") + + parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task") + parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") + parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") + + # [Warning] For vace task, need refactor. + parser.add_argument( + "--src_ref_images", + type=str, + default=None, + help="The file list of the source reference images. Separated by ','. Default None.", + ) + parser.add_argument( + "--src_video", + type=str, + default=None, + help="The file of the source video. Default None.", + ) + parser.add_argument( + "--src_mask", + type=str, + default=None, + help="The file of the source mask. Default None.", + ) + parser.add_argument( + "--src_pose_path", + type=str, + default=None, + help="The file of the source pose. 
Default None.", + ) + parser.add_argument( + "--src_face_path", + type=str, + default=None, + help="The file of the source face. Default None.", + ) + parser.add_argument( + "--src_bg_path", + type=str, + default=None, + help="The file of the source background. Default None.", + ) + parser.add_argument( + "--src_mask_path", + type=str, + default=None, + help="The file of the source mask. Default None.", + ) + parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file") + parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)") + args = parser.parse_args() + + seed_all(args.seed) + + # set config + config = set_config(args) + + if config["parallel"]: + platform_device = PLATFORM_DEVICE_REGISTER.get(AI_DEVICE, None) + platform_device.init_parallel_env() + set_parallel_config(config) + + print_config(config) + + with ProfilingContext4DebugL1("Total Cost"): + runner = init_runner(config) + input_info = set_input_info(args) + runner.run_pipeline(input_info) + + # Clean up distributed process group + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Distributed process group cleaned up") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/models/__init__.py b/lightx2v/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/__init__.py b/lightx2v/models/input_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/__init__.py b/lightx2v/models/input_encoders/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/animate/__init__.py b/lightx2v/models/input_encoders/hf/animate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/animate/face_encoder.py b/lightx2v/models/input_encoders/hf/animate/face_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..991b1e44fcd522060cbf1ee774342249e99185b7 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/animate/face_encoder.py @@ -0,0 +1,171 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +try: + from flash_attn import flash_attn_func, flash_attn_qkvpacked_func # noqa: F401 +except ImportError: + flash_attn_func = None + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="flash", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. 
+ drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + 
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local diff --git a/lightx2v/models/input_encoders/hf/animate/motion_encoder.py b/lightx2v/models/input_encoders/hf/animate/motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab2cf9cc734a002b506f3e58eff1aacea0bec4c --- /dev/null +++ b/lightx2v/models/input_encoders/hf/animate/motion_encoder.py @@ -0,0 +1,300 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + :, + max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0), + ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor**2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + 
def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}, {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = {4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16} + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in 
range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + motion_feat = self.enc.enc_motion(img) + # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.amp.autocast("cuda", dtype=torch.float32): + motion = self.dec.direction(motion_feat) + return motion diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/byt5/__init__.py b/lightx2v/models/input_encoders/hf/hunyuan15/byt5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/byt5/format_prompt.py b/lightx2v/models/input_encoders/hf/hunyuan15/byt5/format_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..335b4f0cabc89be07b4e04d5913c9c94f25929f2 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/hunyuan15/byt5/format_prompt.py @@ -0,0 +1,68 @@ +import json + + +def closest_color(requested_color): + import webcolors + + min_colors = {} + for key, name in webcolors.CSS3_HEX_TO_NAMES.items(): + r_c, g_c, b_c = webcolors.hex_to_rgb(key) + rd = (r_c - requested_color[0]) ** 2 + gd = (g_c - requested_color[1]) ** 2 + bd = (b_c - requested_color[2]) ** 2 + min_colors[(rd + gd + bd)] = name + return min_colors[min(min_colors.keys())] + + +def convert_rgb_to_names(rgb_tuple): + try: + import webcolors + + color_name = webcolors.rgb_to_name(rgb_tuple) + except ValueError: + color_name = closest_color(rgb_tuple) + return color_name + + +class MultilingualPromptFormat: + def __init__( + self, + font_path: str = "assets/glyph_sdxl_assets/multilingual_10-lang_idx.json", + color_path: str = "assets/glyph_sdxl_assets/color_idx.json", + ): + with open(font_path, "r") as f: + self.font_dict = json.load(f) + with open(color_path, "r") as f: + self.color_dict = json.load(f) + + def format_prompt(self, 
texts, styles): + """ + Text "{text}" in {color}, {type}. + """ + + prompt = "" + for text, style in zip(texts, styles): + text_prompt = f'Text "{text}"' + + attr_list = [] + + # format color + if style["color"] is not None: + import webcolors + + hex_color = style["color"] + rgb_color = webcolors.hex_to_rgb(hex_color) + color_name = convert_rgb_to_names(rgb_color) + attr_list.append(f"<color-{self.color_dict[color_name]}>") + + # format font + if style["font-family"] is not None: + attr_list.append(f"<{style['font-family'][:2]}-font-{self.font_dict[style['font-family']]}>") + attr_suffix = ", ".join(attr_list) + text_prompt += " in " + attr_suffix + text_prompt += ". " + else: + text_prompt += ". " + + prompt = prompt + text_prompt + return prompt diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py b/lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd0724cb49056c14c153258d721f13d5de75e501 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py @@ -0,0 +1,369 @@ +import glob +import json +import os +import re + +import torch +import torch.nn as nn +from safetensors import safe_open +from transformers import AutoTokenizer, T5ForConditionalGeneration + +from lightx2v_platform.base.global_var import AI_DEVICE + +from .format_prompt import MultilingualPromptFormat + + +def add_special_token( + tokenizer, + text_encoder, + add_color, + add_font, + color_ann_path, + font_ann_path, + multilingual=False, +): + """ + Add special tokens for color and font to tokenizer and text encoder. + + Args: + tokenizer: Huggingface tokenizer. + text_encoder: Huggingface T5 encoder. + add_color (bool): Whether to add color tokens. + add_font (bool): Whether to add font tokens. + color_ann_path (str): Path to color annotation JSON. + font_ann_path (str): Path to font annotation JSON. + multilingual (bool): Whether to use multilingual font tokens. + """ + with open(font_ann_path, "r") as f: + idx_font_dict = json.load(f) + with open(color_ann_path, "r") as f: + idx_color_dict = json.load(f) + + if multilingual: + font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict] + else: + font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))] + color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))] + additional_special_tokens = [] + if add_color: + additional_special_tokens += color_token + if add_font: + additional_special_tokens += font_token + + tokenizer.add_tokens(additional_special_tokens, special_tokens=True) + # Set mean_resizing=False to avoid PyTorch LAPACK dependency + text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) + + +def load_byt5_and_byt5_tokenizer( + byt5_name="google/byt5-small", + special_token=False, + color_special_token=False, + font_special_token=False, + color_ann_path="assets/color_idx.json", + font_ann_path="assets/font_idx_512.json", + huggingface_cache_dir=None, + multilingual=False, + device=None, +): + """ + Load ByT5 encoder and tokenizer from Huggingface, and add special tokens if needed. + + Args: + byt5_name (str): Model name or path. + special_token (bool): Whether to add special tokens. + color_special_token (bool): Whether to add color tokens. + font_special_token (bool): Whether to add font tokens. + color_ann_path (str): Path to color annotation JSON. + font_ann_path (str): Path to font annotation JSON. + huggingface_cache_dir (str): Huggingface cache directory. + multilingual (bool): Whether to use multilingual font tokens.
+ device (str or torch.device): Device to load the model onto. + + Returns: + tuple: (byt5_text_encoder, byt5_tokenizer) + """ + byt5_tokenizer = AutoTokenizer.from_pretrained( + byt5_name, + cache_dir=huggingface_cache_dir, + ) + byt5_text_encoder = T5ForConditionalGeneration.from_pretrained( + byt5_name, + cache_dir=huggingface_cache_dir, + ).get_encoder() + + if "cuda" not in str(device): + device = torch.device(device) + else: + device = torch.device(device) + byt5_text_encoder = byt5_text_encoder.to(device) + + if special_token: + add_special_token( + byt5_tokenizer, + byt5_text_encoder, + add_color=color_special_token, + add_font=font_special_token, + color_ann_path=color_ann_path, + font_ann_path=font_ann_path, + multilingual=multilingual, + ) + return byt5_text_encoder, byt5_tokenizer + + +class ByT5Mapper(nn.Module): + """ + ByT5Mapper: Maps ByT5 encoder outputs to a new space, with optional residual connection. + + Args: + in_dim (int): Input dimension (must equal out_dim if use_residual). + out_dim (int): Output dimension after second linear layer. + hidden_dim (int): Hidden dimension for intermediate layer. + out_dim1 (int): Final output dimension. + use_residual (bool): Whether to use residual connection (default: True). + """ + + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.fc3 = nn.Linear(out_dim, out_dim1) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + """ + Forward pass for ByT5Mapper. + + Args: + x (Tensor): Input tensor of shape (..., in_dim). + + Returns: + Tensor: Output tensor of shape (..., out_dim1). 
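+
+ Example (a sketch; dimensions match the ByT5TextEncoder setup below, which uses in_dim=1472 and hidden_size=2048):
+ mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=2048, use_residual=False)
+ out = mapper(torch.randn(1, 256, 1472)) # -> torch.Size([1, 256, 2048])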
+ """ + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x2 = self.act_fn(x) + x2 = self.fc3(x2) + if self.use_residual: + x2 = x2 + residual + return x2 + + +class ByT5TextEncoder: + def __init__( + self, + config, + device=torch.device("cpu"), + checkpoint_path=None, + byt5_max_length=256, + cpu_offload=False, + ): + self.cpu_offload = cpu_offload + self.config = config + self.byt5_max_length = byt5_max_length + self.enable_cfg = config.get("enable_cfg", False) + byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small") + byT5_ckpt_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "checkpoints/byt5_model.pt") + multilingual_prompt_format_color_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "assets/color_idx.json") + multilingual_prompt_format_font_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "assets/multilingual_10-lang_idx.json") + byt5_args = dict( + byT5_google_path=byT5_google_path, + byT5_ckpt_path=byT5_ckpt_path, + multilingual_prompt_format_color_path=multilingual_prompt_format_color_path, + multilingual_prompt_format_font_path=multilingual_prompt_format_font_path, + byt5_max_length=byt5_max_length, + ) + self.byt5_tokenizer, self.byt5_model, self.byt5_max_length = self.create_byt5(byt5_args, device) + self.byt5_model = self.byt5_model.to(device=device) + self.prompt_format = MultilingualPromptFormat(font_path=multilingual_prompt_format_font_path, color_path=multilingual_prompt_format_color_path) + + self.byt5_mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.config["hidden_size"], use_residual=False).to(torch.bfloat16) + + byt5_mapper_model_path = os.path.join(checkpoint_path, "transformer", self.config["transformer_model_name"]) + safetensors_files = glob.glob(os.path.join(byt5_mapper_model_path, "*.safetensors")) + byt5_mapper_state_dict = {} + for safetensor_path in safetensors_files: + with safe_open(safetensor_path, framework="pt", device="cpu") as f: + byt5_mapper_state_dict.update({key.replace("byt5_in.", ""): f.get_tensor(key).to(torch.bfloat16) for key in f.keys() if "byt5_in" in key}) + + self.byt5_mapper.load_state_dict(byt5_mapper_state_dict) + self.byt5_mapper.to(device=device) + + def create_byt5(self, args, device): + """ + Create ByT5 tokenizer and encoder, load weights if provided. + + Args: + args (dict): Configuration dictionary. + device (str or torch.device): Device to load the model onto. 
+ + Returns: + tuple: (byt5_tokenizer, byt5_model, byt5_max_length) + """ + byt5_max_length = args["byt5_max_length"] + byt5_config = dict( + byt5_name=args["byT5_google_path"], + special_token=True, + color_special_token=True, + font_special_token=True, + color_ann_path=args["multilingual_prompt_format_color_path"], + font_ann_path=args["multilingual_prompt_format_font_path"], + multilingual=True, + ) + huggingface_cache_dir = None + byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer( + **byt5_config, + huggingface_cache_dir=huggingface_cache_dir, + device=device, + ) + + # Load custom checkpoint if provided + if args["byT5_ckpt_path"] is not None: + if "cuda" not in str(device): + byt5_state_dict = torch.load(args["byT5_ckpt_path"], map_location=device) + else: + byt5_state_dict = torch.load(args["byT5_ckpt_path"], map_location=device) + if "state_dict" in byt5_state_dict: + sd = byt5_state_dict["state_dict"] + newsd = {} + for k, v in sd.items(): + if k.startswith("module.text_tower.encoder."): + newsd[k[len("module.text_tower.encoder.") :]] = v + byt5_state_dict = newsd + byt5_model.load_state_dict(byt5_state_dict) + byt5_model.requires_grad_(False) + return byt5_tokenizer, byt5_model, byt5_max_length + + def _extract_glyph_texts(self, prompt): + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + return result + + def get_byt5_text_tokens(self, byt5_tokenizer, byt5_max_length, text_prompt): + """ + Tokenize text prompt for byT5 model. + + Args: + byt5_tokenizer: The byT5 tokenizer + byt5_max_length: Maximum sequence length + text_prompt: Text prompt to tokenize + + Returns: + Tuple of (input_ids, attention_mask) + """ + byt5_text_inputs = byt5_tokenizer( + text_prompt, + padding="max_length", + max_length=byt5_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + + return byt5_text_inputs.input_ids, byt5_text_inputs.attention_mask + + def _process_single_byt5_prompt(self, prompt_text, device): + """ + Process a single prompt for byT5 encoding. 
+ + Args: + prompt_text: The prompt text to process + device: Target device for tensors + + Returns: + Tuple of (byt5_embeddings, byt5_mask) + """ + byt5_embeddings = torch.zeros((1, self.byt5_max_length, 1472), device=device) + byt5_mask = torch.zeros((1, self.byt5_max_length), device=device, dtype=torch.int64) + + glyph_texts = self._extract_glyph_texts(prompt_text) + + if len(glyph_texts) > 0: + text_styles = [{"color": None, "font-family": None} for _ in range(len(glyph_texts))] + formatted_text = self.prompt_format.format_prompt(glyph_texts, text_styles) + + text_ids, text_mask = self.get_byt5_text_tokens(self.byt5_tokenizer, self.byt5_max_length, formatted_text) + text_ids = text_ids.to(device) + text_mask = text_mask.to(device) + + byt5_outputs = self.byt5_model(text_ids, attention_mask=text_mask.float()) + byt5_embeddings = byt5_outputs[0] + byt5_mask = text_mask + + return byt5_embeddings, byt5_mask + + def _prepare_byt5_embeddings(self, prompts): + if isinstance(prompts, str): + prompt_list = [prompts] + elif isinstance(prompts, list): + prompt_list = prompts + else: + raise ValueError("prompts must be str or list of str") + + positive_embeddings = [] + positive_masks = [] + negative_embeddings = [] + negative_masks = [] + + for prompt in prompt_list: + pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, AI_DEVICE) + positive_embeddings.append(pos_emb) + positive_masks.append(pos_mask) + + if self.enable_cfg: # TODO: 把cfg拆出去,更适合并行 + neg_emb, neg_mask = self._process_single_byt5_prompt("", AI_DEVICE) + negative_embeddings.append(neg_emb) + negative_masks.append(neg_mask) + + byt5_positive = torch.cat(positive_embeddings, dim=0) + byt5_positive_mask = torch.cat(positive_masks, dim=0) + + if self.enable_cfg: # TODO: 把cfg拆出去,更适合并行 + byt5_negative = torch.cat(negative_embeddings, dim=0) + byt5_negative_mask = torch.cat(negative_masks, dim=0) + + byt5_embeddings = torch.cat([byt5_negative, byt5_positive], dim=0) + byt5_masks = torch.cat([byt5_negative_mask, byt5_positive_mask], dim=0) + else: + byt5_embeddings = byt5_positive + byt5_masks = byt5_positive_mask + + return byt5_embeddings, byt5_masks + + @torch.no_grad() + def infer(self, prompts): + if self.cpu_offload: + self.byt5_model = self.byt5_model.to(AI_DEVICE) + self.byt5_mapper = self.byt5_mapper.to(AI_DEVICE) + byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts) + byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16)) + if self.cpu_offload: + self.byt5_model = self.byt5_model.to("cpu") + self.byt5_mapper = self.byt5_mapper.to("cpu") + return byt5_features, byt5_masks + + +if __name__ == "__main__": + byt5 = ByT5TextEncoder(config={"transformer_model_name": "480p_t2v", "hidden_size": 2048}, device="cuda", checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5") + prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. 
A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." + byt5_features, byt5_masks = byt5.infer(prompt) + print(byt5_features.shape, byt5_features.sum()) + print(byt5_masks.shape, byt5_masks.sum()) diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/__init__.py b/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py b/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py new file mode 100644 index 0000000000000000000000000000000000000000..32150f8acf50da71396ccad0c318c105070463ef --- /dev/null +++ b/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py @@ -0,0 +1,641 @@ +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +import gc +import sys +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import loguru +import torch +import torch.nn as nn +from accelerate import init_empty_weights +from safetensors.torch import load_file +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, +) +from transformers.utils import ModelOutput + +current_dir = Path(__file__).resolve().parent +project_root = current_dir.parent.parent.parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402 + Q8FQuantLinearFp8, # noqa E402 + Q8FQuantLinearInt8, # noqa E402 + SglQuantLinearFp8, # noqa E402 + TorchaoQuantLinearInt8, # noqa E402 + VllmQuantLinearInt8, # noqa E402 +) +from lightx2v_platform.base.global_var import AI_DEVICE # noqa E402 + +torch_device_module = getattr(torch, AI_DEVICE) + + +def use_default(value, default): + """Utility: return value if not None, else default.""" + return value if value is not None else default + + +# Prompt templates for different models and tasks + + +__all__ = [ + "C_SCALE", + "PROMPT_TEMPLATE", + "MODEL_BASE", +] + +# =================== Constant Values ===================== +# Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid +# overflow error when tensorboard logging values. +C_SCALE = 1_000_000_000_000_000 + +PROMPT_TEMPLATE_ENCODE_IMAGE_JSON = [ + { + "role": "system", + "content": "You are a helpful assistant. Describe the image by detailing the following aspects: \ + 1. The main content and theme of the image. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. The background environment, light, style and atmosphere.", + }, + {"role": "user", "content": "{}"}, +] + +PROMPT_TEMPLATE_ENCODE_VIDEO_JSON = [ + { + "role": "system", + "content": "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. 
camera angles, movements, and transitions used in the video.", + }, + {"role": "user", "content": "{}"}, +] + +PROMPT_TEMPLATE = { + "li-dit-encode-image-json": {"template": PROMPT_TEMPLATE_ENCODE_IMAGE_JSON, "crop_start": -1}, # auto-calculate crop_start + "li-dit-encode-video-json": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO_JSON, "crop_start": -1}, # auto-calculate crop_start +} + + +MODEL_BASE = os.getenv("MODEL_BASE", "") +TEXT_ENCODER_PATH = { + "qwen-2.5vl-7b": f"{MODEL_BASE}/Qwen2.5-VL-7B-Instruct", +} +TOKENIZER_PATH = { + "qwen-2.5vl-7b": f"{MODEL_BASE}/Qwen2.5-VL-7B-Instruct", +} + +PRECISION_TO_TYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def replace_linear(module, new_linear_cls): + for name, child in list(module.named_children()): + if isinstance(child, nn.Linear): + new_linear = new_linear_cls(child.in_features, child.out_features, bias=(child.bias is not None)) + new_linear.to(device=next(child.parameters(), None).device if any(True for _ in child.parameters()) else torch.device("cpu")) + setattr(module, name, new_linear) + else: + replace_linear(child, new_linear_cls) + + +def load_text_encoder( + text_encoder_type, text_encoder_precision=None, text_encoder_path=None, logger=None, device=None, text_encoder_quantized=False, text_encoder_quant_scheme=None, text_encoder_quant_ckpt=None +): + if text_encoder_path is None: + if text_encoder_type not in TEXT_ENCODER_PATH: + raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type] + + if text_encoder_quantized: + config = AutoConfig.from_pretrained(text_encoder_path) + with init_empty_weights(): + text_encoder = AutoModel.from_config(config) + text_encoder = text_encoder.language_model + + if text_encoder_quant_scheme in ["int8", "int8-vllm"]: + linear_cls = VllmQuantLinearInt8 + elif text_encoder_quant_scheme in ["fp8", "fp8-sgl"]: + linear_cls = SglQuantLinearFp8 + elif text_encoder_quant_scheme == "int8-torchao": + linear_cls = TorchaoQuantLinearInt8 + elif text_encoder_quant_scheme == "int8-q8f": + linear_cls = Q8FQuantLinearInt8 + elif text_encoder_quant_scheme == "fp8-q8f": + linear_cls = Q8FQuantLinearFp8 + else: + NotImplementedError(f"Unsupported Qwen25_vl quant scheme: {text_encoder_quant_scheme}") + + replace_linear(text_encoder.layers, linear_cls) + + weight_dict = load_file(text_encoder_quant_ckpt, device=str(device)) + new_w_dict = {} + for key in weight_dict.keys(): + if key == "lm_head.weight": + continue + new_w_dict[key.replace("model.", "")] = weight_dict[key] + del weight_dict + + torch_device_module.empty_cache() + gc.collect() + text_encoder.load_state_dict(new_w_dict, assign=True) + + else: + text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True) + text_encoder = text_encoder.language_model + + text_encoder.final_layer_norm = text_encoder.norm + + # from_pretrained will ensure that the model is in eval mode. 
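+ # The encoder is used as a frozen feature extractor: the steps below cast it to the
+ # requested precision, disable gradients, and optionally move it to the target device.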
+ if text_encoder_precision is not None: + text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision]) + + text_encoder.requires_grad_(False) + + if device is not None: + text_encoder = text_encoder.to(device) + + return text_encoder, text_encoder_path + + +def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right", logger=None): + processor = None + if tokenizer_path is None: + if tokenizer_type not in TOKENIZER_PATH: + raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") + tokenizer_path = TOKENIZER_PATH[tokenizer_type] + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side) + + return tokenizer, tokenizer_path, processor + + +@dataclass +class TextEncoderModelOutput(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + text_outputs (`list`, *optional*, returned when `return_texts=True` is passed): + List of decoded texts. + """ + + hidden_state: torch.FloatTensor = None + attention_mask: Optional[torch.LongTensor] = None + hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None + text_outputs: Optional[list] = None + image_features: Optional[list] = None + + +class TextEncoder(nn.Module): + def __init__( + self, + text_encoder_type: str, + max_length: int, + text_encoder_precision: Optional[str] = None, + text_encoder_path: Optional[str] = None, + tokenizer_type: Optional[str] = None, + tokenizer_path: Optional[str] = None, + output_key: Optional[str] = None, + use_attention_mask: bool = True, + prompt_template: Optional[dict] = None, + prompt_template_video: Optional[dict] = None, + hidden_state_skip_layer: Optional[int] = None, + apply_final_norm: bool = False, + reproduce: bool = False, + logger=None, + device=None, + qwen25vl_quantized=False, + qwen25vl_quant_scheme=None, + qwen25vl_quant_ckpt=None, + ): + super().__init__() + self.text_encoder_type = text_encoder_type + self.max_length = max_length + self.precision = text_encoder_precision + self.model_path = text_encoder_path + self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type + self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path + self.use_attention_mask = use_attention_mask + if prompt_template_video is not None: + assert use_attention_mask is True, "Attention mask is True required when training videos." 
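+ # prompt_template / prompt_template_video wrap the raw prompt in a chat conversation.
+ # Their crop_start entries (auto-computed on first use) mark where the instruction tokens
+ # end, so that encode() can strip them from the returned hidden states.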
+ self.prompt_template = prompt_template + self.prompt_template_video = prompt_template_video + self.hidden_state_skip_layer = hidden_state_skip_layer + self.apply_final_norm = apply_final_norm + self.reproduce = reproduce + self.logger = logger + + self.use_template = self.prompt_template is not None + if self.use_template: + assert isinstance(self.prompt_template, dict) and "template" in self.prompt_template, f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}" + assert "{}" in str(self.prompt_template["template"]), f"`prompt_template['template']` must contain a placeholder `{{}}` for the input text, got {self.prompt_template['template']}" + + self.use_video_template = self.prompt_template_video is not None + if self.use_video_template: + if self.prompt_template_video is not None: + assert isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video, ( + f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}" + ) + assert "{}" in str(self.prompt_template_video["template"]), ( + f"`prompt_template_video['template']` must contain a placeholder `{{}}` for the input text, got {self.prompt_template_video['template']}" + ) + + if text_encoder_type != "qwen-2.5vl-7b": + raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + self.output_key = output_key or "last_hidden_state" + + self.model, self.model_path = load_text_encoder( + text_encoder_type=self.text_encoder_type, + text_encoder_precision=self.precision, + text_encoder_path=self.model_path, + logger=self.logger, + device=device, + text_encoder_quantized=qwen25vl_quantized, + text_encoder_quant_scheme=qwen25vl_quant_scheme, + text_encoder_quant_ckpt=qwen25vl_quant_ckpt, + ) + + self.tokenizer, self.tokenizer_path, self.processor = load_tokenizer( + tokenizer_type=self.tokenizer_type, + tokenizer_path=self.tokenizer_path, + padding_side="right", + logger=self.logger, + ) + + # pre-calculate crop_start for image and video + if self.use_template and self.prompt_template is not None: + self.text2tokens("a photo of a cat", data_type="image") + # self.logger.info(f"crop_start for image: {self.prompt_template['crop_start']}") + if self.use_video_template and self.prompt_template_video is not None: + self.text2tokens("a photo of a cat", data_type="video") + # self.logger.info(f"crop_start for video: {self.prompt_template_video['crop_start']}") + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + def __repr__(self): + return f"{self.text_encoder_type} ({self.precision} - {self.model_path})" + + @staticmethod + def apply_text_to_template(text, template, prevent_empty_text=True): + """ + Apply text to template. + + Args: + text (str): Input text. + template (str or list): Template string or list of chat conversation. + prevent_empty_text (bool): If Ture, we will prevent the user text from being empty + by adding a space. Defaults to True. + """ + if isinstance(template, str): + # Will send string to tokenizer. 
Used for llm + return template.format(text) + elif isinstance(template, list): + # For JSON list template format (chat conversation) + # Create a deep copy to avoid modifying the original template + template_copy = deepcopy(template) + for item in template_copy: + if isinstance(item, dict) and "content" in item: + # Replace placeholder with text in the content field + item["content"] = item["content"].format(text if text else (" " if prevent_empty_text else "")) + return template_copy + else: + raise TypeError(f"Unsupported template type: {type(template)}") + + def calculate_crop_start(self, tokenized_input): + """ + Automatically calculate the crop_start position based on identifying user tokens. + + Args: + tokenized_input: The output from the tokenizer containing input_ids + + Returns: + int: The position where the actual prompt content begins (after user markers) + """ + input_ids = tokenized_input["input_ids"][0].tolist() # Get the first example's tokens + + # Qwen user marker + marker = "<|im_start|>user\n" + + # Tokenize just the marker to get its token IDs + marker_tokens = self.tokenizer(marker, add_special_tokens=False)["input_ids"] + + # Find the end position of the marker in the input sequence + for i in range(len(input_ids) - len(marker_tokens) + 1): + if input_ids[i : i + len(marker_tokens)] == marker_tokens: + # Return the position after the marker + # print(f"crop_start: {i + len(marker_tokens)}, {self.tokenizer.decode(tokenized_input["input_ids"][0][i:i+len(marker_tokens)+10])}") # check crop_start + return i + len(marker_tokens) + + # If marker not found, try to find based on special tokens + if hasattr(self.tokenizer, "special_tokens_map"): + # Check for user token or any other special token that might indicate user input start + for token_name, token_value in self.tokenizer.special_tokens_map.items(): + if "user" in token_name.lower(): + user_token_id = self.tokenizer.convert_tokens_to_ids(token_value) + if user_token_id in input_ids: + return input_ids.index(user_token_id) + 1 + + # Default fallback: return 0 (no cropping) + return 0 + + def text2tokens(self, text, data_type="image", max_length=300): + """ + Tokenize the input text. + + Args: + text (str or list): Input text. 
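+ data_type (str): "image" or "video"; selects which prompt template is applied.
+ max_length (int): Maximum number of prompt tokens kept after the template prefix.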
+ """ + tokenize_input_type = "str" + if self.use_template or self.use_video_template: + if data_type == "image": + prompt_template = self.prompt_template["template"] + crop_start = self.prompt_template.get("crop_start", -1) + elif data_type == "video": + prompt_template = self.prompt_template_video["template"] + crop_start = self.prompt_template_video.get("crop_start", -1) + else: + raise ValueError(f"Unsupported data type: {data_type}") + if isinstance(text, (list, tuple)): + text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text] + if isinstance(text[0], list): + tokenize_input_type = "list" + elif isinstance(text, str): + text = self.apply_text_to_template(text, prompt_template) + if isinstance(text, list): + tokenize_input_type = "list" + else: + raise TypeError(f"Unsupported text type: {type(text)}") + + # First pass: tokenize with arbitrary max_length to find crop_start + if crop_start == -1: + # Use temporary max_length for the first pass (large enough) + temp_kwargs = dict( + truncation=True, + max_length=256, # Temporary large value + padding="max_length", + return_tensors="pt", + ) + + # First tokenization pass to calculate crop_start + if tokenize_input_type == "str": + temp_tokenized = self.tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + **temp_kwargs, + ) + elif tokenize_input_type == "list": + temp_tokenized = self.tokenizer.apply_chat_template( + text, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **temp_kwargs, + ) + + # Calculate the crop_start from this first pass + crop_start = self.calculate_crop_start(temp_tokenized) + + # Store the calculated crop_start for future use + if data_type == "image": + self.prompt_template["crop_start"] = crop_start + else: + self.prompt_template_video["crop_start"] = crop_start + else: + crop_start = 0 + + # Second pass: tokenize with the proper max_length using the found crop_start + kwargs = dict( + truncation=True, + max_length=max_length + (crop_start if crop_start > 0 else 0), + padding="max_length", + return_tensors="pt", + ) + + if tokenize_input_type == "str": + tokenized_output = self.tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + **kwargs, + ) + elif tokenize_input_type == "list": + tokenized_output = self.tokenizer.apply_chat_template( + text, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **kwargs, + ) + else: + raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}") + + return tokenized_output + + def encode( + self, + batch_encoding, + use_attention_mask=None, + output_hidden_states=False, + do_sample=None, + hidden_state_skip_layer=None, + return_texts=False, + data_type="image", + device=None, + semantic_images=None, + is_uncond=False, + ): + """ + Args: + batch_encoding (dict): Batch encoding from tokenizer. + use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask. + Defaults to None. + output_hidden_states (bool): Whether to output hidden states. If False, return the value of + self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer, + output_hidden_states will be set True. Defaults to False. + do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None. + When self.produce is False, do_sample is set to True by default. 
+ hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer. + If None, self.output_key will be used. Defaults to None. + return_texts (bool): Whether to return the decoded texts. Defaults to False. + """ + device = self.model.device if device is None else device + use_attention_mask = use_default(use_attention_mask, self.use_attention_mask) + hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer) + do_sample = use_default(do_sample, not self.reproduce) + + attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None + outputs = self.model( + input_ids=batch_encoding["input_ids"].to(device), + attention_mask=attention_mask, + output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None, + ) + if hidden_state_skip_layer is not None: + last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + # Real last hidden state already has layer norm applied. So here we only apply it + # for intermediate layers. + if hidden_state_skip_layer > 0 and self.apply_final_norm: + last_hidden_state = self.model.final_layer_norm(last_hidden_state) + else: + last_hidden_state = outputs[self.output_key] + + # Remove hidden states of instruction tokens, only keep prompt tokens. + if self.use_template: + if data_type == "image": + crop_start = self.prompt_template.get("crop_start", 0) + elif data_type == "video": + crop_start = self.prompt_template_video.get("crop_start", 0) + else: + raise ValueError(f"Unsupported data type: {data_type}") + if crop_start > 0: + last_hidden_state = last_hidden_state[:, crop_start:] + attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None + + if output_hidden_states: + return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states) + return TextEncoderModelOutput(last_hidden_state, attention_mask) + + def forward( + self, + text, + use_attention_mask=None, + output_hidden_states=False, + do_sample=False, + hidden_state_skip_layer=None, + return_texts=False, + ): + batch_encoding = self.text2tokens(text, max_length=self.max_length) + return self.encode( + batch_encoding, + use_attention_mask=use_attention_mask, + output_hidden_states=output_hidden_states, + do_sample=do_sample, + hidden_state_skip_layer=hidden_state_skip_layer, + return_texts=return_texts, + ) + + +class Qwen25VL_TextEncoder: + def __init__( + self, + text_len=1000, + dtype=torch.float16, + device=torch.device("cpu"), + checkpoint_path=None, + cpu_offload=False, + qwen25vl_quantized=False, + qwen25vl_quant_scheme=None, + qwen25vl_quant_ckpt=None, + ): + self.text_len = text_len + self.dtype = dtype + self.cpu_offload = cpu_offload + self.qwen25vl_quantized = qwen25vl_quantized + self.qwen25vl_quant_scheme = qwen25vl_quant_scheme + if self.qwen25vl_quantized: + assert self.qwen25vl_quant_scheme is not None + self.qwen25vl_quant_ckpt = qwen25vl_quant_ckpt + self.num_videos_per_prompt = 1 + + self.text_encoder = TextEncoder( + text_encoder_type="qwen-2.5vl-7b", # TODO: 不要用 qwen, 改成 llm + tokenizer_type="qwen-2.5vl-7b", + text_encoder_path=checkpoint_path, + max_length=text_len, + text_encoder_precision="fp16", + prompt_template=PROMPT_TEMPLATE["li-dit-encode-image-json"], + prompt_template_video=PROMPT_TEMPLATE["li-dit-encode-video-json"], + hidden_state_skip_layer=2, + apply_final_norm=False, + reproduce=False, + logger=loguru.logger, + device=device, + qwen25vl_quantized=qwen25vl_quantized, + 
qwen25vl_quant_scheme=qwen25vl_quant_scheme, + qwen25vl_quant_ckpt=qwen25vl_quant_ckpt, + ) + + def infer(self, texts): + if self.cpu_offload: + self.text_encoder = self.text_encoder.to(AI_DEVICE) + text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len) + prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=AI_DEVICE) + if self.cpu_offload: + self.text_encoder = self.text_encoder.to("cpu") + prompt_embeds = prompt_outputs.hidden_state + attention_mask = prompt_outputs.attention_mask + + if attention_mask is not None: + attention_mask = attention_mask.to(AI_DEVICE) + _, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt) + attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len) + prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE) + + seq_len = prompt_embeds.shape[1] + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, self.num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(self.num_videos_per_prompt, seq_len, -1) + return prompt_embeds, attention_mask + + +if __name__ == "__main__": + text_encoder_path = "/data/nvme0/models/hy1118/ckpts/hunyuanvideo-1.5/text_encoder/llm" + device = "cuda" + import torch.nn.functional as F + + prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." 
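+ # Minimal usage sketch without quantization (paths are illustrative placeholders):
+ # encoder = Qwen25VL_TextEncoder(text_len=1000, dtype=torch.float16, device="cuda", checkpoint_path="/path/to/hunyuanvideo-1.5/text_encoder/llm")
+ # prompt_embeds, attention_mask = encoder.infer(["a photo of a cat"])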
+ negative_prompt = "" + + model = Qwen25VL_TextEncoder( + text_len=1000, + dtype=torch.float16, + device=device, + checkpoint_path=text_encoder_path, + cpu_offload=False, + qwen25vl_quantized=True, + qwen25vl_quant_scheme="int8-q8f", + qwen25vl_quant_ckpt="/data/nvme0/models/hy1118/quant_ckpts/qwen25vl-llm-int8.safetensors", + ) + + prompt_embeds, attention_mask = model.infer([prompt]) + print(f"prompt_embeds: {prompt_embeds}, {prompt_embeds.shape}") + a = torch.load("prompt_embeds.pth") + # print(f"attention_mask: {attention_mask}, {attention_mask.sum()}, {attention_mask.shape}") + print(F.cosine_similarity(prompt_embeds.flatten().unsqueeze(0), a.flatten().unsqueeze(0), dim=1)) + + negative_prompt_embeds, negative_attention_mask = model.infer([negative_prompt]) + print(f"negative_prompt_embeds: {negative_prompt_embeds}, {negative_prompt_embeds.shape}") + b = torch.load("negative_prompt_embeds.pth") + print(F.cosine_similarity(negative_prompt_embeds.flatten().unsqueeze(0), b.flatten().unsqueeze(0), dim=1)) + +# print(f"negative_attention_mask: {negative_attention_mask}, {negative_attention_mask.sum()}, {negative_attention_mask.shape}") diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/siglip/__init__.py b/lightx2v/models/input_encoders/hf/hunyuan15/siglip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py b/lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..baa87ae344d1a99ddc3ea2daa6302181f66fe289 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py @@ -0,0 +1,303 @@ +import glob +import os +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from safetensors.torch import safe_open +from transformers import SiglipImageProcessor, SiglipVisionModel +from transformers.utils import ModelOutput + +from lightx2v_platform.base.global_var import AI_DEVICE + +PRECISION_TO_TYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VISION_ENCODER_PATH = {} + + +def use_default(value, default): + return value if value is not None else default + + +def load_vision_encoder( + vision_encoder_type, + vision_encoder_precision=None, + vision_encoder_path=None, + logger=None, + device=None, +): + if vision_encoder_path is None: + vision_encoder_path = VISION_ENCODER_PATH[vision_encoder_type] + + if vision_encoder_type == "siglip": + vision_encoder = SiglipVisionModel.from_pretrained(vision_encoder_path, subfolder="image_encoder") + else: + raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}") + + # from_pretrained will ensure that the model is in eval mode. 
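+ # As with the text encoder, the SigLIP tower is a frozen feature extractor: cast the
+ # precision, disable gradients, and move it to the requested device below.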
+ if vision_encoder_precision is not None: + vision_encoder = vision_encoder.to(dtype=PRECISION_TO_TYPE[vision_encoder_precision]) + + vision_encoder.requires_grad_(False) + + if device is not None: + vision_encoder = vision_encoder.to(device) + + return vision_encoder, vision_encoder_path + + +def load_image_processor(processor_type, processor_path=None, logger=None): + if processor_path is None: + processor_path = VISION_ENCODER_PATH[processor_type] + + if processor_type == "siglip": + processor = SiglipImageProcessor.from_pretrained(processor_path, subfolder="feature_extractor") + else: + raise ValueError(f"Unsupported processor type: {processor_type}") + + return processor, processor_path + + +@dataclass +class VisionEncoderModelOutput(ModelOutput): + """ + Base class for vision encoder model's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*): + Last layer hidden-state of the first token of the sequence (classification token) + after further processing through the layers used for the auxiliary pretraining task. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class VisionEncoder(nn.Module): + def __init__( + self, + vision_encoder_type: str, + vision_encoder_precision: Optional[str] = None, + vision_encoder_path: Optional[str] = None, + processor_type: Optional[str] = None, + processor_path: Optional[str] = None, + output_key: Optional[str] = None, + logger=None, + device=None, + cpu_offload=False, + ): + super().__init__() + self.cpu_offload = cpu_offload + self.vision_encoder_type = vision_encoder_type + self.precision = vision_encoder_precision + self.model_path = vision_encoder_path + self.processor_type = processor_type if processor_type is not None else vision_encoder_type + self.processor_path = processor_path if processor_path is not None else vision_encoder_path + self.logger = logger + + if "siglip" in vision_encoder_type: + self.output_key = output_key or "last_hidden_state" + else: + raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}") + + self.model, self.model_path = load_vision_encoder( + vision_encoder_type=self.vision_encoder_type, + vision_encoder_precision=self.precision, + vision_encoder_path=self.model_path, + logger=self.logger, + device=device, + ) + self.dtype = self.model.dtype + self.device = self.model.device + + self.processor, self.processor_path = load_image_processor( + processor_type=self.processor_type, + processor_path=self.processor_path, + logger=self.logger, + ) + + def __repr__(self): + return f"{self.vision_encoder_type} ({self.precision} - {self.model_path})" + + def encode_latents_to_images(self, latents, vae, reorg_token=False): + """ + Convert latents to images using VAE decoder. 
+ + Args: + latents: Input latents tensor + vae: VAE model for decoding + reorg_token: Whether to reorg the token + Returns: + images: Decoded images as numpy array + """ + # Handle both 4D and 5D latents (for video, take first frame) + first_image_latents = latents[:, :, 0, ...] if len(latents.shape) == 5 else latents + first_image_latents = 1 / vae.config.scaling_factor * first_image_latents + + first_image = vae.decode(first_image_latents.unsqueeze(2).to(vae.dtype), return_dict=False)[0].cpu() + + first_image = first_image[:, :, 0, :, :] + first_image = (first_image / 2 + 0.5).clamp(0, 1) + first_image = (first_image * 255.0).clamp(0, 255.0) + first_image = first_image.to(torch.uint8).numpy() + first_image = first_image.transpose(0, 2, 3, 1) + + assert isinstance(first_image, np.ndarray) + assert first_image.ndim == 4 and first_image.shape[3] == 3 + assert first_image.dtype == np.uint8 + + return first_image + + def encode_images(self, images): + """ + Encode images using the vision encoder. + + Args: + images: Input images (numpy array or preprocessed tensor) + + Returns: + VisionEncoderModelOutput with encoded features + """ + if self.cpu_offload: + self.model = self.model.to(AI_DEVICE) + self.processor = self.processor.to(AI_DEVICE) + + if isinstance(images, np.ndarray): + # Preprocess images if they're numpy arrays + preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=AI_DEVICE, dtype=self.model.dtype) + else: + # Assume already preprocessed + preprocessed = images + + outputs = self.model(**preprocessed) + + if self.cpu_offload: + self.model = self.model.to("cpu") + self.processor = self.processor.to("cpu") + + return VisionEncoderModelOutput( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output if hasattr(outputs, "pooler_output") else None, + hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None, + ) + + def encode_latents(self, latents, vae, reorg_token=False): + """ + Encode latents by first converting to images, then encoding. + This is the main function that replaces sigclip_vision_encode. + + Args: + latents: Input latent tensors + vae: VAE model for decoding latents to images + + Returns: + Encoded image features + """ + # Convert latents to images + images = self.encode_latents_to_images(latents, vae, reorg_token) + + # Encode images + outputs = self.encode_images(images) + + return outputs.last_hidden_state + + def forward(self, images): + """ + Forward pass for direct image encoding. 
+ + Args: + images: Input images + + Returns: + VisionEncoderModelOutput with encoded features + """ + return self.encode_images(images) + + +class SiglipVisionEncoder: + def __init__( + self, + config, + device=torch.device("cpu"), + checkpoint_path=None, + cpu_offload=False, + ): + self.config = config + self.device = device + self.cpu_offload = cpu_offload + self.vision_states_dim = 1152 + vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip") + + self.vision_encoder = VisionEncoder( + vision_encoder_type="siglip", + vision_encoder_precision="fp16", + vision_encoder_path=vision_encoder_path, + processor_type=None, + processor_path=None, + output_key=None, + logger=None, + device=self.device, + cpu_offload=self.cpu_offload, + ) + + self.vision_in = VisionProjection(in_dim=self.vision_states_dim, out_dim=self.config["hidden_size"], flf_pos_emb=False).to(torch.bfloat16) + + vision_in_model_path = os.path.join(checkpoint_path, "transformer", self.config["transformer_model_name"]) + safetensors_files = glob.glob(os.path.join(vision_in_model_path, "*.safetensors")) + vision_in_state_dict = {} + for safetensor_path in safetensors_files: + with safe_open(safetensor_path, framework="pt", device="cpu") as f: + vision_in_state_dict.update({key.replace("vision_in.", ""): f.get_tensor(key).to(torch.bfloat16) for key in f.keys() if "vision_in" in key}) + self.vision_in.load_state_dict(vision_in_state_dict) + self.vision_in.to(device=device) + + @torch.no_grad() + def infer(self, vision_states): + if self.cpu_offload: + self.vision_in = self.vision_in.to(AI_DEVICE) + vision_states = self.vision_in(vision_states) + if self.cpu_offload: + self.vision_in = self.vision_in.to("cpu") + return vision_states + + @torch.no_grad() + def encode_images(self, images): + return self.vision_encoder.encode_images(images) + + +class VisionProjection(torch.nn.Module): + """ + Projects vision embeddings. Also handles dropout for classifier-free guidance. 
+ + Adapted from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py#L488 + """ + + def __init__(self, in_dim, out_dim, flf_pos_emb=False): + super().__init__() + + self.proj = torch.nn.Sequential(torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim)) + + if flf_pos_emb: # NOTE: we only use this for `flf2v` + self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) + + @torch.no_grad() + def forward(self, image_embeds): + if hasattr(self, "emb_pos"): + bs, n, d = image_embeds.shape + image_embeds = image_embeds.view(-1, 2 * n, d) + image_embeds = image_embeds + self.emb_pos + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens diff --git a/lightx2v/models/input_encoders/hf/q_linear.py b/lightx2v/models/input_encoders/hf/q_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ee40badbf65438a5ab580b3632d7c7566ffb00 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/q_linear.py @@ -0,0 +1,312 @@ +import torch +import torch.nn as nn + +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None + +try: + import sgl_kernel +except ImportError: + sgl_kernel = None + +try: + from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax +except ImportError: + quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None + +try: + from q8_kernels.functional.linear import q8_linear +except ImportError: + q8_linear = None + +try: + from q8_kernels.functional.linear import fp8_linear +except ImportError: + fp8_linear = None + + +class VllmQuantLinearInt8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True) + return input_tensor_quant, input_tensor_scale + + def forward(self, input_tensor): + input_tensor = input_tensor.squeeze(0) + shape = (input_tensor.shape[0], self.weight.shape[0]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + torch.ops._C.cutlass_scaled_mm( + output_tensor, + input_tensor_quant, + self.weight.t(), + input_tensor_scale, + self.weight_scale.float(), + self.bias, + ) + return output_tensor.unsqueeze(0) + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self + + +class VllmQuantLinearFp8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", 
torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True) + return input_tensor_quant, input_tensor_scale + + def forward(self, input_tensor): + input_tensor = input_tensor.squeeze(0) + shape = (input_tensor.shape[0], self.weight.shape[0]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + torch.ops._C.cutlass_scaled_mm( + output_tensor, + input_tensor_quant, + self.weight.t(), + input_tensor_scale, + self.weight_scale.float(), + self.bias, + ) + + return output_tensor.unsqueeze(0) + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self + + +class SglQuantLinearFp8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + m, k = x.shape + input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False) + input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False) + sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale) + return input_tensor_quant, input_tensor_scale + + def forward(self, input_tensor): + input_tensor = input_tensor.squeeze(0) + shape = (input_tensor.shape[0], self.weight.shape[0]) + dtype = input_tensor.dtype + device = input_tensor.device + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = sgl_kernel.fp8_scaled_mm( + input_tensor_quant, + self.weight.t(), + input_tensor_scale, + self.weight_scale.float(), + dtype, + bias=self.bias, + ) + + return output_tensor.unsqueeze(0) + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self + + +class TorchaoQuantLinearInt8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), 
dtype=torch.float32)) + + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x) + return input_tensor_quant, input_tensor_scale + + def forward(self, input_tensor): + input_tensor = input_tensor.squeeze(0) + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16) + if self.bias is not None: + output_tensor = output_tensor + self.bias + + return output_tensor.unsqueeze(0) + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self + + +class Q8FQuantLinearInt8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.float32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True) + return input_tensor_quant, input_tensor_scale + + def forward(self, x): + input_tensor_quant, input_tensor_scale = self.act_quant_func(x) + output_tensor = q8_linear( + input_tensor_quant, + self.weight, + self.bias.float() if self.bias is not None else None, + input_tensor_scale, + self.weight_scale.float(), + fuse_gelu=False, + out_dtype=torch.bfloat16, + ) + return output_tensor + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self + + +class Q8FQuantLinearFp8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.float32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x.squeeze(0), None, scale_ub=None, use_per_token_if_dynamic=True) + return input_tensor_quant, input_tensor_scale + + def forward(self, x): + input_tensor_quant, input_tensor_scale = self.act_quant_func(x) + output_tensor = fp8_linear( + input_tensor_quant, + self.weight, + self.bias.float() if self.bias is not None else None, + input_tensor_scale, + self.weight_scale.float(), + out_dtype=torch.bfloat16, + ) + return output_tensor + + def _apply(self, fn): + 
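+ # Custom _apply: buffers are only touched when fn would change their device, so calls
+ # like .to(dtype=...) cannot silently upcast the fp8 weights or their scales.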
for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self diff --git a/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py b/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py new file mode 100644 index 0000000000000000000000000000000000000000..f752023eb6eeedb85e473e579dcdfd7ffd10b8e2 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py @@ -0,0 +1,190 @@ +import gc +import math +import os + +import torch + +try: + from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration +except ImportError: + Qwen2Tokenizer = None + Qwen2_5_VLForConditionalGeneration = None + +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + +try: + from diffusers.image_processor import VaeImageProcessor + from transformers import Qwen2VLProcessor +except ImportError: + VaeImageProcessor = None + Qwen2VLProcessor = None + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class Qwen25_VLForConditionalGeneration_TextEncoder: + def __init__(self, config): + self.config = config + self.tokenizer_max_length = 1024 + self.prompt_template_encode = config["prompt_template_encode"] + self.prompt_template_encode_start_idx = config["prompt_template_encode_start_idx"] + """ + for Qwen-Image-Edit model, CONDITION_IMAGE_SIZE = 1024 * 1024 + for Qwen-Image-Edit-2509 model, CONDITION_IMAGE_SIZE = 384 * 384 + """ + self.CONDITION_IMAGE_SIZE = config.get("CONDITION_IMAGE_SIZE", 384 * 384) + self.USE_IMAGE_ID_IN_PROMPT = config.get("USE_IMAGE_ID_IN_PROMPT", True) + self.VAE_IMAGE_SIZE = 1024 * 1024 + + self.cpu_offload = config.get("cpu_offload", False) + self.dtype = torch.bfloat16 + self.load() + + def load(self): + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16) + if not self.cpu_offload: + self.text_encoder = self.text_encoder.to(AI_DEVICE) + + self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer")) + if self.config["task"] == "i2i": + self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2) + self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(self.config["model_path"], "processor")) + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def preprocess_image(self, image): + image_width, image_height = image.size + condition_width, condition_height = calculate_dimensions(self.CONDITION_IMAGE_SIZE, image_width / image_height) + vae_width, vae_height = 
calculate_dimensions(self.VAE_IMAGE_SIZE, image_width / image_height) + condition_image = self.image_processor.resize(image, condition_height, condition_width) + vae_image = self.image_processor.preprocess(image, vae_height, vae_width).unsqueeze(2) + + return condition_image, vae_image, (condition_height, condition_width), (vae_height, vae_width) + + @torch.no_grad() + def infer(self, text, image_list=None): + if self.cpu_offload: + self.text_encoder.to(AI_DEVICE) + + if image_list is not None: + condition_image_list = [] + vae_image_list = [] + condition_image_info_list = [] + vae_image_info_list = [] + if self.USE_IMAGE_ID_IN_PROMPT: + base_img_prompt = "" + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + for i, image in enumerate(image_list): + base_img_prompt += img_prompt_template.format(i + 1) + condition_image, vae_image, condition_image_info, vae_image_info = self.preprocess_image(image) + condition_image_list.append(condition_image) + vae_image_list.append(vae_image) + condition_image_info_list.append(condition_image_info) + vae_image_info_list.append(vae_image_info) + else: + base_img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" + for i, image in enumerate(image_list): + condition_image, vae_image, condition_image_info, vae_image_info = self.preprocess_image(image) + condition_image_list.append(condition_image) + vae_image_list.append(vae_image) + condition_image_info_list.append(condition_image_info) + vae_image_info_list.append(vae_image_info) + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in text] + + model_inputs = self.processor( + text=txt, + images=condition_image_list, + padding=True, + return_tensors="pt", + ).to(AI_DEVICE) + + encoder_hidden_states = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + image_info = { + "condition_image_list": condition_image_list, + "vae_image_list": vae_image_list, + "condition_image_info_list": condition_image_info_list, + "vae_image_info_list": vae_image_info_list, + } + + else: + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in text] + + image_info = {} + model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(AI_DEVICE) + encoder_hidden_states = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + output_hidden_states=True, + ) + + hidden_states = encoder_hidden_states.hidden_states[-1] + + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + + prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE) + prompt_embeds_mask = encoder_attention_mask + + _, seq_len, _ = prompt_embeds.shape + 
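+        # Duplicate the embeddings and attention mask along the batch dimension so each prompt
+        # yields num_images_per_prompt samples (final batch = batchsize * num_images_per_prompt).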
prompt_embeds = prompt_embeds.repeat(1, self.config["num_images_per_prompt"], 1) + prompt_embeds = prompt_embeds.view(self.config["batchsize"] * self.config["num_images_per_prompt"], seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config["num_images_per_prompt"], 1) + prompt_embeds_mask = prompt_embeds_mask.view(self.config["batchsize"] * self.config["num_images_per_prompt"], seq_len) + + if self.cpu_offload: + self.text_encoder.to(torch.device("cpu")) + torch_device_module.empty_cache() + gc.collect() + + return prompt_embeds, prompt_embeds_mask, image_info diff --git a/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py b/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..c783dd8fff1acce85020725a8330ac38a228bac2 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py @@ -0,0 +1,311 @@ +try: + import flash_attn +except ModuleNotFoundError: + flash_attn = None + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from einops import rearrange + +from lightx2v_platform.base.global_var import AI_DEVICE + + +def linear_interpolation(features, output_len: int): + features = features.transpose(1, 2) + output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear") + return output_features.transpose(1, 2) + + +@torch.compiler.disable +def get_max_int(q_lens, k_lens): + max_seqlen_q = int(q_lens.max().item()) + max_seqlen_k = int(k_lens.max().item()) + return max_seqlen_q, max_seqlen_k + + +def get_qk_lens_audio_range( + n_tokens_per_rank: torch.Tensor, + n_query_tokens: torch.Tensor, + n_tokens_per_frame: torch.Tensor, + sp_rank: torch.Tensor, + num_tokens_x4, +): + device = n_tokens_per_rank.device + dtype = torch.int32 + + if n_query_tokens == 0: + q_lens = torch.ones(1, dtype=dtype, device=device) + t0 = torch.tensor(0, device=device) + t1 = torch.tensor(1, device=device) + k_lens = torch.full((t1 - t0,), num_tokens_x4, dtype=dtype, device=device) + max_seqlen_q, max_seqlen_k = get_max_int(q_lens, k_lens) + return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 + + idx0 = n_tokens_per_rank * sp_rank + + first_length = n_tokens_per_frame - idx0 % n_tokens_per_frame + first_length = torch.minimum(first_length, n_query_tokens) + + n_frames = torch.div(n_query_tokens - first_length, n_tokens_per_frame, rounding_mode="floor") + + last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length + + first_tensor = first_length.unsqueeze(0) # [1] + frame_tensor = n_tokens_per_frame.repeat(n_frames) # [n_frames] + last_tensor = last_length.unsqueeze(0) # [1] + + q_lens_all = torch.cat([first_tensor, frame_tensor, last_tensor]) + q_lens = q_lens_all[q_lens_all > 0].to(dtype) + + t0 = idx0 // n_tokens_per_frame + t1 = t0 + q_lens.numel() + + k_lens = torch.full((t1 - t0,), num_tokens_x4, dtype=dtype, device=device) + + assert q_lens.shape == k_lens.shape + max_seqlen_q, max_seqlen_k = get_max_int(q_lens, k_lens) + + return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 + + +def calculate_n_query_tokens(hidden_states, person_mask_latens, sp_rank, sp_size, n_tokens_per_rank, n_tokens): + tail_length = n_tokens_per_rank * sp_size - n_tokens + n_unused_ranks = tail_length // n_tokens_per_rank + + if sp_rank > sp_size - n_unused_ranks - 1: + n_query_tokens = 0 + elif sp_rank == sp_size - n_unused_ranks - 1: + val = n_tokens_per_rank - 
(tail_length % n_tokens_per_rank) + n_query_tokens = val + else: + n_query_tokens = n_tokens_per_rank + + if n_query_tokens > 0: + hidden_states_aligned = hidden_states[:n_query_tokens] + hidden_states_tail = hidden_states[n_query_tokens:] + if person_mask_latens is not None: + person_mask_aligned = person_mask_latens[:, :n_query_tokens] + else: + person_mask_aligned = None + else: + # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works. + hidden_states_aligned = hidden_states[:1] + hidden_states_tail = hidden_states[1:] + if person_mask_latens is not None: + person_mask_aligned = person_mask_latens[:, :1] + else: + person_mask_aligned = None + + return n_query_tokens, hidden_states_aligned, hidden_states_tail, person_mask_aligned + + +''' +class PerceiverAttentionCA(nn.Module): + def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False, quantized=False, quant_scheme=None): + super().__init__() + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + kv_dim = inner_dim if kv_dim is None else kv_dim + self.norm_kv = nn.LayerNorm(kv_dim) + self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN) + + if quantized: + if quant_scheme == "fp8": + self.to_q = SglQuantLinearFp8(inner_dim, inner_dim) + self.to_kv = nn.Linear(kv_dim, inner_dim * 2) + self.to_out = SglQuantLinearFp8(inner_dim, inner_dim) + else: + raise ValueError(f"Unsupported quant_scheme: {quant_scheme}") + else: + self.to_q = nn.Linear(inner_dim, inner_dim) + self.to_kv = nn.Linear(kv_dim, inner_dim * 2) + self.to_out = nn.Linear(inner_dim, inner_dim) + if adaLN: + self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5) + else: + shift_scale_gate = torch.zeros((1, 3, inner_dim)) + shift_scale_gate[:, 2] = 1 + self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False) + + def forward(self, x, latents, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k): + """x shape (batchsize, latent_frame, audio_tokens_per_latent, + model_dim) latents (batchsize, length, model_dim)""" + batchsize = len(x) + x = self.norm_kv(x) + shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1) + norm_q = self.norm_q(latents) + if scale.shape[0] != norm_q.shape[0]: + scale = scale.transpose(0, 1) # (1, 5070, 3072) + shift = shift.transpose(0, 1) + gate = gate.transpose(0, 1) + latents = norm_q * (1 + scale) + shift + q = self.to_q(latents) + k, v = self.to_kv(x).chunk(2, dim=-1) + q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads) + k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads) + v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads) + + out = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + ) + out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize) + return self.to_out(out) * gate +''' + + +class AudioProjection(nn.Module): + def __init__( + self, + audio_feature_dim: int = 768, + n_neighbors: tuple = (2, 2), + num_tokens: int = 32, + mlp_dims: tuple = (1024, 1024, 32 * 768), + transformer_layers: int = 4, + ): + super().__init__() + mlp = [] + self.left, 
self.right = n_neighbors + self.audio_frames = sum(n_neighbors) + 1 + in_dim = audio_feature_dim * self.audio_frames + for i, out_dim in enumerate(mlp_dims): + mlp.append(nn.Linear(in_dim, out_dim)) + if i != len(mlp_dims) - 1: + mlp.append(nn.ReLU()) + in_dim = out_dim + self.mlp = nn.Sequential(*mlp) + self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens) + self.num_tokens = num_tokens + if transformer_layers > 0: + decoder_layer = nn.TransformerDecoderLayer(d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, + num_layers=transformer_layers, + ) + else: + self.transformer_decoder = None + + def forward(self, audio_feature, latent_frame): + video_frame = (latent_frame - 1) * 4 + 1 + audio_feature_ori = audio_feature + audio_feature = linear_interpolation(audio_feature_ori, video_frame) + if self.transformer_decoder is not None: + audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori) + audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate") + audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1) + audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)") + audio_feature = self.mlp(audio_feature) # (B, video_frame, C) + audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens) # (B, video_frame, num_tokens, C) + return self.norm(audio_feature) + + +class TimeEmbedding(nn.Module): + def __init__(self, dim, time_freq_dim, time_proj_dim): + super().__init__() + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + + def forward(self, timestep: torch.Tensor): + # Project timestep + if timestep.dim() == 2: + timestep = self.timesteps_proj(timestep.squeeze(0)).unsqueeze(0) + else: + timestep = self.timesteps_proj(timestep) + + # Match dtype with time_embedder (except int8) + target_dtype = next(self.time_embedder.parameters()).dtype + if timestep.dtype != target_dtype and target_dtype != torch.int8: + timestep = timestep.to(target_dtype) + + # Time embedding projection + temb = self.time_embedder(timestep) + timestep_proj = self.time_proj(self.act_fn(temb)) + + return timestep_proj.squeeze(0) if timestep_proj.dim() == 3 else timestep_proj + + +class AudioAdapter(nn.Module): + def __init__( + self, + attention_head_dim=64, + num_attention_heads=40, + base_num_layers=30, + interval=1, + audio_feature_dim: int = 768, + num_tokens: int = 32, + mlp_dims: tuple = (1024, 1024, 32 * 768), + time_freq_dim: int = 256, + projection_transformer_layers: int = 4, + quantized: bool = False, + quant_scheme: str = None, + cpu_offload: bool = False, + ): + super().__init__() + self.cpu_offload = cpu_offload + self.audio_proj = AudioProjection( + audio_feature_dim=audio_feature_dim, + n_neighbors=(2, 2), + num_tokens=num_tokens, + mlp_dims=mlp_dims, + transformer_layers=projection_transformer_layers, + ) + # self.num_tokens = num_tokens * 4 + self.num_tokens_x4 = num_tokens * 4 + self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02) + # ca_num = math.ceil(base_num_layers / interval) + self.base_num_layers = base_num_layers + self.interval = interval + """ + self.ca = nn.ModuleList( + [ + 
PerceiverAttentionCA( + dim_head=attention_head_dim, + heads=num_attention_heads, + kv_dim=mlp_dims[-1] // num_tokens, + adaLN=time_freq_dim > 0, + quantized=quantized, + quant_scheme=quant_scheme, + ) + for _ in range(ca_num) + ] + ) + """ + self.dim = attention_head_dim * num_attention_heads + if time_freq_dim > 0: + self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3) + else: + self.time_embedding = None + + def rearange_audio_features(self, audio_feature: torch.Tensor): + # audio_feature (B, video_frame, num_tokens, C) + audio_feature_0 = audio_feature[:, :1] + audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1) + audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1) # (B, 4 * latent_frame, num_tokens, C) + audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4) + return audio_feature + + @torch.no_grad() + def forward_audio_proj(self, audio_feat, latent_frame): + if self.cpu_offload: + self.audio_proj.to(AI_DEVICE) + x = self.audio_proj(audio_feat, latent_frame) + x = self.rearange_audio_features(x) + x = x + self.audio_pe.to(AI_DEVICE) + if self.cpu_offload: + self.audio_proj.to("cpu") + return x diff --git a/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py b/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5b727af4437d845cd8a216ba4df12848a3c45584 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py @@ -0,0 +1,40 @@ +import torch +from transformers import AutoFeatureExtractor, AutoModel + +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + + +class SekoAudioEncoderModel: + def __init__(self, model_path, audio_sr, cpu_offload): + self.model_path = model_path + self.audio_sr = audio_sr + self.cpu_offload = cpu_offload + if self.cpu_offload: + self.device = torch.device("cpu") + else: + self.device = torch.device(AI_DEVICE) + self.load() + + def load(self): + self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path) + self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path) + self.audio_feature_encoder.to(self.device) + self.audio_feature_encoder.eval() + self.audio_feature_encoder.to(GET_DTYPE()) + + def to_cpu(self): + self.audio_feature_encoder = self.audio_feature_encoder.to("cpu") + + def to_cuda(self): + self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE) + + @torch.no_grad() + def infer(self, audio_segment): + audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(AI_DEVICE).to(dtype=GET_DTYPE()) + if self.cpu_offload: + self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE) + audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state + if self.cpu_offload: + self.audio_feature_encoder = self.audio_feature_encoder.to("cpu") + return audio_feat diff --git a/lightx2v/models/input_encoders/hf/vace/vace_processor.py b/lightx2v/models/input_encoders/hf/vace/vace_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..f1006493fba14fe3b7f4424934204e790d72a9c0 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/vace/vace_processor.py @@ -0,0 +1,173 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
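+# Usage sketch (illustrative only; the argument values below are assumptions, not defaults
+# shipped with this file):
+#     processor = VaceVideoProcessor(
+#         downsample=(4, 8, 8), min_area=480 * 832, max_area=480 * 832,
+#         min_fps=16, max_fps=16, zero_start=True, seq_len=32760, keep_last=True,
+#     )
+#     video, frame_ids, (oh, ow), fps = processor.load_video("input.mp4")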
+import numpy as np +import torch +import torch.nn.functional as F + + +class VaceVideoProcessor(object): + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + self.downsample = downsample + self.min_area = min_area + self.max_area = max_area + self.min_fps = min_fps + self.max_fps = max_fps + self.zero_start = zero_start + self.keep_last = keep_last + self.seq_len = seq_len + assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) + + def set_area(self, area): + self.min_area = area + self.max_area = area + + def set_seq_len(self, seq_len): + self.seq_len = seq_len + + @staticmethod + def resize_crop(video: torch.Tensor, oh: int, ow: int): + """ + Resize, center crop and normalize for decord loaded video (torch.Tensor type) + + Parameters: + video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) + oh - target height (int) + ow - target width (int) + + Returns: + The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) + + Raises: + """ + # permute ([t, h, w, c] -> [t, c, h, w]) + video = video.permute(0, 3, 1, 2) + + # resize and crop + ih, iw = video.shape[2:] + if ih != oh or iw != ow: + # resize + scale = max(ow / iw, oh / ih) + video = F.interpolate(video, size=(round(scale * ih), round(scale * iw)), mode="bicubic", antialias=True) + assert video.size(3) >= ow and video.size(2) >= oh + + # center crop + x1 = (video.size(3) - ow) // 2 + y1 = (video.size(2) - oh) // 2 + video = video[:, :, y1 : y1 + oh, x1 : x1 + ow] + + # permute ([t, c, h, w] -> [c, t, h, w]) and normalize + video = video.transpose(0, 1).float().div_(127.5).sub_(1.0) + return video + + def _video_preprocess(self, video, oh, ow): + return self.resize_crop(video, oh, ow) + + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + target_fps = min(fps, self.max_fps) + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min((int(duration * target_fps) - 1) // df + 1, int(self.seq_len / area_z)) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = of / target_fps + begin = 0.0 if self.zero_start else rng.uniform(0, duration - target_duration) + timestamps = np.linspace(begin, begin + target_duration, of) + frame_ids = np.argmax(np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], timestamps[:, None] < frame_timestamps[None, :, 1]), axis=1).tolist() + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng): + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min((len(frame_timestamps) - 1) // df + 1, int(self.seq_len / area_z)) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / 
oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = duration + target_fps = of / target_duration + timestamps = np.linspace(0.0, target_duration, of) + frame_ids = np.argmax(np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], timestamps[:, None] <= frame_timestamps[None, :, 1]), axis=1).tolist() + # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): + if self.keep_last: + return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng) + else: + return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng) + + def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs): + rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) + # read video + import decord + + decord.bridge.set_bridge("torch") + readers = [] + for data_k in data_key_batch: + reader = decord.VideoReader(data_k) + readers.append(reader) + + fps = readers[0].get_avg_fps() + length = min([len(r) for r in readers]) + frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] + frame_timestamps = np.array(frame_timestamps, dtype=np.float32) + h, w = readers[0].next().shape[:2] + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng) + + # preprocess video + videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] + videos = [self._video_preprocess(video, oh, ow) for video in videos] + return *videos, frame_ids, (oh, ow), fps + # return videos if len(videos) > 1 else videos[0] + + +def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device): + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_video is None and sub_src_mask is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device) + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + for j, ref_img in enumerate(ref_images): + if ref_img is not None and ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top : top + new_height, left : left + new_width] = resized_image + src_ref_images[i][j] = white_canvas + return src_video, src_mask, src_ref_images diff --git a/lightx2v/models/input_encoders/hf/wan/matrix_game2/__init__.py 
b/lightx2v/models/input_encoders/hf/wan/matrix_game2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py b/lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..3008c9b79aef0aa607c32ce580f2a93ad24f5302 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py @@ -0,0 +1,332 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from diffusers.models import ModelMixin + +from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer +from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import VisionTransformer + + +class SelfAttention(nn.Module): + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. 
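+    Maps token ids of shape [B, L] to contextual hidden states of shape [B, L, dim].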
+ """ + + def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + def __init__(self, **kwargs): + self.out_dim = kwargs.pop("out_dim") + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + def __init__( + self, + dtype=torch.float16, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + vision_pre_norm=True, + vision_post_norm=False, + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + quantized=False, + quant_scheme=None, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + ): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + dtype=dtype, + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + 
norm_eps=norm_eps, + quantized=quantized, + quant_scheme=quant_scheme, + ) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout, + ) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + +def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs): + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if "siglip" in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)]) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + ) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel(ModelMixin): + def __init__(self, checkpoint_path, tokenizer_path): + super().__init__() + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False, + ) + self.model = self.model.eval().requires_grad_(False) + logging.info(f"loading {checkpoint_path}") + self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace") + + def encode_video(self, video): + # preprocess + b, c, t, h, w = video.shape + video = video.transpose(1, 2) + video = video.reshape(b * t, c, h, w) + size = (self.model.image_size,) * 2 + video = F.interpolate(video, size=size, mode="bicubic", align_corners=False) + + video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5)) + + # forward + with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type): + out = self.model.visual(video, use_31_block=True) + + return out + + def forward(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + with torch.amp.autocast("cuda", dtype=self.dtype): + out = self.model.visual(videos, use_31_block=True) + return out diff --git 
a/lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py b/lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca99055d34844b679b104e58f790226de5cbef01
--- /dev/null
+++ b/lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py
@@ -0,0 +1,203 @@
+import random
+
+import torch
+
+
+def combine_data(data, num_frames=57, keyboard_dim=6, mouse=True):
+    assert num_frames % 4 == 1
+    keyboard_condition = torch.zeros((num_frames, keyboard_dim))
+    if mouse:
+        mouse_condition = torch.zeros((num_frames, 2))
+
+    current_frame = 0
+    selections = [12]
+
+    while current_frame < num_frames:
+        rd_frame = selections[random.randint(0, len(selections) - 1)]
+        rd = random.randint(0, len(data) - 1)
+        k = data[rd]["keyboard_condition"]
+        if mouse:
+            m = data[rd]["mouse_condition"]
+
+        if current_frame == 0:
+            keyboard_condition[:1] = k[:1]
+            if mouse:
+                mouse_condition[:1] = m[:1]
+            current_frame = 1
+        else:
+            rd_frame = min(rd_frame, num_frames - current_frame)
+            repeat_time = rd_frame // 4
+            keyboard_condition[current_frame : current_frame + rd_frame] = k.repeat(repeat_time, 1)
+            if mouse:
+                mouse_condition[current_frame : current_frame + rd_frame] = m.repeat(repeat_time, 1)
+            current_frame += rd_frame
+    if mouse:
+        return {"keyboard_condition": keyboard_condition, "mouse_condition": mouse_condition}
+    return {"keyboard_condition": keyboard_condition}
+
+
+def Bench_actions_universal(num_frames, num_samples_per_action=4):
+    actions_single_action = [
+        "forward",
+        # "back",
+        "left",
+        "right",
+    ]
+    actions_double_action = [
+        "forward_left",
+        "forward_right",
+        # "back_left",
+        # "back_right",
+    ]
+
+    actions_single_camera = [
+        "camera_l",
+        "camera_r",
+        # "camera_ur",
+        # "camera_ul",
+        # "camera_dl",
+        # "camera_dr"
+        # "camera_up",
+        # "camera_down",
+    ]
+    actions_to_test = actions_double_action * 5 + actions_single_camera * 5 + actions_single_action * 5
+    for action in actions_single_action + actions_double_action:
+        for camera in actions_single_camera:
+            double_action = f"{action}_{camera}"
+            actions_to_test.append(double_action)
+
+    # print("length of actions: ", len(actions_to_test))
+    base_action = actions_single_action + actions_single_camera
+
+    KEYBOARD_IDX = {"forward": 0, "back": 1, "left": 2, "right": 3}
+
+    CAM_VALUE = 0.1
+    CAMERA_VALUE_MAP = {
+        "camera_up": [CAM_VALUE, 0],
+        "camera_down": [-CAM_VALUE, 0],
+        "camera_l": [0, -CAM_VALUE],
+        "camera_r": [0, CAM_VALUE],
+        "camera_ur": [CAM_VALUE, CAM_VALUE],
+        "camera_ul": [CAM_VALUE, -CAM_VALUE],
+        "camera_dr": [-CAM_VALUE, CAM_VALUE],
+        "camera_dl": [-CAM_VALUE, -CAM_VALUE],
+    }
+
+    data = []
+
+    for action_name in actions_to_test:
+        keyboard_condition = [[0, 0, 0, 0] for _ in range(num_samples_per_action)]
+        mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
+
+        for sub_act in base_action:
+            if sub_act not in action_name:  # only handle sub-actions contained in action_name
+                continue
+            # print(f"action name: {action_name} sub_act: {sub_act}")
+            if sub_act in CAMERA_VALUE_MAP:
+                mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
+
+            elif sub_act in KEYBOARD_IDX:
+                col = KEYBOARD_IDX[sub_act]
+                for row in keyboard_condition:
+                    row[col] = 1
+
+        data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
+    return combine_data(data, num_frames, keyboard_dim=4, mouse=True)
+
+
+def Bench_actions_gta_drive(num_frames, num_samples_per_action=4):
+    actions_single_action = [
+        "forward",
+        "back",
+    ]
+
+    actions_single_camera = [
+        "camera_l",
+        "camera_r",
+    ]
+    actions_to_test = actions_single_camera * 2 + actions_single_action * 2
+    for action in actions_single_action:
+        for camera in actions_single_camera:
+            double_action = f"{action}_{camera}"
+            actions_to_test.append(double_action)
+
+    # print("length of actions: ", len(actions_to_test))
+    base_action = actions_single_action + actions_single_camera
+
+    KEYBOARD_IDX = {"forward": 0, "back": 1}
+
+    CAM_VALUE = 0.1
+    CAMERA_VALUE_MAP = {
+        "camera_l": [0, -CAM_VALUE],
+        "camera_r": [0, CAM_VALUE],
+    }
+
+    data = []
+
+    for action_name in actions_to_test:
+        keyboard_condition = [[0, 0] for _ in range(num_samples_per_action)]
+        mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
+
+        for sub_act in base_action:
+            if sub_act not in action_name:  # only handle sub-actions contained in action_name
+                continue
+            # print(f"action name: {action_name} sub_act: {sub_act}")
+            if sub_act in CAMERA_VALUE_MAP:
+                mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
+
+            elif sub_act in KEYBOARD_IDX:
+                col = KEYBOARD_IDX[sub_act]
+                for row in keyboard_condition:
+                    row[col] = 1
+
+        data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
+    return combine_data(data, num_frames, keyboard_dim=2, mouse=True)
+
+
+def Bench_actions_templerun(num_frames, num_samples_per_action=4):
+    actions_single_action = ["jump", "slide", "leftside", "rightside", "turnleft", "turnright", "nomove"]
+
+    actions_to_test = actions_single_action
+
+    base_action = actions_single_action
+
+    KEYBOARD_IDX = {"nomove": 0, "jump": 1, "slide": 2, "turnleft": 3, "turnright": 4, "leftside": 5, "rightside": 6}
+
+    data = []
+
+    for action_name in actions_to_test:
+        keyboard_condition = [[0, 0, 0, 0, 0, 0, 0] for _ in range(num_samples_per_action)]
+
+        for sub_act in base_action:
+            if sub_act not in action_name:  # only handle sub-actions contained in action_name
+                continue
+            # print(f"action name: {action_name} sub_act: {sub_act}")
+            if sub_act in KEYBOARD_IDX:
+                col = KEYBOARD_IDX[sub_act]
+                for row in keyboard_condition:
+                    row[col] = 1
+
+        data.append({"keyboard_condition": torch.tensor(keyboard_condition)})
+    return combine_data(data, num_frames, keyboard_dim=7, mouse=False)
+
+
+class MatrixGame2_Bench:
+    def __init__(self):
+        self.device = torch.device("cuda")
+        self.weight_dtype = torch.bfloat16
+
+    def get_conditions(self, mode, num_frames):
+        conditional_dict = {}
+        if mode == "universal":
+            cond_data = Bench_actions_universal(num_frames)
+            mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
+            conditional_dict["mouse_cond"] = mouse_condition
+        elif mode == "gta_drive":
+            cond_data = Bench_actions_gta_drive(num_frames)
+            mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
+            conditional_dict["mouse_cond"] = mouse_condition
+        else:
+            cond_data = Bench_actions_templerun(num_frames)
+            keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
+            conditional_dict["keyboard_cond"] = keyboard_condition
+        return conditional_dict
diff --git a/lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py b/lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c66ecd6df8c0a4e8bf1f951cca5a181658cdcf
--- /dev/null
+++
b/lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py @@ -0,0 +1,75 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ["HuggingfaceTokenizer"] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + + # arguments + _kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/lightx2v/models/input_encoders/hf/wan/t5/__init__.py b/lightx2v/models/input_encoders/hf/wan/t5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/wan/t5/model.py b/lightx2v/models/input_encoders/hf/wan/t5/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a3813878bdff773ab055a0b2e972bc0620274f9d --- /dev/null +++ b/lightx2v/models/input_encoders/hf/wan/t5/model.py @@ -0,0 +1,849 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
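+# T5 text encoder used by the Wan pipelines. Linear layers can be swapped for the quantized
+# implementations imported below via `quant_scheme` (e.g. "int8", "fp8-sgl", "int8-q8f"), and
+# `cpu_offload=True` switches T5Encoder to block-wise offloading driven by
+# WeightAsyncStreamManager.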
+import gc +import math +import os +import sys +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger + +current_dir = Path(__file__).resolve().parent +project_root = current_dir.parent.parent.parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList # noqa E402 +from lightx2v.common.offload.manager import WeightAsyncStreamManager # noqa E402 +from lightx2v.common.ops import * # noqa E402 +from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402 + Q8FQuantLinearFp8, # noqa E402 + Q8FQuantLinearInt8, # noqa E402 + SglQuantLinearFp8, # noqa E402 + TorchaoQuantLinearInt8, # noqa E402 + VllmQuantLinearInt8, # noqa E402, + VllmQuantLinearFp8, # noqa E402 +) +from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 # noqa E402 +from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402 +from lightx2v.utils.envs import * # noqa E402 +from lightx2v.utils.registry_factory import ( # noqa E402 + EMBEDDING_WEIGHT_REGISTER, # noqa E402 + MM_WEIGHT_REGISTER, # noqa E402 + RMS_WEIGHT_REGISTER, # noqa E402 +) +from lightx2v.utils.utils import load_weights # noqa E402 +from lightx2v_platform.base.global_var import AI_DEVICE # noqa E402 + +__all__ = [ + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", +] + + +class T5OffloadBlocksWeights(WeightModule): + def __init__(self, block_nums, mm_type): + super().__init__() + self.block_nums = block_nums + self.offload_block_buffers = WeightModuleList([T5OffloadSelfAttention(i, mm_type, create_cuda_buffer=True) for i in range(2)]) + self.blocks = WeightModuleList([T5OffloadSelfAttention(i, mm_type) for i in range(block_nums)]) + self.add_module("offload_block_buffers", self.offload_block_buffers) + self.add_module("blocks", self.blocks) + + +class T5OffloadSelfAttention(WeightModule): + def __init__(self, block_index, mm_type, block_prefix="blocks", create_cuda_buffer=False): + super().__init__() + self.block_index = block_index + if mm_type is None: + mm_type = "Default" + self.mm_type = mm_type + + self.add_module( + "norm1", + RMS_WEIGHT_REGISTER["sgl-kernel"](f"{block_prefix}.{self.block_index}.norm1.weight", create_cuda_buffer), + ) + self.add_module( + "norm2", + RMS_WEIGHT_REGISTER["sgl-kernel"](f"{block_prefix}.{self.block_index}.norm2.weight", create_cuda_buffer), + ) + self.add_module( + "pos_embedding", + EMBEDDING_WEIGHT_REGISTER["Default"](f"{block_prefix}.{self.block_index}.pos_embedding.embedding.weight", create_cuda_buffer), + ) + + self.compute_phases = WeightModuleList( + [ + T5OffloadAttention(block_index, block_prefix, mm_type, create_cuda_buffer), + T5OffloadFeedForward(block_index, block_prefix, mm_type, create_cuda_buffer), + ] + ) + self.add_module("compute_phases", self.compute_phases) + + +class T5OffloadAttention(WeightModule): + def __init__(self, block_index, block_prefix, mm_type, create_cuda_buffer=False): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.add_module( + "attn_q", + MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.q.weight", None, create_cuda_buffer), + ) + self.add_module( + "attn_k", + MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.k.weight", None, create_cuda_buffer), + ) + self.add_module( + "attn_v", + 
            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.v.weight", None, create_cuda_buffer),
+        )
+        self.add_module(
+            "attn_o",
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.o.weight", None, create_cuda_buffer),
+        )
+
+
+class T5OffloadFeedForward(WeightModule):
+    def __init__(self, block_index, block_prefix, mm_type, create_cuda_buffer=False):
+        super().__init__()
+        self.block_index = block_index
+        self.mm_type = mm_type
+
+        self.add_module(
+            "ffn_fc1",
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc1.weight", None, create_cuda_buffer),
+        )
+        self.add_module(
+            "ffn_fc2",
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc2.weight", None, create_cuda_buffer),
+        )
+        self.add_module(
+            "ffn_gate_0",
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.gate.0.weight", None, create_cuda_buffer),
+        )
+        self.gelu = GELU()
+
+
+def fp16_clamp(x):
+    if x.dtype == torch.float16 and torch.isinf(x).any():
+        clamp = torch.finfo(x.dtype).max - 1000
+        x = torch.clamp(x, min=-clamp, max=clamp)
+    return x
+
+
+def init_weights(m):
+    if isinstance(m, T5LayerNorm):
+        nn.init.ones_(m.weight)
+    elif isinstance(m, T5Model):
+        nn.init.normal_(m.token_embedding.weight, std=1.0)
+    elif isinstance(m, T5FeedForward):
+        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
+        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
+        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
+    elif isinstance(m, T5Attention):
+        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
+        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
+        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
+        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
+    elif isinstance(m, T5RelativeEmbedding):
+        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
+
+
+class GELU(nn.Module):
+    def forward(self, x):
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5LayerNorm(nn.Module):
+    def __init__(self, dim, eps=1e-6, dtype=torch.float16):
+        super(T5LayerNorm, self).__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
+
+    def forward(self, x):
+        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            x = x.type_as(self.weight)
+        return self.weight * x
+
+
+class T5Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_attn,
+        num_heads,
+        dropout=0.1,
+        quantized=False,
+        quant_scheme=None,
+        dtype=torch.bfloat16,
+    ):
+        assert dim_attn % num_heads == 0
+        super(T5Attention, self).__init__()
+        self.dim = dim
+        self.dim_attn = dim_attn
+        self.num_heads = num_heads
+        self.head_dim = dim_attn // num_heads
+
+        if quantized:
+            if quant_scheme in ["int8", "int8-vllm"]:
+                linear_cls = VllmQuantLinearInt8
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
+                linear_cls = SglQuantLinearFp8
+            elif quant_scheme == "fp8-vllm":
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
+            elif quant_scheme == "int8-q8f":
+                linear_cls = Q8FQuantLinearInt8
+            elif quant_scheme == "fp8-q8f":
+                linear_cls = Q8FQuantLinearFp8
+            elif quant_scheme == "int8-tmo":
+                linear_cls = MluQuantLinearInt8
+            else:
+                raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
+        else:
+            linear_cls = nn.Linear
+
+        # layers
+        self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, context=None, mask=None, pos_bias=None):
+        """
+        x: [B, L1, C].
+        context: [B, L2, C] or None.
+        mask: [B, L2] or [B, L1, L2] or None.
+        """
+        # check inputs
+        context = x if context is None else context
+        b, n, c = x.size(0), self.num_heads, self.head_dim
+
+        # compute query, key, value
+        q = self.q(x).view(b, -1, n, c)
+        k = self.k(context).view(b, -1, n, c)
+        v = self.v(context).view(b, -1, n, c)
+
+        # attention bias
+        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
+        if pos_bias is not None:
+            attn_bias += pos_bias
+        if mask is not None:
+            assert mask.ndim in [2, 3]
+            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
+            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
+
+        # compute attention (T5 does not use scaling)
+        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
+        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+        x = torch.einsum("bnij,bjnc->binc", attn, v)
+        x = x.reshape(b, -1, n * c)
+        x = self.o(x)
+
+        return x
+
+
+class T5FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_ffn,
+        dropout=0.1,
+        quantized=False,
+        quant_scheme=None,
+        dtype=torch.bfloat16,
+    ):
+        super(T5FeedForward, self).__init__()
+        self.dim = dim
+        self.dim_ffn = dim_ffn
+
+        if quantized:
+            if quant_scheme in ["int8", "int8-vllm"]:
+                linear_cls = VllmQuantLinearInt8
+            elif quant_scheme in ["fp8", "fp8-sgl"]:
+                linear_cls = SglQuantLinearFp8
+            elif quant_scheme == "fp8-vllm":
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
+            elif quant_scheme == "int8-q8f":
+                linear_cls = Q8FQuantLinearInt8
+            elif quant_scheme == "fp8-q8f":
+                linear_cls = Q8FQuantLinearFp8
+            elif quant_scheme == "int8-tmo":
+                linear_cls = MluQuantLinearInt8
+            else:
+                raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
+        else:
+            linear_cls = nn.Linear
+        # layers
+        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
+
+        self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
+        self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.fc1(x) * self.gate(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        x = self.dropout(x)
+        return x
+
+
+class T5SelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_attn,
+        dim_ffn,
+        num_heads,
+        num_buckets,
+        shared_pos=True,
+        dropout=0.1,
+        quantized=False,
+        quant_scheme=None,
+        dtype=torch.bfloat16,
+    ):
+        super(T5SelfAttention, self).__init__()
+        self.dim = dim
+        self.dim_attn = dim_attn
+        self.dim_ffn = dim_ffn
+        self.num_heads = num_heads
+        self.num_buckets = num_buckets
+        self.shared_pos = shared_pos
+
+        # layers
+        self.norm1 = T5LayerNorm(dim, dtype=dtype)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
+        self.norm2 = T5LayerNorm(dim, dtype=dtype)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
+        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)
+
+    def forward(self, x, mask=None, pos_bias=None):
+        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
+        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+        x =
fp16_clamp(x + self.ffn(self.norm2(x))) + + return x + + +class T5CrossAttention(nn.Module): + def __init__( + self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) + + def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + + rel_pos_embeds = self.embedding(rel_pos) + + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long() + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + def __init__( + self, + dtype, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + cpu_offload=False, + quantized=False, + quant_scheme=None, + ): + super(T5Encoder, self).__init__() + self.cpu_offload = cpu_offload + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + self.quant_scheme = quant_scheme + + # layers + self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype) + self.pos_embedding = 
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None + self.dropout = nn.Dropout(dropout) + + if cpu_offload: + self.offload_manager = WeightAsyncStreamManager(offload_granularity="block") + self.blocks_weights = T5OffloadBlocksWeights(num_layers, quant_scheme) + self.offload_manager.init_cuda_buffer(self.blocks_weights.offload_block_buffers, None) + self.blocks = self.blocks_weights.blocks + else: + self.blocks = nn.ModuleList( + [ + T5SelfAttention( + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos, + dropout, + quantized, + quant_scheme, + dtype, + ) + for _ in range(num_layers) + ] + ) + + self.norm = T5LayerNorm(dim, dtype=dtype) + + def forward_without_offload(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + + for i, block in enumerate(self.blocks): + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x.to(GET_DTYPE()) + + def forword_attn_with_offload(self, x, attn_phase, context=None, mask=None, pos_bias=None): + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.dim_attn // self.num_heads + # compute query, key, value + q = attn_phase.attn_q.apply(x.squeeze(0)).view(b, -1, n, c) + k = attn_phase.attn_k.apply(context.squeeze(0)).view(b, -1, n, c) + v = attn_phase.attn_v.apply(context.squeeze(0)).view(b, -1, n, c) + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + x = x.reshape(b, -1, n * c) + x = attn_phase.attn_o.apply(x.squeeze(0)).unsqueeze(0) + return x + + def forward_ffn_with_offload(self, x, ffn_phase): + x = x.squeeze(0) + x = ffn_phase.ffn_fc1.apply(x) * ffn_phase.gelu(ffn_phase.ffn_gate_0.apply(x)) + x = ffn_phase.ffn_fc2.apply(x) + return x.unsqueeze(0) + + def forward_block_with_offload(self, block, x, mask=None, pos_bias=None): + if self.shared_pos: + e = pos_bias + else: + lq, lk = x.size(1), x.size(1) + rel_pos = torch.arange(lk, device=AI_DEVICE).unsqueeze(0) - torch.arange(lq, device=AI_DEVICE).unsqueeze(1) + num_buckets = block.pos_embedding.weight.shape[0] // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(128 / max_exact) * (num_buckets - max_exact)).long() + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + e = block.pos_embedding.apply(rel_buckets).permute(2, 0, 1).unsqueeze(0).contiguous() + + norm1_out = block.norm1.apply(x) + x = fp16_clamp(x + self.forword_attn_with_offload(norm1_out, block.compute_phases[0], mask=mask, pos_bias=e)) + norm2_out = block.norm2.apply(x) + x = fp16_clamp(x + self.forward_ffn_with_offload(norm2_out, block.compute_phases[1])) + return x + + def forward_with_offload(self, ids, mask=None): + self.token_embedding = self.token_embedding.to(AI_DEVICE) + self.pos_embedding = 
self.pos_embedding.to(AI_DEVICE) if self.pos_embedding is not None else None + + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + self.norm = self.norm.to(AI_DEVICE) + + for block_idx in range(len(self.blocks)): + self.block_idx = block_idx + self.offload_manager.cuda_buffers[0].load_state_dict( + self.blocks[block_idx].state_dict(), + block_idx, + ) + x = self.forward_block_with_offload(self.offload_manager.cuda_buffers[0], x, mask, pos_bias=e) + + x = self.norm(x) + x = self.dropout(x) + return x.to(GET_DTYPE()) + + def forward(self, ids, mask=None): + if self.cpu_offload: + return self.forward_with_offload(ids, mask) + else: + return self.forward_without_offload(ids, mask) + + +class T5Decoder(nn.Module): + def __init__( + self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + def __init__( + self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder( + self.token_embedding, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + num_buckets, + shared_pos, + dropout, + ) + self.decoder = T5Decoder( + self.token_embedding, + dim, + dim_attn, + dim_ffn, + num_heads, + decoder_layers, + num_buckets, + shared_pos, + dropout, + ) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5( + name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device="cpu", + 
**kwargs, +): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(dtype=dtype, **kwargs) + + # set device + model = model.to(device=device) + return model + + +def split_block_weights(weights): + block_weights = {} + all_keys = list(weights.keys()) + for key in all_keys: + if key.startswith(("blocks.")): + block_weights[key] = weights.pop(key) + return block_weights + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) + + +class T5EncoderModel: + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.device("cuda"), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + cpu_offload=False, + t5_quantized=False, + t5_quantized_ckpt=None, + quant_scheme=None, + load_from_rank0=False, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + if t5_quantized_ckpt is not None and t5_quantized: + self.checkpoint_path = t5_quantized_ckpt + else: + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # sync cpu offload + self.cpu_offload = cpu_offload + + model = ( + umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + device=device, + cpu_offload=cpu_offload, + quantized=t5_quantized, + quant_scheme=quant_scheme, + ) + .eval() + .requires_grad_(False) + ) + + weights_dict = load_weights( + self.checkpoint_path, + cpu_offload=cpu_offload, + load_from_rank0=load_from_rank0, + ) + + if cpu_offload: + block_weights_dict = split_block_weights(weights_dict) + model.blocks_weights.load(block_weights_dict) + del block_weights_dict + gc.collect() + + model.load_state_dict(weights_dict) + del weights_dict + gc.collect() + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def infer(self, texts): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(AI_DEVICE) + mask = mask.to(AI_DEVICE) + seq_lens = mask.gt(0).sum(dim=1).long() + + with torch.no_grad(): + context = self.model(ids, mask) + + return [u[:v] for u, v in zip(context, seq_lens)] + + +if __name__ == "__main__": + import time + + checkpoint_dir = "" + t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" + t5_tokenizer = "google/umt5-xxl" + + cpu_offload = False + if cpu_offload: + device = torch.device("cpu") + else: + device = torch.device("cuda") + + model = T5EncoderModel( + text_len=512, + dtype=torch.bfloat16, + device=device, + checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer), + shard_fn=None, + cpu_offload=cpu_offload, + ) + text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
+ + torch.cuda.synchronize() + s_t = time.time() + outputs = model.infer(text) + + torch.cuda.synchronize() + e_t = time.time() + + logger.info(e_t - s_t) + logger.info(outputs) diff --git a/lightx2v/models/input_encoders/hf/wan/t5/tokenizer.py b/lightx2v/models/input_encoders/hf/wan/t5/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fe937821b4a09505998268cc42723942213007f7 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/wan/t5/tokenizer.py @@ -0,0 +1,81 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ["HuggingfaceTokenizer"] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + + # arguments + _kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + _kwargs.update( + { + "padding": "max_length", + "truncation": True, + "max_length": self.seq_len, + } + ) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/__init__.py b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef108635940090264c8547f5696f912998e6462 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py @@ -0,0 +1,473 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
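# Illustrative usage sketch for the HuggingfaceTokenizer defined in tokenizer.py
# above; the tokenizer name matches the umt5-xxl default used in the T5 example,
# the prompt is a placeholder, and running it requires the transformers and ftfy
# packages plus access to the tokenizer files.
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer

tokenizer = HuggingfaceTokenizer(name="google/umt5-xxl", seq_len=512, clean="whitespace")
ids, mask = tokenizer(["A cat boxing on a spotlighted stage."], return_mask=True, add_special_tokens=True)
# ids: [1, 512] token ids padded/truncated to seq_len; mask: [1, 512] with 1s
# over real tokens and 0s over padding, as consumed by T5EncoderModel.infer.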
+import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from loguru import logger + +# from lightx2v.attentions import attention +from lightx2v.common.ops.attn import TorchSDPAWeight +from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8 +from lightx2v.utils.utils import load_weights +from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 + +__all__ = [ + "XLMRobertaCLIP", + "clip_xlm_roberta_vit_h_14", + "CLIPModel", +] + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat( + [ + pos[:, :n], + F.interpolate(pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), size=(tar_grid, tar_grid), mode="bicubic", align_corners=False).flatten(2).transpose(1, 2), + ], + dim=1, + ) + + +class QuickGELU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + if quantized: + if quant_scheme in ["int8", "int8-vllm"]: + linear_cls = VllmQuantLinearInt8 + elif quant_scheme in ["fp8", "fp8-sgl"]: + linear_cls = SglQuantLinearFp8 + elif quant_scheme == "fp8-vllm": + linear_cls = VllmQuantLinearFp8 + elif quant_scheme == "int8-torchao": + linear_cls = TorchaoQuantLinearInt8 + elif quant_scheme == "int8-q8f": + linear_cls = Q8FQuantLinearInt8 + elif quant_scheme == "fp8-q8f": + linear_cls = Q8FQuantLinearFp8 + elif quant_scheme == "int8-tmo": + linear_cls = MluQuantLinearInt8 + else: + NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}") + else: + linear_cls = nn.Linear + + self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype) + self.proj = linear_cls(dim, dim, dtype=dtype) + + def forward(self, x): + """ + x: [B, L, C]. 
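        Illustrative shape walk-through (assuming dim=1280, num_heads=16, as in
        the ViT-H/14 config below): to_qkv(x) -> [B, L, 3*1280], viewed as
        [B, L, 3, 16, 80] and unbound into q, k, v of shape [B, L, 16, 80];
        the attention output is then reshaped back to [B, L, 1280] before the
        output projection.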
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + x = TorchSDPAWeight().apply(q=q, k=k, v=v) + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation="quick_gelu", + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5, + quantized=False, + quant_scheme=None, + dtype=torch.float16, + ): + assert activation in ["quick_gelu", "gelu", "swi_glu"] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + if quantized: + if quant_scheme in ["int8", "int8-vllm"]: + linear_cls = VllmQuantLinearInt8 + elif quant_scheme in ["fp8", "fp8-sgl"]: + linear_cls = SglQuantLinearFp8 + elif quant_scheme == "fp8-vllm": + linear_cls = VllmQuantLinearFp8 + elif quant_scheme == "int8-torchao": + linear_cls = TorchaoQuantLinearInt8 + elif quant_scheme == "int8-q8f": + linear_cls = Q8FQuantLinearInt8 + elif quant_scheme == "fp8-q8f": + linear_cls = Q8FQuantLinearFp8 + elif quant_scheme == "int8-tmo": + linear_cls = MluQuantLinearInt8 + else: + NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") + else: + linear_cls = nn.Linear + + self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype) + self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype) + if activation == "swi_glu": + self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype) + else: + self.mlp = nn.Sequential( + linear_cls(dim, int(dim * mlp_ratio), dtype=dtype), + QuickGELU() if activation == "quick_gelu" else nn.GELU(), + linear_cls(int(dim * mlp_ratio), dim, dtype=dtype), + nn.Dropout(proj_dropout), + ) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim, dtype=dtype) + self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype) + self.proj = nn.Linear(dim, dim, dtype=dtype) + self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout) + ) + + def forward(self, x): + """ + 
x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = attention(q=q, k=k, v=v, attention_type="torch_sdpa") + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + def __init__( + self, + dtype=torch.float16, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type="token", + pre_norm=True, + post_norm=False, + activation="quick_gelu", + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + quantized=False, + quant_scheme=None, + ): + if image_size % patch_size != 0: + logger.info("[WARNING] image_size is not divisible by patch_size", flush=True) + assert pool_type in ("token", "token_fc", "attn_pool") + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size) ** 2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype) + if pool_type in ("token", "token_fc"): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype)) + self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None + self.transformer = nn.Sequential( + *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)] + ) + self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) + + # head + if pool_type == "token": + self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype)) + elif pool_type == "token_fc": + self.head = nn.Linear(dim, out_dim, dtype=dtype) + elif pool_type == "attn_pool": + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x.type(self.patch_embedding.weight.type())).flatten(2).permute(0, 2, 1) + if self.pool_type in ("token", "token_fc"): + x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaCLIP(nn.Module): + def __init__( + self, + dtype=torch.float16, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + vision_pre_norm=True, + vision_post_norm=False, + 
activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + quantized=False, + quant_scheme=None, + ): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + dtype=dtype, + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps, + quantized=quantized, + quant_scheme=quant_scheme, + ) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + +def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs): + # init a model on device + with torch.device(device): + model = model_cls(dtype=dtype, **kwargs) + + model = model.to(device=device) + + output = (model,) + # init transforms + if return_transforms: + # mean and std + if "siglip" in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)]) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + ) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel: + def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False): + self.dtype = dtype + self.quantized = clip_quantized + self.cpu_offload = cpu_offload + self.use_31_block = use_31_block + + if self.quantized: + self.checkpoint_path = clip_quantized_ckpt + else: + self.checkpoint_path = checkpoint_path + + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme + ) + self.model = self.model.eval().requires_grad_(False) + weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", 
load_from_rank0=load_from_rank0) + self.model.load_state_dict(weight_dict) + + def visual(self, videos): + if self.cpu_offload: + self.to_cuda() + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + # forward + with torch.amp.autocast("cuda", dtype=self.dtype): + out = self.model.visual(videos, use_31_block=self.use_31_block) + + if self.cpu_offload: + self.to_cpu() + return out + + def to_cuda(self): + self.model = self.model.to(AI_DEVICE) + + def to_cpu(self): + self.model = self.model.cpu() diff --git a/lightx2v/models/networks/__init__.py b/lightx2v/models/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/hunyuan_video/__init__.py b/lightx2v/models/networks/hunyuan_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py b/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..c5409f66aa079262ca28943bb126957ad04fc0fc --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py @@ -0,0 +1,129 @@ +import torch +from einops import rearrange +from loguru import logger + +try: + from flash_attn import flash_attn_varlen_qkvpacked_func +except ImportError: + flash_attn_varlen_qkvpacked_func = None + logger.info("flash_attn_varlen_qkvpacked_func not available") +try: + from flash_attn.bert_padding import pad_input, unpad_input +except ImportError: + pad_input = None + unpad_input = None + logger.info("flash_attn.bert_padding not available") + +try: + from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 +except ImportError: + flash_attn_varlen_func_v3 = None + logger.info("flash_attn_varlen_func_v3 not available") + +if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]: + try: + from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn + except ImportError: + logger.info("sageattn not found, please install sageattention first") + sageattn = None +else: + try: + from sageattention import sageattn + except ImportError: + logger.info("sageattn not found, please install sageattention first") + sageattn = None + +try: + from sageattn3 import sageattn3_blackwell +except ImportError: + logger.info("sageattn3 not found, please install sageattention first") + sageattn3_blackwell = None + + +def flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False): + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(x, key_padding_mask) + + x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + ) + output = rearrange( + pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), + "b s (h d) -> b s h d", + h=nheads, + ) + return output + + +def flash_attn_no_pad_v3(qkv, 
key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False): + if flash_attn_varlen_func_v3 is None: + raise ImportError("FlashAttention V3 backend not available") + + batch_size, seqlen, _, nheads, head_dim = qkv.shape + query, key, value = qkv.unbind(dim=2) + + query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) + key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) + value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) + + query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads) + key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads) + value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads) + + output_unpad = flash_attn_varlen_func_v3( + query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_q, softmax_scale=softmax_scale, causal=causal, deterministic=deterministic + ) + + output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads) + return output + + +def sage_attn_no_pad_v2(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False): + batch_size, seqlen, _, nheads, head_dim = qkv.shape + query, key, value = qkv.unbind(dim=2) + + query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) + key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) + value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) + + query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads) + key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads) + value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads) + + output_unpad = sageattn( + query_unpad.unsqueeze(0), + key_unpad.unsqueeze(0), + value_unpad.unsqueeze(0), + tensor_layout="NHD", + ).squeeze(0) + + output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads) + return output + + +def sage_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False): + batch_size, seqlen, _, nheads, head_dim = qkv.shape + query, key, value = qkv.unbind(dim=2) + + query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) + key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) + value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) + + query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads) + key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads) + value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads) + + output_unpad = sageattn3_blackwell(query_unpad.unsqueeze(0).transpose(1, 2), key_unpad.unsqueeze(0).transpose(1, 2), value_unpad.unsqueeze(0).transpose(1, 2)).transpose(1, 2).squeeze(0) + + output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads) + return output diff --git a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/__init__.py 
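# Illustrative call sketch for the packed-QKV, no-padding attention helpers in
# attn_no_pad.py above. Shapes are placeholders, and the optional flash-attn
# (and, for the sage variants, sageattention) packages must be installed.
import torch

from lightx2v.models.networks.hunyuan_video.infer.attn_no_pad import flash_attn_no_pad

B, S, H, D = 2, 256, 24, 128
qkv = torch.randn(B, S, 3, H, D, dtype=torch.bfloat16, device="cuda")
# key_padding_mask marks real tokens; padded positions are removed by
# unpad_input before the varlen kernel and re-inserted afterwards.
key_padding_mask = torch.ones(B, S, dtype=torch.bool, device="cuda")
key_padding_mask[1, 200:] = False
out = flash_attn_no_pad(qkv, key_padding_mask)  # -> [B, S, H, D]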
b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..56041ec8b1f3ba3f1c91ff59a1f3e9e0d9fb5565 --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py @@ -0,0 +1,229 @@ +import gc +import json + +import numpy as np +import torch +import torch.nn.functional as F + +from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer +from lightx2v_platform.base.global_var import AI_DEVICE + + +class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.magcache_thresh = config.get("magcache_thresh", 0.2) + self.K = config.get("magcache_K", 6) + self.retention_ratio = config.get("magcache_retention_ratio", 0.2) + self.mag_ratios = np.array(config.get("magcache_ratios", [])) + self.enable_magcache_calibration = config.get("magcache_calibration", True) + # {True: cond_param, False: uncond_param} + self.accumulated_err = {True: 0.0, False: 0.0} + self.accumulated_steps = {True: 0, False: 0} + self.accumulated_ratio = {True: 1.0, False: 1.0} + self.residual_cache = {True: None, False: None} + self.residual_cache_txt = {True: None, False: None} + # calibration args + self.norm_ratio = [[1.0], [1.0]] # mean of magnitude ratio + self.norm_std = [[0.0], [0.0]] # std of magnitude ratio + self.cos_dis = [[0.0], [0.0]] # cosine distance of residual features + + @torch.no_grad() + def infer(self, weights, infer_module_out): + skip_forward = False + step_index = self.scheduler.step_index + infer_condition = self.scheduler.infer_condition + + if self.enable_magcache_calibration: + skip_forward = False + else: + if step_index >= int(self.config["infer_steps"] * self.retention_ratio): + # conditional and unconditional in one list + cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index] + # magnitude ratio between current step and the cached step + self.accumulated_ratio[infer_condition] = self.accumulated_ratio[infer_condition] * cur_mag_ratio + self.accumulated_steps[infer_condition] += 1 # skip steps plus 1 + # skip error of current steps + cur_skip_err = np.abs(1 - self.accumulated_ratio[infer_condition]) + # accumulated error of multiple steps + self.accumulated_err[infer_condition] += cur_skip_err + + if self.accumulated_err[infer_condition] < self.magcache_thresh and self.accumulated_steps[infer_condition] <= self.K: + skip_forward = True + else: + self.accumulated_err[infer_condition] = 0 + self.accumulated_steps[infer_condition] = 0 + self.accumulated_ratio[infer_condition] = 1.0 + + if not skip_forward: + self.infer_calculating(weights, infer_module_out) + else: + self.infer_using_cache(infer_module_out) + + x = self.infer_final_layer(weights, infer_module_out) + + return x + + def infer_calculating(self, weights, infer_module_out): + step_index = self.scheduler.step_index + infer_condition = self.scheduler.infer_condition + + ori_img = infer_module_out.img.clone() + ori_txt = infer_module_out.txt.clone() + self.infer_func(weights, infer_module_out) + + previous_residual = infer_module_out.img - 
ori_img + previous_residual_txt = infer_module_out.txt - ori_txt + + if self.config["cpu_offload"]: + previous_residual = previous_residual.cpu() + previous_residual_txt = previous_residual_txt.cpu() + + if self.enable_magcache_calibration and step_index >= 1: + norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item() + norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item() + cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item() + _index = int(not infer_condition) + self.norm_ratio[_index].append(round(norm_ratio, 5)) + self.norm_std[_index].append(round(norm_std, 5)) + self.cos_dis[_index].append(round(cos_dis, 5)) + print(f"time: {step_index}, infer_condition: {infer_condition}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}") + + self.residual_cache[infer_condition] = previous_residual + self.residual_cache_txt[infer_condition] = previous_residual_txt + + if self.config["cpu_offload"]: + ori_img = ori_img.to("cpu") + ori_txt = ori_txt.to("cpu") + del ori_img, ori_txt + torch.cuda.empty_cache() + gc.collect() + + def infer_using_cache(self, infer_module_out): + residual_img = self.residual_cache[self.scheduler.infer_condition] + residual_txt = self.residual_cache_txt[self.scheduler.infer_condition] + infer_module_out.img.add_(residual_img.to(AI_DEVICE)) + infer_module_out.txt.add_(residual_txt.to(AI_DEVICE)) + + def clear(self): + self.accumulated_err = {True: 0.0, False: 0.0} + self.accumulated_steps = {True: 0, False: 0} + self.accumulated_ratio = {True: 1.0, False: 1.0} + self.residual_cache = {True: None, False: None} + self.residual_cache_txt = {True: None, False: None} + if self.enable_magcache_calibration: + print("norm ratio") + print(self.norm_ratio) + print("norm std") + print(self.norm_std) + print("cos_dis") + print(self.cos_dis) + + def save_json(filename, obj_list): + with open(filename + ".json", "w") as f: + json.dump(obj_list, f) + + save_json("mag_ratio", self.norm_ratio) + save_json("mag_std", self.norm_std) + save_json("cos_dis", self.cos_dis) + + +class HunyuanTransformerInferTeaCaching(HunyuanVideo15OffloadTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.teacache_thresh = self.config["teacache_thresh"] + self.coefficients = self.config["coefficients"] + + self.accumulated_rel_l1_distance_odd = 0 + self.previous_modulated_input_odd = None + self.previous_residual_odd = None + + self.accumulated_rel_l1_distance_even = 0 + self.previous_modulated_input_even = None + self.previous_residual_even = None + + def calculate_should_calc(self, img, vec, block): + inp = img.clone() + vec_ = vec.clone() + img_mod_layer = block.img_branch.img_mod + if self.config["cpu_offload"]: + img_mod_layer.to_cuda() + + img_mod1_shift, img_mod1_scale, _, _, _, _ = img_mod_layer.apply(vec_).chunk(6, dim=-1) + inp = inp.squeeze(0) + normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6) + modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift + + del normed_inp, inp, vec_ + + if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1: + should_calc = True + if self.scheduler.infer_condition: + self.accumulated_rel_l1_distance_odd = 0 + self.previous_modulated_input_odd = modulated_inp + else: + self.accumulated_rel_l1_distance_even = 0 + self.previous_modulated_input_even = 
modulated_inp + else: + rescale_func = np.poly1d(self.coefficients) + if self.scheduler.infer_condition: + self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_modulated_input_odd).abs().mean() / self.previous_modulated_input_odd.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance_odd = 0 + self.previous_modulated_input_odd = modulated_inp + else: + self.accumulated_rel_l1_distance_even += rescale_func( + ((modulated_inp - self.previous_modulated_input_even).abs().mean() / self.previous_modulated_input_even.abs().mean()).cpu().item() + ) + if self.accumulated_rel_l1_distance_even < self.teacache_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance_even = 0 + + self.previous_modulated_input_even = modulated_inp + del modulated_inp + + return should_calc + + def infer(self, weights, infer_module_out): + should_calc = self.calculate_should_calc(infer_module_out.img, infer_module_out.vec, weights.double_blocks[0]) + if not should_calc: + if self.scheduler.infer_condition: + infer_module_out.img += self.previous_residual_odd + else: + infer_module_out.img += self.previous_residual_even + else: + ori_img = infer_module_out.img.clone() + + self.infer_func(weights, infer_module_out) + + if self.scheduler.infer_condition: + self.previous_residual_odd = infer_module_out.img - ori_img + else: + self.previous_residual_even = infer_module_out.img - ori_img + + x = self.infer_final_layer(weights, infer_module_out) + return x + + def clear(self): + if self.previous_residual_odd is not None: + self.previous_residual_odd = self.previous_residual_odd.cpu() + + if self.previous_modulated_input_odd is not None: + self.previous_modulated_input_odd = self.previous_modulated_input_odd.cpu() + + if self.previous_residual_even is not None: + self.previous_residual_even = self.previous_residual_even.cpu() + + if self.previous_modulated_input_even is not None: + self.previous_modulated_input_even = self.previous_modulated_input_even.cpu() + + self.previous_modulated_input_odd = None + self.previous_residual_odd = None + self.previous_modulated_input_even = None + self.previous_residual_even = None + torch.cuda.empty_cache() diff --git a/lightx2v/models/networks/hunyuan_video/infer/module_io.py b/lightx2v/models/networks/hunyuan_video/infer/module_io.py new file mode 100644 index 0000000000000000000000000000000000000000..199922e76f0d137f0f68a40533fdf4091b3cc5e6 --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/module_io.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass + +import torch + + +@dataclass +class HunyuanVideo15InferModuleOutput: + img: torch.Tensor + txt: torch.Tensor + vec: torch.Tensor + grid_sizes: tuple + + +@dataclass +class HunyuanVideo15ImgBranchOutput: + img_mod1_gate: torch.Tensor + img_mod2_shift: torch.Tensor + img_mod2_scale: torch.Tensor + img_mod2_gate: torch.Tensor + + +@dataclass +class HunyuanVideo15TxtBranchOutput: + txt_mod1_gate: torch.Tensor + txt_mod2_shift: torch.Tensor + txt_mod2_scale: torch.Tensor + txt_mod2_gate: torch.Tensor diff --git a/lightx2v/models/networks/hunyuan_video/infer/offload/__init__.py b/lightx2v/models/networks/hunyuan_video/infer/offload/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py 
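# Compact sketch of the TeaCache-style skip rule implemented above: a block's
# full forward pass is re-run only when the accumulated, rescaled relative-L1
# change of the modulated input exceeds the configured threshold; otherwise the
# cached residual is added back. This is a standalone helper for illustration,
# not a method of the classes above; `coefficients` and `thresh` stand in for
# the `coefficients` and `teacache_thresh` config values.
import numpy as np

def should_recompute(accumulated, modulated, previous, coefficients, thresh):
    rescale = np.poly1d(coefficients)
    rel_l1 = ((modulated - previous).abs().mean() / previous.abs().mean()).item()
    accumulated += rescale(rel_l1)
    if accumulated < thresh:
        return False, accumulated  # reuse the cached residual for this step
    return True, 0.0               # recompute the blocks and reset the accumulator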
b/lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..4700e30b0523e0d701d05f71f0b5ca3139f6ab89 --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py @@ -0,0 +1,34 @@ +import torch + +from lightx2v.common.offload.manager import WeightAsyncStreamManager +from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer): + def __init__(self, config): + super().__init__(config) + if self.config.get("cpu_offload", False): + offload_granularity = self.config.get("offload_granularity", "block") + if offload_granularity == "block": + self.infer_func = self.infer_with_blocks_offload + elif offload_granularity == "model": + self.infer_func = self.infer_without_offload + else: + raise NotImplementedError + if offload_granularity != "model": + self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity) + + @torch.no_grad() + def infer_with_blocks_offload(self, weights, infer_module_out): + for block_idx in range(self.double_blocks_num): + self.block_idx = block_idx + if block_idx == 0: + self.offload_manager.init_first_buffer(weights.double_blocks) + if block_idx < self.double_blocks_num - 1: + self.offload_manager.prefetch_weights(block_idx + 1, weights.double_blocks) + with torch_device_module.stream(self.offload_manager.compute_stream): + infer_module_out.img, infer_module_out.txt = self.infer_double_block(self.offload_manager.cuda_buffers[0], infer_module_out) + self.offload_manager.swap_blocks() diff --git a/lightx2v/models/networks/hunyuan_video/infer/post_infer.py b/lightx2v/models/networks/hunyuan_video/infer/post_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..a34545f5aa2375bbc608351807a07daf9af299f3 --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/post_infer.py @@ -0,0 +1,39 @@ +import torch + +from lightx2v.utils.envs import * + + +class HunyuanVideo15PostInfer: + def __init__(self, config): + self.config = config + self.unpatchify_channels = config["out_channels"] + self.patch_size = config["patch_size"] # (1, 1, 1) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, x, pre_infer_out): + x = self.unpatchify(x, pre_infer_out.grid_sizes[0], pre_infer_out.grid_sizes[1], pre_infer_out.grid_sizes[2]) + return x + + def unpatchify(self, x, t, h, w): + """ + Unpatchify a tensorized input back to frame format. 
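        Example shapes (illustrative; patch_size = (1, 1, 1) as noted in __init__,
        with out_channels = 16 and (t, h, w) = (9, 30, 52) assumed):
            x: (1, 9*30*52, 16) -> reshape to (1, 9, 30, 52, 16, 1, 1, 1)
            -> einsum("nthwcopq->nctohpwq") -> final reshape -> (1, 16, 9, 30, 52)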
+ + Args: + x (Tensor): Input tensor of shape (N, T, patch_size**2 * C) + t (int): Number of time steps + h (int): Height in patch units + w (int): Width in patch units + + Returns: + Tensor: Output tensor of shape (N, C, t * pt, h * ph, w * pw) + """ + c = self.unpatchify_channels + pt, ph, pw = self.patch_size + x = x[:, : t * h * w] # remove padding from seq parallel + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs diff --git a/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py b/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..feb62f5c4663489708a562728a265d53c33109bd --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py @@ -0,0 +1,240 @@ +import math +from typing import Optional + +import torch +from einops import rearrange + +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + +from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2 +from .module_io import HunyuanVideo15InferModuleOutput + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +@torch.compiler.disable +def attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, drop_rate: float = 0.0, attn_mask: Optional[torch.Tensor] = None, causal: bool = False, attn_type: str = "flash_attn2" +) -> torch.Tensor: + """ + Compute attention using flash_attn_no_pad. + + Args: + q: Query tensor of shape [B, L, H, D] + k: Key tensor of shape [B, L, H, D] + v: Value tensor of shape [B, L, H, D] + drop_rate: Dropout rate for attention weights. + attn_mask: Optional attention mask of shape [B, L]. + causal: Whether to apply causal masking. 
+ + Returns: + Output tensor after attention of shape [B, L, H*D] + """ + qkv = torch.stack([q, k, v], dim=2) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.bool() + if attn_type == "flash_attn2": + x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) + elif attn_type == "flash_attn3": + x = flash_attn_no_pad_v3(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) + elif attn_type == "sage_attn2": + x = sage_attn_no_pad_v2(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class HunyuanVideo15PreInfer: + def __init__(self, config): + self.config = config + self.patch_size = config["patch_size"] + self.heads_num = config["heads_num"] + self.frequency_embedding_size = 256 + self.max_period = 10000 + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, weights, inputs): + latents = self.scheduler.latents + grid_sizes_t, grid_sizes_h, grid_sizes_w = latents.shape[2:] + + timesteps = self.scheduler.timesteps + t = timesteps[self.scheduler.step_index] + + if self.scheduler.infer_condition: + txt, text_mask = inputs["text_encoder_output"]["context"][0], inputs["text_encoder_output"]["context"][1] + else: + txt, text_mask = inputs["text_encoder_output"]["context_null"][0], inputs["text_encoder_output"]["context_null"][1] + + byt5_txt, byt5_text_mask = inputs["text_encoder_output"]["byt5_features"], inputs["text_encoder_output"]["byt5_masks"] + siglip_output, siglip_mask = inputs["image_encoder_output"]["siglip_output"], inputs["image_encoder_output"]["siglip_mask"] + txt = txt.to(torch.bfloat16) + + if self.config["is_sr_running"]: + if t < 1000 * self.scheduler.noise_scale: + condition = self.scheduler.zero_condition + else: + condition = self.scheduler.condition + + img = x = latent_model_input = torch.concat([latents, condition], dim=1) + else: + cond_latents_concat = self.scheduler.cond_latents_concat + mask_concat = self.scheduler.mask_concat + img = x = latent_model_input = torch.concat([latents, cond_latents_concat, mask_concat], dim=1) + + img = img.to(torch.bfloat16) + + t_expand = t.repeat(latent_model_input.shape[0]) + guidance_expand = None + + img = weights.img_in.apply(img) + img = img.flatten(2).transpose(1, 2) + + t_freq = self.timestep_embedding(t_expand, self.frequency_embedding_size, self.max_period).to(torch.bfloat16) + vec = weights.time_in_0.apply(t_freq) + vec = torch.nn.functional.silu(vec) + vec = weights.time_in_2.apply(vec) + + if self.config["is_sr_running"]: + use_meanflow = self.config.get("video_super_resolution", {}).get("use_meanflow", False) + if use_meanflow: + if self.scheduler.step_index == len(timesteps) - 1: + timesteps_r = torch.tensor([0.0], device=latent_model_input.device) + else: + timesteps_r = timesteps[self.scheduler.step_index + 1] + timesteps_r = timesteps_r.repeat(latent_model_input.shape[0]) + else: + timesteps_r = None + + if timesteps_r is not None: + t_freq = self.timestep_embedding(timesteps_r, self.frequency_embedding_size, self.max_period).to(torch.bfloat16) + vec_res = weights.time_r_in_0.apply(t_freq) + vec_res = torch.nn.functional.silu(vec_res) + vec_res = weights.time_r_in_2.apply(vec_res) + vec = vec + vec_res + + t_freq = self.timestep_embedding(t_expand, self.frequency_embedding_size, self.max_period).to(torch.bfloat16) + timestep_aware_representations = weights.txt_in_t_embedder_0.apply(t_freq) 
+ timestep_aware_representations = torch.nn.functional.silu(timestep_aware_representations) + timestep_aware_representations = weights.txt_in_t_embedder_2.apply(timestep_aware_representations) + + mask_float = text_mask.float().unsqueeze(-1) + context_aware_representations = (txt * mask_float).sum(dim=1) / mask_float.sum(dim=1) + context_aware_representations = context_aware_representations.to(torch.bfloat16) + context_aware_representations = weights.txt_in_c_embedder_0.apply(context_aware_representations) + context_aware_representations = torch.nn.functional.silu(context_aware_representations) + context_aware_representations = weights.txt_in_c_embedder_2.apply(context_aware_representations) + + c = timestep_aware_representations + context_aware_representations + out = weights.txt_in_input_embedder.apply(txt[0].to(torch.bfloat16)) + txt = self.run_individual_token_refiner(weights, out, text_mask, c) + + # TODO: this block of computation can be removed + txt = txt.unsqueeze(0) + txt = txt + weights.cond_type_embedding.apply(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long)) + byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long)) + txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True) + + siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=AI_DEVICE)) + txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask) + txt = txt[:, : text_mask.sum(), :] + + return HunyuanVideo15InferModuleOutput( + img=img.contiguous(), + txt=txt.contiguous(), + vec=torch.nn.functional.silu(vec), + grid_sizes=(grid_sizes_t, grid_sizes_h, grid_sizes_w), + ) + + def run_individual_token_refiner(self, weights, out, mask, c): + mask = mask.clone().bool() + mask[:, 0] = True # Prevent attention weights from becoming NaN + for block in weights.individual_token_refiner: # block num = 2 + gate_msa, gate_mlp = self.adaLN_modulation(block, c) + norm_x = block.norm1.apply(out.unsqueeze(0)).squeeze(0) + qkv = block.self_attn_qkv.apply(norm_x).unsqueeze(0) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + attn = attention(q, k, v, attn_mask=mask, attn_type="flash_attn2").squeeze(0) + out = out + apply_gate(block.self_attn_proj.apply(attn).unsqueeze(0), gate_msa).squeeze(0) + tmp = block.mlp_fc1.apply(block.norm2.apply(out)) + tmp = torch.nn.functional.silu(tmp) + tmp = block.mlp_fc2.apply(tmp) + out = out + apply_gate(tmp.unsqueeze(0), gate_mlp).squeeze(0) + return out + + def adaLN_modulation(self, weights, c): + c = torch.nn.functional.silu(c) + gate_msa, gate_mlp = weights.adaLN_modulation.apply(c).chunk(2, dim=1) + return gate_msa, gate_mlp + + def timestep_embedding(self, t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + ..
ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def reorder_txt_token(self, byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=False, is_reorder=True): + if is_reorder: + reorder_txt = [] + reorder_mask = [] + for i in range(text_mask.shape[0]): + byt5_text_mask_i = byt5_text_mask[i].bool() + text_mask_i = text_mask[i].bool() + + byt5_txt_i = byt5_txt[i] + txt_i = txt[i] + if zero_feat: + # When using block mask with approximate computation, set pad to zero to reduce error + pad_byt5 = torch.zeros_like(byt5_txt_i[~byt5_text_mask_i]) + pad_text = torch.zeros_like(txt_i[~text_mask_i]) + reorder_txt_i = torch.cat([byt5_txt_i[byt5_text_mask_i], txt_i[text_mask_i], pad_byt5, pad_text], dim=0) + else: + reorder_txt_i = torch.cat([byt5_txt_i[byt5_text_mask_i], txt_i[text_mask_i], byt5_txt_i[~byt5_text_mask_i], txt_i[~text_mask_i]], dim=0) + reorder_mask_i = torch.cat([byt5_text_mask_i[byt5_text_mask_i], text_mask_i[text_mask_i], byt5_text_mask_i[~byt5_text_mask_i], text_mask_i[~text_mask_i]], dim=0) + + reorder_txt.append(reorder_txt_i) + reorder_mask.append(reorder_mask_i) + + reorder_txt = torch.stack(reorder_txt) + reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64) + else: + reorder_txt = torch.concat([byt5_txt, txt], dim=1) + reorder_mask = torch.concat([byt5_text_mask, text_mask], dim=1).to(dtype=torch.int64) + + return reorder_txt, reorder_mask diff --git a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..869f99d8b310abdb8b11ab5248d19f6ba9cdc1a5 --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py @@ -0,0 +1,265 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +try: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace +except Exception as e: + apply_rope_with_cos_sin_cache_inplace = None + +from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer +from lightx2v_platform.base.global_var import AI_DEVICE + +from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput +from .triton_ops import fuse_scale_shift_kernel + + +def modulate(x, scale, shift): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False): + """Apply an optional per-sample gating tensor to the input. + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after applying the gate.
+ """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +def apply_hunyuan_rope_with_flashinfer( + xq: torch.Tensor, + xk: torch.Tensor, + cos_sin_cache: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, L, H, D = xq.shape + + query = xq.reshape(B * L, H * D).contiguous() + key = xk.reshape(B * L, H * D).contiguous() + + positions = torch.arange(B * L, device=xq.device, dtype=torch.long) + + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=D, + cos_sin_cache=cos_sin_cache, + is_neox=False, + ) + + xq_out = query.view(B, L, H, D) + xk_out = key.view(B, L, H, D) + return xq_out, xk_out + + +def apply_hunyuan_rope_with_torch( + xq: torch.Tensor, + xk: torch.Tensor, + cos_sin_cache: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, L, H, D = xq.shape + + cos = cos_sin_cache[:, : D // 2] + sin = cos_sin_cache[:, D // 2 :] + + def _apply_rope(x: torch.Tensor) -> torch.Tensor: + x_flat = x.view(B * L, H, D) + x1 = x_flat[..., ::2] + x2 = x_flat[..., 1::2] + + cos_ = cos.unsqueeze(1) + sin_ = sin.unsqueeze(1) + + o1 = x1.float() * cos_ - x2.float() * sin_ + o2 = x2.float() * cos_ + x1.float() * sin_ + + out = torch.empty_like(x_flat) + out[..., ::2] = o1 + out[..., 1::2] = o2 + return out.view(B, L, H, D) + + xq_out = _apply_rope(xq) + xk_out = _apply_rope(xk) + return xq_out, xk_out + + +class HunyuanVideo15TransformerInfer(BaseTransformerInfer): + def __init__(self, config): + self.config = config + self.double_blocks_num = config["mm_double_blocks_depth"] + self.heads_num = config["heads_num"] + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False) + else: + self.seq_p_group = None + self.seq_p_fp8_comm = False + self.infer_func = self.infer_without_offload + if self.config.get("modulate_type", "triton") == "triton": + self.modulate_func = fuse_scale_shift_kernel + else: + self.modulate_func = modulate + if self.config.get("rope_type", "flashinfer") == "flashinfer": + self.apply_rope_func = apply_hunyuan_rope_with_flashinfer + else: + self.apply_rope_func = apply_hunyuan_rope_with_torch + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + self.scheduler.transformer_infer = self + + @torch.no_grad() + def infer(self, weights, infer_module_out): + self.infer_func(weights, infer_module_out) + x = self.infer_final_layer(weights, infer_module_out) + return x + + @torch.no_grad() + def infer_without_offload(self, weights, infer_module_out): + for i in range(self.double_blocks_num): + infer_module_out.img, infer_module_out.txt = self.infer_double_block(weights.double_blocks[i], infer_module_out) + + @torch.no_grad() + def infer_final_layer(self, weights, infer_module_out): + x = torch.cat((infer_module_out.img, infer_module_out.txt), 1) + img = x[:, : infer_module_out.img.shape[1], ...] 
+ shift, scale = weights.final_layer.adaLN_modulation.apply(infer_module_out.vec).chunk(2, dim=1) + img = self.modulate_func(weights.final_layer.norm_final.apply(img.squeeze(0)), scale=scale, shift=shift).squeeze(0) + img = weights.final_layer.linear.apply(img) + return img.unsqueeze(0) + + @torch.no_grad() + def infer_double_block(self, weights, infer_module_out): + img_q, img_k, img_v, img_branch_out = self._infer_img_branch_before_attn(weights, infer_module_out) + txt_q, txt_k, txt_v, txt_branch_out = self._infer_txt_branch_before_attn(weights, infer_module_out) + img_attn, txt_attn = self._infer_attn(weights, img_q, img_k, img_v, txt_q, txt_k, txt_v) + img = self._infer_img_branch_after_attn(weights, img_attn, infer_module_out.img, img_branch_out) + txt = self._infer_txt_branch_after_attn(weights, txt_attn, infer_module_out.txt, txt_branch_out) + return img, txt + + @torch.no_grad() + def _infer_img_branch_before_attn(self, weights, infer_module_out): + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = weights.img_branch.img_mod.apply(infer_module_out.vec).chunk(6, dim=-1) + img_modulated = weights.img_branch.img_norm1.apply(infer_module_out.img.squeeze(0)) + img_modulated = self.modulate_func(img_modulated, scale=img_mod1_scale, shift=img_mod1_shift).squeeze(0) + img_q = weights.img_branch.img_attn_q.apply(img_modulated) + img_k = weights.img_branch.img_attn_k.apply(img_modulated) + img_v = weights.img_branch.img_attn_v.apply(img_modulated) + img_q = rearrange(img_q, "L (H D) -> L H D", H=self.heads_num) + img_k = rearrange(img_k, "L (H D) -> L H D", H=self.heads_num) + img_v = rearrange(img_v, "L (H D) -> L H D", H=self.heads_num) + img_q = weights.img_branch.img_attn_q_norm.apply(img_q) + img_k = weights.img_branch.img_attn_k_norm.apply(img_k) + img_q, img_k = self.apply_rope_func(img_q.unsqueeze(0), img_k.unsqueeze(0), cos_sin_cache=self.scheduler.cos_sin) + return ( + img_q, + img_k, + img_v.unsqueeze(0), + HunyuanVideo15ImgBranchOutput( + img_mod1_gate=img_mod1_gate, + img_mod2_shift=img_mod2_shift, + img_mod2_scale=img_mod2_scale, + img_mod2_gate=img_mod2_gate, + ), + ) + + @torch.no_grad() + def _infer_txt_branch_before_attn(self, weights, infer_module_out): + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = weights.txt_branch.txt_mod.apply(infer_module_out.vec).chunk(6, dim=-1) + txt_modulated = weights.txt_branch.txt_norm1.apply(infer_module_out.txt.squeeze(0)) + txt_modulated = self.modulate_func(txt_modulated, scale=txt_mod1_scale, shift=txt_mod1_shift).squeeze(0) + txt_q = weights.txt_branch.txt_attn_q.apply(txt_modulated) + txt_k = weights.txt_branch.txt_attn_k.apply(txt_modulated) + txt_v = weights.txt_branch.txt_attn_v.apply(txt_modulated) + txt_q = rearrange(txt_q, "L (H D) -> L H D", H=self.heads_num) + txt_k = rearrange(txt_k, "L (H D) -> L H D", H=self.heads_num) + txt_v = rearrange(txt_v, "L (H D) -> L H D", H=self.heads_num) + txt_q = weights.txt_branch.txt_attn_q_norm.apply(txt_q).to(txt_v) + txt_k = weights.txt_branch.txt_attn_k_norm.apply(txt_k).to(txt_v) + return ( + txt_q.unsqueeze(0), + txt_k.unsqueeze(0), + txt_v.unsqueeze(0), + HunyuanVideo15TxtBranchOutput( + txt_mod1_gate=txt_mod1_gate, + txt_mod2_shift=txt_mod2_shift, + txt_mod2_scale=txt_mod2_scale, + txt_mod2_gate=txt_mod2_gate, + ), + ) + + @torch.no_grad() + def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v): + img_seqlen = img_q.shape[1] + query = 
torch.cat([img_q, txt_q], dim=1) + key = torch.cat([img_k, txt_k], dim=1) + value = torch.cat([img_v, txt_v], dim=1) + seqlen = query.shape[1] + cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + + if self.config["seq_parallel"]: + attn_out = weights.self_attention_parallel.apply( + q=query, + k=key, + v=value, + img_qkv_len=img_seqlen, + cu_seqlens_qkv=cu_seqlens_qkv, + attention_module=weights.self_attention, + seq_p_group=self.seq_p_group, + use_fp8_comm=self.seq_p_fp8_comm, + model_cls=self.config["model_cls"], + ) + else: + attn_out = weights.self_attention.apply( + q=query, k=key, v=value, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=seqlen, max_seqlen_kv=seqlen, model_cls=self.config["model_cls"] + ) + + img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:] + return img_attn, txt_attn + + @torch.no_grad() + def _infer_img_branch_after_attn(self, weights, img_attn, img, img_branch_out): + img = img + apply_gate(weights.img_branch.img_attn_proj.apply(img_attn).unsqueeze(0), gate=img_branch_out.img_mod1_gate) + out = weights.img_branch.img_mlp_fc1.apply( + self.modulate_func(weights.img_branch.img_norm2.apply(img.squeeze(0)), scale=img_branch_out.img_mod2_scale, shift=img_branch_out.img_mod2_shift).squeeze(0) + ) + out = weights.img_branch.img_mlp_fc2.apply(F.gelu(out, approximate="tanh")) + img = img + apply_gate(out.unsqueeze(0), gate=img_branch_out.img_mod2_gate) + return img + + @torch.no_grad() + def _infer_txt_branch_after_attn(self, weights, txt_attn, txt, txt_branch_out): + txt = txt + apply_gate(weights.txt_branch.txt_attn_proj.apply(txt_attn).unsqueeze(0), gate=txt_branch_out.txt_mod1_gate) + out = weights.txt_branch.txt_mlp_fc1.apply( + self.modulate_func(weights.txt_branch.txt_norm2.apply(txt.squeeze(0)), scale=txt_branch_out.txt_mod2_scale, shift=txt_branch_out.txt_mod2_shift).squeeze(0) + ) + out = weights.txt_branch.txt_mlp_fc2.apply(F.gelu(out, approximate="tanh")) + txt = txt + apply_gate(out.unsqueeze(0), gate=txt_branch_out.txt_mod2_gate) + return txt diff --git a/lightx2v/models/networks/hunyuan_video/infer/triton_ops.py b/lightx2v/models/networks/hunyuan_video/infer/triton_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdd71c62f746d45bfe7707fb61870ca9c74f50f --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/infer/triton_ops.py @@ -0,0 +1,902 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang + +# TODO: for temporary usage, expecting a refactor +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from torch import Tensor + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_warps=4), + triton.Config({"BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_warps=8), + ], + key=["inner_dim"], +) +@triton.jit +def _fused_scale_shift_4d_kernel( + output_ptr, + normalized_ptr, + scale_ptr, + shift_ptr, + rows, + inner_dim, + seq_len, + num_frames, + frame_seqlen, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col_offsets < inner_dim + + # Pointers for normalized and output + row_base = pid_row * inner_dim + norm_ptrs = normalized_ptr + row_base + col_offsets + out_ptrs = 
output_ptr + row_base + col_offsets + + # Pointers for scale and shift for 4D + b_idx = pid_row // seq_len + t_idx = pid_row % seq_len + frame_idx_in_batch = t_idx // frame_seqlen + + scale_row_idx = b_idx * num_frames + frame_idx_in_batch + scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets + shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets + + normalized = tl.load(norm_ptrs, mask=mask, other=0.0) + scale = tl.load(scale_ptrs, mask=mask, other=0.0) + shift = tl.load(shift_ptrs, mask=mask, other=0.0) + + one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype) + output = normalized * (one + scale) + shift + + tl.store(out_ptrs, output, mask=mask) + + +@triton.jit +def fuse_scale_shift_kernel_blc_opt( + x_ptr, + shift_ptr, + scale_ptr, + y_ptr, + B, + L, + C, + stride_x_b, + stride_x_l, + stride_x_c, + stride_s_b, + stride_s_l, + stride_s_c, + stride_sc_b, + stride_sc_l, + stride_sc_c, + SCALE_IS_SCALAR: tl.constexpr, + SHIFT_IS_SCALAR: tl.constexpr, + BLOCK_L: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + + l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_l = l_offsets < L + mask_c = c_offsets < C + mask = mask_l[:, None] & mask_c[None, :] + + x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c + x = tl.load(x_ptr + x_off, mask=mask, other=0) + + if SHIFT_IS_SCALAR: + shift_val = tl.load(shift_ptr) + shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype) + else: + s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c + shift = tl.load(shift_ptr + s_off, mask=mask, other=0) + + if SCALE_IS_SCALAR: + scale_val = tl.load(scale_ptr) + scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype) + else: + sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c + scale = tl.load(scale_ptr + sc_off, mask=mask, other=0) + + y = x * (1 + scale) + shift + tl.store(y_ptr + x_off, y, mask=mask) + + +def fuse_scale_shift_kernel( + x: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + block_l: int = 128, + block_c: int = 128, +): + assert x.is_cuda and scale.is_cuda + assert x.is_contiguous() + if x.dim() == 2: + x = x.unsqueeze(0) + + B, L, C = x.shape + output = torch.empty_like(x) + + if scale.dim() == 4: + # scale/shift: [B, F, 1, C] + rows = B * L + x_2d = x.view(rows, C) + output_2d = output.view(rows, C) + grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa + num_frames = scale.shape[1] + assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift" + frame_seqlen = L // num_frames + + # Compact [B, F, C] without the singleton dim into [B*F, C] + scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous() + shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous() + + _fused_scale_shift_4d_kernel[grid]( + output_2d, + x_2d, + scale_reshaped, + shift_reshaped, + rows, + C, + L, + num_frames, + frame_seqlen, + ) + else: + # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L + # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C]) + # Also support scalar (0D or 1-element) + if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1): + scale_blc = scale.reshape(1) + elif scale.dim() == 2: + scale_blc = scale[:, None, :] + elif scale.dim() == 3: + scale_blc = scale + else: + raise 
ValueError("scale must be 0D/1D(1)/2D/3D or 4D") + + if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1): + shift_blc = shift.reshape(1) + elif shift.dim() == 2: + shift_blc = shift[:, None, :] + elif shift.dim() == 3: + shift_blc = shift + else: + # broadcast later via expand if possible + shift_blc = shift + + need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1 + need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1 + + if not need_scale_scalar: + scale_exp = scale_blc.expand(B, L, C) + s_sb, s_sl, s_sc = scale_exp.stride() + else: + s_sb = s_sl = s_sc = 0 + + if not need_shift_scalar: + shift_exp = shift_blc.expand(B, L, C) + sh_sb, sh_sl, sh_sc = shift_exp.stride() + else: + sh_sb = sh_sl = sh_sc = 0 + + # If both scalars and both zero, copy fast-path + if need_scale_scalar and need_shift_scalar: + if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0): + output.copy_(x) + return output + + grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) + fuse_scale_shift_kernel_blc_opt[grid]( + x, + shift_blc if need_shift_scalar else shift_exp, + scale_blc if need_scale_scalar else scale_exp, + output, + B, + L, + C, + x.stride(0), + x.stride(1), + x.stride(2), + sh_sb, + sh_sl, + sh_sc, + s_sb, + s_sl, + s_sc, + SCALE_IS_SCALAR=need_scale_scalar, + SHIFT_IS_SCALAR=need_shift_scalar, + BLOCK_L=block_l, + BLOCK_C=block_c, + num_warps=4, + num_stages=2, + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2), + triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4), + triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4), + triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8), + ], + key=["head_size", "interleaved"], +) +@triton.jit +def _rotary_embedding_kernel( + output_ptr, + x_ptr, + cos_ptr, + sin_ptr, + num_heads, + head_size, + num_tokens, + stride_x_row, + stride_cos_row, + stride_sin_row, + interleaved: tl.constexpr, + BLOCK_HS_HALF: tl.constexpr, +): + row_idx = tl.program_id(0) + token_idx = (row_idx // num_heads) % num_tokens + + x_row_ptr = x_ptr + row_idx * stride_x_row + cos_row_ptr = cos_ptr + token_idx * stride_cos_row + sin_row_ptr = sin_ptr + token_idx * stride_sin_row + output_row_ptr = output_ptr + row_idx * stride_x_row + + # half size for x1 and x2 + head_size_half = head_size // 2 + + for block_start in range(0, head_size_half, BLOCK_HS_HALF): + offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF) + mask = offsets_half < head_size_half + + cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0) + sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0) + + offsets_x1 = 2 * offsets_half + offsets_x2 = 2 * offsets_half + 1 + + x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0) + x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0) + + x1_fp32 = x1_vals.to(tl.float32) + x2_fp32 = x2_vals.to(tl.float32) + cos_fp32 = cos_vals.to(tl.float32) + sin_fp32 = sin_vals.to(tl.float32) + o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32) + o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32) + + tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask) + tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask) + + +def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + output = torch.empty_like(x) + + if x.dim() > 3: + bsz, num_tokens, num_heads, head_size = x.shape + else: + num_tokens, num_heads, head_size = x.shape 
+ bsz = 1 + + assert head_size % 2 == 0, "head_size must be divisible by 2" + + x_reshaped = x.view(-1, head_size) + output_reshaped = output.view(-1, head_size) + + # num_tokens per head, 1 token per block + grid = (bsz * num_tokens * num_heads,) + + if interleaved and cos.shape[-1] == head_size: + cos = cos[..., ::2].contiguous() + sin = sin[..., ::2].contiguous() + else: + cos = cos.contiguous() + sin = sin.contiguous() + + _rotary_embedding_kernel[grid]( + output_reshaped, + x_reshaped, + cos, + sin, + num_heads, + head_size, + num_tokens, + x_reshaped.stride(0), + cos.stride(0), + sin.stride(0), + interleaved, + ) + + return output + + +# RMSNorm-fp32 +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + if not torch.cuda.is_available(): + return [] + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block] + # return [triton.Config({}, num_warps=8)] + + +# Copied from flash-attn +@triton.autotune( + configs=triton_autotune_configs(), + key=[ + "N", + "HAS_RESIDUAL", + "STORE_RESIDUAL_OUT", + "IS_RMS_NORM", + "HAS_BIAS", + "HAS_WEIGHT", + "HAS_X1", + "HAS_W1", + "HAS_B1", + ], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + DROPOUT_MASK1, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
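+ # Per-row math implemented below (comment only). After the optional rowscale, dropout,
+ # x1 and residual terms are folded into x:
+ #     LayerNorm: y = (x - mean(x)) / sqrt(var(x) + eps) * w + b
+ #     RMSNorm:   y = x / sqrt(mean(x * x) + eps) * w          (IS_RMS_NORM)
+ # With zero_centered_weight the stored weight is offset by +1.0 first, and a second
+ # output y1 reuses the same normalized x_hat with (W1, B1) when HAS_W1 is set.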
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + y = x_hat * w + b if HAS_BIAS else x_hat * w + else: + y = x_hat + b if HAS_BIAS else x_hat + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. 
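+ # Rough eager equivalent of this fused forward (comment-only sketch; ignores dropout,
+ # rowscale and the second weight set; F = torch.nn.functional, rms_norm needs PyTorch >= 2.4):
+ #     res = x + (x1 if x1 is not None else 0) + (residual if residual is not None else 0)
+ #     y = F.rms_norm(res.float(), (N,), weight, eps) if is_rms_norm \
+ #         else F.layer_norm(res.float(), (N,), weight, bias, eps)
+ # The Triton kernel computes this row by row in fp32 and additionally returns mean/rstd,
+ # the pre-norm residual stream, and the dropout seeds/masks.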
+ if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None: + residual_dtype = residual.dtype + if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None): + residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +def _layer_norm_fwd_impl( + x: Tensor, + weight: Optional[Tensor], + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + assert out.shape == x.shape + assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None + else: + dropout_mask, dropout_mask1 = None, None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with 
torch.cuda.device(x.device.index): + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( + x, + out, + weight if weight is not None else x, # unused when HAS_WEIGHT == False + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + dropout_mask1, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, + ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 + + +class LayerNormFn: + @staticmethod + def forward( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) + if residual is not None: + assert residual.shape == x_shape_og + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) + # weight can be None when elementwise_affine=False for LayerNorm + if weight is not None: + weight = weight.contiguous() + bias = maybe_contiguous(bias) + weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + out_dtype=out_dtype, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out, + ) + y = y.reshape(x_shape_og) + return y + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) + + +@triton.jit +def _norm_infer_kernel( + X, + Y, + W, + B, + stride_x_row, + 
stride_y_row, + M, + N, + eps, + IS_RMS_NORM: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_WEIGHT: + W += 0 + if HAS_BIAS: + B += 0 + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32) + y = x_hat * w + else: + y = x_hat + if HAS_BIAS: + b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32) + y += b + tl.store(Y + cols, y, mask=cols < N) + + +def norm_infer( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, + is_rms_norm: bool = False, + out: Optional[Tensor] = None, +): + M, N = x.shape + assert x.stride(-1) == 1 + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.shape == (N,) + assert bias.stride(-1) == 1 + if out is None: + out = torch.empty_like(x) + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_N // 256, 1), 8) + _norm_infer_kernel[(M,)]( + x, + out, + weight if weight is not None else x, # dummy when HAS_WEIGHT=False + bias if bias is not None else x, # dummy when HAS_BIAS=False + x.stride(0), + out.stride(0), + M, + N, + eps, + IS_RMS_NORM=is_rms_norm, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + ) + return out + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) diff --git a/lightx2v/models/networks/hunyuan_video/model.py b/lightx2v/models/networks/hunyuan_video/model.py new file mode 100644 index 0000000000000000000000000000000000000000..69a2515f211e0adcf77ea47b4918937444a579ef --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/model.py @@ -0,0 +1,279 @@ +import gc +import glob +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from loguru import logger +from safetensors import safe_open + +from lightx2v.models.networks.hunyuan_video.infer.feature_caching.transformer_infer import HunyuanTransformerInferTeaCaching, HunyuanVideo15TransformerInferMagCaching +from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer +from lightx2v.models.networks.hunyuan_video.infer.post_infer import HunyuanVideo15PostInfer +from lightx2v.models.networks.hunyuan_video.infer.pre_infer import HunyuanVideo15PreInfer +from lightx2v.models.networks.hunyuan_video.infer.transformer_infer 
import HunyuanVideo15TransformerInfer +from lightx2v.models.networks.hunyuan_video.weights.post_weights import HunyuanVideo15PostWeights +from lightx2v.models.networks.hunyuan_video.weights.pre_weights import HunyuanVideo15PreWeights +from lightx2v.models.networks.hunyuan_video.weights.transformer_weights import HunyuanVideo15TransformerWeights +from lightx2v.utils.custom_compiler import CompiledMethodsMixin +from lightx2v.utils.envs import * + + +class HunyuanVideo15Model(CompiledMethodsMixin): + def __init__(self, model_path, config, device): + super().__init__() + self.model_path = model_path + self.config = config + self.device = device + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + self.cpu_offload = self.config.get("cpu_offload", False) + self.offload_granularity = self.config.get("offload_granularity", "block") + self.remove_keys = [] + self.remove_keys.extend(["byt5_in", "vision_in"]) + self.dit_quantized = self.config.get("dit_quantized", False) + if self.dit_quantized: + assert self.config.get("dit_quant_scheme", "Default") in [ + "Default-Force-FP32", + "fp8-vllm", + "int8-vllm", + "fp8-q8f", + "int8-q8f", + "fp8-b128-deepgemm", + "fp8-sgl", + "int8-sgl", + "int8-torchao", + "nvfp4", + "mxfp4", + "mxfp6-mxfp8", + "mxfp8", + ] + self._init_infer_class() + self._init_weights() + self._init_infer() + + def _init_infer_class(self): + self.pre_infer_class = HunyuanVideo15PreInfer + self.post_infer_class = HunyuanVideo15PostInfer + if self.config["feature_caching"] == "NoCaching": + self.transformer_infer_class = HunyuanVideo15TransformerInfer if not self.cpu_offload else HunyuanVideo15OffloadTransformerInfer + elif self.config["feature_caching"] == "Mag": + self.transformer_infer_class = HunyuanVideo15TransformerInferMagCaching + elif self.config["feature_caching"] == "Tea": + self.transformer_infer_class = HunyuanTransformerInferTeaCaching + else: + raise NotImplementedError + + def _init_weights(self): + unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE() + sensitive_layer = {} + if not self.dit_quantized: + weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) + else: + weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) + + self.original_weight_dict = weight_dict + self.pre_weight = HunyuanVideo15PreWeights(self.config) + self.transformer_weights = HunyuanVideo15TransformerWeights(self.config) + self.post_weight = HunyuanVideo15PostWeights(self.config) + self._apply_weights() + + def _apply_weights(self, weight_dict=None): + if weight_dict is not None: + self.original_weight_dict = weight_dict + del weight_dict + gc.collect() + # Load weights into containers + self.pre_weight.load(self.original_weight_dict) + self.transformer_weights.load(self.original_weight_dict) + + del self.original_weight_dict + torch.cuda.empty_cache() + gc.collect() + + def _init_infer(self): + self.pre_infer = self.pre_infer_class(self.config) + self.transformer_infer = self.transformer_infer_class(self.config) + self.post_infer = self.post_infer_class(self.config) + if hasattr(self.transformer_infer, "offload_manager"): + self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + self.pre_infer.set_scheduler(scheduler) + self.transformer_infer.set_scheduler(scheduler) + 
self.post_infer.set_scheduler(scheduler) + + def _load_quant_ckpt(self, unified_dtype, sensitive_layer): + remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] + + if self.config.get("dit_quantized_ckpt", None): + safetensors_path = self.config["dit_quantized_ckpt"] + else: + safetensors_path = self.model_path + + if os.path.isdir(safetensors_path): + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + else: + safetensors_files = [safetensors_path] + safetensors_path = os.path.dirname(safetensors_path) + + weight_dict = {} + for safetensor_path in safetensors_files: + if self.config.get("adapter_model_path", None) is not None: + if self.config["adapter_model_path"] == safetensor_path: + continue + with safe_open(safetensor_path, framework="pt") as f: + logger.info(f"Loading weights from {safetensor_path}") + for k in f.keys(): + if any(remove_key in k for remove_key in remove_keys): + continue + if f.get_tensor(k).dtype in [ + torch.float16, + torch.bfloat16, + torch.float, + ]: + if unified_dtype or all(s not in k for s in sensitive_layer): + weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device) + else: + weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device) + else: + weight_dict[k] = f.get_tensor(k).to(self.device) + + if self.config.get("dit_quant_scheme", "Default") == "nvfp4": + calib_path = os.path.join(safetensors_path, "calib.pt") + logger.info(f"[CALIB] Loaded calibration data from: {calib_path}") + calib_data = torch.load(calib_path, map_location="cpu") + for k, v in calib_data["absmax"].items(): + weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device) + + return weight_dict + + def _load_ckpt(self, unified_dtype, sensitive_layer): + if self.config.get("dit_original_ckpt", None): + safetensors_path = self.config["dit_original_ckpt"] + else: + safetensors_path = self.config["transformer_model_path"] + + if os.path.isdir(safetensors_path): + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + else: + safetensors_files = [safetensors_path] + + weight_dict = {} + for file_path in safetensors_files: + if self.config.get("adapter_model_path", None) is not None: + if self.config["adapter_model_path"] == file_path: + continue + logger.info(f"Loading weights from {file_path}") + file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer) + weight_dict.update(file_weights) + + return weight_dict + + def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer): + remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] + + if self.device.type != "cpu" and dist.is_initialized(): + device = dist.get_rank() + else: + device = str(self.device) + + with safe_open(file_path, framework="pt", device=device) as f: + return { + key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) + for key in f.keys() + if not any(remove_key in key for remove_key in remove_keys) + } + + def to_cpu(self): + self.pre_weight.to_cpu() + self.transformer_weights.to_cpu() + + def to_cuda(self): + self.pre_weight.to_cuda() + self.transformer_weights.to_cuda() + + @torch.no_grad() + def infer(self, inputs): + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + 
self.transformer_weights.non_block_weights_to_cuda() + + if self.config["enable_cfg"]: + if self.config["cfg_parallel"]: + # ==================== CFG Parallel Processing ==================== + cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p") + assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2" + cfg_p_rank = dist.get_rank(cfg_p_group) + + if cfg_p_rank == 0: + noise_pred = self._infer_cond_uncond(inputs, infer_condition=True).contiguous() + else: + noise_pred = self._infer_cond_uncond(inputs, infer_condition=False).contiguous() + + noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)] + dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group) + noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0 + noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1 + else: + # ==================== CFG Processing ==================== + noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True) + noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False) + + self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond) + else: + # ==================== No CFG ==================== + self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) + + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() + + @torch.no_grad() + def _infer_cond_uncond(self, inputs, infer_condition=True): + self.scheduler.infer_condition = infer_condition + + pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) + + if self.config["seq_parallel"]: + pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) + + x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out) + + if self.config["seq_parallel"]: + x = self._seq_parallel_post_process(x) + + noise_pred = self.post_infer.infer(x, pre_infer_out)[0] + + return noise_pred + + @torch.no_grad() + def _seq_parallel_pre_process(self, pre_infer_out): + seqlen = pre_infer_out.img.shape[1] + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + pre_infer_out.img = F.pad(pre_infer_out.img, (0, 0, 0, padding_size)) + + pre_infer_out.img = torch.chunk(pre_infer_out.img, world_size, dim=1)[cur_rank] + return pre_infer_out + + @torch.no_grad() + def _seq_parallel_post_process(self, x): + world_size = dist.get_world_size(self.seq_p_group) + gathered_x = [torch.empty_like(x) for _ in range(world_size)] + dist.all_gather(gathered_x, x, group=self.seq_p_group) + combined_output = torch.cat(gathered_x, dim=1) + return combined_output diff --git a/lightx2v/models/networks/hunyuan_video/weights/post_weights.py b/lightx2v/models/networks/hunyuan_video/weights/post_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..5c762263d598f68d0b587e92b63d57ab70236ac1 --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/weights/post_weights.py @@ -0,0 +1,7 @@ +from lightx2v.common.modules.weight_module import WeightModule + + +class HunyuanVideo15PostWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.config = config diff --git 
a/lightx2v/models/networks/hunyuan_video/weights/pre_weights.py b/lightx2v/models/networks/hunyuan_video/weights/pre_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..f5600d35a30b3c9a7987b17ef0678b67e8855cdc --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/weights/pre_weights.py @@ -0,0 +1,147 @@ +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList +from lightx2v.utils.registry_factory import ( + CONV3D_WEIGHT_REGISTER, + EMBEDDING_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, +) + + +class HunyuanVideo15PreWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.config = config + self.mm_type = config.get("dit_quant_scheme", "Default") + self.patch_size = config["patch_size"] # (1, 1, 1) + + self.add_module( + "img_in", + CONV3D_WEIGHT_REGISTER["Default"]( + "img_in.proj.weight", + "img_in.proj.bias", + stride=self.patch_size, + ), + ) + self.add_module( + "time_in_0", + MM_WEIGHT_REGISTER["Default"]( + "time_in.mlp.0.weight", + "time_in.mlp.0.bias", + ), + ) + self.add_module( + "time_in_2", + MM_WEIGHT_REGISTER["Default"]( + "time_in.mlp.2.weight", + "time_in.mlp.2.bias", + ), + ) + if self.config["is_sr_running"]: + self.add_module( + "time_r_in_0", + MM_WEIGHT_REGISTER["Default"]( + "time_r_in.mlp.0.weight", + "time_r_in.mlp.0.bias", + ), + ) + self.add_module( + "time_r_in_2", + MM_WEIGHT_REGISTER["Default"]( + "time_r_in.mlp.2.weight", + "time_r_in.mlp.2.bias", + ), + ) + self.add_module( + "txt_in_t_embedder_0", + MM_WEIGHT_REGISTER["Default"]( + "txt_in.t_embedder.mlp.0.weight", + "txt_in.t_embedder.mlp.0.bias", + ), + ) + self.add_module( + "txt_in_t_embedder_2", + MM_WEIGHT_REGISTER["Default"]( + "txt_in.t_embedder.mlp.2.weight", + "txt_in.t_embedder.mlp.2.bias", + ), + ) + + self.add_module( + "txt_in_c_embedder_0", + MM_WEIGHT_REGISTER["Default"]( + "txt_in.c_embedder.linear_1.weight", + "txt_in.c_embedder.linear_1.bias", + ), + ) + self.add_module( + "txt_in_c_embedder_2", + MM_WEIGHT_REGISTER["Default"]( + "txt_in.c_embedder.linear_2.weight", + "txt_in.c_embedder.linear_2.bias", + ), + ) + + self.add_module( + "txt_in_input_embedder", + MM_WEIGHT_REGISTER["Default"]( + "txt_in.input_embedder.weight", + "txt_in.input_embedder.bias", + ), + ) + + self.add_module( + "individual_token_refiner", + WeightModuleList( + [ + IndividualTokenRefinerBlock( + i, + self.mm_type, + self.config, + "txt_in.individual_token_refiner.blocks", + ) + for i in range(2) # 2 blocks + ] + ), + ) + + self.add_module( + "cond_type_embedding", + EMBEDDING_WEIGHT_REGISTER["Default"]( + "cond_type_embedding.weight", + ), + ) + + +class IndividualTokenRefinerBlock(WeightModule): + def __init__(self, block_idx, mm_type, config, block_prefix): + super().__init__() + self.config = config + self.mm_type = mm_type + self.add_module( + "norm1", + LN_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.norm1.weight", f"{block_prefix}.{block_idx}.norm1.bias"), + ) + self.add_module( + "self_attn_qkv", + MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.self_attn_qkv.weight", f"{block_prefix}.{block_idx}.self_attn_qkv.bias"), + ) + self.add_module( + "self_attn_proj", + MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.self_attn_proj.weight", f"{block_prefix}.{block_idx}.self_attn_proj.bias"), + ) + self.add_module( + "norm2", + LN_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.norm2.weight", f"{block_prefix}.{block_idx}.norm2.bias"), + ) + self.add_module( + "mlp_fc1", + 
MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.mlp.fc1.weight", f"{block_prefix}.{block_idx}.mlp.fc1.bias"), + ) + self.add_module( + "mlp_fc2", + MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.mlp.fc2.weight", f"{block_prefix}.{block_idx}.mlp.fc2.bias"), + ) + self.add_module( + "adaLN_modulation", + MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.adaLN_modulation.1.weight", f"{block_prefix}.{block_idx}.adaLN_modulation.1.bias"), + ) diff --git a/lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py b/lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..7eba13667f853ec8cf2ac4e61ce21d8c2a75c81e --- /dev/null +++ b/lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py @@ -0,0 +1,390 @@ +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList +from lightx2v.utils.registry_factory import ( + ATTN_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, +) + + +class HunyuanVideo15TransformerWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.config = config + self.task = config["task"] + self.mm_type = config.get("dit_quant_scheme", "Default") + self.ln_type = config.get("ln_type", "Triton") + self.rms_type = config.get("rms_type", "sgl-kernel") + self.double_blocks_num = config["mm_double_blocks_depth"] + self.register_offload_buffers(config) + self.add_module("double_blocks", WeightModuleList([MMDoubleStreamBlock(i, self.task, self.config, block_prefix="double_blocks") for i in range(self.double_blocks_num)])) + self.add_module("final_layer", FinalLayerWeights(self.config)) + + def register_offload_buffers(self, config): + if config["cpu_offload"]: + if config.get("offload_granularity", "block") == "block": + self.offload_blocks_num = 2 + self.offload_block_cuda_buffers = WeightModuleList( + [ + MMDoubleStreamBlock( + i, + self.task, + self.config, + "double_blocks", + True, + ) + for i in range(self.offload_blocks_num) + ] + ) + self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers) + self.offload_phase_cuda_buffers = None + + def non_block_weights_to_cuda(self): + self.final_layer.to_cuda() + + def non_block_weights_to_cpu(self): + self.final_layer.to_cpu() + + +class MMDoubleStreamBlock(WeightModule): + def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False): + super().__init__() + self.block_index = block_index + self.task = task + self.config = config + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + + self.lazy_load = False + self.lazy_load_file = None + + self.add_module( + "img_branch", + MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer), + ) + self.add_module( + "txt_branch", + MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer), + ) + attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["attn_type"]] + self.add_module("self_attention", attention_weights_cls()) + if self.config["seq_parallel"]: + self.add_module( + "self_attention_parallel", + ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")](), + ) + + +class MMDoubleStreamBlockImgBranch(WeightModule): + def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, 
create_cpu_buffer=False): + super().__init__() + self.block_index = block_index + self.task = task + self.config = config + + self.lazy_load = False + self.lazy_load_file = None + + self.mm_type = config.get("dit_quant_scheme", "Default") + self.ln_type = config.get("ln_type", "Triton") + self.rms_type = config.get("rms_type", "sgl-kernel") + + self.add_module( + "img_mod", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_mod.linear.weight", + f"{block_prefix}.{self.block_index}.img_mod.linear.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_norm1", + LN_WEIGHT_REGISTER[self.ln_type]( + None, + None, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_attn_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_attn_q.weight", + f"{block_prefix}.{self.block_index}.img_attn_q.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_attn_k", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_attn_k.weight", + f"{block_prefix}.{self.block_index}.img_attn_k.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_attn_v", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_attn_v.weight", + f"{block_prefix}.{self.block_index}.img_attn_v.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_attn_q_norm", + RMS_WEIGHT_REGISTER[self.rms_type]( + f"{block_prefix}.{self.block_index}.img_attn_q_norm.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_attn_k_norm", + RMS_WEIGHT_REGISTER[self.rms_type]( + f"{block_prefix}.{self.block_index}.img_attn_k_norm.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_attn_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_attn_proj.weight", + f"{block_prefix}.{self.block_index}.img_attn_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_norm2", + LN_WEIGHT_REGISTER[self.ln_type]( + None, + None, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_mlp_fc1", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_mlp.fc1.weight", + f"{block_prefix}.{self.block_index}.img_mlp.fc1.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_mlp_fc2", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_mlp.fc2.weight", + f"{block_prefix}.{self.block_index}.img_mlp.fc2.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + +class MMDoubleStreamBlockTxtBranch(WeightModule): + def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False): + super().__init__() + self.block_index = block_index + self.task = task + self.config = config + + self.lazy_load = False + self.lazy_load_file = None + + self.mm_type = config.get("dit_quant_scheme", "Default") + self.ln_type = 
config.get("ln_type", "Triton") + self.rms_type = config.get("rms_type", "sgl-kernel") + + self.add_module( + "txt_mod", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_mod.linear.weight", + f"{block_prefix}.{self.block_index}.txt_mod.linear.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_norm1", + LN_WEIGHT_REGISTER[self.ln_type]( + None, + None, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_attn_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_attn_q.weight", + f"{block_prefix}.{self.block_index}.txt_attn_q.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_attn_k", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_attn_k.weight", + f"{block_prefix}.{self.block_index}.txt_attn_k.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_attn_v", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_attn_v.weight", + f"{block_prefix}.{self.block_index}.txt_attn_v.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_attn_q_norm", + RMS_WEIGHT_REGISTER[self.rms_type]( + f"{block_prefix}.{self.block_index}.txt_attn_q_norm.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_attn_k_norm", + RMS_WEIGHT_REGISTER[self.rms_type]( + f"{block_prefix}.{self.block_index}.txt_attn_k_norm.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_attn_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_attn_proj.weight", + f"{block_prefix}.{self.block_index}.txt_attn_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_norm2", + LN_WEIGHT_REGISTER[self.ln_type]( + None, + None, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_mlp_fc1", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_mlp.fc1.weight", + f"{block_prefix}.{self.block_index}.txt_mlp.fc1.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_mlp_fc2", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_mlp.fc2.weight", + f"{block_prefix}.{self.block_index}.txt_mlp.fc2.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + +class FinalLayerWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.config = config + self.lazy_load = False + self.lazy_load_file = None + + self.mm_type = config.get("dit_quant_scheme", "Default") + self.ln_type = config.get("ln_type", "Triton") + + self.add_module( + "adaLN_modulation", + MM_WEIGHT_REGISTER["Default"]( + "final_layer.adaLN_modulation.1.weight", + "final_layer.adaLN_modulation.1.bias", + False, + False, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "linear", + MM_WEIGHT_REGISTER["Default"]( + "final_layer.linear.weight", + "final_layer.linear.bias", + False, + False, + self.lazy_load, + 
self.lazy_load_file, + ), + ) + self.add_module( + "norm_final", + LN_WEIGHT_REGISTER[self.ln_type]( + None, + None, + False, + False, + self.lazy_load, + self.lazy_load_file, + ), + ) diff --git a/lightx2v/models/networks/qwen_image/infer/offload/__init__.py b/lightx2v/models/networks/qwen_image/infer/offload/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py b/lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0b40ba7f3bf9d38996385a64b92b6a19f8c445 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py @@ -0,0 +1,50 @@ +import torch + +from lightx2v.common.offload.manager import WeightAsyncStreamManager +from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.phases_num = 3 + self.num_blocks = config["num_layers"] + if self.config.get("cpu_offload", False): + if "offload_ratio" in self.config: + self.offload_ratio = self.config["offload_ratio"] + else: + self.offload_ratio = 1 + offload_granularity = self.config.get("offload_granularity", "block") + if offload_granularity == "block": + if not self.config.get("lazy_load", False): + self.infer_func = self.infer_with_blocks_offload + else: + assert NotImplementedError + else: + assert NotImplementedError + + if offload_granularity != "model": + self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity) + else: + assert NotImplementedError + + def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb): + for block_idx in range(self.num_blocks): + self.block_idx = block_idx + if block_idx == 0: + self.offload_manager.init_first_buffer(block_weights.blocks) + + if block_idx < self.num_blocks - 1: + self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks) + + with torch_device_module.stream(self.offload_manager.compute_stream): + encoder_hidden_states, hidden_states = self.infer_block( + block_weight=self.offload_manager.cuda_buffers[0], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb + ) + + self.offload_manager.swap_blocks() + + return encoder_hidden_states, hidden_states diff --git a/lightx2v/models/networks/qwen_image/infer/post_infer.py b/lightx2v/models/networks/qwen_image/infer/post_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..6af7e76a84ea38af1456bebf31005dbc0f29a089 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/infer/post_infer.py @@ -0,0 +1,19 @@ +import torch +import torch.nn.functional as F + + +class QwenImagePostInfer: + def __init__(self, config): + self.config = config + self.cpu_offload = config.get("cpu_offload", False) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + def infer(self, weights, hidden_states, temb): + temb1 = F.silu(temb) + temb1 = weights.norm_out_linear.apply(temb1) + scale, shift = torch.chunk(temb1, 2, dim=1) + hidden_states = weights.norm_out.apply(hidden_states) * (1 + scale) + shift + output = 
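The block-granularity offload path above keeps only two transformer blocks resident on the accelerator: while block i runs on the compute stream, the weights of block i+1 are prefetched on a copy stream, and the two buffers swap roles each iteration. A rough sketch of that pattern, assuming a CUDA device, pinned-CPU blocks, and preallocated GPU buffer modules; the actual WeightAsyncStreamManager differs in its buffer and synchronization details.

import torch

# Conceptual sketch of double-buffered block offload (CUDA assumed).
compute_stream = torch.cuda.Stream()
copy_stream = torch.cuda.Stream()

def prefetch(gpu_buffer, cpu_block):
    # Copy each parameter of the CPU block into the resident GPU buffer slot.
    with torch.no_grad(), torch.cuda.stream(copy_stream):
        for dst, src in zip(gpu_buffer.parameters(), cpu_block.parameters()):
            dst.copy_(src, non_blocking=True)

def run_blocks(gpu_buffers, cpu_blocks, x):
    prefetch(gpu_buffers[0], cpu_blocks[0])               # stage block 0
    torch.cuda.synchronize()
    for i in range(len(cpu_blocks)):
        if i + 1 < len(cpu_blocks):
            prefetch(gpu_buffers[1], cpu_blocks[i + 1])   # overlap the next copy
        with torch.cuda.stream(compute_stream):
            x = gpu_buffers[0](x)                         # compute block i
        torch.cuda.synchronize()                          # join both streams
        gpu_buffers[0], gpu_buffers[1] = gpu_buffers[1], gpu_buffers[0]
    return x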
weights.proj_out_linear.apply(hidden_states.squeeze(0)) + return output.unsqueeze(0) diff --git a/lightx2v/models/networks/qwen_image/infer/pre_infer.py b/lightx2v/models/networks/qwen_image/infer/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..c9835c2cda6f774b3304e0d53a23e68e8f4bedf4 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/infer/pre_infer.py @@ -0,0 +1,188 @@ +import functools +import math +from typing import List + +import torch +from torch import nn + +from lightx2v.utils.envs import * + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int = 256, + flip_sin_to_cos: bool = True, + downscale_freq_shift: float = 0, + scale: float = 1000, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 
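A quick usage note on get_timestep_embedding above: with the defaults used here (scale=1000, flip_sin_to_cos=True, downscale_freq_shift=0), a 1-D batch of fractional timesteps becomes an [N, 256] feature whose first 128 entries are cosines and last 128 are sines of 1000*t at geometrically decreasing frequencies.

import torch

t = torch.tensor([0.25, 0.8])
emb = get_timestep_embedding(t, embedding_dim=256)
print(emb.shape)   # torch.Size([2, 256])
# Frequencies are w_k = 10000 ** (-k / 128) for k = 0..127, so
# emb[:, :128] = cos(1000 * t * w_k) and emb[:, 128:] = sin(1000 * t * w_k).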
integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq = self.rope_cache[rope_key] + else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = txt_seq_lens + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenImagePreInfer: + def __init__(self, config): + self.config = config + self.attention_kwargs = {} + self.cpu_offload = config.get("cpu_offload", False) + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(config["axes_dims_rope"]), scale_rope=True) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + def infer(self, weights, hidden_states, timestep, guidance, encoder_hidden_states_mask, encoder_hidden_states, img_shapes, txt_seq_lens, attention_kwargs): + hidden_states = hidden_states.squeeze(0) + hidden_states = weights.img_in.apply(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = encoder_hidden_states.squeeze(0) + + encoder_hidden_states = weights.txt_norm.apply(encoder_hidden_states) + encoder_hidden_states = weights.txt_in.apply(encoder_hidden_states) + timesteps_proj = get_timestep_embedding(timestep).to(torch.bfloat16) + + embed = weights.time_text_embed_timestep_embedder_linear_1.apply(timesteps_proj) + embed0 = torch.nn.functional.silu(embed) + embed0 = weights.time_text_embed_timestep_embedder_linear_2.apply(embed0) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens[0], device=hidden_states.device) + + return 
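QwenEmbedRope above builds one complex rotary frequency vector per video token by computing separate frequencies for the frame, height, and width axes and concatenating them along the channel dimension. A shape-level sketch of that construction; the axes_dim values here are chosen only for illustration.

import torch

def rope_axis(index, dim, theta=10000):
    freqs = torch.outer(index.float(), 1.0 / theta ** (torch.arange(0, dim, 2) / dim))
    return torch.polar(torch.ones_like(freqs), freqs)   # unit complex numbers

frame, height, width = 1, 32, 32
f = rope_axis(torch.arange(frame), 16)[:, None, None, :].expand(frame, height, width, -1)
h = rope_axis(torch.arange(height), 56)[None, :, None, :].expand(frame, height, width, -1)
w = rope_axis(torch.arange(width), 56)[None, None, :, :].expand(frame, height, width, -1)
freqs = torch.cat([f, h, w], dim=-1).reshape(frame * height * width, -1)
print(freqs.shape)   # torch.Size([1024, 64]): one complex frequency vector per token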
hidden_states, encoder_hidden_states, encoder_hidden_states_mask, (embed0, image_rotary_emb) diff --git a/lightx2v/models/networks/qwen_image/infer/transformer_infer.py b/lightx2v/models/networks/qwen_image/infer/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca98393d37c0409014cf9144b8c61db38ea50865 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/infer/transformer_infer.py @@ -0,0 +1,224 @@ +from typing import Tuple, Union + +import torch +import torch.nn.functional as F + +from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +def calculate_q_k_len(q, k_lens): + q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) + cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) + cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) + return cu_seqlens_q, cu_seqlens_k + + +def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_emb, attn_type): + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = block_weight.attn.to_q.apply(hidden_states[0]) + img_key = block_weight.attn.to_k.apply(hidden_states[0]) + img_value = block_weight.attn.to_v.apply(hidden_states[0]) + + # Compute QKV for text stream (context projections) + txt_query = block_weight.attn.add_q_proj.apply(encoder_hidden_states[0]) + txt_key = block_weight.attn.add_k_proj.apply(encoder_hidden_states[0]) + txt_value = 
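apply_rotary_emb_qwen with use_real=False (the path used by apply_attn here) views each consecutive feature pair as a complex number and multiplies it by a unit complex frequency, i.e. a per-token rotation that preserves the tensor shape. A minimal shape check with made-up sizes:

import torch

x = torch.randn(1, 4, 2, 8)                              # [B, S, H, D]
angles = torch.rand(4, 4)                                # one angle per (token, pair)
freqs = torch.polar(torch.ones_like(angles), angles)     # [S, D//2], complex64
out = apply_rotary_emb_qwen(x, freqs, use_real=False)
print(out.shape)                                         # torch.Size([1, 4, 2, 8])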
block_weight.attn.add_v_proj.apply(encoder_hidden_states[0]) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (block_weight.attn.heads, -1)) + img_key = img_key.unflatten(-1, (block_weight.attn.heads, -1)) + img_value = img_value.unflatten(-1, (block_weight.attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (block_weight.attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (block_weight.attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (block_weight.attn.heads, -1)) + + # Apply QK normalization + if block_weight.attn.norm_q is not None: + img_query = block_weight.attn.norm_q.apply(img_query) + if block_weight.attn.norm_k is not None: + img_key = block_weight.attn.norm_k.apply(img_key) + if block_weight.attn.norm_added_q is not None: + txt_query = block_weight.attn.norm_added_q.apply(txt_query) + if block_weight.attn.norm_added_k is not None: + txt_key = block_weight.attn.norm_added_k.apply(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs1 = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query.unsqueeze(0), img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key.unsqueeze(0), img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query.unsqueeze(0), txt_freqs1, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key.unsqueeze(0), txt_freqs1, use_real=False) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value.unsqueeze(0), img_value.unsqueeze(0)], dim=1) + + # Compute joint attention + if attn_type == "torch_sdpa": + joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value) + + else: + joint_query = joint_query.squeeze(0) + joint_key = joint_key.squeeze(0) + joint_value = joint_value.squeeze(0) + + k_lens = torch.tensor([joint_key.size(0)], dtype=torch.int32, device=joint_key.device) + cu_seqlens_q, cu_seqlens_k = calculate_q_k_len(joint_query, k_lens=k_lens) + + joint_hidden_states = block_weight.attn.calculate.apply( + q=joint_query, k=joint_key, v=joint_value, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=joint_query.size(0), max_seqlen_kv=joint_key.size(0), model_cls="qwen_image" + ) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:seq_txt, :] # Text part + img_attn_output = joint_hidden_states[seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = block_weight.attn.to_out.apply(img_attn_output) + txt_attn_output = block_weight.attn.to_add_out.apply(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerInfer(BaseTransformerInfer): + def __init__(self, config): + self.config = config + self.infer_conditional = True + self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) + self.infer_func = self.infer_calculating + self.attn_type = config.get("attn_type", "flash_attn3") + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def infer_block(self, block_weight, hidden_states, encoder_hidden_states, temb, image_rotary_emb): + # Get modulation parameters for both streams + img_mod_params = block_weight.img_mod.apply(F.silu(temb)) + txt_mod_params = 
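In apply_attn above, the text and image streams are concatenated as [text, image] along the sequence dimension before attention and split back at seq_txt afterwards, so both streams attend over the joint sequence. A shape-level illustration with made-up sizes:

import torch

seq_txt, seq_img, heads, head_dim = 7, 25, 2, 8
txt_q = torch.randn(1, seq_txt, heads, head_dim)
img_q = torch.randn(1, seq_img, heads, head_dim)
joint_q = torch.cat([txt_q, img_q], dim=1)        # [1, seq_txt + seq_img, H, D]
# ... attention runs over the joint sequence, producing one row per token ...
joint_out = torch.randn(seq_txt + seq_img, heads * head_dim)
txt_out, img_out = joint_out[:seq_txt], joint_out[seq_txt:]
print(joint_q.shape, txt_out.shape, img_out.shape)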
block_weight.txt_mod.apply(F.silu(temb)) + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) + + # Process image stream - norm1 + modulation + img_normed = block_weight.img_norm1.apply(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + + # Process text stream - norm1 + modulation + txt_normed = block_weight.txt_norm1.apply(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + attn_output = apply_attn( + block_weight=block_weight, + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + image_rotary_emb=image_rotary_emb, + attn_type=self.attn_type, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_normed2 = block_weight.img_norm2.apply(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_mlp_output = F.gelu(block_weight.img_mlp.mlp_0.apply(img_modulated2.squeeze(0)), approximate="tanh") + img_mlp_output = block_weight.img_mlp.mlp_2.apply(img_mlp_output) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_normed2 = block_weight.txt_norm2.apply(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_mlp_output = F.gelu(block_weight.txt_mlp.mlp_0.apply(txt_modulated2.squeeze(0)), approximate="tanh") + txt_mlp_output = block_weight.txt_mlp.mlp_2.apply(txt_mlp_output) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + def infer_calculating(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb): + for idx in range(len(block_weights.blocks)): + block_weight = block_weights.blocks[idx] + encoder_hidden_states, hidden_states = self.infer_block( + block_weight=block_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb + ) + return encoder_hidden_states, hidden_states + + def infer(self, hidden_states, encoder_hidden_states, pre_infer_out, block_weights): + temb, image_rotary_emb = pre_infer_out + encoder_hidden_states, hidden_states = self.infer_func(block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb) + return encoder_hidden_states, hidden_states diff --git a/lightx2v/models/networks/qwen_image/lora_adapter.py b/lightx2v/models/networks/qwen_image/lora_adapter.py new file mode 100644 index 
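The _modulate helper above implements the usual adaLN-style conditioning: a linear projection of silu(temb) supplies per-channel shift, scale, and gate, and the gated output is added back as a residual. A self-contained sketch of that step; the Linear layer below merely stands in for the img_mod/txt_mod weights.

import torch
import torch.nn.functional as F

def modulate(x, mod_params):
    # Mirrors _modulate: split into (shift, scale, gate) along the channel dim.
    shift, scale, gate = mod_params.chunk(3, dim=-1)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

dim = 8
temb = torch.randn(1, dim)
mod = torch.nn.Linear(dim, 3 * dim)          # stand-in for img_mod / txt_mod
x = torch.randn(1, 5, dim)                   # [batch, tokens, dim]
x_mod, gate = modulate(x, mod(F.silu(temb)))
x = x + gate * x_mod                         # gated residual, as in the block
print(x.shape)                               # torch.Size([1, 5, 8])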
0000000000000000000000000000000000000000..6897a10b359fdec4e12e83f6e5d8b1e2a8513125 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/lora_adapter.py @@ -0,0 +1,90 @@ +import os + +import torch +from loguru import logger +from safetensors import safe_open + +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + + +def fuse_lora_weights(original_weight, lora_down, lora_up, alpha): + rank = lora_down.shape[0] + lora_delta = torch.mm(lora_up, lora_down) # W_up × W_down + scaling = alpha / rank + lora_delta = lora_delta * scaling + fused_weight = original_weight + lora_delta + return fused_weight + + +class QwenImageLoraWrapper: + def __init__(self, qwenimage_model): + self.model = qwenimage_model + self.lora_metadata = {} + self.device = torch.device(AI_DEVICE) if not self.model.config.get("cpu_offload", False) else torch.device("cpu") + + def load_lora(self, lora_path, lora_name=None): + if lora_name is None: + lora_name = os.path.basename(lora_path).split(".")[0] + + if lora_name in self.lora_metadata: + logger.info(f"LoRA {lora_name} already loaded, skipping...") + return lora_name + + self.lora_metadata[lora_name] = {"path": lora_path} + logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}") + + return lora_name + + def _load_lora_file(self, file_path): + with safe_open(file_path, framework="pt") as f: + tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).to(self.device) for key in f.keys()} + return tensor_dict + + def apply_lora(self, lora_name, alpha=1.0): + if lora_name not in self.lora_metadata: + logger.info(f"LoRA {lora_name} not found. Please load it first.") + + if not hasattr(self.model, "original_weight_dict"): + logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.") + return False + + lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"]) + + weight_dict = self.model.original_weight_dict + + weight_dict = self._apply_lora_weights(weight_dict, lora_weights, alpha) + + self.model._apply_weights(weight_dict) + + logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}") + del lora_weights + return True + + @torch.no_grad() + def _apply_lora_weights(self, weight_dict, lora_weights, alpha): + lora_prefixs = [ + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_add_out", + "attn.to_out.0", + "img_mlp.net.0.proj", + "txt_mlp.net.0.proj", + "txt_mlp.net.2", + ] + + for prefix in lora_prefixs: + for idx in range(self.model.config["num_layers"]): + prefix_name = f"transformer_blocks.{idx}.{prefix}" + lora_up = lora_weights[f"{prefix_name}.lora_up.weight"] + lora_down = lora_weights[f"{prefix_name}.lora_down.weight"] + lora_alpha = lora_weights[f"{prefix_name}.alpha"] + origin = weight_dict[f"{prefix_name}.weight"] + weight_dict[f"{prefix_name}.weight"] = fuse_lora_weights(origin, lora_down, lora_up, lora_alpha) + + return weight_dict diff --git a/lightx2v/models/networks/qwen_image/model.py b/lightx2v/models/networks/qwen_image/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ec46875b633c23e1480b85f0a5ec89b2e02a4682 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/model.py @@ -0,0 +1,383 @@ +import gc +import glob +import json +import os + +import torch +from safetensors import safe_open + +from lightx2v.utils.envs import * +from lightx2v.utils.utils import * + +from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer +from 
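fuse_lora_weights above folds a LoRA pair into the base weight as W' = W + (alpha / rank) * (up @ down), with rank read from lora_down's first dimension. A small worked check with made-up shapes:

import torch

out_dim, in_dim, rank, alpha = 6, 4, 2, 8.0
W = torch.randn(out_dim, in_dim)
down = torch.randn(rank, in_dim)             # lora_down: [rank, in]
up = torch.randn(out_dim, rank)              # lora_up:   [out, rank]
fused = fuse_lora_weights(W, down, up, alpha)
# Sanity check: the fusion equals the base weight plus the scaled low-rank delta.
assert torch.allclose(fused, W + (alpha / rank) * (up @ down))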
.infer.post_infer import QwenImagePostInfer +from .infer.pre_infer import QwenImagePreInfer +from .infer.transformer_infer import QwenImageTransformerInfer +from .weights.post_weights import QwenImagePostWeights +from .weights.pre_weights import QwenImagePreWeights +from .weights.transformer_weights import QwenImageTransformerWeights + + +class QwenImageTransformerModel: + pre_weight_class = QwenImagePreWeights + transformer_weight_class = QwenImageTransformerWeights + post_weight_class = QwenImagePostWeights + + def __init__(self, config): + self.config = config + self.model_path = os.path.join(config["model_path"], "transformer") + self.cpu_offload = config.get("cpu_offload", False) + self.offload_granularity = self.config.get("offload_granularity", "block") + self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE) + + with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f: + transformer_config = json.load(f) + self.in_channels = transformer_config["in_channels"] + self.attention_kwargs = {} + + self.dit_quantized = self.config.get("dit_quantized", False) + + self._init_infer_class() + self._init_weights() + self._init_infer() + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + self.pre_infer.set_scheduler(scheduler) + self.transformer_infer.set_scheduler(scheduler) + self.post_infer.set_scheduler(scheduler) + + def _init_infer_class(self): + if self.config["feature_caching"] == "NoCaching": + self.transformer_infer_class = QwenImageTransformerInfer if not self.cpu_offload else QwenImageOffloadTransformerInfer + else: + assert NotImplementedError + self.pre_infer_class = QwenImagePreInfer + self.post_infer_class = QwenImagePostInfer + + def _init_weights(self, weight_dict=None): + unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE() + # Some layers run with float32 to achieve high accuracy + sensitive_layer = {} + + if weight_dict is None: + is_weight_loader = self._should_load_weights() + if is_weight_loader: + if not self.dit_quantized: + # Load original weights + weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) + else: + # Load quantized weights + if not self.config.get("lazy_load", False): + weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) + else: + weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer) + + if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False): + weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) + + self.original_weight_dict = weight_dict + else: + self.original_weight_dict = weight_dict + + # Initialize weight containers + self.pre_weight = self.pre_weight_class(self.config) + self.transformer_weights = self.transformer_weight_class(self.config) + self.post_weight = self.post_weight_class(self.config) + if not self._should_init_empty_model(): + self._apply_weights() + + def _apply_weights(self, weight_dict=None): + if weight_dict is not None: + self.original_weight_dict = weight_dict + del weight_dict + gc.collect() + # Load weights into containers + self.pre_weight.load(self.original_weight_dict) + self.transformer_weights.load(self.original_weight_dict) + self.post_weight.load(self.original_weight_dict) + + del self.original_weight_dict + torch.cuda.empty_cache() + gc.collect() + + def _should_load_weights(self): + """Determine if current rank should load weights from disk.""" + if self.config.get("device_mesh") is None: + # Single GPU mode + return True + elif 
dist.is_initialized(): + if self.config.get("load_from_rank0", False): + # Multi-GPU mode, only rank 0 loads + if dist.get_rank() == 0: + logger.info(f"Loading weights from {self.model_path}") + return True + else: + return True + return False + + def _should_init_empty_model(self): + if self.config.get("lora_configs") and self.config["lora_configs"]: + return True + return False + + def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer): + remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] + + if self.device.type != "cpu" and dist.is_initialized(): + device = dist.get_rank() + else: + device = str(self.device) + + with safe_open(file_path, framework="pt", device=device) as f: + return { + key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) + for key in f.keys() + if not any(remove_key in key for remove_key in remove_keys) + } + + def _load_ckpt(self, unified_dtype, sensitive_layer): + if self.config.get("dit_original_ckpt", None): + safetensors_path = self.config["dit_original_ckpt"] + else: + safetensors_path = self.model_path + + if os.path.isdir(safetensors_path): + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + else: + safetensors_files = [safetensors_path] + + weight_dict = {} + for file_path in safetensors_files: + logger.info(f"Loading weights from {file_path}") + file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer) + weight_dict.update(file_weights) + + return weight_dict + + def _load_quant_ckpt(self, unified_dtype, sensitive_layer): + remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] + + if self.config.get("dit_quantized_ckpt", None): + safetensors_path = self.config["dit_quantized_ckpt"] + else: + safetensors_path = self.model_path + + if os.path.isdir(safetensors_path): + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + else: + safetensors_files = [safetensors_path] + safetensors_path = os.path.dirname(safetensors_path) + + weight_dict = {} + for safetensor_path in safetensors_files: + with safe_open(safetensor_path, framework="pt") as f: + logger.info(f"Loading weights from {safetensor_path}") + for k in f.keys(): + if any(remove_key in k for remove_key in remove_keys): + continue + if f.get_tensor(k).dtype in [ + torch.float16, + torch.bfloat16, + torch.float, + ]: + if unified_dtype or all(s not in k for s in sensitive_layer): + weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device) + else: + weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device) + else: + weight_dict[k] = f.get_tensor(k).to(self.device) + + if self.config.get("dit_quant_scheme", "Default") == "nvfp4": + calib_path = os.path.join(safetensors_path, "calib.pt") + logger.info(f"[CALIB] Loaded calibration data from: {calib_path}") + calib_data = torch.load(calib_path, map_location="cpu") + for k, v in calib_data["absmax"].items(): + weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device) + + return weight_dict + + def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer): # Need rewrite + lazy_load_model_path = self.dit_quantized_ckpt + logger.info(f"Loading splited quant model from {lazy_load_model_path}") + pre_post_weight_dict = {} + + safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors") + with safe_open(safetensor_path, framework="pt", device="cpu") as f: + for k in 
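_load_safetensor_to_dict and _load_ckpt above stream tensors out of safetensors shards, casting floating-point weights to the run dtype while keeping any "sensitive" layers at a higher precision. A condensed sketch of that policy; RUN_DTYPE and SENSITIVE_KEYS below are stand-ins for GET_DTYPE()/GET_SENSITIVE_DTYPE() and the sensitive_layer set, not real lightx2v names.

import torch
from safetensors import safe_open

RUN_DTYPE = torch.bfloat16          # stand-in for GET_DTYPE()
SENSITIVE_KEYS = {"norm_out"}       # stand-in for the sensitive_layer set

def load_file(path):
    weights = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            t = f.get_tensor(key)
            if t.is_floating_point():
                # Sensitive layers keep higher precision; everything else runs in RUN_DTYPE.
                dtype = torch.float32 if any(s in key for s in SENSITIVE_KEYS) else RUN_DTYPE
                t = t.to(dtype)
            weights[key] = t
    return weights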
f.keys(): + if f.get_tensor(k).dtype in [ + torch.float16, + torch.bfloat16, + torch.float, + ]: + if unified_dtype or all(s not in k for s in sensitive_layer): + pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device) + else: + pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device) + else: + pre_post_weight_dict[k] = f.get_tensor(k).to(self.device) + + return pre_post_weight_dict + + def _load_weights_from_rank0(self, weight_dict, is_weight_loader): + logger.info("Loading distributed weights") + global_src_rank = 0 + target_device = "cpu" if self.cpu_offload else "cuda" + + if is_weight_loader: + meta_dict = {} + for key, tensor in weight_dict.items(): + meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype} + + obj_list = [meta_dict] + dist.broadcast_object_list(obj_list, src=global_src_rank) + synced_meta_dict = obj_list[0] + else: + obj_list = [None] + dist.broadcast_object_list(obj_list, src=global_src_rank) + synced_meta_dict = obj_list[0] + + distributed_weight_dict = {} + for key, meta in synced_meta_dict.items(): + distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) + + if target_device == "cuda": + dist.barrier(device_ids=[torch.cuda.current_device()]) + + for key in sorted(synced_meta_dict.keys()): + if is_weight_loader: + distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True) + + if target_device == "cpu": + if is_weight_loader: + gpu_tensor = distributed_weight_dict[key].cuda() + dist.broadcast(gpu_tensor, src=global_src_rank) + distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) + del gpu_tensor + torch.cuda.empty_cache() + else: + gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda") + dist.broadcast(gpu_tensor, src=global_src_rank) + distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) + del gpu_tensor + torch.cuda.empty_cache() + + if distributed_weight_dict[key].is_pinned(): + distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True) + else: + dist.broadcast(distributed_weight_dict[key], src=global_src_rank) + + if target_device == "cuda": + torch.cuda.synchronize() + else: + for tensor in distributed_weight_dict.values(): + if tensor.is_pinned(): + tensor.copy_(tensor, non_blocking=False) + + logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}") + + return distributed_weight_dict + + def _init_infer(self): + self.transformer_infer = self.transformer_infer_class(self.config) + self.pre_infer = self.pre_infer_class(self.config) + self.post_infer = self.post_infer_class(self.config) + if hasattr(self.transformer_infer, "offload_manager"): + self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers) + + def to_cpu(self): + self.pre_weight.to_cpu() + self.transformer_weights.to_cpu() + self.post_weight.to_cpu() + + def to_cuda(self): + self.pre_weight.to_cuda() + self.transformer_weights.to_cuda() + self.post_weight.to_cuda() + + @torch.no_grad() + def infer(self, inputs): + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == 0: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + self.post_weight.to_cuda() + + t = self.scheduler.timesteps[self.scheduler.step_index] + latents = self.scheduler.latents + if self.config["task"] == "i2i": + image_latents = 
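_load_weights_from_rank0 above avoids having every rank read the checkpoint: rank 0 first broadcasts a {key: (shape, dtype)} metadata dict so the other ranks can allocate matching buffers, then each tensor is broadcast in turn. A rough sketch of the GPU-resident case only (the pinned-CPU offload path above adds extra staging copies); broadcast_weights is an illustrative helper and assumes an already-initialized process group.

import torch
import torch.distributed as dist

def broadcast_weights(weight_dict_or_none, src=0):
    is_src = dist.get_rank() == src
    # 1) Share metadata so every rank can allocate matching buffers.
    meta = [{k: (v.shape, v.dtype) for k, v in weight_dict_or_none.items()}] if is_src else [None]
    dist.broadcast_object_list(meta, src=src)
    out = {}
    for key, (shape, dtype) in meta[0].items():
        t = weight_dict_or_none[key].cuda() if is_src else torch.empty(shape, dtype=dtype, device="cuda")
        dist.broadcast(t, src=src)      # 2) Send the actual tensor data.
        out[key] = t
    return out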
torch.cat([item["image_latents"] for item in inputs["image_encoder_output"]], dim=1) + latents_input = torch.cat([latents, image_latents], dim=1) + else: + latents_input = latents + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + img_shapes = inputs["img_shapes"] + + prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"] + prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"] + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + + hidden_states, encoder_hidden_states, _, pre_infer_out = self.pre_infer.infer( + weights=self.pre_weight, + hidden_states=latents_input, + timestep=timestep / 1000, + guidance=self.scheduler.guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + ) + + encoder_hidden_states, hidden_states = self.transformer_infer.infer( + block_weights=self.transformer_weights, + hidden_states=hidden_states.unsqueeze(0), + encoder_hidden_states=encoder_hidden_states.unsqueeze(0), + pre_infer_out=pre_infer_out, + ) + + noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0]) + + if self.config["do_true_cfg"]: + neg_prompt_embeds = inputs["text_encoder_output"]["negative_prompt_embeds"] + neg_prompt_embeds_mask = inputs["text_encoder_output"]["negative_prompt_embeds_mask"] + + negative_txt_seq_lens = neg_prompt_embeds_mask.sum(dim=1).tolist() if neg_prompt_embeds_mask is not None else None + + neg_hidden_states, neg_encoder_hidden_states, _, neg_pre_infer_out = self.pre_infer.infer( + weights=self.pre_weight, + hidden_states=latents_input, + timestep=timestep / 1000, + guidance=self.scheduler.guidance, + encoder_hidden_states_mask=neg_prompt_embeds_mask, + encoder_hidden_states=neg_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + ) + + neg_encoder_hidden_states, neg_hidden_states = self.transformer_infer.infer( + block_weights=self.transformer_weights, + hidden_states=neg_hidden_states.unsqueeze(0), + encoder_hidden_states=neg_encoder_hidden_states.unsqueeze(0), + pre_infer_out=neg_pre_infer_out, + ) + + neg_noise_pred = self.post_infer.infer(self.post_weight, neg_hidden_states, neg_pre_infer_out[0]) + + if self.config["task"] == "i2i": + noise_pred = noise_pred[:, : latents.size(1)] + + if self.config["do_true_cfg"]: + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + self.config["true_cfg_scale"] * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + noise_pred = noise_pred[:, : latents.size(1)] + self.scheduler.noise_pred = noise_pred diff --git a/lightx2v/models/networks/qwen_image/weights/post_weights.py b/lightx2v/models/networks/qwen_image/weights/post_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..a14add998995ff66c51bee607f887e87a6162973 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/weights/post_weights.py @@ -0,0 +1,49 @@ +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.utils.registry_factory import ( + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, +) + + +class QwenImagePostWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.task = config["task"] + self.config = 
config + self.lazy_load = self.config.get("lazy_load", False) + if self.lazy_load: + assert NotImplementedError + self.lazy_load_file = False + + # norm_out + self.add_module( + "norm_out_linear", + MM_WEIGHT_REGISTER["Default"]( + "norm_out.linear.weight", + "norm_out.linear.bias", + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module("norm_out", LN_WEIGHT_REGISTER["Default"](eps=1e-6)) + + # proj_out + self.add_module( + "proj_out_linear", + MM_WEIGHT_REGISTER["Default"]( + "proj_out.weight", + "proj_out.bias", + self.lazy_load, + self.lazy_load_file, + ), + ) + + def to_cpu(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu(non_blocking=non_blocking) + + def to_cuda(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cuda"): + module.to_cuda(non_blocking=non_blocking) diff --git a/lightx2v/models/networks/qwen_image/weights/pre_weights.py b/lightx2v/models/networks/qwen_image/weights/pre_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..80fe1b75c60e51dea3e0f932e0c7a85d38eb41fc --- /dev/null +++ b/lightx2v/models/networks/qwen_image/weights/pre_weights.py @@ -0,0 +1,40 @@ +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.utils.registry_factory import ( + MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, +) + + +class QwenImagePreWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.config = config + # img_in + self.add_module( + "img_in", + MM_WEIGHT_REGISTER["Default"]("img_in.weight", "img_in.bias"), + ) + # txt_in + self.add_module( + "txt_in", + MM_WEIGHT_REGISTER["Default"]("txt_in.weight", "txt_in.bias"), + ) + # txt_norm + self.add_module("txt_norm", RMS_WEIGHT_REGISTER["fp32_variance"]("txt_norm.weight")) + # time_text_embed + self.add_module( + "time_text_embed_timestep_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_1.weight", "time_text_embed.timestep_embedder.linear_1.bias") + ) + self.add_module( + "time_text_embed_timestep_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias") + ) + + def to_cpu(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu(non_blocking=non_blocking) + + def to_cuda(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cuda"): + module.to_cuda(non_blocking=non_blocking) diff --git a/lightx2v/models/networks/qwen_image/weights/transformer_weights.py b/lightx2v/models/networks/qwen_image/weights/transformer_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..5fec43a906449f453b83e2a73c4b8d962e4c9b90 --- /dev/null +++ b/lightx2v/models/networks/qwen_image/weights/transformer_weights.py @@ -0,0 +1,323 @@ +import os + +from safetensors import safe_open + +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList +from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER + + +class QwenImageTransformerWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.blocks_num = config["num_layers"] + self.task = config["task"] + self.config = config + self.mm_type = config.get("dit_quant_scheme", "Default") 
+ if self.mm_type != "Default": + assert config.get("dit_quantized") is True + blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, False, "transformer_blocks") for i in range(self.blocks_num)) + self.register_offload_buffers(config) + self.add_module("blocks", blocks) + + def register_offload_buffers(self, config): + if config["cpu_offload"]: + if config["offload_granularity"] == "block": + self.offload_blocks_num = 2 + self.offload_block_cuda_buffers = WeightModuleList( + [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, False, "transformer_blocks") for i in range(self.offload_blocks_num)] + ) + self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers) + self.offload_phase_cuda_buffers = None + else: + raise NotImplementedError + + +class QwenImageTransformerAttentionBlock(WeightModule): + def __init__(self, block_index, task, mm_type, config, create_cuda_buffer=False, create_cpu_buffer=False, block_prefix="transformer_blocks"): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.sparge = config.get("sparge", False) + + self.lazy_load = self.config.get("lazy_load", False) + if self.lazy_load: + lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors") + self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu") + else: + self.lazy_load_file = None + + # Image processing modules + self.add_module( + "img_mod", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.img_mod.1.weight", + f"{block_prefix}.{self.block_index}.img_mod.1.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "img_norm1", + LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6), + ) + self.attn = QwenImageCrossAttention( + block_index=block_index, + block_prefix="transformer_blocks", + task=config["task"], + mm_type=mm_type, + config=config, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + ) + self.add_module("attn", self.attn) + + self.add_module( + "img_norm2", + LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6), + ) + img_mlp = QwenImageFFN( + block_index=block_index, + block_prefix="transformer_blocks", + ffn_prefix="img_mlp", + task=config["task"], + mm_type=mm_type, + config=config, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + ) + self.add_module("img_mlp", img_mlp) + + # Text processing modules + self.add_module( + "txt_mod", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.txt_mod.1.weight", + f"{block_prefix}.{self.block_index}.txt_mod.1.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "txt_norm1", + LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6), + ) + + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.add_module( + "txt_norm2", + 
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6), + ) + txt_mlp = QwenImageFFN( + block_index=block_index, + block_prefix="transformer_blocks", + ffn_prefix="txt_mlp", + task=config["task"], + mm_type=mm_type, + config=config, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + ) + self.add_module("txt_mlp", txt_mlp) + + +class QwenImageCrossAttention(WeightModule): + def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.sparge = config.get("sparge", False) + self.attn_type = config.get("attn_type", "flash_attn3") + self.heads = config["attention_out_dim"] // config["attention_dim_head"] + + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + + # norm_q + self.add_module( + "norm_q", + RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer), + ) + # norm_k + self.add_module( + "norm_k", + RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer), + ) + # to_q + self.add_module( + "to_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.to_q.weight", + f"{block_prefix}.{self.block_index}.attn.to_q.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # to_k + self.add_module( + "to_k", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.to_k.weight", + f"{block_prefix}.{self.block_index}.attn.to_k.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # to_v + self.add_module( + "to_v", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.to_v.weight", + f"{block_prefix}.{self.block_index}.attn.to_v.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # add_q_proj + self.add_module( + "add_q_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.add_q_proj.weight", + f"{block_prefix}.{self.block_index}.attn.add_q_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # add_k_proj + self.add_module( + "add_k_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.add_k_proj.weight", + f"{block_prefix}.{self.block_index}.attn.add_k_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # add_v_proj + self.add_module( + "add_v_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.add_v_proj.weight", + f"{block_prefix}.{self.block_index}.attn.add_v_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # to_out + self.add_module( + "to_out", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.to_out.0.weight", + f"{block_prefix}.{self.block_index}.attn.to_out.0.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + 
), + ) + # to_add_out + self.add_module( + "to_add_out", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.attn.to_add_out.weight", + f"{block_prefix}.{self.block_index}.attn.to_add_out.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + # norm_added_q + self.add_module( + "norm_added_q", + RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer), + ) + # norm_added_k + self.add_module( + "norm_added_k", + RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer), + ) + # attn + self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]()) + + def to_cpu(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu(non_blocking=non_blocking) + + def to_cuda(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cuda"): + module.to_cuda(non_blocking=non_blocking) + + +class QwenImageFFN(WeightModule): + def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + + self.add_module( + "mlp_0", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.weight", + f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "mlp_2", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.weight", + f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + def to_cpu(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu(non_blocking=non_blocking) + + def to_cuda(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cuda"): + module.to_cuda(non_blocking=non_blocking) diff --git a/lightx2v/models/networks/wan/animate_model.py b/lightx2v/models/networks/wan/animate_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b78555f2ad42768440d41f386f77976ac368b44d --- /dev/null +++ b/lightx2v/models/networks/wan/animate_model.py @@ -0,0 +1,22 @@ +from lightx2v.models.networks.wan.infer.animate.pre_infer import WanAnimatePreInfer +from lightx2v.models.networks.wan.infer.animate.transformer_infer import WanAnimateTransformerInfer +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.weights.animate.transformer_weights import WanAnimateTransformerWeights +from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights + + +class WanAnimateModel(WanModel): + pre_weight_class = WanPreWeights + transformer_weight_class = WanAnimateTransformerWeights + + def __init__(self, model_path, config, device): + self.remove_keys = ["face_encoder", 
"motion_encoder"] + super().__init__(model_path, config, device) + + def _init_infer_class(self): + super()._init_infer_class() + self.pre_infer_class = WanAnimatePreInfer + self.transformer_infer_class = WanAnimateTransformerInfer + + def set_animate_encoders(self, motion_encoder, face_encoder): + self.pre_infer.set_animate_encoders(motion_encoder, face_encoder) diff --git a/lightx2v/models/networks/wan/audio_model.py b/lightx2v/models/networks/wan/audio_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9813280c98d9ff3f82e3102fa2ac3fbc69cc755f --- /dev/null +++ b/lightx2v/models/networks/wan/audio_model.py @@ -0,0 +1,164 @@ +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from loguru import logger + +from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer +from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer +from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights +from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights +from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights +from lightx2v.utils.utils import load_weights +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanAudioModel(WanModel): + pre_weight_class = WanPreWeights + post_weight_class = WanPostWeights + transformer_weight_class = WanAudioTransformerWeights + + def __init__(self, model_path, config, device): + self.config = config + self._load_adapter_ckpt() + super().__init__(model_path, config, device) + + def _load_adapter_ckpt(self): + if self.config.get("adapter_model_path", None) is None: + if self.config.get("adapter_quantized", False): + if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl"]: + adapter_model_name = "audio_adapter_model_fp8.safetensors" + elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl", "int8-tmo"]: + adapter_model_name = "audio_adapter_model_int8.safetensors" + elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]: + adapter_model_name = "audio_adapter_model_mxfp4.safetensors" + elif self.config.get("adapter_quant_scheme", None) in ["mxfp6", "mxfp6-mxfp8"]: + adapter_model_name = "audio_adapter_model_mxfp6.safetensors" + elif self.config.get("adapter_quant_scheme", None) in ["mxfp8"]: + adapter_model_name = "audio_adapter_model_mxfp8.safetensors" + else: + raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}") + else: + adapter_model_name = "audio_adapter_model.safetensors" + self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name) + + adapter_offload = self.config.get("cpu_offload", False) + load_from_rank0 = self.config.get("load_from_rank0", False) + self.adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0) + if not adapter_offload: + if not dist.is_initialized() or not load_from_rank0: + for key in self.adapter_weights_dict: + self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(AI_DEVICE)) + + def _init_infer_class(self): + super()._init_infer_class() + self.pre_infer_class = WanAudioPreInfer + self.post_infer_class = 
WanAudioPostInfer + self.transformer_infer_class = WanAudioTransformerInfer + + def get_graph_name(self, shape, audio_num, with_mask=True): + return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}" + + def start_compile(self, shape, audio_num, with_mask=True): + graph_name = self.get_graph_name(shape, audio_num, with_mask) + logger.info(f"[Compile] Compile shape: {shape}, audio_num:{audio_num}, graph_name: {graph_name}") + + target_video_length = self.config.get("target_video_length", 81) + latents_length = (target_video_length - 1) // 16 * 4 + 1 + latents_h = shape[0] // self.config["vae_stride"][1] + latents_w = shape[1] // self.config["vae_stride"][2] + + new_inputs = {} + new_inputs["text_encoder_output"] = {} + new_inputs["text_encoder_output"]["context"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda() + new_inputs["text_encoder_output"]["context_null"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda() + + new_inputs["image_encoder_output"] = {} + new_inputs["image_encoder_output"]["clip_encoder_out"] = torch.randn(257, 1280, dtype=torch.bfloat16).cuda() + new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda() + + new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda() + if with_mask: + new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda() + else: + assert audio_num == 1, "audio_num must be 1 when with_mask is False" + new_inputs["person_mask_latens"] = None + + new_inputs["previmg_encoder_output"] = {} + new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda() + new_inputs["previmg_encoder_output"]["prev_mask"] = torch.randn(4, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda() + + self.scheduler.latents = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda() + self.scheduler.timestep_input = torch.tensor([600.0], dtype=torch.float32).cuda() + self.scheduler.audio_adapter_t_emb = torch.randn(1, 3, 5120, dtype=torch.bfloat16).cuda() + + self._infer_cond_uncond(new_inputs, infer_condition=True, graph_name=graph_name) + + def compile(self, compile_shapes): + self.check_compile_shapes(compile_shapes) + self.enable_compile_mode("_infer_cond_uncond") + + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + self.transformer_weights.non_block_weights_to_cuda() + + max_audio_num_num = self.config.get("compile_max_audios", 1) + for audio_num in range(1, max_audio_num_num + 1): + for shape in compile_shapes: + self.start_compile(shape, audio_num, with_mask=True) + if audio_num == 1: + self.start_compile(shape, audio_num, with_mask=False) + + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() + + self.disable_compile_mode("_infer_cond_uncond") + logger.info(f"[Compile] Compile status: {self.get_compile_status()}") + + def check_compile_shapes(self, compile_shapes): + for shape in compile_shapes: + 
assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]] + + def select_graph_for_compile(self, input_info): + logger.info(f"target_h, target_w : {input_info.target_shape[0]}, {input_info.target_shape[1]}, audio_num: {input_info.audio_num}") + graph_name = self.get_graph_name(input_info.target_shape, input_info.audio_num, with_mask=input_info.with_mask) + self.select_graph("_infer_cond_uncond", graph_name) + logger.info(f"[Compile] Compile status: {self.get_compile_status()}") + + @torch.no_grad() + def _seq_parallel_pre_process(self, pre_infer_out): + x = pre_infer_out.x + person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"] + + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + + padding_size = (world_size - (x.shape[0] % world_size)) % world_size + if padding_size > 0: + x = F.pad(x, (0, 0, 0, padding_size)) + if person_mask_latens is not None: + person_mask_latens = F.pad(person_mask_latens, (0, padding_size)) + + pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank] + if person_mask_latens is not None: + pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank] + + if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]: + embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0 + padding_size = (world_size - (embed.shape[0] % world_size)) % world_size + if padding_size > 0: + embed = F.pad(embed, (0, 0, 0, padding_size)) + embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) + pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank] + pre_infer_out.embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank] + return pre_infer_out diff --git a/lightx2v/models/networks/wan/causvid_model.py b/lightx2v/models/networks/wan/causvid_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a76dccb22da97330908fe8dc6adfb44e5294a485 --- /dev/null +++ b/lightx2v/models/networks/wan/causvid_model.py @@ -0,0 +1,58 @@ +import os + +import torch + +from lightx2v.models.networks.wan.infer.causvid.transformer_infer import ( + WanTransformerInferCausVid, +) +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights +from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanTransformerWeights, +) +from lightx2v.utils.envs import * +from lightx2v.utils.utils import find_torch_model_path + + +class WanCausVidModel(WanModel): + pre_weight_class = WanPreWeights + post_weight_class = WanPostWeights + transformer_weight_class = WanTransformerWeights + + def __init__(self, model_path, config, device): + super().__init__(model_path, config, device) + + def _init_infer_class(self): + self.pre_infer_class = WanPreInfer + self.post_infer_class = WanPostInfer + self.transformer_infer_class = WanTransformerInferCausVid + + def _load_ckpt(self, unified_dtype, sensitive_layer): + ckpt_path = find_torch_model_path(self.config, self.model_path, "causvid_model.pt") + if os.path.exists(ckpt_path): + weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + weight_dict = { + key: (weight_dict[key].to(GET_DTYPE()) if 
unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) + for key in weight_dict.keys() + } + return weight_dict + + return super()._load_ckpt(unified_dtype, sensitive_layer) + + @torch.no_grad() + def infer(self, inputs, kv_start, kv_end): + if self.config["cpu_offload"]: + self.pre_weight.to_cuda() + self.transformer_weights.post_weights_to_cuda() + + embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end) + + x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end) + self.scheduler.noise_pred = self.post_infer.infer(x, embed, grid_sizes)[0] + + if self.config["cpu_offload"]: + self.pre_weight.to_cpu() + self.transformer_weights.post_weights_to_cpu() diff --git a/lightx2v/models/networks/wan/distill_model.py b/lightx2v/models/networks/wan/distill_model.py new file mode 100644 index 0000000000000000000000000000000000000000..96c48b369108e87a0a3a0eeb468da54f8f185626 --- /dev/null +++ b/lightx2v/models/networks/wan/distill_model.py @@ -0,0 +1,35 @@ +import os + +import torch +from loguru import logger + +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights +from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanTransformerWeights, +) +from lightx2v.utils.envs import * +from lightx2v.utils.utils import * + + +class WanDistillModel(WanModel): + pre_weight_class = WanPreWeights + post_weight_class = WanPostWeights + transformer_weight_class = WanTransformerWeights + + def __init__(self, model_path, config, device, model_type="wan2.1"): + super().__init__(model_path, config, device, model_type) + + def _load_ckpt(self, unified_dtype, sensitive_layer): + # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill + ckpt_path = os.path.join(self.model_path, "distill_model.pt") + if os.path.exists(ckpt_path): + logger.info(f"Loading weights from {ckpt_path}") + weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + weight_dict = { + key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) + for key in weight_dict.keys() + } + return weight_dict + return super()._load_ckpt(unified_dtype, sensitive_layer) diff --git a/lightx2v/models/networks/wan/infer/animate/pre_infer.py b/lightx2v/models/networks/wan/infer/animate/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f1b4194a4df061774fa9b2b7c9c71bce034c8f --- /dev/null +++ b/lightx2v/models/networks/wan/infer/animate/pre_infer.py @@ -0,0 +1,31 @@ +import math + +import torch + +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.utils.envs import * + + +class WanAnimatePreInfer(WanPreInfer): + def __init__(self, config): + super().__init__(config) + self.encode_bs = 8 + + def set_animate_encoders(self, motion_encoder, face_encoder): + self.motion_encoder = motion_encoder + self.face_encoder = face_encoder + + @torch.no_grad() + def after_patch_embedding(self, weights, x, pose_latents, face_pixel_values): + pose_latents = weights.pose_patch_embedding.apply(pose_latents) + x[:, :, 1:].add_(pose_latents) + + face_pixel_values_tmp = [] + 
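+        # Encode the driving face frames in chunks of self.encode_bs so the motion encoder never processes the whole clip at once, which keeps peak memory bounded.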
for i in range(math.ceil(face_pixel_values.shape[0] / self.encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * self.encode_bs : (i + 1) * self.encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + motion_vec = self.face_encoder(motion_vec.unsqueeze(0).to(GET_DTYPE())).squeeze(0) + pad_face = torch.zeros(1, motion_vec.shape[1], motion_vec.shape[2], dtype=motion_vec.dtype, device="cuda") + motion_vec = torch.cat([pad_face, motion_vec], dim=0) + return x, motion_vec diff --git a/lightx2v/models/networks/wan/infer/animate/transformer_infer.py b/lightx2v/models/networks/wan/infer/animate/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8edb16f23d4aa134ae2ebc021465a2cacc3d17 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/animate/transformer_infer.py @@ -0,0 +1,74 @@ +import torch +from einops import rearrange + +from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer + + +class WanAnimateTransformerInfer(WanOffloadTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.has_post_adapter = True + self.phases_num = 4 + + def infer_with_blocks_offload(self, blocks, x, pre_infer_out): + for block_idx in range(len(blocks)): + self.block_idx = block_idx + if block_idx == 0: + self.offload_manager.init_first_buffer(blocks, block_idx // 5) + if block_idx < len(blocks) - 1: + self.offload_manager.prefetch_weights(block_idx + 1, blocks, (block_idx + 1) // 5) + + with torch.cuda.stream(self.offload_manager.compute_stream): + x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out) + self.offload_manager.swap_blocks() + return x + + def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy): + for phase_idx in range(self.phases_num): + if block_idx == 0 and phase_idx == 0: + if lazy: + obj_key = (block_idx, phase_idx) + phase = self.offload_manager.pin_memory_buffer.get(obj_key) + phase.to_cuda() + self.offload_manager.cuda_buffers[0] = (obj_key, phase) + else: + self.offload_manager.init_first_buffer(blocks, block_idx // 5) + is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1 + if not is_last_phase: + next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx + next_phase_idx = (phase_idx + 1) % self.phases_num + self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks, (block_idx + 1) // 5) + + with torch.cuda.stream(self.offload_manager.compute_stream): + x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out) + + self.offload_manager.swap_phases() + + return x + + @torch.no_grad() + def infer_post_adapter(self, phase, x, pre_infer_out): + if phase.is_empty() or phase.linear1_kv.weight is None: + return x + T = pre_infer_out.adapter_args["motion_vec"].shape[0] + x_motion = phase.pre_norm_motion.apply(pre_infer_out.adapter_args["motion_vec"]) + x_feat = phase.pre_norm_feat.apply(x) + kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1])) + kv = kv.view(T, -1, kv.shape[-1]) + q = phase.linear1_q.apply(x_feat) + k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config["num_heads"]) + q = rearrange(q, "S (H D) -> S H D", H=self.config["num_heads"]) + + q = phase.q_norm.apply(q).view(T, q.shape[0] // T, q.shape[1], q.shape[2]) + k = phase.k_norm.apply(k) + attn = phase.adapter_attn.apply( + q=q, + k=k, + v=v, + max_seqlen_q=q.shape[1], + model_cls=self.config["model_cls"], + ) 
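+        # Project the adapter attention output back to the model hidden size and inject it into x as a residual.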
+ + output = phase.linear2.apply(attn) + x = x.add_(output) + return x diff --git a/lightx2v/models/networks/wan/infer/audio/post_infer.py b/lightx2v/models/networks/wan/infer/audio/post_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..3653cddb4e02c862b48654bb718c92d870fdb47c --- /dev/null +++ b/lightx2v/models/networks/wan/infer/audio/post_infer.py @@ -0,0 +1,21 @@ +import torch + +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.utils.envs import * + + +class WanAudioPostInfer(WanPostInfer): + def __init__(self, config): + super().__init__(config) + + @torch.no_grad() + def infer(self, x, pre_infer_out): + t, h, w = pre_infer_out.grid_sizes.tuple + grid_sizes = (t - 1, h, w) + + x = self.unpatchify(x, grid_sizes) + + if self.clean_cuda_cache: + torch.cuda.empty_cache() + + return [u.float() for u in x] diff --git a/lightx2v/models/networks/wan/infer/audio/pre_infer.py b/lightx2v/models/networks/wan/infer/audio/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c56c4162a60fc6d400219b1c936114313fa4675 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/audio/pre_infer.py @@ -0,0 +1,124 @@ +import torch + +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.utils.envs import * + +from ..module_io import GridOutput, WanPreInferModuleOutput +from ..utils import sinusoidal_embedding_1d + + +class WanAudioPreInfer(WanPreInfer): + def __init__(self, config): + super().__init__(config) + assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0 + self.config = config + self.task = config["task"] + self.freq_dim = config["freq_dim"] + self.dim = config["dim"] + self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) + self.infer_dtype = GET_DTYPE() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + + @torch.no_grad() + def infer(self, weights, inputs): + infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input + prev_latents = inputs["previmg_encoder_output"]["prev_latents"] + hidden_states = latents + if self.config["model_cls"] != "wan2.2_audio": + prev_mask = inputs["previmg_encoder_output"]["prev_mask"] + hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0) + + x = hidden_states + t = timestep_input + + if infer_condition: + context = inputs["text_encoder_output"]["context"] + else: + context = inputs["text_encoder_output"]["context_null"] + + clip_fea = inputs["image_encoder_output"]["clip_encoder_out"] + ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(latents.dtype) + + num_channels, _, height, width = x.shape + ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape + + if ref_num_channels != num_channels: + zero_padding = torch.zeros( + (num_channels - ref_num_channels, ref_num_frames, height, width), + dtype=latents.dtype, + device=latents.device, + ) + ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0) + y = ref_image_encoder + + # embeddings + x = weights.patch_embedding.apply(x.unsqueeze(0)) + grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] + x = x.flatten(2).transpose(1, 2).contiguous() + # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0) + + y = weights.patch_embedding.apply(y.unsqueeze(0)) + y = y.flatten(2).transpose(1, 2).contiguous() + if not self.config.get("f2v_process", False): + x = torch.cat([x, 
y], dim=1).squeeze(0) + else: + x = x.squeeze(0) + + ####for r2v # zero temporl component corresponding to ref embeddings + # self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0 + grid_sizes_t += 1 + + person_mask_latens = inputs["person_mask_latens"] + if person_mask_latens is not None: + person_mask_latens = person_mask_latens.expand(-1, grid_sizes_t, -1, -1) + person_mask_latens = person_mask_latens.reshape(person_mask_latens.shape[0], -1) + + embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten()) + if self.sensitive_layer_dtype != self.infer_dtype: + embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype)) + else: + embed = weights.time_embedding_0.apply(embed) + embed = torch.nn.functional.silu(embed) + + embed = weights.time_embedding_2.apply(embed) + embed0 = torch.nn.functional.silu(embed) + embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)) + + # text embeddings + if self.sensitive_layer_dtype != self.infer_dtype: + out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype)) + else: + out = weights.text_embedding_0.apply(context.squeeze(0)) + out = torch.nn.functional.gelu(out, approximate="tanh") + context = weights.text_embedding_2.apply(out) + if self.clean_cuda_cache: + del out + torch.cuda.empty_cache() + + if self.task in ["i2v", "s2v"] and self.config.get("use_image_encoder", True): + context_clip = weights.proj_0.apply(clip_fea) + if self.clean_cuda_cache: + del clip_fea + torch.cuda.empty_cache() + context_clip = weights.proj_1.apply(context_clip) + context_clip = torch.nn.functional.gelu(context_clip, approximate="none") + context_clip = weights.proj_3.apply(context_clip) + context_clip = weights.proj_4.apply(context_clip) + if self.clean_cuda_cache: + torch.cuda.empty_cache() + context = torch.concat([context_clip, context], dim=0) + + if self.clean_cuda_cache: + if self.config.get("use_image_encoder", True): + del context_clip + torch.cuda.empty_cache() + + grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w)) + return WanPreInferModuleOutput( + embed=embed, + grid_sizes=grid_sizes, + x=x, + embed0=embed0.squeeze(0), + context=context, + adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens}, + ) diff --git a/lightx2v/models/networks/wan/infer/audio/transformer_infer.py b/lightx2v/models/networks/wan/infer/audio/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0dfa0c0ebfad8966f6a9f88530d6bdc8ce9517 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/audio/transformer_infer.py @@ -0,0 +1,96 @@ +import torch +import torch.distributed as dist +from loguru import logger + +try: + import flash_attn # noqa: F401 + from flash_attn.flash_attn_interface import flash_attn_varlen_func +except ImportError: + logger.info("flash_attn_varlen_func not found, please install flash_attn2 first") + flash_attn_varlen_func = None + +from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import calculate_n_query_tokens, get_qk_lens_audio_range +from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer + + +class WanAudioTransformerInfer(WanOffloadTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.has_post_adapter = True + self.phases_num = 4 + + @torch.no_grad() + def infer_post_adapter(self, phase, x, 
pre_infer_out):
+        grid_sizes = pre_infer_out.grid_sizes.tensor
+        audio_encoder_output = pre_infer_out.adapter_args["audio_encoder_output"]
+        person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"]
+        total_tokens = grid_sizes[0].prod()
+        pre_frame_tokens = grid_sizes[0][1:].prod()
+        n_tokens = total_tokens - pre_frame_tokens  # drop the tokens that belong to the ref image
+
+        ori_dtype = x.dtype
+        device = x.device
+        n_tokens_per_rank = torch.tensor(x.size(0), dtype=torch.int32, device=device)
+
+        if self.seq_p_group is not None:
+            sp_size = dist.get_world_size(self.seq_p_group)
+            sp_rank = dist.get_rank(self.seq_p_group)
+        else:
+            sp_size = 1
+            sp_rank = 0
+
+        n_query_tokens, hidden_states_aligned, hidden_states_tail, person_mask_aligned = calculate_n_query_tokens(x, person_mask_latens, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
+
+        q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 = get_qk_lens_audio_range(
+            n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=128
+        )
+
+        total_residual = None
+        for i in range(audio_encoder_output.shape[0]):
+            audio_encoder = audio_encoder_output[i]
+            audio_encoder = audio_encoder[t0:t1].reshape(-1, audio_encoder.size(-1))
+            residual = self.perceiver_attention_ca(phase, audio_encoder, hidden_states_aligned, self.scheduler.audio_adapter_t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k)
+
+            residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back into the hidden states as a residual
+            if n_query_tokens == 0:
+                residual = residual * 0.0
+            if person_mask_aligned is not None:
+                residual = residual * person_mask_aligned[i].unsqueeze(-1)
+
+            if total_residual is None:
+                total_residual = residual
+            else:
+                total_residual += residual
+
+        x = torch.cat([hidden_states_aligned + total_residual, hidden_states_tail], dim=0)
+        return x
+
+    @torch.no_grad()
+    def perceiver_attention_ca(self, phase, audio_encoder_output, latents, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k):
+        audio_encoder_output = phase.norm_kv.apply(audio_encoder_output)
+        shift, scale, gate = (t_emb + phase.shift_scale_gate.tensor)[0].chunk(3, dim=0)
+        norm_q = phase.norm_q.apply(latents)
+        latents = norm_q * (1 + scale) + shift
+        q = phase.to_q.apply(latents)
+        k, v = phase.to_kv.apply(audio_encoder_output).chunk(2, dim=-1)
+
+        q = q.view(q.size(0), self.num_heads, self.head_dim)
+        k = k.view(k.size(0), self.num_heads, self.head_dim)
+        v = v.view(v.size(0), self.num_heads, self.head_dim)
+
+        out = flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+            window_size=(-1, -1),
+            deterministic=False,
+        )
+        out = out.view(-1, self.num_heads * self.head_dim)
+        return phase.to_out.apply(out) * gate
diff --git a/lightx2v/models/networks/wan/infer/causvid/__init__.py b/lightx2v/models/networks/wan/infer/causvid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py b/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..17589850b64bf5ea96c37b33e29fdd3efd66de5b
--- /dev/null
+++ 
b/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py @@ -0,0 +1,222 @@ +import math + +import torch + +from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer +from lightx2v.utils.envs import * + +from ..utils import apply_rotary_emb, compute_freqs_causvid + + +class WanTransformerInferCausVid(WanOffloadTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.num_frames = config["num_frames"] + self.num_frame_per_block = config["num_frame_per_block"] + self.frame_seq_length = config["frame_seq_length"] + self.text_len = config["text_len"] + self.kv_cache = None + self.crossattn_cache = None + + def _init_kv_cache(self, dtype, device): + kv_cache = [] + kv_size = self.num_frames * self.frame_seq_length + + for _ in range(self.blocks_num): + kv_cache.append( + { + "k": torch.zeros([kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device), + "v": torch.zeros([kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device), + } + ) + + self.kv_cache = kv_cache + + def _init_crossattn_cache(self, dtype, device): + crossattn_cache = [] + + for _ in range(self.blocks_num): + crossattn_cache.append( + { + "k": torch.zeros([self.text_len, self.num_heads, self.head_dim], dtype=dtype, device=device), + "v": torch.zeros([self.text_len, self.num_heads, self.head_dim], dtype=dtype, device=device), + "is_init": False, + } + ) + + self.crossattn_cache = crossattn_cache + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end): + return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end) + + def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end): + for block_idx in range(self.blocks_num): + if block_idx == 0: + self.weights_stream_mgr.active_weights[0] = weights.blocks[0] + self.weights_stream_mgr.active_weights[0].to_cuda() + + with torch.cuda.stream(self.weights_stream_mgr.compute_stream): + x = self.infer_block( + self.weights_stream_mgr.active_weights[0], + grid_sizes, + embed, + x, + embed0, + seq_lens, + freqs, + context, + block_idx, + kv_start, + kv_end, + ) + + if block_idx < self.blocks_num - 1: + self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks) + self.weights_stream_mgr.swap_weights() + + return x + + def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end): + for block_idx in range(self.blocks_num): + x = self.infer_block( + weights.blocks[block_idx], + grid_sizes, + embed, + x, + embed0, + seq_lens, + freqs, + context, + block_idx, + kv_start, + kv_end, + ) + return x + + def infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end): + norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) + norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0) + + s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim + q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d) + k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d) + v = weights.self_attn_v.apply(norm1_out).view(s, n, d) + + if not self.parallel_attention: + freqs_i = compute_freqs_causvid(q.size(2) // 2, grid_sizes, freqs, start_frame=kv_start // math.prod(grid_sizes[0][1:]).item()) + else: + # TODO: Implement parallel attention for causvid inference + 
raise NotImplementedError("Parallel attention is not implemented for causvid inference") + + q = apply_rotary_emb(q, freqs_i) + k = apply_rotary_emb(k, freqs_i) + + self.kv_cache[block_idx]["k"][kv_start:kv_end] = k + self.kv_cache[block_idx]["v"][kv_start:kv_end] = v + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q=q, k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device)) + + if not self.parallel_attention: + attn_out = weights.self_attn_1.apply( + q=q, + k=self.kv_cache[block_idx]["k"][:kv_end], + v=self.kv_cache[block_idx]["v"][:kv_end], + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + model_cls=self.config["model_cls"], + ) + else: + # TODO: Implement parallel attention for causvid inference + raise NotImplementedError("Parallel attention is not implemented for causvid inference") + + y = weights.self_attn_o.apply(attn_out) + + x = x + y * embed0[2].squeeze(0) + + return x + + def infer_cross_attn(self, weights, x, context, block_idx): + norm3_out = weights.norm3.apply(x) + + if self.task in ["i2v", "s2v"]: + context_img = context[:257] + context = context[257:] + + n, d = self.num_heads, self.head_dim + q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d) + if not self.crossattn_cache[block_idx]["is_init"]: + k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d) + v = weights.cross_attn_v.apply(context).view(-1, n, d) + self.crossattn_cache[block_idx]["k"] = k + self.crossattn_cache[block_idx]["v"] = v + self.crossattn_cache[block_idx]["is_init"] = True + else: + k = self.crossattn_cache[block_idx]["k"] + v = self.crossattn_cache[block_idx]["v"] + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device)) + + attn_out = weights.cross_attn_1.apply( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + model_cls=self.config["model_cls"], + ) + + if self.task in ["i2v", "s2v"]: + k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d) + v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d) + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( + q, + k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), + ) + + img_attn_out = weights.cross_attn_2.apply( + q=q, + k=k_img, + v=v_img, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k_img.size(0), + model_cls=self.config["model_cls"], + ) + + attn_out = attn_out + img_attn_out + + attn_out = weights.cross_attn_o.apply(attn_out) + + x = x + attn_out + + return x + + def infer_ffn(self, weights, x, embed0): + norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) + y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0)) + y = torch.nn.functional.gelu(y, approximate="tanh") + y = weights.ffn_2.apply(y) + x = x + y * embed0[5].squeeze(0) + + return x + + def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end): + if embed0.dim() == 3: + modulation = weights.compute_phases[0].modulation.tensor.unsqueeze(2) # 1, 6, 1, dim + embed0 = embed0.unsqueeze(0) # + embed0 = (modulation + embed0).chunk(6, dim=1) + embed0 = [ei.squeeze(1) for ei in embed0] + elif embed0.dim() == 2: + embed0 = 
(weights.compute_phases[0].modulation.tensor + embed0).chunk(6, dim=1) + + x = self.infer_self_attn(weights.compute_phases[1], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end) + x = self.infer_cross_attn(weights.compute_phases[2], x, context, block_idx) + x = self.infer_ffn(weights.compute_phases[3], x, embed0) + + return x diff --git a/lightx2v/models/networks/wan/infer/feature_caching/__init__.py b/lightx2v/models/networks/wan/infer/feature_caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py b/lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f5342fde69bedf3e600cccf335bc9d3089c81b --- /dev/null +++ b/lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py @@ -0,0 +1,1104 @@ +import gc +import json + +import numpy as np +import torch +import torch.nn.functional as F + +from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer +from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanTransformerInferCaching(WanOffloadTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.must_calc_steps = [] + if self.config.get("changing_resolution", False): + self.must_calc_steps = self.config["changing_resolution_steps"] + + def must_calc(self, step_index): + if step_index in self.must_calc_steps: + return True + return False + + +class WanTransformerInferTeaCaching(WanTransformerInferCaching): + def __init__(self, config): + super().__init__(config) + self.teacache_thresh = config["teacache_thresh"] + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = None + self.previous_residual_even = None + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = None + self.previous_residual_odd = None + self.use_ret_steps = config["use_ret_steps"] + if self.use_ret_steps: + self.coefficients = self.config["coefficients"][0] + self.ret_steps = 5 + self.cutoff_steps = self.config["infer_steps"] + else: + self.coefficients = self.config["coefficients"][1] + self.ret_steps = 1 + self.cutoff_steps = self.config["infer_steps"] - 1 + + # calculate should_calc + @torch.no_grad() + def calculate_should_calc(self, embed, embed0): + # 1. timestep embedding + modulated_inp = embed0 if self.use_ret_steps else embed + + # 2. 
accumulate the relative L1 distance and decide whether to recompute
+        should_calc = False
+        if self.scheduler.infer_condition:
+            if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
+                should_calc = True
+                self.accumulated_rel_l1_distance_even = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_even += rescale_func(
+                    ((modulated_inp - self.previous_e0_even.to(AI_DEVICE)).abs().mean() / self.previous_e0_even.to(AI_DEVICE).abs().mean()).cpu().item()
+                )
+                if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    self.accumulated_rel_l1_distance_even = 0
+            self.previous_e0_even = modulated_inp.clone()
+            if self.config["cpu_offload"]:
+                self.previous_e0_even = self.previous_e0_even.cpu()
+
+        else:
+            if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
+                should_calc = True
+                self.accumulated_rel_l1_distance_odd = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.to(AI_DEVICE)).abs().mean() / self.previous_e0_odd.to(AI_DEVICE).abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    self.accumulated_rel_l1_distance_odd = 0
+            self.previous_e0_odd = modulated_inp.clone()
+
+            if self.config["cpu_offload"]:
+                self.previous_e0_odd = self.previous_e0_odd.cpu()
+
+        if self.config["cpu_offload"]:
+            modulated_inp = modulated_inp.cpu()
+            del modulated_inp
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        if self.clean_cuda_cache:
+            del embed, embed0
+            torch.cuda.empty_cache()
+
+        # 3. return the judgement
+        return should_calc
+
+    def infer_main_blocks(self, weights, pre_infer_out):
+        if self.scheduler.infer_condition:
+            index = self.scheduler.step_index
+            caching_records = self.scheduler.caching_records
+            if index <= self.scheduler.infer_steps - 1:
+                should_calc = self.calculate_should_calc(pre_infer_out.embed, pre_infer_out.embed0)
+                self.scheduler.caching_records[index] = should_calc
+
+            if caching_records[index] or self.must_calc(index):
+                x = self.infer_calculating(weights, pre_infer_out)
+            else:
+                x = self.infer_using_cache(pre_infer_out.x)
+
+        else:
+            index = self.scheduler.step_index
+            caching_records_2 = self.scheduler.caching_records_2
+            if index <= self.scheduler.infer_steps - 1:
+                should_calc = self.calculate_should_calc(pre_infer_out.embed, pre_infer_out.embed0)
+                self.scheduler.caching_records_2[index] = should_calc
+
+            if caching_records_2[index] or self.must_calc(index):
+                x = self.infer_calculating(weights, pre_infer_out)
+            else:
+                x = self.infer_using_cache(pre_infer_out.x)
+
+        if self.clean_cuda_cache:
+            torch.cuda.empty_cache()
+
+        return x
+
+    def infer_calculating(self, weights, pre_infer_out):
+        ori_x = pre_infer_out.x.clone()
+
+        x = super().infer_main_blocks(weights, pre_infer_out)
+        if self.scheduler.infer_condition:
+            self.previous_residual_even = x - ori_x
+            if self.config["cpu_offload"]:
+                self.previous_residual_even = self.previous_residual_even.cpu()
+        else:
+            self.previous_residual_odd = x - ori_x
+            if self.config["cpu_offload"]:
+                self.previous_residual_odd = self.previous_residual_odd.cpu()
+
+        if self.config["cpu_offload"]:
+            ori_x = ori_x.to("cpu")
+            del ori_x
+            torch.cuda.empty_cache()
+            gc.collect()
+        return x
+
+    def infer_using_cache(self, x):
+        if self.scheduler.infer_condition:
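+            # cache hit: skip all transformer blocks and reuse the residual saved at the last fully computed step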
x.add_(self.previous_residual_even.to(AI_DEVICE)) + else: + x.add_(self.previous_residual_odd.to(AI_DEVICE)) + return x + + def clear(self): + if self.previous_residual_even is not None: + self.previous_residual_even = self.previous_residual_even.cpu() + if self.previous_residual_odd is not None: + self.previous_residual_odd = self.previous_residual_odd.cpu() + if self.previous_e0_even is not None: + self.previous_e0_even = self.previous_e0_even.cpu() + if self.previous_e0_odd is not None: + self.previous_e0_odd = self.previous_e0_odd.cpu() + + self.previous_residual_even = None + self.previous_residual_odd = None + self.previous_e0_even = None + self.previous_e0_odd = None + + torch.cuda.empty_cache() + + +class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCachingTransformerInfer): + def __init__(self, config): + super().__init__(config) + + self.blocks_cache_even = [{} for _ in range(self.blocks_num)] + self.blocks_cache_odd = [{} for _ in range(self.blocks_num)] + + # 1. get taylor step_diff when there is two caching_records in scheduler + def get_taylor_step_diff(self): + step_diff = 0 + if self.infer_conditional: + current_step = self.scheduler.step_index + last_calc_step = current_step - 1 + while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]: + last_calc_step -= 1 + step_diff = current_step - last_calc_step + else: + current_step = self.scheduler.step_index + last_calc_step = current_step - 1 + while last_calc_step >= 0 and not self.scheduler.caching_records_2[last_calc_step]: + last_calc_step -= 1 + step_diff = current_step - last_calc_step + + return step_diff + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + if self.infer_conditional: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records + + if caching_records[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + + else: + index = self.scheduler.step_index + caching_records_2 = self.scheduler.caching_records_2 + + if caching_records_2[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + + if self.config["enable_cfg"]: + self.switch_status() + + return x + + def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + for block_idx in range(self.blocks_num): + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0) + + y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa) + if self.infer_conditional: + self.derivative_approximation(self.blocks_cache_even[block_idx], "self_attn_out", y_out) + else: + self.derivative_approximation(self.blocks_cache_odd[block_idx], "self_attn_out", y_out) + + x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa) + if self.infer_conditional: + self.derivative_approximation(self.blocks_cache_even[block_idx], "cross_attn_out", attn_out) + else: + self.derivative_approximation(self.blocks_cache_odd[block_idx], "cross_attn_out", attn_out) + + y_out = 
self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa) + if self.infer_conditional: + self.derivative_approximation(self.blocks_cache_even[block_idx], "ffn_out", y_out) + else: + self.derivative_approximation(self.blocks_cache_odd[block_idx], "ffn_out", y_out) + + x = self.post_process(x, y_out, c_gate_msa) + return x + + def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + for block_idx in range(self.blocks_num): + x = self.infer_block(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx) + return x + + # 1. taylor using caching + def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, i): + # 1. shift, scale, gate + _, _, gate_msa, _, _, c_gate_msa = self.infer_modulation(weights.compute_phases[0], embed0) + + # 2. residual and taylor + if self.infer_conditional: + out = self.taylor_formula(self.blocks_cache_even[i]["self_attn_out"]) + out = out * gate_msa.squeeze(0) + x = x + out + + out = self.taylor_formula(self.blocks_cache_even[i]["cross_attn_out"]) + x = x + out + + out = self.taylor_formula(self.blocks_cache_even[i]["ffn_out"]) + out = out * c_gate_msa.squeeze(0) + x = x + out + + else: + out = self.taylor_formula(self.blocks_cache_odd[i]["self_attn_out"]) + out = out * gate_msa.squeeze(0) + x = x + out + + out = self.taylor_formula(self.blocks_cache_odd[i]["cross_attn_out"]) + x = x + out + + out = self.taylor_formula(self.blocks_cache_odd[i]["ffn_out"]) + out = out * c_gate_msa.squeeze(0) + x = x + out + + return x + + def clear(self): + for cache in self.blocks_cache_even: + for key in cache: + if cache[key] is not None: + if isinstance(cache[key], torch.Tensor): + cache[key] = cache[key].cpu() + elif isinstance(cache[key], dict): + for k, v in cache[key].items(): + if isinstance(v, torch.Tensor): + cache[key][k] = v.cpu() + cache.clear() + + for cache in self.blocks_cache_odd: + for key in cache: + if cache[key] is not None: + if isinstance(cache[key], torch.Tensor): + cache[key] = cache[key].cpu() + elif isinstance(cache[key], dict): + for k, v in cache[key].items(): + if isinstance(v, torch.Tensor): + cache[key][k] = v.cpu() + cache.clear() + torch.cuda.empty_cache() + + +class WanTransformerInferAdaCaching(WanTransformerInferCaching): + def __init__(self, config): + super().__init__(config) + + # 1. fixed args + self.decisive_double_block_id = self.blocks_num // 2 + self.codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3} + + # 2. Create two instances of AdaArgs + self.args_even = AdaArgs(config) + self.args_odd = AdaArgs(config) + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + if self.infer_conditional: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records + + if caching_records[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + + # 1. 
calculate the skipped step length
+                if index <= self.scheduler.infer_steps - 2:
+                    self.args_even.skipped_step_length = self.calculate_skip_step_length()
+                    for i in range(1, self.args_even.skipped_step_length):
+                        if (index + i) <= self.scheduler.infer_steps - 1:
+                            self.scheduler.caching_records[index + i] = False
+            else:
+                x = self.infer_using_cache(x)
+
+        else:
+            index = self.scheduler.step_index
+            caching_records = self.scheduler.caching_records_2
+
+            if caching_records[index] or self.must_calc(index):
+                x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+
+                # 1. calculate the skipped step length
+                if index <= self.scheduler.infer_steps - 2:
+                    self.args_odd.skipped_step_length = self.calculate_skip_step_length()
+                    for i in range(1, self.args_odd.skipped_step_length):
+                        if (index + i) <= self.scheduler.infer_steps - 1:
+                            self.scheduler.caching_records_2[index + i] = False
+            else:
+                x = self.infer_using_cache(x)
+
+        if self.config["enable_cfg"]:
+            self.switch_status()
+
+        return x
+
+    def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        ori_x = x.clone()
+
+        for block_idx in range(self.blocks_num):
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0)
+
+            y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+            if block_idx == self.decisive_double_block_id:
+                if self.infer_conditional:
+                    self.args_even.now_residual_tiny = y_out * gate_msa.squeeze(0)
+                else:
+                    self.args_odd.now_residual_tiny = y_out * gate_msa.squeeze(0)
+
+            x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
+            y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
+            x = self.post_process(x, y_out, c_gate_msa)
+
+        if self.infer_conditional:
+            self.args_even.previous_residual = x - ori_x
+        else:
+            self.args_odd.previous_residual = x - ori_x
+        return x
+
+    def infer_using_cache(self, x):
+        if self.infer_conditional:
+            x += self.args_even.previous_residual
+        else:
+            x += self.args_odd.previous_residual
+        return x
+
+    def calculate_skip_step_length(self):
+        if self.infer_conditional:
+            if self.args_even.previous_residual_tiny is None:
+                self.args_even.previous_residual_tiny = self.args_even.now_residual_tiny
+                return 1
+            else:
+                cache = self.args_even.previous_residual_tiny
+                res = self.args_even.now_residual_tiny
+                norm_ord = self.args_even.norm_ord
+                cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
+                cache_diff = cache_diff / self.args_even.skipped_step_length
+
+                if self.args_even.moreg_steps[0] <= self.scheduler.step_index <= self.args_even.moreg_steps[1]:
+                    moreg = 0
+                    for i in self.args_even.moreg_strides:
+                        moreg_i = (res[i * self.args_even.spatial_dim :, :] - res[: -i * self.args_even.spatial_dim, :]).norm(p=norm_ord)
+                        moreg_i /= res[i * self.args_even.spatial_dim :, :].norm(p=norm_ord) + res[: -i * self.args_even.spatial_dim, :].norm(p=norm_ord)
+                        moreg += moreg_i
+                    moreg = moreg / len(self.args_even.moreg_strides)
+                    moreg = ((1 / self.args_even.moreg_hyp[0] * moreg) ** self.args_even.moreg_hyp[1]) / self.args_even.moreg_hyp[2]
+                else:
+                    moreg = 1.0
+
+                mograd = self.args_even.mograd_mul * (moreg - self.args_even.previous_moreg) / 
self.args_even.skipped_step_length + self.args_even.previous_moreg = moreg + moreg = moreg + abs(mograd) + cache_diff = cache_diff * moreg + + metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values()) + if cache_diff < metric_thres[0]: + new_rate = cache_rates[0] + elif cache_diff < metric_thres[1]: + new_rate = cache_rates[1] + elif cache_diff < metric_thres[2]: + new_rate = cache_rates[2] + elif cache_diff < metric_thres[3]: + new_rate = cache_rates[3] + elif cache_diff < metric_thres[4]: + new_rate = cache_rates[4] + else: + new_rate = cache_rates[-1] + + self.args_even.previous_residual_tiny = self.args_even.now_residual_tiny + return new_rate + + else: + if self.args_odd.previous_residual_tiny is None: + self.args_odd.previous_residual_tiny = self.args_odd.now_residual_tiny + return 1 + else: + cache = self.args_odd.previous_residual_tiny + res = self.args_odd.now_residual_tiny + norm_ord = self.args_odd.norm_ord + cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord) + cache_diff = cache_diff / self.args_odd.skipped_step_length + + if self.args_odd.moreg_steps[0] <= self.scheduler.step_index <= self.args_odd.moreg_steps[1]: + moreg = 0 + for i in self.args_odd.moreg_strides: + moreg_i = (res[i * self.args_odd.spatial_dim :, :] - res[: -i * self.args_odd.spatial_dim, :]).norm(p=norm_ord) + moreg_i /= res[i * self.args_odd.spatial_dim :, :].norm(p=norm_ord) + res[: -i * self.args_odd.spatial_dim, :].norm(p=norm_ord) + moreg += moreg_i + moreg = moreg / len(self.args_odd.moreg_strides) + moreg = ((1 / self.args_odd.moreg_hyp[0] * moreg) ** self.args_odd.moreg_hyp[1]) / self.args_odd.moreg_hyp[2] + else: + moreg = 1.0 + + mograd = self.args_odd.mograd_mul * (moreg - self.args_odd.previous_moreg) / self.args_odd.skipped_step_length + self.args_odd.previous_moreg = moreg + moreg = moreg + abs(mograd) + cache_diff = cache_diff * moreg + + metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values()) + if cache_diff < metric_thres[0]: + new_rate = cache_rates[0] + elif cache_diff < metric_thres[1]: + new_rate = cache_rates[1] + elif cache_diff < metric_thres[2]: + new_rate = cache_rates[2] + elif cache_diff < metric_thres[3]: + new_rate = cache_rates[3] + elif cache_diff < metric_thres[4]: + new_rate = cache_rates[4] + else: + new_rate = cache_rates[-1] + + self.args_odd.previous_residual_tiny = self.args_odd.now_residual_tiny + return new_rate + + def clear(self): + if self.args_even.previous_residual is not None: + self.args_even.previous_residual = self.args_even.previous_residual.cpu() + if self.args_even.previous_residual_tiny is not None: + self.args_even.previous_residual_tiny = self.args_even.previous_residual_tiny.cpu() + if self.args_even.now_residual_tiny is not None: + self.args_even.now_residual_tiny = self.args_even.now_residual_tiny.cpu() + + if self.args_odd.previous_residual is not None: + self.args_odd.previous_residual = self.args_odd.previous_residual.cpu() + if self.args_odd.previous_residual_tiny is not None: + self.args_odd.previous_residual_tiny = self.args_odd.previous_residual_tiny.cpu() + if self.args_odd.now_residual_tiny is not None: + self.args_odd.now_residual_tiny = self.args_odd.now_residual_tiny.cpu() + + self.args_even.previous_residual = None + self.args_even.previous_residual_tiny = None + self.args_even.now_residual_tiny = None + + self.args_odd.previous_residual = None + self.args_odd.previous_residual_tiny = None + self.args_odd.now_residual_tiny = None + + 
torch.cuda.empty_cache() + + +class AdaArgs: + def __init__(self, config): + # Cache related attributes + self.previous_residual_tiny = None + self.now_residual_tiny = None + self.norm_ord = 1 + self.skipped_step_length = 1 + self.previous_residual = None + + # Moreg related attributes + self.previous_moreg = 1.0 + self.moreg_strides = [1] + self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])] + self.moreg_hyp = [0.385, 8, 1, 2] + self.mograd_mul = 10 + self.spatial_dim = 1536 + + +class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCachingTransformerInfer): + def __init__(self, config): + super().__init__(config) + self.cnt = 0 + self.teacache_thresh = config["teacache_thresh"] + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = None + self.previous_residual_even = None + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = None + self.previous_residual_odd = None + self.cache_even = {} + self.cache_odd = {} + self.use_ret_steps = config["use_ret_steps"] + if self.use_ret_steps: + self.coefficients = self.config["coefficients"][0] + self.ret_steps = 5 * 2 + self.cutoff_steps = self.config["infer_steps"] * 2 + else: + self.coefficients = self.config["coefficients"][1] + self.ret_steps = 1 * 2 + self.cutoff_steps = self.config["infer_steps"] * 2 - 2 + + # 1. get taylor step_diff when there is two caching_records in scheduler + def get_taylor_step_diff(self): + step_diff = 0 + if self.infer_conditional: + current_step = self.scheduler.step_index + last_calc_step = current_step - 1 + while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]: + last_calc_step -= 1 + step_diff = current_step - last_calc_step + else: + current_step = self.scheduler.step_index + last_calc_step = current_step - 1 + while last_calc_step >= 0 and not self.scheduler.caching_records_2[last_calc_step]: + last_calc_step -= 1 + step_diff = current_step - last_calc_step + + return step_diff + + # calculate should_calc + def calculate_should_calc(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + # 1. timestep embedding + modulated_inp = embed0 if self.use_ret_steps else embed + + # 2. L1 calculate + should_calc = False + if self.infer_conditional: + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc = True + self.accumulated_rel_l1_distance_even = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_even < self.teacache_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = modulated_inp.clone() + + else: + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc = True + self.accumulated_rel_l1_distance_odd = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = modulated_inp.clone() + + # 3. 
return the judgement + return should_calc + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + if self.infer_conditional: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records + if index <= self.scheduler.infer_steps - 1: + should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + self.scheduler.caching_records[index] = should_calc + + if caching_records[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(x) + + else: + index = self.scheduler.step_index + caching_records_2 = self.scheduler.caching_records_2 + if index <= self.scheduler.infer_steps - 1: + should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + self.scheduler.caching_records_2[index] = should_calc + + if caching_records_2[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(x) + + if self.config["enable_cfg"]: + self.switch_status() + + self.cnt += 1 + + return x + + def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + ori_x = x.clone() + + for block_idx in range(self.blocks_num): + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0) + + y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa) + x, attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa) + y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa) + x = self.post_process(x, y_out, c_gate_msa) + + if self.infer_conditional: + self.previous_residual_even = x - ori_x + self.derivative_approximation(self.cache_even, "previous_residual", self.previous_residual_even) + else: + self.previous_residual_odd = x - ori_x + self.derivative_approximation(self.cache_odd, "previous_residual", self.previous_residual_odd) + return x + + def infer_using_cache(self, x): + if self.infer_conditional: + x += self.taylor_formula(self.cache_even["previous_residual"]) + else: + x += self.taylor_formula(self.cache_odd["previous_residual"]) + return x + + def clear(self): + if self.previous_residual_even is not None: + self.previous_residual_even = self.previous_residual_even.cpu() + if self.previous_residual_odd is not None: + self.previous_residual_odd = self.previous_residual_odd.cpu() + if self.previous_e0_even is not None: + self.previous_e0_even = self.previous_e0_even.cpu() + if self.previous_e0_odd is not None: + self.previous_e0_odd = self.previous_e0_odd.cpu() + + for key in self.cache_even: + if self.cache_even[key] is not None and hasattr(self.cache_even[key], "cpu"): + self.cache_even[key] = self.cache_even[key].cpu() + self.cache_even.clear() + + for key in self.cache_odd: + if self.cache_odd[key] is not None and hasattr(self.cache_odd[key], "cpu"): + self.cache_odd[key] = self.cache_odd[key].cpu() + self.cache_odd.clear() + + self.previous_residual_even = None + self.previous_residual_odd = None + self.previous_e0_even = None + self.previous_e0_odd = None + + torch.cuda.empty_cache() + + +class WanTransformerInferFirstBlock(WanTransformerInferCaching): + def __init__(self, config): + 
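+        # First-block caching: block 0 always runs; the remaining blocks are recomputed only
+        # when the downsampled first-block residual drifts past residual_diff_threshold
+        # (or the scheduler forces a calculation), otherwise their cached residual is reused.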
super().__init__(config) + + self.residual_diff_threshold = config["residual_diff_threshold"] + self.prev_first_block_residual_even = None + self.prev_remaining_blocks_residual_even = None + self.prev_first_block_residual_odd = None + self.prev_remaining_blocks_residual_odd = None + self.downsample_factor = self.config["downsample_factor"] + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + ori_x = x.clone() + x = super().infer_block(weights.blocks[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context) + x_residual = x - ori_x + del ori_x + + if self.infer_conditional: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records + if index <= self.scheduler.infer_steps - 1: + should_calc = self.calculate_should_calc(x_residual) + self.scheduler.caching_records[index] = should_calc + + if caching_records[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(x) + + else: + index = self.scheduler.step_index + caching_records_2 = self.scheduler.caching_records_2 + if index <= self.scheduler.infer_steps - 1: + should_calc = self.calculate_should_calc(x_residual) + self.scheduler.caching_records_2[index] = should_calc + + if caching_records_2[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(x) + + if self.config["enable_cfg"]: + self.switch_status() + + return x + + def calculate_should_calc(self, x_residual): + diff = 1.0 + x_residual_downsampled = x_residual[..., :: self.downsample_factor] + if self.infer_conditional: + if self.prev_first_block_residual_even is not None: + t1 = self.prev_first_block_residual_even + t2 = x_residual_downsampled + mean_diff = (t1 - t2).abs().mean() + mean_t1 = t1.abs().mean() + diff = (mean_diff / mean_t1).item() + self.prev_first_block_residual_even = x_residual_downsampled + else: + if self.prev_first_block_residual_odd is not None: + t1 = self.prev_first_block_residual_odd + t2 = x_residual_downsampled + mean_diff = (t1 - t2).abs().mean() + mean_t1 = t1.abs().mean() + diff = (mean_diff / mean_t1).item() + self.prev_first_block_residual_odd = x_residual_downsampled + + return diff >= self.residual_diff_threshold + + def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + ori_x = x.clone() + + for block_idx in range(1, self.blocks_num): + x = super().infer_block( + weights.blocks[block_idx], + grid_sizes, + embed, + x, + embed0, + seq_lens, + freqs, + context, + ) + + if self.infer_conditional: + self.prev_remaining_blocks_residual_even = x - ori_x + else: + self.prev_remaining_blocks_residual_odd = x - ori_x + del ori_x + + return x + + def infer_using_cache(self, x): + if self.infer_conditional: + return x.add_(self.prev_remaining_blocks_residual_even) + else: + return x.add_(self.prev_remaining_blocks_residual_odd) + + def clear(self): + self.prev_first_block_residual_even = None + self.prev_remaining_blocks_residual_even = None + self.prev_first_block_residual_odd = None + self.prev_remaining_blocks_residual_odd = None + torch.cuda.empty_cache() + + +class WanTransformerInferDualBlock(WanTransformerInferCaching): + def __init__(self, config): + super().__init__(config) + + self.residual_diff_threshold = config["residual_diff_threshold"] + self.prev_front_blocks_residual_even = None + self.prev_middle_blocks_residual_even = None + 
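+        # Separate even/odd caches are kept for the conditional and unconditional (CFG)
+        # passes, selected via infer_conditional, so the two passes never overwrite each other.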
self.prev_front_blocks_residual_odd = None + self.prev_middle_blocks_residual_odd = None + self.downsample_factor = self.config["downsample_factor"] + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + ori_x = x.clone() + for block_idx in range(0, 5): + x = super().infer_block( + weights.blocks[block_idx], + grid_sizes, + embed, + x, + embed0, + seq_lens, + freqs, + context, + ) + x_residual = x - ori_x + del ori_x + + if self.infer_conditional: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records + if index <= self.scheduler.infer_steps - 1: + should_calc = self.calculate_should_calc(x_residual) + self.scheduler.caching_records[index] = should_calc + + if caching_records[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(x) + + else: + index = self.scheduler.step_index + caching_records_2 = self.scheduler.caching_records_2 + if index <= self.scheduler.infer_steps - 1: + should_calc = self.calculate_should_calc(x_residual) + self.scheduler.caching_records_2[index] = should_calc + + if caching_records_2[index] or self.must_calc(index): + x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x = self.infer_using_cache(x) + + for block_idx in range(self.blocks_num - 5, self.blocks_num): + x = super().infer_block( + weights.blocks[block_idx], + grid_sizes, + embed, + x, + embed0, + seq_lens, + freqs, + context, + ) + + if self.config["enable_cfg"]: + self.switch_status() + + return x + + def calculate_should_calc(self, x_residual): + diff = 1.0 + x_residual_downsampled = x_residual[..., :: self.downsample_factor] + if self.infer_conditional: + if self.prev_front_blocks_residual_even is not None: + t1 = self.prev_front_blocks_residual_even + t2 = x_residual_downsampled + mean_diff = (t1 - t2).abs().mean() + mean_t1 = t1.abs().mean() + diff = (mean_diff / mean_t1).item() + self.prev_front_blocks_residual_even = x_residual_downsampled + else: + if self.prev_front_blocks_residual_odd is not None: + t1 = self.prev_front_blocks_residual_odd + t2 = x_residual_downsampled + mean_diff = (t1 - t2).abs().mean() + mean_t1 = t1.abs().mean() + diff = (mean_diff / mean_t1).item() + self.prev_front_blocks_residual_odd = x_residual_downsampled + + return diff >= self.residual_diff_threshold + + def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + ori_x = x.clone() + + for block_idx in range(5, self.blocks_num - 5): + x = super().infer_block( + weights.blocks[block_idx], + grid_sizes, + embed, + x, + embed0, + seq_lens, + freqs, + context, + ) + + if self.infer_conditional: + self.prev_middle_blocks_residual_even = x - ori_x + else: + self.prev_middle_blocks_residual_odd = x - ori_x + del ori_x + + return x + + def infer_using_cache(self, x): + if self.infer_conditional: + return x.add_(self.prev_middle_blocks_residual_even) + else: + return x.add_(self.prev_middle_blocks_residual_odd) + + def clear(self): + self.prev_front_blocks_residual_even = None + self.prev_middle_blocks_residual_even = None + self.prev_front_blocks_residual_odd = None + self.prev_middle_blocks_residual_odd = None + torch.cuda.empty_cache() + + +class WanTransformerInferDynamicBlock(WanTransformerInferCaching): + def __init__(self, config): + super().__init__(config) + self.residual_diff_threshold = config["residual_diff_threshold"] + self.downsample_factor = 
self.config["downsample_factor"] + + self.block_in_cache_even = {i: None for i in range(self.blocks_num)} + self.block_residual_cache_even = {i: None for i in range(self.blocks_num)} + self.block_in_cache_odd = {i: None for i in range(self.blocks_num)} + self.block_residual_cache_odd = {i: None for i in range(self.blocks_num)} + + def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): + for block_idx in range(self.blocks_num): + x = self.infer_block(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx) + + return x + + def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx): + ori_x = x.clone() + + if self.infer_conditional: + if self.block_in_cache_even[block_idx] is not None: + should_calc = self.are_two_tensor_similar(self.block_in_cache_even[block_idx], x) + if should_calc or self.must_calc(block_idx): + x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x += self.block_residual_cache_even[block_idx] + + else: + x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + + self.block_in_cache_even[block_idx] = ori_x + self.block_residual_cache_even[block_idx] = x - ori_x + del ori_x + + else: + if self.block_in_cache_odd[block_idx] is not None: + should_calc = self.are_two_tensor_similar(self.block_in_cache_odd[block_idx], x) + if should_calc or self.must_calc(block_idx): + x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + else: + x += self.block_residual_cache_odd[block_idx] + + else: + x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) + + self.block_in_cache_odd[block_idx] = ori_x + self.block_residual_cache_odd[block_idx] = x - ori_x + del ori_x + + return x + + def are_two_tensor_similar(self, t1, t2): + diff = 1.0 + t1_downsampled = t1[..., :: self.downsample_factor] + t2_downsampled = t2[..., :: self.downsample_factor] + mean_diff = (t1_downsampled - t2_downsampled).abs().mean() + mean_t1 = t1_downsampled.abs().mean() + diff = (mean_diff / mean_t1).item() + + return diff >= self.residual_diff_threshold + + def clear(self): + for i in range(self.blocks_num): + self.block_in_cache_even[i] = None + self.block_residual_cache_even[i] = None + self.block_in_cache_odd[i] = None + self.block_residual_cache_odd[i] = None + torch.cuda.empty_cache() + + +class WanTransformerInferMagCaching(WanTransformerInferCaching): + def __init__(self, config): + super().__init__(config) + self.magcache_thresh = config["magcache_thresh"] + self.K = config["magcache_K"] + self.retention_ratio = config["magcache_retention_ratio"] + self.mag_ratios = np.array(config["magcache_ratios"]) + # {True: cond_param, False: uncond_param} + self.accumulated_err = {True: 0.0, False: 0.0} + self.accumulated_steps = {True: 0, False: 0} + self.accumulated_ratio = {True: 1.0, False: 1.0} + self.residual_cache = {True: None, False: None} + # calibration args + self.norm_ratio = [[1.0], [1.0]] # mean of magnitude ratio + self.norm_std = [[0.0], [0.0]] # std of magnitude ratio + self.cos_dis = [[0.0], [0.0]] # cosine distance of residual features + + def infer_main_blocks(self, weights, pre_infer_out): + skip_forward = False + step_index = self.scheduler.step_index + infer_condition = self.scheduler.infer_condition + + if self.config["magcache_calibration"]: + skip_forward = False + else: + if step_index >= int(self.config["infer_steps"] * self.retention_ratio): 
+ # conditional and unconditional in one list + cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index] + # magnitude ratio between current step and the cached step + self.accumulated_ratio[infer_condition] = self.accumulated_ratio[infer_condition] * cur_mag_ratio + self.accumulated_steps[infer_condition] += 1 # skip steps plus 1 + # skip error of current steps + cur_skip_err = np.abs(1 - self.accumulated_ratio[infer_condition]) + # accumulated error of multiple steps + self.accumulated_err[infer_condition] += cur_skip_err + + if self.accumulated_err[infer_condition] < self.magcache_thresh and self.accumulated_steps[infer_condition] <= self.K: + skip_forward = True + else: + self.accumulated_err[infer_condition] = 0 + self.accumulated_steps[infer_condition] = 0 + self.accumulated_ratio[infer_condition] = 1.0 + + if not skip_forward: + x = self.infer_calculating(weights, pre_infer_out) + else: + x = self.infer_using_cache(pre_infer_out.x) + + if self.clean_cuda_cache: + torch.cuda.empty_cache() + + return x + + def infer_calculating(self, weights, pre_infer_out): + step_index = self.scheduler.step_index + infer_condition = self.scheduler.infer_condition + + ori_x = pre_infer_out.x.clone() + + x = super().infer_main_blocks(weights, pre_infer_out) + + previous_residual = x - ori_x + if self.config["cpu_offload"]: + previous_residual = previous_residual.cpu() + + if self.config["magcache_calibration"] and step_index >= 1: + norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item() + norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item() + cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item() + _index = int(not infer_condition) + self.norm_ratio[_index].append(round(norm_ratio, 5)) + self.norm_std[_index].append(round(norm_std, 5)) + self.cos_dis[_index].append(round(cos_dis, 5)) + print(f"time: {step_index}, infer_condition: {infer_condition}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}") + + self.residual_cache[infer_condition] = previous_residual + + if self.config["cpu_offload"]: + ori_x = ori_x.to("cpu") + del ori_x + torch.cuda.empty_cache() + gc.collect() + return x + + def infer_using_cache(self, x): + residual_x = self.residual_cache[self.scheduler.infer_condition] + x.add_(residual_x.to(AI_DEVICE)) + return x + + def clear(self): + self.accumulated_err = {True: 0.0, False: 0.0} + self.accumulated_steps = {True: 0, False: 0} + self.accumulated_ratio = {True: 1.0, False: 1.0} + self.residual_cache = {True: None, False: None} + if self.config["magcache_calibration"]: + print("norm ratio") + print(self.norm_ratio) + print("norm std") + print(self.norm_std) + print("cos_dis") + print(self.cos_dis) + + def save_json(filename, obj_list): + with open(filename + ".json", "w") as f: + json.dump(obj_list, f) + + save_json("wan2_1_mag_ratio", self.norm_ratio) + save_json("wan2_1_mag_std", self.norm_std) + save_json("wan2_1_cos_dis", self.cos_dis) + torch.cuda.empty_cache() diff --git a/lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py b/lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..549ea44daef992c8a45bc347015912b9d4e2c78e --- /dev/null +++ b/lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py @@ -0,0 +1,291 @@ +from typing import List, 
Tuple, Union + +import torch + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] + """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def reshape_for_broadcast( + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + x: torch.Tensor, + head_first=False, +): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Notes: + When using FlashMHAModified, head_first should be False. + When using Attention, head_first should be True. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
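+    In the (cos, sin) tuple path with head_first=False used below, the shape assertion is
+    skipped and both tensors are simply viewed as [1, S, 1, D] for broadcasting.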
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + # assert freqs_cis[0].shape == ( + # x.shape[1], + # x.shape[-1], + # ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, + start_offset: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + # print(freqs_cis[0].shape, xq.shape, xk.shape) + xk_out = None + assert isinstance(freqs_cis, tuple) + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos[:, start_offset : start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset : start_offset + xq.shape[1], :, :]).type_as(xq) + xk_out = (xk.float() * cos[:, start_offset : start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset : start_offset + xk.shape[1], :, :]).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + theta_rescale_factor: Union[float, List[float]] = 1.0, + interpolation_factor: Union[float, List[float]] = 1.0, +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 
+ + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] + + if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) + assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): + interpolation_factor = [interpolation_factor] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) + assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos, device=torch.cuda.current_device()).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim)) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis diff --git a/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..d828476cdc0bc1c05dff767ded5405acaac9ddb3 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py @@ -0,0 +1,98 @@ +import torch + +from lightx2v.models.networks.wan.infer.module_io import GridOutput +from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer, WanSFPreInferModuleOutput, sinusoidal_embedding_1d +from lightx2v.utils.envs import * + + +def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode="universal"): + new_cond = {} + + new_cond["cond_concat"] = conditional_dict["image_encoder_output"]["cond_concat"][:, :, current_start_frame : current_start_frame + num_frame_per_block] + new_cond["visual_context"] = conditional_dict["image_encoder_output"]["visual_context"] + if replace: + if current_start_frame == 0: + last_frame_num = 1 + 4 * (num_frame_per_block - 1) + else: + last_frame_num = 4 * num_frame_per_block + final_frame = 1 + 4 * (current_start_frame + num_frame_per_block - 1) + if mode != "templerun": + conditional_dict["text_encoder_output"]["mouse_cond"][:, -last_frame_num + final_frame : final_frame] = replace["mouse"][None, None, :].repeat(1, last_frame_num, 1) + conditional_dict["text_encoder_output"]["keyboard_cond"][:, -last_frame_num + final_frame : final_frame] = replace["keyboard"][None, None, :].repeat(1, last_frame_num, 1) + if mode != "templerun": + new_cond["mouse_cond"] = conditional_dict["text_encoder_output"]["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)] + new_cond["keyboard_cond"] = conditional_dict["text_encoder_output"]["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)] + + if replace: + return new_cond, conditional_dict + else: + return new_cond + + +# @amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +class WanMtxg2PreInfer(WanSFPreInfer): + def __init__(self, config): + super().__init__(config) + d = config["dim"] // config["num_heads"] + self.freqs = torch.cat([rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1).to(torch.device("cuda")) + self.dim = config["dim"] + + def img_emb(self, weights, x): + x = 
weights.img_emb_0.apply(x) + x = weights.img_emb_1.apply(x.squeeze(0)) + x = torch.nn.functional.gelu(x, approximate="none") + x = weights.img_emb_3.apply(x) + x = weights.img_emb_4.apply(x) + x = x.unsqueeze(0) + return x + + @torch.no_grad() + def infer(self, weights, inputs, kv_start=0, kv_end=0): + x = self.scheduler.latents_input + t = self.scheduler.timestep_input + current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block + + if self.config["streaming"]: + current_actions = inputs["current_actions"] + current_conditional_dict, _ = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, replace=current_actions, mode=self.config["mode"]) + else: + current_conditional_dict = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, mode=self.config["mode"]) + cond_concat = current_conditional_dict["cond_concat"] + visual_context = current_conditional_dict["visual_context"] + + x = torch.cat([x.unsqueeze(0), cond_concat], dim=1) + + # embeddings + x = weights.patch_embedding.apply(x) + grid_sizes_t, grid_sizes_h, grid_sizes_w = torch.tensor(x.shape[2:], dtype=torch.long) + grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w)) + + x = x.flatten(2).transpose(1, 2) # B FHW C' + seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long, device=torch.device("cuda")) + assert seq_lens[0] <= 15 * 1 * 880 + + embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) # torch.Size([3, 256]) + embed = self.time_embedding(weights, embed_tmp) # torch.Size([3, 1536]) + embed0 = self.time_projection(weights, embed).unflatten(dim=0, sizes=t.shape) + + # context + context_lens = None + context = self.img_emb(weights, visual_context) + + return WanSFPreInferModuleOutput( + embed=embed, + grid_sizes=grid_sizes, + x=x.squeeze(0), + embed0=embed0.squeeze(0), + seq_lens=seq_lens, + freqs=self.freqs, + context=context[0], + conditional_dict=current_conditional_dict, + ) diff --git a/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb1cdc60b092134502c048410fad8f5ea0cf9a2 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py @@ -0,0 +1,672 @@ +import math + +import torch +from einops import rearrange + +try: + import flash_attn_interface + + FLASH_ATTN_3_AVAILABLE = True +except ImportError: + try: + from flash_attn import flash_attn_func + + FLASH_ATTN_3_AVAILABLE = False + + except ImportError: + FLASH_ATTN_3_AVAILABLE = False + + +from lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed +from lightx2v.models.networks.wan.infer.self_forcing.transformer_infer import WanSFTransformerInfer, causal_rope_apply + + +class WanMtxg2TransformerInfer(WanSFTransformerInfer): + def __init__(self, config): + super().__init__(config) + self._initialize_kv_cache_mouse_and_keyboard(self.device, self.dtype) + self.sink_size = 0 + self.vae_time_compression_ratio = config["action_config"]["vae_time_compression_ratio"] + self.windows_size = config["action_config"]["windows_size"] + self.patch_size = config["action_config"]["patch_size"] + + self.rope_theta = config["action_config"]["rope_theta"] + self.enable_keyboard = config["action_config"]["enable_keyboard"] + 
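+        # Hyper-parameters of the action (mouse / keyboard) attention path, all read from the
+        # "action_config" section of the model config.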
self.heads_num = config["action_config"]["heads_num"] + self.hidden_size = config["action_config"]["hidden_size"] + self.img_hidden_size = config["action_config"]["img_hidden_size"] + self.keyboard_dim_in = config["action_config"]["keyboard_dim_in"] + self.keyboard_hidden_dim = config["action_config"]["keyboard_hidden_dim"] + + self.qk_norm = config["action_config"]["qk_norm"] + self.qkv_bias = config["action_config"]["qkv_bias"] + self.rope_dim_list = config["action_config"]["rope_dim_list"] + self.freqs_cos, self.freqs_sin = self.get_rotary_pos_embed(7500, self.patch_size[1], self.patch_size[2], 64, self.rope_dim_list, start_offset=0) + + self.enable_mouse = config["action_config"]["enable_mouse"] + if self.enable_mouse: + self.mouse_dim_in = config["action_config"]["mouse_dim_in"] + self.mouse_hidden_dim = config["action_config"]["mouse_hidden_dim"] + self.mouse_qk_dim_list = config["action_config"]["mouse_qk_dim_list"] + + def get_rotary_pos_embed(self, video_length, height, width, head_dim, rope_dim_list=None, start_offset=0): + target_ndim = 3 + ndim = 5 - 2 + + latents_size = [video_length + start_offset, height, width] + + if isinstance(self.patch_size, int): + assert all(s % self.patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), but got {latents_size}." + rope_sizes = [s // self.patch_size for s in latents_size] + elif isinstance(self.patch_size, list): + assert all(s % self.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), ( + f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), but got {latents_size}." + ) + rope_sizes = [s // self.patch_size[idx] for idx, s in enumerate(latents_size)] + + if len(rope_sizes) != target_ndim: + rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list, + rope_sizes, + theta=self.rope_theta, + use_real=True, + theta_rescale_factor=1, + ) + return freqs_cos[-video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] :], freqs_sin[-video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] :] + + def _initialize_kv_cache(self, dtype, device): + """ + Initialize a Per-GPU KV cache for the Wan model. + """ + kv_cache1 = [] + if self.local_attn_size != -1: + # Use the local attention size to compute the KV cache size + kv_cache_size = self.local_attn_size * self.frame_seq_length + else: + # Use the default KV cache size + kv_cache_size = 32760 + for _ in range(self.num_transformer_blocks): + kv_cache1.append( + { + "k": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device), + "v": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device), + "global_end_index": 0, + "local_end_index": 0, + } + ) + + self.kv_cache1_default = kv_cache1 + + def _initialize_kv_cache_mouse_and_keyboard(self, device, dtype): + """ + Initialize a Per-GPU KV cache for the Wan model. 
+ """ + kv_cache_mouse = [] + kv_cache_keyboard = [] + if self.local_attn_size != -1: + kv_cache_size = self.local_attn_size + else: + kv_cache_size = 15 * 1 + for _ in range(self.num_transformer_blocks): + kv_cache_keyboard.append( + { + "k": torch.zeros([1, kv_cache_size, 16, 64], dtype=dtype, device=device), + "v": torch.zeros([1, kv_cache_size, 16, 64], dtype=dtype, device=device), + "global_end_index": torch.tensor([0], dtype=torch.long, device=device), + "local_end_index": torch.tensor([0], dtype=torch.long, device=device), + } + ) + kv_cache_mouse.append( + { + "k": torch.zeros([self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device), + "v": torch.zeros([self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device), + "global_end_index": torch.tensor([0], dtype=torch.long, device=device), + "local_end_index": torch.tensor([0], dtype=torch.long, device=device), + } + ) + self.kv_cache_keyboard = kv_cache_keyboard + self.kv_cache_mouse = kv_cache_mouse + + def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa): + if hasattr(phase, "smooth_norm1_weight"): + norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor + norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor + else: + norm1_weight = 1 + scale_msa.squeeze() + norm1_bias = shift_msa.squeeze() + + norm1_out = phase.norm1.apply(x) + + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.sensitive_layer_dtype) + + norm1_out.mul_(norm1_weight[0:1, :]).add_(norm1_bias[0:1, :]) + + if self.sensitive_layer_dtype != self.infer_dtype: # False + norm1_out = norm1_out.to(self.infer_dtype) + + s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim + + q0 = phase.self_attn_q.apply(norm1_out) + k0 = phase.self_attn_k.apply(norm1_out) + + q = phase.self_attn_norm_q.apply(q0).view(s, n, d) + k = phase.self_attn_norm_k.apply(k0).view(s, n, d) + v = phase.self_attn_v.apply(norm1_out).view(s, n, d) + + seg_index = self.scheduler.seg_index + + frame_seqlen = math.prod(grid_sizes[0][1:]).item() + current_start = seg_index * self.num_frame_per_block * self.frame_seq_length + current_start_frame = current_start // frame_seqlen + + q = causal_rope_apply(q.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0] + k = causal_rope_apply(k.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0] + + current_end = current_start + q.shape[0] + sink_tokens = self.sink_size * frame_seqlen + + kv_cache_size = self.kv_cache1[self.block_idx]["k"].shape[0] + num_new_tokens = q.shape[0] + + if (current_end > self.kv_cache1[self.block_idx]["global_end_index"]) and (num_new_tokens + self.kv_cache1[self.block_idx]["local_end_index"] > kv_cache_size): + num_evicted_tokens = num_new_tokens + self.kv_cache1[self.block_idx]["local_end_index"] - kv_cache_size + num_rolled_tokens = self.kv_cache1[self.block_idx]["local_end_index"] - num_evicted_tokens - sink_tokens + + self.kv_cache1[self.block_idx]["k"][sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache1[self.block_idx]["k"][ + sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + self.kv_cache1[self.block_idx]["v"][sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache1[self.block_idx]["v"][ + sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + + # Insert the new keys/values at the end + local_end_index = 
self.kv_cache1[self.block_idx]["local_end_index"] + current_end - self.kv_cache1[self.block_idx]["global_end_index"] - num_evicted_tokens + local_start_index = local_end_index - num_new_tokens + self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k + self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v + else: + # Assign new keys/values directly up to current_end + local_end_index = self.kv_cache1[self.block_idx]["local_end_index"] + current_end - self.kv_cache1[self.block_idx]["global_end_index"] + local_start_index = local_end_index - num_new_tokens + self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k + self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v + + attn_k = self.kv_cache1[self.block_idx]["k"][max(0, local_end_index - self.max_attention_size) : local_end_index] + attn_v = self.kv_cache1[self.block_idx]["v"][max(0, local_end_index - self.max_attention_size) : local_end_index] + + self.kv_cache1[self.block_idx]["local_end_index"] = local_end_index + self.kv_cache1[self.block_idx]["global_end_index"] = current_end + + k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0)) + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens) + + if self.clean_cuda_cache: + del freqs_i, norm1_out, norm1_weight, norm1_bias + torch.cuda.empty_cache() + + if self.config["seq_parallel"]: + attn_out = phase.self_attn_1_parallel.apply( + q=q, + k=attn_k, + v=attn_v, + img_qkv_len=q.shape[0], + cu_seqlens_qkv=cu_seqlens_q, + attention_module=phase.self_attn_1, + seq_p_group=self.seq_p_group, + model_cls=self.config["model_cls"], + ) + else: + attn_out = phase.self_attn_1.apply( + q=q, + k=attn_k, + v=attn_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=attn_k.size(0), + model_cls=self.config["model_cls"], + ) + + y = phase.self_attn_o.apply(attn_out) + + if self.clean_cuda_cache: + del q, k, v, attn_out + torch.cuda.empty_cache() + + return y + + def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa): + num_frames = gate_msa.shape[0] + frame_seqlen = x.shape[0] // gate_msa.shape[0] + + x.add_((y_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) * gate_msa).flatten(0, 1)) + norm3_out = phase.norm3.apply(x) + + n, d = self.num_heads, self.head_dim + q = phase.cross_attn_q.apply(norm3_out) + q = phase.cross_attn_norm_q.apply(q).view(-1, n, d) + + if not self.crossattn_cache[self.block_idx]["is_init"]: + self.crossattn_cache[self.block_idx]["is_init"] = True + k = phase.cross_attn_k.apply(context) + k = phase.cross_attn_norm_k.apply(k).view(-1, n, d) + v = phase.cross_attn_v.apply(context) + v = v.view(-1, n, d) + self.crossattn_cache[self.block_idx]["k"] = k + self.crossattn_cache[self.block_idx]["v"] = v + else: + k = self.crossattn_cache[self.block_idx]["k"] + v = self.crossattn_cache[self.block_idx]["v"] + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( + q, + k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device), + ) + + attn_out = phase.cross_attn_1.apply( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + model_cls=self.config["model_cls"], + ) + + attn_out = phase.cross_attn_o.apply(attn_out) + if self.clean_cuda_cache: + del q, k, v, norm3_out, context, context_img + torch.cuda.empty_cache() + + return x, attn_out + + def infer_action_model(self, phase, x, grid_sizes, seq_lens, mouse_condition=None, 
keyboard_condition=None, is_causal=False, use_rope_keyboard=True): + tt, th, tw = grid_sizes + current_start = self.scheduler.seg_index * self.num_frame_per_block + start_frame = current_start + B, N_frames, C = keyboard_condition.shape + assert tt * th * tw == x.shape[0] + assert ((N_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0 + N_feats = int((N_frames - 1) / self.vae_time_compression_ratio) + 1 + + # Defined freqs_cis early so it's available for both mouse and keyboard + freqs_cis = (self.freqs_cos, self.freqs_sin) + + cond1 = N_feats == tt + cond2 = is_causal and not self.kv_cache_mouse + cond3 = (N_frames - 1) // self.vae_time_compression_ratio + 1 == current_start + self.num_frame_per_block + assert (cond1 and ((cond2) or not is_causal)) or (cond3 and is_causal) + + x = x.unsqueeze(0) + if self.enable_mouse and mouse_condition is not None: + hidden_states = rearrange(x, "B (T S) C -> (B S) T C", T=tt, S=th * tw) # 65*272*480 -> 17*(272//16)*(480//16) -> 8670 + B, N_frames, C = mouse_condition.shape + else: + hidden_states = x + + pad_t = self.vae_time_compression_ratio * self.windows_size + if self.enable_mouse and mouse_condition is not None: + pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1) + mouse_condition = torch.cat([pad, mouse_condition], dim=1) + if is_causal and self.kv_cache_mouse is not None: + mouse_condition = mouse_condition[:, self.vae_time_compression_ratio * (N_feats - self.num_frame_per_block - self.windows_size) + pad_t :, :] + group_mouse = [ + mouse_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(self.num_frame_per_block) + ] + else: + group_mouse = [mouse_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(N_feats)] + + group_mouse = torch.stack(group_mouse, dim=1) + + S = th * tw + group_mouse = group_mouse.unsqueeze(-1).expand(B, self.num_frame_per_block, pad_t, C, S) + group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(B * S, self.num_frame_per_block, pad_t * C) + + group_mouse = torch.cat([hidden_states, group_mouse], dim=-1) + + # mouse_mlp + # 注释:Batch维度不可避免,因此用 torch.nn.functional + group_mouse = torch.nn.functional.linear(group_mouse, phase.mouse_mlp_0.weight.T, phase.mouse_mlp_0.bias) + group_mouse = torch.nn.functional.gelu(group_mouse, approximate="tanh") + group_mouse = torch.nn.functional.linear(group_mouse, phase.mouse_mlp_2.weight.T, phase.mouse_mlp_2.bias) + group_mouse = torch.nn.functional.layer_norm(group_mouse, (group_mouse.shape[-1],), phase.mouse_mlp_3.weight.T, phase.mouse_mlp_3.bias, 1e-5) + + # qkvc + mouse_qkv = torch.nn.functional.linear(group_mouse, phase.t_qkv.weight.T) + + q0, k0, v = rearrange(mouse_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # BHW F H C # torch.Size([880, 3, 16, 64]) + q = q0 * torch.rsqrt(q0.pow(2).mean(dim=-1, keepdim=True) + 1e-6) + k = k0 * torch.rsqrt(k0.pow(2).mean(dim=-1, keepdim=True) + 1e-6) + + q, k = apply_rotary_emb(q, k, freqs_cis, start_offset=start_frame, head_first=False) + + ## TODO: adding cache here + if is_causal: + current_end = current_start + q.shape[1] + + assert q.shape[1] == self.num_frame_per_block + sink_size = 0 + max_attention_size = self.local_attn_size + sink_tokens = sink_size * 1 + kv_cache_size = self.kv_cache_mouse[self.block_idx]["k"].shape[1] + num_new_tokens = q.shape[1] + + if (current_end > 
self.kv_cache_mouse[self.block_idx]["global_end_index"].item()) and (num_new_tokens + self.kv_cache_mouse[self.block_idx]["local_end_index"].item() > kv_cache_size): + num_evicted_tokens = num_new_tokens + self.kv_cache_mouse[self.block_idx]["local_end_index"].item() - kv_cache_size + num_rolled_tokens = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens + self.kv_cache_mouse[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_mouse[self.block_idx]["k"][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + self.kv_cache_mouse[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_mouse[self.block_idx]["v"][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + # Insert the new keys/values at the end + local_end_index = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_mouse[self.block_idx]["global_end_index"].item() - num_evicted_tokens + local_start_index = local_end_index - num_new_tokens + else: + local_end_index = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_mouse[self.block_idx]["global_end_index"].item() + local_start_index = local_end_index - num_new_tokens + + self.kv_cache_mouse[self.block_idx]["k"][:, local_start_index:local_end_index] = k + self.kv_cache_mouse[self.block_idx]["v"][:, local_start_index:local_end_index] = v + + attn_k = self.kv_cache_mouse[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index] + attn_v = self.kv_cache_mouse[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index] + + attn = flash_attn_interface.flash_attn_func( + q, + attn_k, + attn_v, + ) + + self.kv_cache_mouse[self.block_idx]["global_end_index"].fill_(current_end) + self.kv_cache_mouse[self.block_idx]["local_end_index"].fill_(local_end_index) + else: + attn = flash_attn_func( + q, + k, + v, + ) + # Compute cu_squlens and max_seqlen for flash attention + # qk norm + attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B) + hidden_states = rearrange(x, "(B S) T C -> B (T S) C", B=B) + + attn = phase.proj_mouse.apply(attn[0]).unsqueeze(0) + hidden_states = hidden_states + attn + + if self.enable_keyboard and keyboard_condition is not None: + pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1) + keyboard_condition = torch.cat([pad, keyboard_condition], dim=1) + if is_causal and self.kv_cache_keyboard is not None: + keyboard_condition = keyboard_condition[ + :, self.vae_time_compression_ratio * (N_feats - self.num_frame_per_block - self.windows_size) + pad_t :, : + ] # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:] + + keyboard_condition = phase.keyboard_embed_0.apply(keyboard_condition[0]) + keyboard_condition = torch.nn.functional.silu(keyboard_condition) + keyboard_condition = phase.keyboard_embed_2.apply(keyboard_condition).unsqueeze(0) + group_keyboard = [ + keyboard_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(self.num_frame_per_block) + ] + else: + keyboard_condition = phase.keyboard_embed_0.apply(keyboard_condition[0]) + keyboard_condition = torch.nn.functional.silu(keyboard_condition) + keyboard_condition = 
phase.keyboard_embed_2.apply(keyboard_condition).unsqueeze(0) + group_keyboard = [keyboard_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(N_feats)] + + group_keyboard = torch.stack(group_keyboard, dim=1) # B F RW C + group_keyboard = group_keyboard.reshape(shape=(group_keyboard.shape[0], group_keyboard.shape[1], -1)) + + # apply cross attn + mouse_q = phase.mouse_attn_q.apply(hidden_states[0]).unsqueeze(0) + keyboard_kv = phase.keyboard_attn_kv.apply(group_keyboard[0]).unsqueeze(0) + + B, L, HD = mouse_q.shape + D = HD // self.heads_num + q = mouse_q.view(B, L, self.heads_num, D) + + B, L, KHD = keyboard_kv.shape + k, v = keyboard_kv.view(B, L, 2, self.heads_num, D).permute(2, 0, 1, 3, 4) + + # Compute cu_squlens and max_seqlen for flash attention + # qk norm + q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + 1e-6) + k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + 1e-6) + + S = th * tw + assert S == 880 + # position embed + if use_rope_keyboard: + B, TS, H, D = q.shape + T_ = TS // S + q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D) + q, k = apply_rotary_emb(q, k, freqs_cis, start_offset=start_frame, head_first=False) + + k1, k2, k3, k4 = k.shape + k = k.expand(S, k2, k3, k4) + v = v.expand(S, k2, k3, k4) + + if is_causal: + current_end = current_start + k.shape[1] + assert k.shape[1] == self.num_frame_per_block + sink_size = 0 + max_attention_size = self.local_attn_size + sink_tokens = sink_size * 1 + kv_cache_size = self.kv_cache_keyboard[self.block_idx]["k"].shape[1] + num_new_tokens = k.shape[1] + + if (current_end > self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()) and ( + num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() > kv_cache_size + ): + num_evicted_tokens = num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - kv_cache_size + num_rolled_tokens = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens + + self.kv_cache_keyboard[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["k"][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + self.kv_cache_keyboard[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["v"][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + + # Insert the new keys/values at the end + local_end_index = ( + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item() - num_evicted_tokens + ) + local_start_index = local_end_index - num_new_tokens + else: + local_end_index = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item() + local_start_index = local_end_index - num_new_tokens + + assert k.shape[0] == 880 # BS == 1 or the cache should not be saved/ load method should be modified + self.kv_cache_keyboard[self.block_idx]["k"][:, local_start_index:local_end_index] = k[:1] + self.kv_cache_keyboard[self.block_idx]["v"][:, local_start_index:local_end_index] = v[:1] + + if FLASH_ATTN_3_AVAILABLE: + attn_k = self.kv_cache_keyboard[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : 
local_end_index].repeat(S, 1, 1, 1) + attn_v = self.kv_cache_keyboard[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1) + attn = flash_attn_interface.flash_attn_func( + q, + attn_k, + attn_v, + ) + else: + attn = flash_attn_func( + q, + self.kv_cache_keyboard[self.block_idx]["k"][max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1), + self.kv_cache_keyboard[self.block_idx]["v"][max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1), + ) + + self.kv_cache_keyboard[self.block_idx]["global_end_index"].fill_(current_end) + self.kv_cache_keyboard[self.block_idx]["local_end_index"].fill_(local_end_index) + else: + attn = flash_attn_func( + q, + k, + v, + causal=False, + ) + attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S) + + else: + if is_causal: + current_start = start_frame + current_end = current_start + k.shape[1] + assert k.shape[1] == self.num_frame_per_block + sink_size = 0 + local_attn_size = self.local_attn_size + max_attention_size = self.local_attn_size + sink_tokens = sink_size * 1 + kv_cache_size = self.kv_cache_keyboard[self.block_idx]["k"].shape[1] + num_new_tokens = k.shape[1] + + if (current_end > self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()) and ( + num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() > kv_cache_size + ): + num_evicted_tokens = num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - kv_cache_size + num_rolled_tokens = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens + self.kv_cache_keyboard[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["k"][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + self.kv_cache_keyboard[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["v"][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + # Insert the new keys/values at the end + local_end_index = ( + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item() - num_evicted_tokens + ) + local_start_index = local_end_index - num_new_tokens + else: + local_end_index = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item() + local_start_index = local_end_index - num_new_tokens + self.kv_cache_keyboard[self.block_idx]["k"][:, local_start_index:local_end_index] = k + self.kv_cache_keyboard[self.block_idx]["v"][:, local_start_index:local_end_index] = v + attn = flash_attn_func( + q, + self.kv_cache_keyboard[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index], + self.kv_cache_keyboard[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index], + ) + self.kv_cache_keyboard[self.block_idx]["global_end_index"].fill_(current_end) + self.kv_cache_keyboard[self.block_idx]["local_end_index"].fill_(local_end_index) + else: + attn = flash_attn_func( + q, + k, + v, + ) + attn = rearrange(attn, "B L H D -> B L (H D)") + attn = phase.proj_keyboard.apply(attn[0]).unsqueeze(0) + hidden_states = hidden_states + attn + hidden_states = hidden_states.squeeze(0) + + return 
hidden_states + + def infer_ffn(self, phase, x, c_shift_msa, c_scale_msa): + num_frames = c_shift_msa.shape[0] + frame_seqlen = x.shape[0] // c_shift_msa.shape[0] + + x = phase.norm2.apply(x).unsqueeze(0) + x = x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) + + c_scale_msa = c_scale_msa.unsqueeze(0) + c_shift_msa = c_shift_msa.unsqueeze(0) + x = x * (1 + c_scale_msa) + c_shift_msa + x = x.flatten(1, 2).squeeze(0) + + y = phase.ffn_0.apply(x) + y = torch.nn.functional.gelu(y, approximate="tanh") + y = phase.ffn_2.apply(y) + + return y + + def post_process(self, x, y, c_gate_msa, pre_infer_out=None): + x = x + y * c_gate_msa[0] + x = x.squeeze(0) + return x + + def infer_block_witch_kvcache(self, block, x, pre_infer_out): + if hasattr(block.compute_phases[0], "before_proj"): + x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process( + block.compute_phases[0].modulation, + pre_infer_out.embed0, + ) + + y_out = self.infer_self_attn_with_kvcache( + block.compute_phases[0], + pre_infer_out.grid_sizes.tensor, + x, + pre_infer_out.seq_lens, + pre_infer_out.freqs, + shift_msa, + scale_msa, + ) + + x, attn_out = self.infer_cross_attn_with_kvcache( + block.compute_phases[1], + x, + pre_infer_out.context, + y_out, + gate_msa, + ) + x = x + attn_out + + if len(block.compute_phases) == 4: + if self.config["mode"] != "templerun": + x = self.infer_action_model( + phase=block.compute_phases[2], + x=x, + grid_sizes=pre_infer_out.grid_sizes.tensor[0], + seq_lens=pre_infer_out.seq_lens, + mouse_condition=pre_infer_out.conditional_dict["mouse_cond"], + keyboard_condition=pre_infer_out.conditional_dict["keyboard_cond"], + is_causal=True, + use_rope_keyboard=True, + ) + else: + x = self.infer_action_model( + phase=block.compute_phases[2], + x=x, + grid_sizes=pre_infer_out.grid_sizes.tensor[0], + seq_lens=pre_infer_out.seq_lens, + keyboard_condition=pre_infer_out.conditional_dict["keyboard_cond"], + is_causal=True, + use_rope_keyboard=True, + ) + y = self.infer_ffn(block.compute_phases[3], x, c_shift_msa, c_scale_msa) + + elif len(block.compute_phases) == 3: + y = self.infer_ffn(block.compute_phases[2], x, c_shift_msa, c_scale_msa) + + x = self.post_process(x, y, c_gate_msa, pre_infer_out) + + return x + + def infer_non_blocks(self, weights, x, e): + num_frames = e.shape[0] + frame_seqlen = x.shape[0] // e.shape[0] + e = e.unsqueeze(0).unsqueeze(2) + + x = weights.norm.apply(x).unsqueeze(0) + x = x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) + + modulation = weights.head_modulation.tensor + e = (modulation.unsqueeze(1) + e).chunk(2, dim=2) + + x = x * (1 + e[1]) + e[0] + + x = torch.nn.functional.linear(x, weights.head.weight.T, weights.head.bias) + + if self.clean_cuda_cache: + del e + torch.cuda.empty_cache() + + return x diff --git a/lightx2v/models/networks/wan/infer/module_io.py b/lightx2v/models/networks/wan/infer/module_io.py new file mode 100644 index 0000000000000000000000000000000000000000..f0aa4162f2007fc67580657d98a77a50165a559f --- /dev/null +++ b/lightx2v/models/networks/wan/infer/module_io.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass, field +from typing import Any, Dict + +import torch + + +@dataclass +class GridOutput: + tensor: torch.Tensor + tuple: tuple + + +@dataclass +class WanPreInferModuleOutput: + embed: torch.Tensor + grid_sizes: GridOutput + x: torch.Tensor + embed0: torch.Tensor + context: torch.Tensor + adapter_args: Dict[str, Any] = field(default_factory=dict) 
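+ # Optional per-modality conditioning tensors (e.g. "mouse_cond" / "keyboard_cond" action controls), keyed by name.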
+ conditional_dict: Dict[str, Any] = field(default_factory=dict) diff --git a/lightx2v/models/networks/wan/infer/offload/__init__.py b/lightx2v/models/networks/wan/infer/offload/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/wan/infer/offload/transformer_infer.py b/lightx2v/models/networks/wan/infer/offload/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..a43a99950d8a505762f3efc956f942ff3b431cb8 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/offload/transformer_infer.py @@ -0,0 +1,154 @@ +import torch + +from lightx2v.common.offload.manager import WeightAsyncStreamManager +from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +class WanOffloadTransformerInfer(WanTransformerInfer): + def __init__(self, config): + super().__init__(config) + if self.config.get("cpu_offload", False): + offload_granularity = self.config.get("offload_granularity", "block") + if offload_granularity == "block": + self.infer_func = self.infer_with_blocks_offload + elif offload_granularity == "phase": + self.infer_func = self.infer_with_phases_offload + self.phase_params = { + "shift_msa": None, + "scale_msa": None, + "gate_msa": None, + "c_shift_msa": None, + "c_scale_msa": None, + "c_gate_msa": None, + "y_out": None, + "attn_out": None, + "y": None, + } + elif offload_granularity == "model": + self.infer_func = self.infer_without_offload + + if offload_granularity != "model": + self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity) + self.lazy_load = self.config.get("lazy_load", False) + if self.lazy_load and offload_granularity == "phase": + self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4)) + + def infer_with_blocks_offload(self, blocks, x, pre_infer_out): + for block_idx in range(len(blocks)): + self.block_idx = block_idx + if self.offload_manager.need_init_first_buffer: + self.offload_manager.init_first_buffer(blocks) + + self.offload_manager.prefetch_weights((block_idx + 1) % len(blocks), blocks) + with torch_device_module.stream(self.offload_manager.compute_stream): + x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out) + + self.offload_manager.swap_blocks() + + if self.clean_cuda_cache: + del ( + pre_infer_out.embed0, + pre_infer_out.context, + ) + torch_device_module.empty_cache() + + return x + + def infer_with_phases_offload(self, blocks, x, pre_infer_out): + for block_idx in range(len(blocks)): + self.block_idx = block_idx + if self.lazy_load: + next_prefetch = (block_idx + 1) % len(blocks) + self.offload_manager.start_prefetch_block(next_prefetch) + + x = self.infer_phases(block_idx, blocks, x, pre_infer_out) + if self.clean_cuda_cache: + del ( + self.phase_params["attn_out"], + self.phase_params["y_out"], + self.phase_params["y"], + ) + torch_device_module.empty_cache() + + if self.clean_cuda_cache: + self.clear_offload_params(pre_infer_out) + + return x + + def infer_phases(self, block_idx, blocks, x, pre_infer_out): + for phase_idx in range(self.phases_num): + if self.offload_manager.need_init_first_buffer: + self.offload_manager.init_first_buffer(blocks) + next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx + next_phase_idx = (phase_idx + 1) % self.phases_num + 
if self.lazy_load: + if phase_idx == self.phases_num - 1: + self.offload_manager.swap_cpu_buffers() + self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks) + with torch_device_module.stream(self.offload_manager.compute_stream): + x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out) + + self.offload_manager.swap_phases() + + return x + + def infer_phase(self, cur_phase_idx, cur_phase, x, pre_infer_out): + if cur_phase_idx == 0: + if hasattr(cur_phase, "before_proj") and cur_phase.before_proj.weight is not None: + x = cur_phase.before_proj.apply(x) + pre_infer_out.x + ( + self.phase_params["shift_msa"], + self.phase_params["scale_msa"], + self.phase_params["gate_msa"], + self.phase_params["c_shift_msa"], + self.phase_params["c_scale_msa"], + self.phase_params["c_gate_msa"], + ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0) + self.phase_params["y_out"] = self.infer_self_attn( + cur_phase, + x, + self.phase_params["shift_msa"], + self.phase_params["scale_msa"], + ) + elif cur_phase_idx == 1: + x, self.phase_params["attn_out"] = self.infer_cross_attn( + cur_phase, + x, + pre_infer_out.context, + self.phase_params["y_out"], + self.phase_params["gate_msa"], + ) + elif cur_phase_idx == 2: + self.phase_params["y"] = self.infer_ffn( + cur_phase, + x, + self.phase_params["attn_out"], + self.phase_params["c_shift_msa"], + self.phase_params["c_scale_msa"], + ) + x = self.post_process(x, self.phase_params["y"], self.phase_params["c_gate_msa"], pre_infer_out) + if hasattr(cur_phase, "after_proj"): + pre_infer_out.adapter_args["hints"].append(cur_phase.after_proj.apply(x)) + elif cur_phase_idx == 3: + x = self.infer_post_adapter(cur_phase, x, pre_infer_out) + return x + + def clear_offload_params(self, pre_infer_out): + del ( + self.phase_params["shift_msa"], + self.phase_params["scale_msa"], + self.phase_params["gate_msa"], + ) + del ( + self.phase_params["c_shift_msa"], + self.phase_params["c_scale_msa"], + self.phase_params["c_gate_msa"], + ) + del ( + pre_infer_out.embed0, + pre_infer_out.context, + ) + torch_device_module.empty_cache() diff --git a/lightx2v/models/networks/wan/infer/post_infer.py b/lightx2v/models/networks/wan/infer/post_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..2610d6a63630c6681b5f157c634ba88438d0e2fd --- /dev/null +++ b/lightx2v/models/networks/wan/infer/post_infer.py @@ -0,0 +1,31 @@ +import math + +import torch + +from lightx2v.utils.envs import * + + +class WanPostInfer: + def __init__(self, config): + self.out_dim = config["out_dim"] + self.patch_size = (1, 2, 2) + self.clean_cuda_cache = config.get("clean_cuda_cache", False) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, x, pre_infer_out): + x = self.unpatchify(x, pre_infer_out.grid_sizes.tuple) + + if self.clean_cuda_cache: + torch.cuda.empty_cache() + + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + x = x[: math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c) + x = torch.einsum("fhwpqrc->cfphqwr", x) + x = x.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) + return [x] diff --git a/lightx2v/models/networks/wan/infer/pre_infer.py b/lightx2v/models/networks/wan/infer/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..3587716fc330cd9f3e653781d2f9594af38bdb4a --- /dev/null +++ b/lightx2v/models/networks/wan/infer/pre_infer.py @@ -0,0 +1,124 @@ +import 
torch + +from lightx2v.utils.envs import * + +from .module_io import GridOutput, WanPreInferModuleOutput +from .utils import guidance_scale_embedding, sinusoidal_embedding_1d + + +class WanPreInfer: + def __init__(self, config): + assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0 + self.config = config + self.clean_cuda_cache = config.get("clean_cuda_cache", False) + self.task = config["task"] + self.freq_dim = config["freq_dim"] + self.dim = config["dim"] + self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False) + self.cfg_scale = config.get("cfg_scale", 4.0) + self.infer_dtype = GET_DTYPE() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, weights, inputs, kv_start=0, kv_end=0): + x = self.scheduler.latents + t = self.scheduler.timestep_input + + if self.scheduler.infer_condition: + context = inputs["text_encoder_output"]["context"] + else: + context = inputs["text_encoder_output"]["context_null"] + + if self.task in ["i2v", "flf2v", "animate", "s2v"]: + if self.config.get("use_image_encoder", True): + clip_fea = inputs["image_encoder_output"]["clip_encoder_out"] + + if self.config.get("changing_resolution", False): + image_encoder = inputs["image_encoder_output"]["vae_encoder_out"][self.scheduler.changing_resolution_index] + else: + image_encoder = inputs["image_encoder_output"]["vae_encoder_out"] + + if image_encoder is not None: + frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2) + if kv_end - kv_start >= frame_seq_length: # 如果是CausalVid, image_encoder取片段 + idx_s = kv_start // frame_seq_length + idx_e = kv_end // frame_seq_length + image_encoder = image_encoder[:, idx_s:idx_e, :, :] + y = image_encoder + x = torch.cat([x, y], dim=0) + + # embeddings + x = weights.patch_embedding.apply(x.unsqueeze(0)) + + if hasattr(self, "after_patch_embedding"): + x, motion_vec = self.after_patch_embedding(weights, x, inputs["image_encoder_output"]["pose_latents"], inputs["image_encoder_output"]["face_pixel_values"]) + else: + motion_vec = None + + grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] + x = x.flatten(2).transpose(1, 2).contiguous() + # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0) + + embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten()) + if self.enable_dynamic_cfg: + s = torch.tensor([self.cfg_scale], dtype=torch.float32, device=x.device) + cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x) + cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed) + cfg_embed = torch.nn.functional.silu(cfg_embed) + cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed) + embed = embed + cfg_embed + if self.sensitive_layer_dtype != self.infer_dtype: + embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype)) + else: + embed = weights.time_embedding_0.apply(embed) + embed = torch.nn.functional.silu(embed) + embed = weights.time_embedding_2.apply(embed) + embed0 = torch.nn.functional.silu(embed) + + embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)) + + # text embeddings + if self.sensitive_layer_dtype != self.infer_dtype: + out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype)) + else: + out = weights.text_embedding_0.apply(context.squeeze(0)) + out = torch.nn.functional.gelu(out, approximate="tanh") + 
context = weights.text_embedding_2.apply(out) + if self.clean_cuda_cache: + del out + torch.cuda.empty_cache() + + if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True): + if self.task == "flf2v": + _, n, d = clip_fea.shape + clip_fea = clip_fea.view(2 * n, d) + clip_fea = clip_fea + weights.emb_pos.tensor.squeeze() + context_clip = weights.proj_0.apply(clip_fea) + if self.clean_cuda_cache: + del clip_fea + torch.cuda.empty_cache() + context_clip = weights.proj_1.apply(context_clip) + context_clip = torch.nn.functional.gelu(context_clip, approximate="none") + if self.clean_cuda_cache: + torch.cuda.empty_cache() + context_clip = weights.proj_3.apply(context_clip) + context_clip = weights.proj_4.apply(context_clip) + context = torch.concat([context_clip, context], dim=0) + + if self.clean_cuda_cache: + if self.config.get("use_image_encoder", True): + del context_clip + torch.cuda.empty_cache() + + grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w)) + return WanPreInferModuleOutput( + embed=embed, + grid_sizes=grid_sizes, + x=x.squeeze(0), + embed0=embed0.squeeze(0), + context=context, + adapter_args={"motion_vec": motion_vec}, + ) diff --git a/lightx2v/models/networks/wan/infer/self_forcing/__init__.py b/lightx2v/models/networks/wan/infer/self_forcing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py b/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..52a43f90d7645ceb91633af802aa2b68840b8afb --- /dev/null +++ b/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass, field +from typing import Any, Dict + +import torch + +from lightx2v.models.networks.wan.infer.module_io import GridOutput +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@dataclass +class WanSFPreInferModuleOutput: + embed: torch.Tensor + grid_sizes: GridOutput + x: torch.Tensor + embed0: torch.Tensor + seq_lens: torch.Tensor + freqs: torch.Tensor + context: torch.Tensor + conditional_dict: Dict[str, Any] = field(default_factory=dict) + + +class WanSFPreInfer(WanPreInfer): + def __init__(self, config): + super().__init__(config) + d = config["dim"] // config["num_heads"] + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ).to(AI_DEVICE) + + def time_embedding(self, weights, embed): + embed = weights.time_embedding_0.apply(embed) + embed = torch.nn.functional.silu(embed) + 
embed = weights.time_embedding_2.apply(embed) + + return embed + + def time_projection(self, weights, embed): + embed0 = torch.nn.functional.silu(embed) + embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)) + return embed0 + + @torch.no_grad() + def infer(self, weights, inputs, kv_start=0, kv_end=0): + x = self.scheduler.latents_input + t = self.scheduler.timestep_input + + if self.scheduler.infer_condition: + context = inputs["text_encoder_output"]["context"] + else: + context = inputs["text_encoder_output"]["context_null"] + + # embeddings + x = weights.patch_embedding.apply(x.unsqueeze(0)) + grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] + x = x.flatten(2).transpose(1, 2).contiguous() + seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0) + + embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) + embed = self.time_embedding(weights, embed_tmp) + embed0 = self.time_projection(weights, embed) + + # text embeddings + if self.sensitive_layer_dtype != self.infer_dtype: # False + out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype)) + else: + out = weights.text_embedding_0.apply(context.squeeze(0)) + out = torch.nn.functional.gelu(out, approximate="tanh") + context = weights.text_embedding_2.apply(out) + if self.clean_cuda_cache: + del out + torch.cuda.empty_cache() + + if self.clean_cuda_cache: + if self.config.get("use_image_encoder", True): + del context_clip + torch.cuda.empty_cache() + + grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w)) + + return WanSFPreInferModuleOutput( + embed=embed, + grid_sizes=grid_sizes, + x=x.squeeze(0), + embed0=embed0.squeeze(0), + seq_lens=seq_lens, + freqs=self.freqs, + context=context, + ) diff --git a/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py b/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..56fadb8da0b88efbe40bdbf14a1f65a03a2cfe70 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py @@ -0,0 +1,358 @@ +import torch + +from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer + + +def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).type_as(x) + + +class WanSFTransformerInfer(WanTransformerInfer): + def __init__(self, config): + super().__init__(config) + if self.config.get("cpu_offload", False): + self.device = torch.device("cpu") + else: + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + sf_config = self.config["sf_config"] 
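+ # Self-forcing attention window: local_attn_size is measured in latent frames, with -1 meaning
+ # uncapped; the uncapped case maps to the 32760-token default cache below (32760 = 21 x 1560,
+ # where 1560 appears to be the assumed token count per latent frame).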
+ self.local_attn_size = sf_config["local_attn_size"] + self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560 + self.num_frame_per_block = sf_config["num_frame_per_block"] + self.num_transformer_blocks = sf_config["num_transformer_blocks"] + self.frame_seq_length = sf_config["frame_seq_length"] + self._initialize_kv_cache(self.device, self.dtype) + self._initialize_crossattn_cache(self.device, self.dtype) + + self.infer_func = self.infer_with_kvcache + + def get_scheduler_values(self): + pass + + def _initialize_kv_cache(self, dtype, device): + """ + Initialize a Per-GPU KV cache for the Wan model. + """ + kv_cache1 = [] + if self.local_attn_size != -1: + # Use the local attention size to compute the KV cache size + kv_cache_size = self.local_attn_size * self.frame_seq_length + else: + # Use the default KV cache size + kv_cache_size = 32760 + for _ in range(self.num_transformer_blocks): + kv_cache1.append( + { + "k": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device), + "v": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device), + "global_end_index": torch.tensor([0], dtype=torch.long).to(device), + "local_end_index": torch.tensor([0], dtype=torch.long).to(device), + } + ) + + self.kv_cache1_default = kv_cache1 # always store the clean cache + + def _initialize_crossattn_cache(self, dtype, device): + """ + Initialize a Per-GPU cross-attention cache for the Wan model. + """ + crossattn_cache = [] + + for _ in range(self.num_transformer_blocks): + crossattn_cache.append({"k": torch.zeros((512, 12, 128)).to(dtype).to(device), "v": torch.zeros((512, 12, 128)).to(dtype).to(device), "is_init": False}) + self.crossattn_cache_default = crossattn_cache + + def infer_with_kvcache(self, blocks, x, pre_infer_out): + self.kv_cache1 = self.kv_cache1_default + self.crossattn_cache = self.crossattn_cache_default + for block_idx in range(len(blocks)): + self.block_idx = block_idx + x = self.infer_block_witch_kvcache(blocks[block_idx], x, pre_infer_out) + return x + + def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa): + if hasattr(phase, "smooth_norm1_weight"): + norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor + norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor + else: + norm1_weight = 1 + scale_msa.squeeze() + norm1_bias = shift_msa.squeeze() + + norm1_out = phase.norm1.apply(x) + + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.sensitive_layer_dtype) + + norm1_out.mul_(norm1_weight[0:1, :]).add_(norm1_bias[0:1, :]) + + if self.sensitive_layer_dtype != self.infer_dtype: # False + norm1_out = norm1_out.to(self.infer_dtype) + + s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim + + q0 = phase.self_attn_q.apply(norm1_out) + k0 = phase.self_attn_k.apply(norm1_out) + + q = phase.self_attn_norm_q.apply(q0).view(s, n, d) + k = phase.self_attn_norm_k.apply(k0).view(s, n, d) + v = phase.self_attn_v.apply(norm1_out).view(s, n, d) + + seg_index = self.scheduler.seg_index + + current_start_frame = seg_index * self.num_frame_per_block + + q = causal_rope_apply(q.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0] + k = causal_rope_apply(k.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0] + + # Assign new keys/values directly up to current_end + seg_seq_len = self.frame_seq_length * self.num_frame_per_block + local_start_index = seg_index * seg_seq_len + local_end_index = 
(seg_index + 1) * seg_seq_len + + self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k + self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v + + attn_k = self.kv_cache1[self.block_idx]["k"][max(0, local_end_index - self.max_attention_size) : local_end_index] + attn_v = self.kv_cache1[self.block_idx]["v"][max(0, local_end_index - self.max_attention_size) : local_end_index] + + k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0)) + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens) + + if self.clean_cuda_cache: + del freqs_i, norm1_out, norm1_weight, norm1_bias + torch.cuda.empty_cache() + + if self.config["seq_parallel"]: + attn_out = phase.self_attn_1_parallel.apply( + q=q, + k=attn_k, + v=attn_v, + img_qkv_len=q.shape[0], + cu_seqlens_qkv=cu_seqlens_q, + attention_module=phase.self_attn_1, + seq_p_group=self.seq_p_group, + model_cls=self.config["model_cls"], + ) + else: + attn_out = phase.self_attn_1.apply( + q=q, + k=attn_k, + v=attn_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=attn_k.size(0), + model_cls=self.config["model_cls"], + ) + + y = phase.self_attn_o.apply(attn_out) + + if self.clean_cuda_cache: + del q, k, v, attn_out + torch.cuda.empty_cache() + + return y + + def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa): + num_frames = gate_msa.shape[0] + frame_seqlen = x.shape[0] // gate_msa.shape[0] + seg_index = self.scheduler.seg_index + + x.add_((y_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) * gate_msa).flatten(0, 1)) + norm3_out = phase.norm3.apply(x) + + if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True): + context_img = context[:257] + context = context[257:] + else: + context_img = None + + if self.sensitive_layer_dtype != self.infer_dtype: + context = context.to(self.infer_dtype) + if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True): + context_img = context_img.to(self.infer_dtype) + + n, d = self.num_heads, self.head_dim + + q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d) + + if seg_index == 0: + k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d) + v = phase.cross_attn_v.apply(context).view(-1, n, d) + self.crossattn_cache[self.block_idx]["k"] = k + self.crossattn_cache[self.block_idx]["v"] = v + else: + k = self.crossattn_cache[self.block_idx]["k"] + v = self.crossattn_cache[self.block_idx]["v"] + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( + q, + k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device), + ) + attn_out = phase.cross_attn_1.apply( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + model_cls=self.config["model_cls"], + ) + + if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None: + k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d) + v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d) + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( + q, + k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), + ) + img_attn_out = phase.cross_attn_2.apply( + q=q, + k=k_img, + v=v_img, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k_img.size(0), + model_cls=self.config["model_cls"], + ) + 
attn_out.add_(img_attn_out) + + if self.clean_cuda_cache: + del k_img, v_img, img_attn_out + torch.cuda.empty_cache() + + attn_out = phase.cross_attn_o.apply(attn_out) + + if self.clean_cuda_cache: + del q, k, v, norm3_out, context, context_img + torch.cuda.empty_cache() + return x, attn_out + + def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa): + x.add_(attn_out) + + if self.clean_cuda_cache: + del attn_out + torch.cuda.empty_cache() + + num_frames = c_shift_msa.shape[0] + frame_seqlen = x.shape[0] // c_shift_msa.shape[0] + + norm2_weight = 1 + c_scale_msa + norm2_bias = c_shift_msa + + norm2_out = phase.norm2.apply(x) + norm2_out = norm2_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + norm2_out.mul_(norm2_weight).add_(norm2_bias) + norm2_out = norm2_out.flatten(0, 1) + + y = phase.ffn_0.apply(norm2_out) + if self.clean_cuda_cache: + del norm2_out, x, norm2_weight, norm2_bias + torch.cuda.empty_cache() + y = torch.nn.functional.gelu(y, approximate="tanh") + if self.clean_cuda_cache: + torch.cuda.empty_cache() + y = phase.ffn_2.apply(y) + + return y + + def post_process(self, x, y, c_gate_msa, pre_infer_out=None): + num_frames = c_gate_msa.shape[0] + frame_seqlen = x.shape[0] // c_gate_msa.shape[0] + y = y.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + x = x.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + x.add_(y * c_gate_msa) + x = x.flatten(0, 1) + + if self.clean_cuda_cache: + del y, c_gate_msa + torch.cuda.empty_cache() + return x + + def infer_block_witch_kvcache(self, block, x, pre_infer_out): + if hasattr(block.compute_phases[0], "before_proj"): + x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process( + block.compute_phases[0].modulation, + pre_infer_out.embed0, + ) + + y_out = self.infer_self_attn_with_kvcache( + block.compute_phases[0], + pre_infer_out.grid_sizes.tensor, + x, + pre_infer_out.seq_lens, + pre_infer_out.freqs, + shift_msa, + scale_msa, + ) + + x, attn_out = self.infer_cross_attn_with_kvcache( + block.compute_phases[1], + x, + pre_infer_out.context, + y_out, + gate_msa, + ) + + y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa) + + x = self.post_process(x, y, c_gate_msa, pre_infer_out) + + if hasattr(block.compute_phases[2], "after_proj"): + pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x)) + + if self.has_post_adapter: + x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out) + + return x + + def infer_non_blocks(self, weights, x, e): + num_frames = e.shape[0] + frame_seqlen = x.shape[0] // e.shape[0] + + x = weights.norm.apply(x) + x = x.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + + t = self.scheduler.timestep_input + e = e.unflatten(dim=0, sizes=t.shape).unsqueeze(2) + modulation = weights.head_modulation.tensor + e = (modulation.unsqueeze(1) + e).chunk(2, dim=2) + + x.mul_(1 + e[1][0]).add_(e[0][0]) + x = x.flatten(0, 1) + x = weights.head.apply(x) + + if self.clean_cuda_cache: + del e + torch.cuda.empty_cache() + return x diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed61539312e82e8cb8b437e5a0049622faa6e759 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -0,0 +1,321 @@ +from functools import partial + +import torch + +from 
lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer +from lightx2v.utils.envs import * + +from .triton_ops import fuse_scale_shift_kernel +from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch + + +def modulate(x, scale, shift): + return x * (1 + scale.squeeze()) + shift.squeeze() + + +class WanTransformerInfer(BaseTransformerInfer): + def __init__(self, config): + self.config = config + self.task = config["task"] + self.attention_type = config.get("attention_type", "flash_attn2") + self.blocks_num = config["num_layers"] + self.phases_num = 3 + self.has_post_adapter = False + self.num_heads = config["num_heads"] + self.head_dim = config["dim"] // config["num_heads"] + self.window_size = config.get("window_size", (-1, -1)) + self.parallel_attention = None + if self.config.get("modulate_type", "triton") == "triton": + self.modulate_func = fuse_scale_shift_kernel + else: + self.modulate_func = modulate + if self.config.get("rope_type", "flashinfer") == "flashinfer": + if self.config.get("rope_chunk", False): + self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer) + else: + self.apply_rope_func = apply_wan_rope_with_flashinfer + else: + if self.config.get("rope_chunk", False): + self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_torch) + else: + self.apply_rope_func = apply_wan_rope_with_torch + self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) + self.infer_dtype = GET_DTYPE() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False) + else: + self.seq_p_group = None + self.seq_p_fp8_comm = False + self.infer_func = self.infer_without_offload + + self.cos_sin = None + + def _calculate_q_k_len(self, q, k_lens): + q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) + cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) + cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) + return cu_seqlens_q, cu_seqlens_k + + def get_scheduler_values(self): + self.cos_sin = self.scheduler.cos_sin + + @torch.no_grad() + def infer(self, weights, pre_infer_out): + self.get_scheduler_values() + x = self.infer_main_blocks(weights.blocks, pre_infer_out) + return self.infer_non_blocks(weights, x, pre_infer_out.embed) + + def infer_main_blocks(self, blocks, pre_infer_out): + x = self.infer_func(blocks, pre_infer_out.x, pre_infer_out) + return x + + def infer_non_blocks(self, weights, x, e): + if e.dim() == 2: + modulation = weights.head_modulation.tensor # 1, 2, dim + e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + elif e.dim() == 3: # For Diffustion forcing + modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim + e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + e = [ei.squeeze(1) for ei in e] + + x = weights.norm.apply(x) + + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.sensitive_layer_dtype) + x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze()) + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.infer_dtype) + + x = weights.head.apply(x) + + if self.clean_cuda_cache: + del e + torch.cuda.empty_cache() + 
return x + + def infer_without_offload(self, blocks, x, pre_infer_out): + for block_idx in range(len(blocks)): + self.block_idx = block_idx + x = self.infer_block(blocks[block_idx], x, pre_infer_out) + return x + + def infer_block(self, block, x, pre_infer_out): + if hasattr(block.compute_phases[0], "before_proj") and block.compute_phases[0].before_proj.weight is not None: + x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process( + block.compute_phases[0].modulation, + pre_infer_out.embed0, + ) + y_out = self.infer_self_attn( + block.compute_phases[0], + x, + shift_msa, + scale_msa, + ) + x, attn_out = self.infer_cross_attn(block.compute_phases[1], x, pre_infer_out.context, y_out, gate_msa) + y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa) + x = self.post_process(x, y, c_gate_msa, pre_infer_out) + if hasattr(block.compute_phases[2], "after_proj"): + pre_infer_out.adapter_args["hints"].append(block.compute_phases[2].after_proj.apply(x)) + + if self.has_post_adapter: + x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out) + + return x + + def pre_process(self, modulation, embed0): + if embed0.dim() == 3 and embed0.shape[2] == 1: + modulation = modulation.tensor.unsqueeze(2) + embed0 = (modulation + embed0).chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0] + else: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation.tensor + embed0).chunk(6, dim=1) + + if self.clean_cuda_cache: + del embed0 + torch.cuda.empty_cache() + + return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa + + def infer_self_attn(self, phase, x, shift_msa, scale_msa): + cos_sin = self.cos_sin + if hasattr(phase, "smooth_norm1_weight"): + norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor + norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor + norm1_out = phase.norm1.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.sensitive_layer_dtype) + norm1_out.mul_(norm1_weight).add_(norm1_bias) + else: + norm1_out = phase.norm1.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.sensitive_layer_dtype) + norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze() + + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.infer_dtype) + + s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim + + q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) + k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) + v = phase.self_attn_v.apply(norm1_out).view(s, n, d) + + q, k = self.apply_rope_func(q, k, cos_sin) + + img_qkv_len = q.shape[0] + cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True) + + if self.clean_cuda_cache: + del norm1_out, shift_msa, scale_msa + torch.cuda.empty_cache() + + if self.config["seq_parallel"]: + attn_out = phase.self_attn_1_parallel.apply( + q=q, + k=k, + v=v, + img_qkv_len=img_qkv_len, + cu_seqlens_qkv=cu_seqlens_qkv, + attention_module=phase.self_attn_1, + seq_p_group=self.seq_p_group, + use_fp8_comm=self.seq_p_fp8_comm, + model_cls=self.config["model_cls"], + ) + else: + attn_out = phase.self_attn_1.apply( + q=q, + k=k, + v=v, + 
cu_seqlens_q=cu_seqlens_qkv, + cu_seqlens_kv=cu_seqlens_qkv, + max_seqlen_q=img_qkv_len, + max_seqlen_kv=img_qkv_len, + model_cls=self.config["model_cls"], + ) + + y = phase.self_attn_o.apply(attn_out) + + if self.clean_cuda_cache: + del q, k, v, attn_out + torch.cuda.empty_cache() + + return y + + def infer_cross_attn(self, phase, x, context, y_out, gate_msa): + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze() + else: + x.add_(y_out * gate_msa.squeeze()) + + norm3_out = phase.norm3.apply(x) + if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True): + context_img = context[:257] + context = context[257:] + else: + context_img = None + + if self.sensitive_layer_dtype != self.infer_dtype: + context = context.to(self.infer_dtype) + if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True): + context_img = context_img.to(self.infer_dtype) + + n, d = self.num_heads, self.head_dim + + q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d) + k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d) + v = phase.cross_attn_v.apply(context).view(-1, n, d) + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( + q, + k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device), + ) + attn_out = phase.cross_attn_1.apply( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + model_cls=self.config["model_cls"], + ) + + if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and context_img is not None: + k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d) + v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d) + + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( + q, + k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), + ) + img_attn_out = phase.cross_attn_2.apply( + q=q, + k=k_img, + v=v_img, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k_img.size(0), + model_cls=self.config["model_cls"], + ) + attn_out.add_(img_attn_out) + + if self.clean_cuda_cache: + del k_img, v_img, img_attn_out + torch.cuda.empty_cache() + + attn_out = phase.cross_attn_o.apply(attn_out) + + if self.clean_cuda_cache: + del q, k, v, norm3_out, context, context_img + torch.cuda.empty_cache() + return x, attn_out + + def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa): + x.add_(attn_out) + + if self.clean_cuda_cache: + del attn_out + torch.cuda.empty_cache() + + if hasattr(phase, "smooth_norm2_weight"): + norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor + norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor + norm2_out = phase.norm2.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + norm2_out = norm2_out.to(self.sensitive_layer_dtype) + norm2_out.mul_(norm2_weight).add_(norm2_bias) + else: + norm2_out = phase.norm2.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + norm2_out = norm2_out.to(self.sensitive_layer_dtype) + norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze() + + if self.sensitive_layer_dtype != self.infer_dtype: + norm2_out = norm2_out.to(self.infer_dtype) + + y = phase.ffn_0.apply(norm2_out) + if 
self.clean_cuda_cache: + del norm2_out, x + torch.cuda.empty_cache() + y = torch.nn.functional.gelu(y, approximate="tanh") + if self.clean_cuda_cache: + torch.cuda.empty_cache() + y = phase.ffn_2.apply(y) + + return y + + def post_process(self, x, y, c_gate_msa, pre_infer_out=None): + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze() + else: + x.add_(y * c_gate_msa.squeeze()) + + if self.clean_cuda_cache: + del y, c_gate_msa + torch.cuda.empty_cache() + return x diff --git a/lightx2v/models/networks/wan/infer/triton_ops.py b/lightx2v/models/networks/wan/infer/triton_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e619b0ee462937b530507db58eb8a71c47eeac --- /dev/null +++ b/lightx2v/models/networks/wan/infer/triton_ops.py @@ -0,0 +1,902 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang + +# TODO: for temporary usage, expecting a refactor +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from torch import Tensor + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_warps=4), + triton.Config({"BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_warps=8), + ], + key=["inner_dim"], +) +@triton.jit +def _fused_scale_shift_4d_kernel( + output_ptr, + normalized_ptr, + scale_ptr, + shift_ptr, + rows, + inner_dim, + seq_len, + num_frames, + frame_seqlen, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col_offsets < inner_dim + + # Pointers for normalized and output + row_base = pid_row * inner_dim + norm_ptrs = normalized_ptr + row_base + col_offsets + out_ptrs = output_ptr + row_base + col_offsets + + # Pointers for scale and shift for 4D + b_idx = pid_row // seq_len + t_idx = pid_row % seq_len + frame_idx_in_batch = t_idx // frame_seqlen + + scale_row_idx = b_idx * num_frames + frame_idx_in_batch + scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets + shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets + + normalized = tl.load(norm_ptrs, mask=mask, other=0.0) + scale = tl.load(scale_ptrs, mask=mask, other=0.0) + shift = tl.load(shift_ptrs, mask=mask, other=0.0) + + one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype) + output = normalized * (one + scale) + shift + + tl.store(out_ptrs, output, mask=mask) + + +@triton.jit +def fuse_scale_shift_kernel_blc_opt( + x_ptr, + shift_ptr, + scale_ptr, + y_ptr, + B, + L, + C, + stride_x_b, + stride_x_l, + stride_x_c, + stride_s_b, + stride_s_l, + stride_s_c, + stride_sc_b, + stride_sc_l, + stride_sc_c, + SCALE_IS_SCALAR: tl.constexpr, + SHIFT_IS_SCALAR: tl.constexpr, + BLOCK_L: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + + l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_l = l_offsets < L + mask_c = c_offsets < C + mask = mask_l[:, None] & mask_c[None, :] + + x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c + x = tl.load(x_ptr + x_off, mask=mask, other=0) + + if SHIFT_IS_SCALAR: + shift_val = tl.load(shift_ptr) + shift = tl.full((BLOCK_L, BLOCK_C), shift_val, 
dtype=shift_val.dtype) + else: + s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c + shift = tl.load(shift_ptr + s_off, mask=mask, other=0) + + if SCALE_IS_SCALAR: + scale_val = tl.load(scale_ptr) + scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype) + else: + sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c + scale = tl.load(scale_ptr + sc_off, mask=mask, other=0) + + y = x * (1 + scale) + shift + tl.store(y_ptr + x_off, y, mask=mask) + + +def fuse_scale_shift_kernel( + x: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + block_l: int = 128, + block_c: int = 128, +): + # assert x.is_cuda and scale.is_cuda + assert x.is_contiguous() + if x.dim() == 2: + x = x.unsqueeze(0) + + B, L, C = x.shape + output = torch.empty_like(x) + + if scale.dim() == 4: + # scale/shift: [B, F, 1, C] + rows = B * L + x_2d = x.view(rows, C) + output_2d = output.view(rows, C) + grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa + num_frames = scale.shape[1] + assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift" + frame_seqlen = L // num_frames + + # Compact [B, F, C] without the singleton dim into [B*F, C] + scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous() + shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous() + + _fused_scale_shift_4d_kernel[grid]( + output_2d, + x_2d, + scale_reshaped, + shift_reshaped, + rows, + C, + L, + num_frames, + frame_seqlen, + ) + else: + # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L + # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C]) + # Also support scalar (0D or 1-element) + if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1): + scale_blc = scale.reshape(1) + elif scale.dim() == 2: + scale_blc = scale[:, None, :] + elif scale.dim() == 3: + scale_blc = scale + else: + raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D") + + if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1): + shift_blc = shift.reshape(1) + elif shift.dim() == 2: + shift_blc = shift[:, None, :] + elif shift.dim() == 3: + shift_blc = shift + else: + # broadcast later via expand if possible + shift_blc = shift + + need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1 + need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1 + + if not need_scale_scalar: + scale_exp = scale_blc.expand(B, L, C) + s_sb, s_sl, s_sc = scale_exp.stride() + else: + s_sb = s_sl = s_sc = 0 + + if not need_shift_scalar: + shift_exp = shift_blc.expand(B, L, C) + sh_sb, sh_sl, sh_sc = shift_exp.stride() + else: + sh_sb = sh_sl = sh_sc = 0 + + # If both scalars and both zero, copy fast-path + if need_scale_scalar and need_shift_scalar: + if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0): + output.copy_(x) + return output + + grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) + fuse_scale_shift_kernel_blc_opt[grid]( + x, + shift_blc if need_shift_scalar else shift_exp, + scale_blc if need_scale_scalar else scale_exp, + output, + B, + L, + C, + x.stride(0), + x.stride(1), + x.stride(2), + sh_sb, + sh_sl, + sh_sc, + s_sb, + s_sl, + s_sc, + SCALE_IS_SCALAR=need_scale_scalar, + SHIFT_IS_SCALAR=need_shift_scalar, + BLOCK_L=block_l, + BLOCK_C=block_c, + num_warps=4, + num_stages=2, + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2), + triton.Config({"BLOCK_HS_HALF": 64}, 
num_warps=4), + triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4), + triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8), + ], + key=["head_size", "interleaved"], +) +@triton.jit +def _rotary_embedding_kernel( + output_ptr, + x_ptr, + cos_ptr, + sin_ptr, + num_heads, + head_size, + num_tokens, + stride_x_row, + stride_cos_row, + stride_sin_row, + interleaved: tl.constexpr, + BLOCK_HS_HALF: tl.constexpr, +): + row_idx = tl.program_id(0) + token_idx = (row_idx // num_heads) % num_tokens + + x_row_ptr = x_ptr + row_idx * stride_x_row + cos_row_ptr = cos_ptr + token_idx * stride_cos_row + sin_row_ptr = sin_ptr + token_idx * stride_sin_row + output_row_ptr = output_ptr + row_idx * stride_x_row + + # half size for x1 and x2 + head_size_half = head_size // 2 + + for block_start in range(0, head_size_half, BLOCK_HS_HALF): + offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF) + mask = offsets_half < head_size_half + + cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0) + sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0) + + offsets_x1 = 2 * offsets_half + offsets_x2 = 2 * offsets_half + 1 + + x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0) + x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0) + + x1_fp32 = x1_vals.to(tl.float32) + x2_fp32 = x2_vals.to(tl.float32) + cos_fp32 = cos_vals.to(tl.float32) + sin_fp32 = sin_vals.to(tl.float32) + o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32) + o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32) + + tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask) + tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask) + + +def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + output = torch.empty_like(x) + + if x.dim() > 3: + bsz, num_tokens, num_heads, head_size = x.shape + else: + num_tokens, num_heads, head_size = x.shape + bsz = 1 + + assert head_size % 2 == 0, "head_size must be divisible by 2" + + x_reshaped = x.view(-1, head_size) + output_reshaped = output.view(-1, head_size) + + # num_tokens per head, 1 token per block + grid = (bsz * num_tokens * num_heads,) + + if interleaved and cos.shape[-1] == head_size: + cos = cos[..., ::2].contiguous() + sin = sin[..., ::2].contiguous() + else: + cos = cos.contiguous() + sin = sin.contiguous() + + _rotary_embedding_kernel[grid]( + output_reshaped, + x_reshaped, + cos, + sin, + num_heads, + head_size, + num_tokens, + x_reshaped.stride(0), + cos.stride(0), + sin.stride(0), + interleaved, + ) + + return output + + +# RMSNorm-fp32 +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + if not torch.cuda.is_available(): + return [] + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block] + # return [triton.Config({}, num_warps=8)] + + 
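+# The kernels below implement a single-pass fused LayerNorm / RMSNorm forward, adapted from
+# flash-attn. A minimal usage sketch (assuming a 2D activation `x` of shape (tokens, dim) with a
+# contiguous last dim and a learned `weight` of shape (dim,)); `layer_norm_fn` is defined further
+# down in this file:
+#
+#     y = layer_norm_fn(x, weight, bias=None, eps=1e-6, is_rms_norm=True)
+#
+# With is_rms_norm=True the mean subtraction is skipped and each row is scaled by
+# 1 / sqrt(mean(x**2) + eps) before the elementwise affine, i.e. plain RMSNorm.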
+# Copied from flash-attn +@triton.autotune( + configs=triton_autotune_configs(), + key=[ + "N", + "HAS_RESIDUAL", + "STORE_RESIDUAL_OUT", + "IS_RMS_NORM", + "HAS_BIAS", + "HAS_WEIGHT", + "HAS_X1", + "HAS_W1", + "HAS_B1", + ], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + DROPOUT_MASK1, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
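+ # Each program instance handles one row of N features in a single pass; every load/store below
+ # is masked with cols < N, so BLOCK_N (the next power of two >= N, capped by shared-memory size)
+ # may overrun the feature dimension safely.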
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + y = x_hat * w + b if HAS_BIAS else x_hat * w + else: + y = x_hat + b if HAS_BIAS else x_hat + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. 
+ if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None: + residual_dtype = residual.dtype + if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None): + residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +def _layer_norm_fwd_impl( + x: Tensor, + weight: Optional[Tensor], + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + assert out.shape == x.shape + assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None + else: + dropout_mask, dropout_mask1 = None, None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with 
torch.cuda.device(x.device.index): + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( + x, + out, + weight if weight is not None else x, # unused when HAS_WEIGHT == False + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + dropout_mask1, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, + ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 + + +class LayerNormFn: + @staticmethod + def forward( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) + if residual is not None: + assert residual.shape == x_shape_og + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) + # weight can be None when elementwise_affine=False for LayerNorm + if weight is not None: + weight = weight.contiguous() + bias = maybe_contiguous(bias) + weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + out_dtype=out_dtype, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out, + ) + y = y.reshape(x_shape_og) + return y + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) + + +@triton.jit +def _norm_infer_kernel( + X, + Y, + W, + B, + stride_x_row, + 
stride_y_row, + M, + N, + eps, + IS_RMS_NORM: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_WEIGHT: + W += 0 + if HAS_BIAS: + B += 0 + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32) + y = x_hat * w + else: + y = x_hat + if HAS_BIAS: + b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32) + y += b + tl.store(Y + cols, y, mask=cols < N) + + +def norm_infer( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, + is_rms_norm: bool = False, + out: Optional[Tensor] = None, +): + M, N = x.shape + assert x.stride(-1) == 1 + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.shape == (N,) + assert bias.stride(-1) == 1 + if out is None: + out = torch.empty_like(x) + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_N // 256, 1), 8) + _norm_infer_kernel[(M,)]( + x, + out, + weight if weight is not None else x, # dummy when HAS_WEIGHT=False + bias if bias is not None else x, # dummy when HAS_BIAS=False + x.stride(0), + out.stride(0), + M, + N, + eps, + IS_RMS_NORM=is_rms_norm, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + ) + return out + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) diff --git a/lightx2v/models/networks/wan/infer/utils.py b/lightx2v/models/networks/wan/infer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13ca4d3a6235ad322a130f9b99a3960b6611499e --- /dev/null +++ b/lightx2v/models/networks/wan/infer/utils.py @@ -0,0 +1,239 @@ +import torch +import torch.distributed as dist + +try: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace +except ImportError: + apply_rope_with_cos_sin_cache_inplace = None + +from lightx2v.utils.envs import * + + +def apply_wan_rope_with_torch( + xq: torch.Tensor, + xk: torch.Tensor, + cos_sin_cache: torch.Tensor, +): + n = xq.size(1) + seq_len = cos_sin_cache.size(0) + + xq = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2)) + xk = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2)) + # Apply rotary embedding + xq = torch.view_as_real(xq * cos_sin_cache).flatten(2) + xk = torch.view_as_real(xk * cos_sin_cache).flatten(2) + xq = torch.cat([xq, 
xq[seq_len:]]) + xk = torch.cat([xk, xk[seq_len:]]) + + return xq.to(GET_DTYPE()), xk.to(GET_DTYPE()) + + +def apply_wan_rope_with_chunk( + xq: torch.Tensor, + xk: torch.Tensor, + cos_sin_cache: torch.Tensor, + chunk_size: int, + rope_func, +): + seq_len = cos_sin_cache.size(0) + x_q = torch.empty_like(xq) + x_k = torch.empty_like(xk) + + for start in range(0, seq_len, chunk_size): + end = min(start + chunk_size, seq_len) + xq_chunk = xq[start:end] + xk_chunk = xk[start:end] + cos_sin_chunk = cos_sin_cache[start:end] + xq_chunk_out, xk_chunk_out = rope_func(xq_chunk, xk_chunk, cos_sin_chunk) + x_q[start:end].copy_(xq_chunk_out, non_blocking=True) + x_k[start:end].copy_(xk_chunk_out, non_blocking=True) + del xq_chunk_out, xk_chunk_out + + target_dtype = GET_DTYPE() + if x_q.dtype != target_dtype: + x_q = x_q.to(target_dtype) + if x_k.dtype != target_dtype: + x_k = x_k.to(target_dtype) + + return x_q, x_k + + +def apply_wan_rope_with_flashinfer( + xq: torch.Tensor, + xk: torch.Tensor, + cos_sin_cache: torch.Tensor, +): + L, H, D = xq.shape + + query = xq.reshape(L, H * D).contiguous() + key = xk.reshape(L, H * D).contiguous() + + positions = torch.arange(L, device="cpu", dtype=torch.long).to(xq.device, non_blocking=True) + + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=D, + cos_sin_cache=cos_sin_cache, + is_neox=False, + ) + + xq_out = query.view(L, H, D) + xk_out = key.view(L, H, D) + return xq_out, xk_out + + +def compute_freqs(c, grid_sizes, freqs): + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + f, h, w = grid_sizes + seq_len = f * h * w + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + return freqs_i + + +def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group): + world_size = dist.get_world_size(seq_p_group) + cur_rank = dist.get_rank(seq_p_group) + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + f, h, w = grid_sizes + seq_len = f * h * w + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + freqs_i = pad_freqs(freqs_i, s * world_size) + s_per_rank = s + freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :] + return freqs_i_rank + + +def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0): + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + f, h, w = grid_sizes + seq_len = f * h * w + freqs_i = torch.cat( + [ + freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + return freqs_i + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +def apply_rotary_emb(x, freqs_i): + n = x.size(1) + seq_len = freqs_i.size(0) + + x_i = torch.view_as_complex(x[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2)) + # Apply rotary embedding + x_i = 
torch.view_as_real(x_i * freqs_i).flatten(2)
+    x_i = torch.cat([x_i, x[seq_len:]])
+    return x_i.to(GET_DTYPE())
+
+
+def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
+    n = x.size(1)
+    seq_len = freqs_i.size(0)
+
+    output_chunks = []
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        x_chunk = x[start:end]
+        freqs_chunk = freqs_i[start:end]
+
+        x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
+        x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(GET_DTYPE())
+        output_chunks.append(x_chunk_embedded)
+        del x_chunk_complex, x_chunk_embedded
+        torch.cuda.empty_cache()
+
+    result = output_chunks  # rotated chunks; the untouched tail beyond seq_len is appended below
+
+    for start in range(seq_len, x.size(0), remaining_chunk_size):
+        end = min(start + remaining_chunk_size, x.size(0))
+        result.append(x[start:end])
+
+    x_i = torch.cat(result, dim=0)
+    del result
+    torch.cuda.empty_cache()
+
+    return x_i.to(GET_DTYPE())
+
+
+def rope_params(max_seq_len, dim, theta=10000):
+    assert dim % 2 == 0
+    freqs = torch.outer(
+        torch.arange(max_seq_len),
+        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
+    )
+    freqs = torch.polar(torch.ones_like(freqs), freqs)
+    return freqs
+
+
+def sinusoidal_embedding_1d(dim, position):
+    # preprocess
+    assert dim % 2 == 0
+    half = dim // 2
+    position = position.type(torch.float32)
+
+    # calculation
+    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+    x = x.to(GET_SENSITIVE_DTYPE())
+    return x
+
+
+def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32):
+    """
+    Args:
+        w: torch.Tensor: 1D tensor of guidance scale values to embed
+        embedding_dim: int: dimension of the embeddings to generate
+        cfg_range: (min, max) range used to normalize the guidance scales
+        target_range: scale applied to the normalized values before embedding
+        dtype: data type of the generated embeddings
+
+    Returns:
+        embedding vectors with shape `(len(w), embedding_dim)`
+    """
+    assert len(w.shape) == 1
+    cfg_min, cfg_max = cfg_range
+    w = torch.round(w)
+    w = torch.clamp(w, min=cfg_min, max=cfg_max)
+    w = (w - cfg_min) / (cfg_max - cfg_min)  # [0, 1]
+    w = w * target_range
+    half_dim = embedding_dim // 2
+    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
+    emb = w.to(dtype)[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad the last dimension up to embedding_dim
+        emb = torch.nn.functional.pad(emb, (0, 1))
+    assert emb.shape == (w.shape[0], embedding_dim)
+    return emb
diff --git a/lightx2v/models/networks/wan/infer/vace/transformer_infer.py b/lightx2v/models/networks/wan/infer/vace/transformer_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2350c6db41f572e5e7f7481ef1ed45d29b2e86bc
--- /dev/null
+++ b/lightx2v/models/networks/wan/infer/vace/transformer_infer.py
@@ -0,0 +1,38 @@
+from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
+from lightx2v.utils.envs import *
+
+
+class WanVaceTransformerInfer(WanOffloadTransformerInfer):
+    def __init__(self, config):
+        super().__init__(config)
+        self.vace_blocks_num = len(self.config["vace_layers"])
+        self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}
+
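+    # VACE blocks run first and produce per-layer hints; the base blocks then consume them in
+    # post_process via vace_blocks_mapping (original block index -> hint index), e.g.
+    # vace_layers = [0, 5, 10] gives {0: 0, 5: 1, 10: 2}.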
+    def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
+        pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
+        self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
+        x = self.infer_main_blocks(weights.blocks, pre_infer_out)
+        return self.infer_non_blocks(weights, x, pre_infer_out.embed)
+
+    def vace_pre_process(self, patch_embedding, vace_context):
+        c = patch_embedding.apply(vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
+        c = c.flatten(2).transpose(1, 2).contiguous().squeeze(0)
+        return c
+
+    def infer_vace_blocks(self, vace_blocks, pre_infer_out):
+        pre_infer_out.adapter_args["hints"] = []
+        self.infer_state = "vace"
+        if hasattr(self, "offload_manager"):
+            self.offload_manager.init_cuda_buffer(self.vace_offload_block_cuda_buffers, self.vace_offload_phase_cuda_buffers)
+        self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
+        self.infer_state = "base"
+        if hasattr(self, "offload_manager"):
+            self.offload_manager.init_cuda_buffer(self.offload_block_cuda_buffers, self.offload_phase_cuda_buffers)
+
+    def post_process(self, x, y, c_gate_msa, pre_infer_out):
+        x = super().post_process(x, y, c_gate_msa, pre_infer_out)
+        if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
+            hint_idx = self.vace_blocks_mapping[self.block_idx]
+            x = x + pre_infer_out.adapter_args["hints"][hint_idx] * pre_infer_out.adapter_args.get("context_scale", 1.0)
+        return x
diff --git a/lightx2v/models/networks/wan/lora_adapter.py b/lightx2v/models/networks/wan/lora_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a46923d9616399e27cb659f4625621dcb25c45
--- /dev/null
+++ b/lightx2v/models/networks/wan/lora_adapter.py
@@ -0,0 +1,131 @@
+import gc
+import os
+
+import torch
+from loguru import logger
+from safetensors import safe_open
+
+from lightx2v.utils.envs import *
+
+
+class WanLoraWrapper:
+    def __init__(self, wan_model):
+        self.model = wan_model
+        self.lora_metadata = {}
+        self.override_dict = {}  # On CPU
+
+    def load_lora(self, lora_path, lora_name=None):
+        if lora_name is None:
+            lora_name = os.path.basename(lora_path).split(".")[0]
+
+        if lora_name in self.lora_metadata:
+            logger.info(f"LoRA {lora_name} already loaded, skipping...")
+            return lora_name
+
+        self.lora_metadata[lora_name] = {"path": lora_path}
+        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")
+
+        return lora_name
+
+    def _load_lora_file(self, file_path):
+        with safe_open(file_path, framework="pt") as f:
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
+        return tensor_dict
+
+    def apply_lora(self, lora_name, alpha=1.0):
+        if lora_name not in self.lora_metadata:
+            logger.warning(f"LoRA {lora_name} not found. Please load it first.")
+            return False
+
+        if not hasattr(self.model, "original_weight_dict"):
+            logger.error("Model does not have 'original_weight_dict'. 
Cannot apply LoRA.") + return False + + lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"]) + weight_dict = self.model.original_weight_dict + self._apply_lora_weights(weight_dict, lora_weights, alpha) + self.model._apply_weights(weight_dict) + + logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}") + del lora_weights + return True + + @torch.no_grad() + def _apply_lora_weights(self, weight_dict, lora_weights, alpha): + lora_pairs = {} + lora_diffs = {} + + def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix): + if key.endswith(suffix_a): + base_name = key[len(prefix) :].replace(suffix_a, target_suffix) + pair_key = key.replace(suffix_a, suffix_b) + if pair_key in lora_weights: + lora_pairs[base_name] = (key, pair_key) + + def try_lora_diff(key, prefix, suffix, target_suffix): + if key.endswith(suffix): + base_name = key[len(prefix) :].replace(suffix, target_suffix) + lora_diffs[base_name] = key + + prefixs = [ + "", # empty prefix + "diffusion_model.", + ] + for prefix in prefixs: + for key in lora_weights.keys(): + if not key.startswith(prefix): + continue + + try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight") + try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight") + try_lora_diff(key, prefix, "diff", "weight") + try_lora_diff(key, prefix, "diff_b", "bias") + try_lora_diff(key, prefix, "diff_m", "modulation") + + applied_count = 0 + for name, param in weight_dict.items(): + if name in lora_pairs: + if name not in self.override_dict: + self.override_dict[name] = param.clone().cpu() + name_lora_A, name_lora_B = lora_pairs[name] + lora_A = lora_weights[name_lora_A].to(param.device, param.dtype) + lora_B = lora_weights[name_lora_B].to(param.device, param.dtype) + if param.shape == (lora_B.shape[0], lora_A.shape[1]): + param += torch.matmul(lora_B, lora_A) * alpha + applied_count += 1 + elif name in lora_diffs: + if name not in self.override_dict: + self.override_dict[name] = param.clone().cpu() + + name_diff = lora_diffs[name] + lora_diff = lora_weights[name_diff].to(param.device, param.dtype) + if param.shape == lora_diff.shape: + param += lora_diff * alpha + applied_count += 1 + + logger.info(f"Applied {applied_count} LoRA weight adjustments") + if applied_count == 0: + logger.info( + "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model..lora_A.weight' and 'diffusion_model..lora_B.weight'. Please verify the LoRA weight file." 
+ ) + + @torch.no_grad() + def remove_lora(self): + logger.info(f"Removing LoRA ...") + + restored_count = 0 + for k, v in self.override_dict.items(): + self.model.original_weight_dict[k] = v.to(self.model.device) + restored_count += 1 + + logger.info(f"LoRA removed, restored {restored_count} weights") + + self.model._apply_weights(self.model.original_weight_dict) + + torch.cuda.empty_cache() + gc.collect() + + self.lora_metadata = {} + self.override_dict = {} + + def list_loaded_loras(self): + return list(self.lora_metadata.keys()) diff --git a/lightx2v/models/networks/wan/matrix_game2_model.py b/lightx2v/models/networks/wan/matrix_game2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..beedb018f51b5b8de25d8105038ef91d70ddc0c7 --- /dev/null +++ b/lightx2v/models/networks/wan/matrix_game2_model.py @@ -0,0 +1,48 @@ +import json +import os + +import torch +from safetensors import safe_open + +from lightx2v.models.networks.wan.infer.matrix_game2.pre_infer import WanMtxg2PreInfer +from lightx2v.models.networks.wan.infer.matrix_game2.transformer_infer import WanMtxg2TransformerInfer +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.models.networks.wan.sf_model import WanSFModel +from lightx2v.models.networks.wan.weights.matrix_game2.pre_weights import WanMtxg2PreWeights +from lightx2v.models.networks.wan.weights.matrix_game2.transformer_weights import WanActionTransformerWeights +from lightx2v.utils.envs import * +from lightx2v.utils.utils import * + + +class WanSFMtxg2Model(WanSFModel): + pre_weight_class = WanMtxg2PreWeights + transformer_weight_class = WanActionTransformerWeights + + def __init__(self, model_path, config, device): + super().__init__(model_path, config, device) + + def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer): + with safe_open(file_path, framework="pt", device=str(self.device)) as f: + return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()} + + def _load_ckpt(self, unified_dtype, sensitive_layer): + file_path = os.path.join(self.config["model_path"], f"{self.config['sub_model_folder']}/{self.config['sub_model_name']}") + _weight_dict = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer) + weight_dict = {} + for k, v in _weight_dict.items(): + name = k[6:] + weight = v.to(torch.bfloat16).to(self.device) + weight_dict.update({name: weight}) + del _weight_dict + return weight_dict + + def _init_infer_class(self): + # update config by real model config + with open(os.path.join(self.config["model_path"], self.config["sub_model_folder"], "config.json")) as f: + model_config = json.load(f) + for k in model_config.keys(): + self.config[k] = model_config[k] + + self.pre_infer_class = WanMtxg2PreInfer + self.post_infer_class = WanPostInfer + self.transformer_infer_class = WanMtxg2TransformerInfer diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..182dc6d1ca8ff787fbdd8d08db65bfb1a0067e3f --- /dev/null +++ b/lightx2v/models/networks/wan/model.py @@ -0,0 +1,505 @@ +import gc +import glob +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from loguru import logger +from safetensors import safe_open + +from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import ( + 
WanTransformerInferAdaCaching, + WanTransformerInferCustomCaching, + WanTransformerInferDualBlock, + WanTransformerInferDynamicBlock, + WanTransformerInferFirstBlock, + WanTransformerInferMagCaching, + WanTransformerInferTaylorCaching, + WanTransformerInferTeaCaching, +) +from lightx2v.models.networks.wan.infer.offload.transformer_infer import ( + WanOffloadTransformerInfer, +) +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.models.networks.wan.infer.transformer_infer import ( + WanTransformerInfer, +) +from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanTransformerWeights, +) +from lightx2v.utils.custom_compiler import CompiledMethodsMixin, compiled_method +from lightx2v.utils.envs import * +from lightx2v.utils.ggml_tensor import load_gguf_sd_ckpt +from lightx2v.utils.utils import * + + +class WanModel(CompiledMethodsMixin): + pre_weight_class = WanPreWeights + transformer_weight_class = WanTransformerWeights + + def __init__(self, model_path, config, device, model_type="wan2.1"): + super().__init__() + self.model_path = model_path + self.config = config + self.cpu_offload = self.config.get("cpu_offload", False) + self.offload_granularity = self.config.get("offload_granularity", "block") + self.model_type = model_type + self.remove_keys = [] + self.lazy_load = self.config.get("lazy_load", False) + if self.lazy_load: + self.remove_keys.extend(["blocks."]) + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + + self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) + self.dit_quantized = self.config.get("dit_quantized", False) + if self.dit_quantized: + assert self.config.get("dit_quant_scheme", "Default") in [ + "Default-Force-FP32", + "fp8-vllm", + "int8-vllm", + "fp8-q8f", + "int8-q8f", + "fp8-b128-deepgemm", + "fp8-sgl", + "int8-sgl", + "int8-torchao", + "nvfp4", + "mxfp4", + "mxfp6-mxfp8", + "mxfp8", + "int8-tmo", + "gguf-Q8_0", + "gguf-Q6_K", + "gguf-Q5_K_S", + "gguf-Q5_K_M", + "gguf-Q5_0", + "gguf-Q5_1", + "gguf-Q4_K_S", + "gguf-Q4_K_M", + "gguf-Q4_0", + "gguf-Q4_1", + "gguf-Q3_K_S", + "gguf-Q3_K_M", + ] + self.device = device + self._init_infer_class() + self._init_weights() + self._init_infer() + + def _init_infer_class(self): + self.pre_infer_class = WanPreInfer + self.post_infer_class = WanPostInfer + + if self.config["feature_caching"] == "NoCaching": + self.transformer_infer_class = WanTransformerInfer if not self.cpu_offload else WanOffloadTransformerInfer + elif self.config["feature_caching"] == "Tea": + self.transformer_infer_class = WanTransformerInferTeaCaching + elif self.config["feature_caching"] == "TaylorSeer": + self.transformer_infer_class = WanTransformerInferTaylorCaching + elif self.config["feature_caching"] == "Ada": + self.transformer_infer_class = WanTransformerInferAdaCaching + elif self.config["feature_caching"] == "Custom": + self.transformer_infer_class = WanTransformerInferCustomCaching + elif self.config["feature_caching"] == "FirstBlock": + self.transformer_infer_class = WanTransformerInferFirstBlock + elif self.config["feature_caching"] == "DualBlock": + self.transformer_infer_class = WanTransformerInferDualBlock + elif self.config["feature_caching"] == "DynamicBlock": + self.transformer_infer_class = WanTransformerInferDynamicBlock + elif 
self.config["feature_caching"] == "Mag": + self.transformer_infer_class = WanTransformerInferMagCaching + else: + raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}") + + def _should_load_weights(self): + """Determine if current rank should load weights from disk.""" + if self.config.get("device_mesh") is None: + # Single GPU mode + return True + elif dist.is_initialized(): + if self.config.get("load_from_rank0", False): + # Multi-GPU mode, only rank 0 loads + if dist.get_rank() == 0: + logger.info(f"Loading weights from {self.model_path}") + return True + else: + return True + return False + + def _should_init_empty_model(self): + if self.config.get("lora_configs") and self.config["lora_configs"]: + if self.model_type in ["wan2.1"]: + return True + if self.model_type in ["wan2.2_moe_high_noise"]: + for lora_config in self.config["lora_configs"]: + if lora_config["name"] == "high_noise_model": + return True + if self.model_type in ["wan2.2_moe_low_noise"]: + for lora_config in self.config["lora_configs"]: + if lora_config["name"] == "low_noise_model": + return True + return False + + def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer): + remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] + + if self.device.type != "cpu" and dist.is_initialized(): + device = dist.get_rank() + else: + device = str(self.device) + + with safe_open(file_path, framework="pt", device=device) as f: + return { + key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) + for key in f.keys() + if not any(remove_key in key for remove_key in remove_keys) + } + + def _load_ckpt(self, unified_dtype, sensitive_layer): + if self.config.get("dit_original_ckpt", None): + safetensors_path = self.config["dit_original_ckpt"] + else: + safetensors_path = self.model_path + + if os.path.isdir(safetensors_path): + if self.lazy_load: + self.lazy_load_path = safetensors_path + non_block_file = os.path.join(safetensors_path, "non_block.safetensors") + if os.path.exists(non_block_file): + safetensors_files = [non_block_file] + else: + raise ValueError(f"Non-block file not found in {safetensors_path}. 
Please check the model path.") + else: + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + else: + if self.lazy_load: + self.lazy_load_path = safetensors_path + safetensors_files = [safetensors_path] + + weight_dict = {} + for file_path in safetensors_files: + if self.config.get("adapter_model_path", None) is not None: + if self.config["adapter_model_path"] == file_path: + continue + logger.info(f"Loading weights from {file_path}") + file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer) + weight_dict.update(file_weights) + + return weight_dict + + def _load_quant_ckpt(self, unified_dtype, sensitive_layer): + remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] + if self.config.get("dit_quantized_ckpt", None): + safetensors_path = self.config["dit_quantized_ckpt"] + else: + safetensors_path = self.model_path + + if "gguf" in self.config.get("dit_quant_scheme", ""): + gguf_path = "" + if os.path.isdir(safetensors_path): + gguf_type = self.config.get("dit_quant_scheme").replace("gguf-", "") + gguf_files = list(filter(lambda x: gguf_type in x, glob.glob(os.path.join(safetensors_path, "*.gguf")))) + gguf_path = gguf_files[0] + else: + gguf_path = safetensors_path + weight_dict = self._load_gguf_ckpt(gguf_path) + return weight_dict + + if os.path.isdir(safetensors_path): + if self.lazy_load: + self.lazy_load_path = safetensors_path + non_block_file = os.path.join(safetensors_path, "non_block.safetensors") + if os.path.exists(non_block_file): + safetensors_files = [non_block_file] + else: + raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.") + else: + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + else: + if self.lazy_load: + self.lazy_load_path = safetensors_path + safetensors_files = [safetensors_path] + safetensors_path = os.path.dirname(safetensors_path) + + weight_dict = {} + for safetensor_path in safetensors_files: + if self.config.get("adapter_model_path", None) is not None: + if self.config["adapter_model_path"] == safetensor_path: + continue + + with safe_open(safetensor_path, framework="pt") as f: + logger.info(f"Loading weights from {safetensor_path}") + for k in f.keys(): + if any(remove_key in k for remove_key in remove_keys): + continue + if f.get_tensor(k).dtype in [ + torch.float16, + torch.bfloat16, + torch.float, + ]: + if unified_dtype or all(s not in k for s in sensitive_layer): + weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device) + else: + weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device) + else: + weight_dict[k] = f.get_tensor(k).to(self.device) + + if self.config.get("dit_quant_scheme", "Default") == "nvfp4": + calib_path = os.path.join(safetensors_path, "calib.pt") + logger.info(f"[CALIB] Loaded calibration data from: {calib_path}") + calib_data = torch.load(calib_path, map_location="cpu") + for k, v in calib_data["absmax"].items(): + weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device) + + return weight_dict + + def _load_gguf_ckpt(self, gguf_path): + state_dict = load_gguf_sd_ckpt(gguf_path, to_device=self.device) + return state_dict + + def _init_weights(self, weight_dict=None): + unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE() + # Some layers run with float32 to achieve high accuracy + sensitive_layer = { + "norm", + "embedding", + "modulation", + "time", + "img_emb.proj.0", + "img_emb.proj.4", + "before_proj", # vace + "after_proj", # vace + } + 
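+        # When no preloaded weight_dict is given, load it from disk (original or quantized checkpoint);
+        # with load_from_rank0 set, only rank 0 reads the files and the tensors are broadcast to the
+        # other ranks before the weight containers are built.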
+ if weight_dict is None: + is_weight_loader = self._should_load_weights() + if is_weight_loader: + if not self.dit_quantized: + # Load original weights + weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) + else: + # Load quantized weights + weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) + + if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False): + weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) + + if hasattr(self, "adapter_weights_dict"): + weight_dict.update(self.adapter_weights_dict) + + self.original_weight_dict = weight_dict + else: + self.original_weight_dict = weight_dict + + # Initialize weight containers + self.pre_weight = self.pre_weight_class(self.config) + if self.lazy_load: + self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path) + else: + self.transformer_weights = self.transformer_weight_class(self.config) + if not self._should_init_empty_model(): + self._apply_weights() + + def _apply_weights(self, weight_dict=None): + if weight_dict is not None: + self.original_weight_dict = weight_dict + del weight_dict + gc.collect() + # Load weights into containers + self.pre_weight.load(self.original_weight_dict) + self.transformer_weights.load(self.original_weight_dict) + + del self.original_weight_dict + torch.cuda.empty_cache() + gc.collect() + + def _load_weights_from_rank0(self, weight_dict, is_weight_loader): + logger.info("Loading distributed weights") + global_src_rank = 0 + target_device = "cpu" if self.cpu_offload else "cuda" + + if is_weight_loader: + meta_dict = {} + for key, tensor in weight_dict.items(): + meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype} + + obj_list = [meta_dict] + dist.broadcast_object_list(obj_list, src=global_src_rank) + synced_meta_dict = obj_list[0] + else: + obj_list = [None] + dist.broadcast_object_list(obj_list, src=global_src_rank) + synced_meta_dict = obj_list[0] + + distributed_weight_dict = {} + for key, meta in synced_meta_dict.items(): + distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) + + if target_device == "cuda": + dist.barrier(device_ids=[torch.cuda.current_device()]) + + for key in sorted(synced_meta_dict.keys()): + if is_weight_loader: + distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True) + + if target_device == "cpu": + if is_weight_loader: + gpu_tensor = distributed_weight_dict[key].cuda() + dist.broadcast(gpu_tensor, src=global_src_rank) + distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) + del gpu_tensor + torch.cuda.empty_cache() + else: + gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda") + dist.broadcast(gpu_tensor, src=global_src_rank) + distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) + del gpu_tensor + torch.cuda.empty_cache() + + if distributed_weight_dict[key].is_pinned(): + distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True) + else: + dist.broadcast(distributed_weight_dict[key], src=global_src_rank) + + if target_device == "cuda": + torch.cuda.synchronize() + else: + for tensor in distributed_weight_dict.values(): + if tensor.is_pinned(): + tensor.copy_(tensor, non_blocking=False) + + logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}") + + return distributed_weight_dict + + def _init_infer(self): + self.pre_infer = self.pre_infer_class(self.config) + 
self.post_infer = self.post_infer_class(self.config) + self.transformer_infer = self.transformer_infer_class(self.config) + if hasattr(self.transformer_infer, "offload_manager"): + self._init_offload_manager() + + def _init_offload_manager(self): + self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers) + if self.lazy_load: + self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers) + if self.config.get("warm_up_cpu_buffers", False): + self.transformer_infer.offload_manager.warm_up_cpu_buffers(self.transformer_weights.blocks_num) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + self.pre_infer.set_scheduler(scheduler) + self.post_infer.set_scheduler(scheduler) + self.transformer_infer.set_scheduler(scheduler) + + def to_cpu(self): + self.pre_weight.to_cpu() + self.transformer_weights.to_cpu() + + def to_cuda(self): + self.pre_weight.to_cuda() + self.transformer_weights.to_cuda() + + @torch.no_grad() + def infer(self, inputs): + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + self.transformer_weights.non_block_weights_to_cuda() + + if self.config["enable_cfg"]: + if self.config["cfg_parallel"]: + # ==================== CFG Parallel Processing ==================== + cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p") + assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2" + cfg_p_rank = dist.get_rank(cfg_p_group) + + if cfg_p_rank == 0: + noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) + else: + noise_pred = self._infer_cond_uncond(inputs, infer_condition=False) + + noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)] + dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group) + noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0 + noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1 + else: + # ==================== CFG Processing ==================== + noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True) + noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False) + + self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond) + else: + # ==================== No CFG ==================== + self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) + + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() + + @compiled_method() + @torch.no_grad() + def _infer_cond_uncond(self, inputs, infer_condition=True): + self.scheduler.infer_condition = infer_condition + + pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) + + if self.config["seq_parallel"]: + pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) + + x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out) + + if self.config["seq_parallel"]: + x = self._seq_parallel_post_process(x) + + noise_pred = self.post_infer.infer(x, 
pre_infer_out)[0] + + if self.clean_cuda_cache: + del x, pre_infer_out + torch.cuda.empty_cache() + + return noise_pred + + @torch.no_grad() + def _seq_parallel_pre_process(self, pre_infer_out): + x = pre_infer_out.x + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + + padding_size = (world_size - (x.shape[0] % world_size)) % world_size + if padding_size > 0: + x = F.pad(x, (0, 0, 0, padding_size)) + + pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank] + + if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]: + embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0 + + padding_size = (world_size - (embed.shape[0] % world_size)) % world_size + if padding_size > 0: + embed = F.pad(embed, (0, 0, 0, padding_size)) + embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) + + pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank] + pre_infer_out.embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank] + + return pre_infer_out + + @torch.no_grad() + def _seq_parallel_post_process(self, x): + world_size = dist.get_world_size(self.seq_p_group) + gathered_x = [torch.empty_like(x) for _ in range(world_size)] + dist.all_gather(gathered_x, x, group=self.seq_p_group) + combined_output = torch.cat(gathered_x, dim=0) + return combined_output diff --git a/lightx2v/models/networks/wan/sf_model.py b/lightx2v/models/networks/wan/sf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a813ef407a8bf6cdc6d32c089b66fa96b56919ea --- /dev/null +++ b/lightx2v/models/networks/wan/sf_model.py @@ -0,0 +1,53 @@ +import os + +import torch + +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer +from lightx2v.models.networks.wan.infer.self_forcing.transformer_infer import WanSFTransformerInfer +from lightx2v.models.networks.wan.model import WanModel + + +class WanSFModel(WanModel): + def __init__(self, model_path, config, device): + super().__init__(model_path, config, device) + if config["model_cls"] not in ["wan2.1_sf_mtxg2"]: + self.to_cuda() + + def _load_ckpt(self, unified_dtype, sensitive_layer): + sf_confg = self.config["sf_config"] + file_path = os.path.join(self.config["sf_model_path"], f"checkpoints/self_forcing_{sf_confg['sf_type']}.pt") + _weight_dict = torch.load(file_path)["generator_ema"] + weight_dict = {} + for k, v in _weight_dict.items(): + name = k[6:] + weight = v.to(torch.bfloat16) + weight_dict.update({name: weight}) + del _weight_dict + return weight_dict + + def _init_infer_class(self): + self.pre_infer_class = WanSFPreInfer + self.post_infer_class = WanPostInfer + self.transformer_infer_class = WanSFTransformerInfer + + @torch.no_grad() + def infer(self, inputs): + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == 0: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + self.transformer_weights.non_block_weights_to_cuda() + + current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block + current_end_frame = (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_block + noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) + + self.scheduler.noise_pred[:, current_start_frame:current_end_frame] = noise_pred + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == 
self.scheduler.infer_steps - 1: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() diff --git a/lightx2v/models/networks/wan/vace_model.py b/lightx2v/models/networks/wan/vace_model.py new file mode 100644 index 0000000000000000000000000000000000000000..452729ba71f4551b023e5b5d66447f130063bad3 --- /dev/null +++ b/lightx2v/models/networks/wan/vace_model.py @@ -0,0 +1,55 @@ +import torch + +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.models.networks.wan.infer.vace.transformer_infer import WanVaceTransformerInfer +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights +from lightx2v.models.networks.wan.weights.vace.transformer_weights import ( + WanVaceTransformerWeights, +) +from lightx2v.utils.envs import * +from lightx2v.utils.utils import * + + +class WanVaceModel(WanModel): + pre_weight_class = WanPreWeights + transformer_weight_class = WanVaceTransformerWeights + + def __init__(self, model_path, config, device): + super().__init__(model_path, config, device) + + def _init_infer(self): + super()._init_infer() + if hasattr(self.transformer_infer, "offload_manager"): + self.transformer_infer.offload_block_cuda_buffers = self.transformer_weights.offload_block_cuda_buffers + self.transformer_infer.offload_phase_cuda_buffers = self.transformer_weights.offload_phase_cuda_buffers + self.transformer_infer.vace_offload_block_cuda_buffers = self.transformer_weights.vace_offload_block_cuda_buffers + self.transformer_infer.vace_offload_phase_cuda_buffers = self.transformer_weights.vace_offload_phase_cuda_buffers + if self.lazy_load: + self.transformer_infer.offload_block_cpu_buffers = self.transformer_weights.offload_block_cpu_buffers + self.transformer_infer.offload_phase_cpu_buffers = self.transformer_weights.offload_phase_cpu_buffers + self.transformer_infer.vace_offload_block_cpu_buffers = self.transformer_weights.vace_offload_block_cpu_buffers + self.transformer_infer.vace_offload_phase_cpu_buffers = self.transformer_weights.vace_offload_phase_cpu_buffers + + def _init_infer_class(self): + self.pre_infer_class = WanPreInfer + self.post_infer_class = WanPostInfer + self.transformer_infer_class = WanVaceTransformerInfer + + @torch.no_grad() + def _infer_cond_uncond(self, inputs, infer_condition=True): + self.scheduler.infer_condition = infer_condition + + pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) + pre_infer_out.vace_context = inputs["image_encoder_output"]["vae_encoder_out"][0] + + x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out) + + noise_pred = self.post_infer.infer(x, pre_infer_out)[0] + + if self.clean_cuda_cache: + del x, pre_infer_out + torch.cuda.empty_cache() + + return noise_pred diff --git a/lightx2v/models/networks/wan/weights/animate/transformer_weights.py b/lightx2v/models/networks/wan/weights/animate/transformer_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c399df6302c101557205bbe5e2cb1595fb90bc38 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/animate/transformer_weights.py @@ -0,0 +1,127 @@ +import os + +from safetensors import safe_open + +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanTransformerWeights, +) +from 
lightx2v.utils.registry_factory import ( + ATTN_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, +) + + +class WanAnimateTransformerWeights(WanTransformerWeights): + def __init__(self, config): + super().__init__(config) + self.adapter_blocks_num = self.blocks_num // 5 + for i in range(self.blocks_num): + if i % 5 == 0: + self.blocks[i].compute_phases.append(WanAnimateFuserBlock(self.config, i // 5, "face_adapter.fuser_blocks", self.mm_type)) + else: + self.blocks[i].compute_phases.append(WeightModule()) + self._add_animate_fuserblock_to_offload_buffers() + + def _add_animate_fuserblock_to_offload_buffers(self): + if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None: + for i in range(self.offload_blocks_num): + self.offload_block_cuda_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True)) + if self.lazy_load: + self.offload_block_cpu_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True)) + elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None: + self.offload_phase_cuda_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True)) + if self.lazy_load: + self.offload_phase_cpu_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True)) + + +class WanAnimateFuserBlock(WeightModule): + def __init__(self, config, block_index, block_prefix, mm_type, create_cuda_buffer=False, create_cpu_buffer=False): + super().__init__() + self.config = config + self.is_post_adapter = True + lazy_load = config.get("lazy_load", False) + if lazy_load: + lazy_load_path = os.path.join( + config.dit_quantized_ckpt, + f"{block_prefix[:-1]}_{block_index}.safetensors", + ) + lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu") + else: + lazy_load_file = None + + self.add_module( + "linear1_kv", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.linear1_kv.weight", + f"{block_prefix}.{block_index}.linear1_kv.bias", + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + self.is_post_adapter, + ), + ) + + self.add_module( + "linear1_q", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.linear1_q.weight", + f"{block_prefix}.{block_index}.linear1_q.bias", + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + self.is_post_adapter, + ), + ) + self.add_module( + "linear2", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.linear2.weight", + f"{block_prefix}.{block_index}.linear2.bias", + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + self.is_post_adapter, + ), + ) + + self.add_module( + "q_norm", + RMS_WEIGHT_REGISTER["sgl-kernel"]( + f"{block_prefix}.{block_index}.q_norm.weight", + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + self.is_post_adapter, + ), + ) + + self.add_module( + "k_norm", + RMS_WEIGHT_REGISTER["sgl-kernel"]( + f"{block_prefix}.{block_index}.k_norm.weight", + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + self.is_post_adapter, + ), + ) + + self.add_module( + "pre_norm_feat", + LN_WEIGHT_REGISTER["Default"](), + ) + + self.add_module( + "pre_norm_motion", + LN_WEIGHT_REGISTER["Default"](), + ) + + self.add_module("adapter_attn", 
ATTN_WEIGHT_REGISTER[config["adapter_attn_type"]]()) diff --git a/lightx2v/models/networks/wan/weights/audio/transformer_weights.py b/lightx2v/models/networks/wan/weights/audio/transformer_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..097c6554fba080e0022bf616a5537399078249d0 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/audio/transformer_weights.py @@ -0,0 +1,161 @@ +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.models.networks.wan.weights.transformer_weights import WanTransformerWeights +from lightx2v.utils.registry_factory import ( + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + TENSOR_REGISTER, +) + + +class WanAudioTransformerWeights(WanTransformerWeights): + def __init__(self, config): + super().__init__(config) + for i in range(self.blocks_num): + self.blocks[i].compute_phases.append( + WanAudioAdapterCA( + i, + f"ca", + self.task, + self.mm_type, + self.config, + False, + False, + self.blocks[i].lazy_load, + self.blocks[i].lazy_load_file, + ) + ) + + self._add_audio_adapter_ca_to_offload_buffers() + + def _add_audio_adapter_ca_to_offload_buffers(self): + if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None: + for i in range(self.offload_blocks_num): + offload_buffer = self.offload_block_cuda_buffers[i] + adapter_ca = WanAudioAdapterCA( + block_index=i, + block_prefix=f"ca", + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=True, + create_cpu_buffer=False, + lazy_load=offload_buffer.lazy_load, + lazy_load_file=offload_buffer.lazy_load_file, + ) + offload_buffer.compute_phases.append(adapter_ca) + if self.lazy_load: + offload_buffer = self.offload_block_cpu_buffers[i] + adapter_ca = WanAudioAdapterCA( + block_index=i, + block_prefix=f"ca", + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=False, + create_cpu_buffer=True, + lazy_load=offload_buffer.lazy_load, + lazy_load_file=offload_buffer.lazy_load_file, + ) + offload_buffer.compute_phases.append(adapter_ca) + + elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None: + adapter_ca = WanAudioAdapterCA( + block_index=0, + block_prefix=f"ca", + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=True, + create_cpu_buffer=False, + lazy_load=self.blocks[0].lazy_load, + lazy_load_file=self.blocks[0].lazy_load_file, + ) + self.offload_phase_cuda_buffers.append(adapter_ca) + if self.lazy_load: + adapter_ca = WanAudioAdapterCA( + block_index=0, + block_prefix=f"ca", + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=False, + create_cpu_buffer=True, + lazy_load=self.blocks[0].lazy_load, + lazy_load_file=self.blocks[0].lazy_load_file, + ) + self.offload_phase_cpu_buffers.append(adapter_ca) + + +class WanAudioAdapterCA(WeightModule): + def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + + self.add_module( + "to_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{block_index}.to_q.weight", + f"{block_prefix}.{block_index}.to_q.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.add_module( + "to_kv", + 
MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{block_index}.to_kv.weight", + f"{block_prefix}.{block_index}.to_kv.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.add_module( + "to_out", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{block_index}.to_out.weight", + f"{block_prefix}.{block_index}.to_out.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.add_module( + "norm_kv", + LN_WEIGHT_REGISTER["Default"]( + f"{block_prefix}.{block_index}.norm_kv.weight", + f"{block_prefix}.{block_index}.norm_kv.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.add_module( + "norm_q", + LN_WEIGHT_REGISTER["Default"](), + ) + + self.add_module( + "shift_scale_gate", + TENSOR_REGISTER["Default"]( + f"{block_prefix}.{block_index}.shift_scale_gate", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) diff --git a/lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py b/lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5f4aa55c66cd6df3c3d197edf829e774ee1b53 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py @@ -0,0 +1,50 @@ +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.utils.registry_factory import ( + CONV3D_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, +) + + +class WanMtxg2PreWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.in_dim = config["in_dim"] + self.dim = config["dim"] + self.patch_size = (1, 2, 2) + self.config = config + # patch + self.add_module( + "patch_embedding", + CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size), + ) + # time + self.add_module( + "time_embedding_0", + MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"), + ) + self.add_module( + "time_embedding_2", + MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"), + ) + self.add_module( + "time_projection_1", + MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"), + ) + # img_emb + self.add_module( + "img_emb_0", + LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5), + ) + self.add_module( + "img_emb_1", + MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"), + ) + self.add_module( + "img_emb_3", + MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"), + ) + self.add_module( + "img_emb_4", + LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5), + ) diff --git a/lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py b/lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c3668a465fd6d3d78ec8224e2778e2b394a1e347 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py @@ -0,0 +1,244 @@ +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanFFN, + WanSelfAttention, + WanTransformerAttentionBlock, +) +from lightx2v.utils.registry_factory import ( + ATTN_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + 
MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, + TENSOR_REGISTER, +) + + +class WanActionTransformerWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.blocks_num = config["num_layers"] + self.task = config["task"] + self.config = config + self.mm_type = config.get("dit_quant_scheme", "Default") + if self.mm_type != "Default": + assert config.get("dit_quantized") is True + + action_blocks = config["action_config"]["blocks"] + block_list = [] + for i in range(self.blocks_num): + if i in action_blocks: + block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config)) + else: + block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config)) + self.blocks = WeightModuleList(block_list) + self.add_module("blocks", self.blocks) + + # non blocks weights + self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]()) + self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")) + self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation")) + + def non_block_weights_to_cuda(self): + self.norm.to_cuda() + self.head.to_cuda() + self.head_modulation.to_cuda() + + def non_block_weights_to_cpu(self): + self.norm.to_cpu() + self.head.to_cpu() + self.head_modulation.to_cpu() + + +class WanTransformerActionBlock(WeightModule): + def __init__(self, block_index, task, mm_type, config, block_prefix="blocks"): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + assert not self.config.get("lazy_load", False) + self.compute_phases = WeightModuleList( + [ + WanSelfAttention(block_index, block_prefix, task, mm_type, config), + WanActionCrossAttention( + block_index, + block_prefix, + task, + mm_type, + config, + ), + WanActionModule( + block_index, + block_prefix, + task, + mm_type, + config, + ), + WanFFN( + block_index, + block_prefix, + task, + mm_type, + config, + ), + ] + ) + + self.add_module("compute_phases", self.compute_phases) + + +class WanActionModule(WeightModule): + def __init__(self, block_index, block_prefix, task, mm_type, config): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + + self.attn_rms_type = "self_forcing" + + self.add_module( + "keyboard_embed_0", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.weight", + f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.bias", + ), + ) + self.add_module( + "keyboard_embed_2", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.weight", + f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.bias", + ), + ) + + self.add_module( + "proj_keyboard", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.proj_keyboard.weight", + bias_name=None, + ), + ) + + self.add_module( + "keyboard_attn_kv", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight", + bias_name=None, + ), + ) + + self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]()) + + self.add_module( + "mouse_attn_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight", + bias_name=None, + ), + ) + + if 
self.config["mode"] != "templerun": + self.add_module( + "t_qkv", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight", + bias_name=None, + ), + ) + + self.add_module( + "proj_mouse", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight", + bias_name=None, + ), + ) + + self.add_module( + "mouse_mlp_0", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight", + f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias", + ), + ) + self.add_module( + "mouse_mlp_2", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight", + f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias", + ), + ) + self.add_module( + "mouse_mlp_3", + LN_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight", + f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias", + eps=1e-6, + ), + ) + + +class WanActionCrossAttention(WeightModule): + def __init__(self, block_index, block_prefix, task, mm_type, config): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + + if self.config.get("sf_config", False): + self.attn_rms_type = "self_forcing" + else: + self.attn_rms_type = "sgl-kernel" + + self.add_module( + "norm3", + LN_WEIGHT_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.norm3.weight", + f"{block_prefix}.{self.block_index}.norm3.bias", + ), + ) + self.add_module( + "cross_attn_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.q.weight", + f"{block_prefix}.{self.block_index}.cross_attn.q.bias", + ), + ) + self.add_module( + "cross_attn_k", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.k.weight", + f"{block_prefix}.{self.block_index}.cross_attn.k.bias", + ), + ) + self.add_module( + "cross_attn_v", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.v.weight", + f"{block_prefix}.{self.block_index}.cross_attn.v.bias", + ), + ) + self.add_module( + "cross_attn_o", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.o.weight", + f"{block_prefix}.{self.block_index}.cross_attn.o.bias", + ), + ) + self.add_module( + "cross_attn_norm_q", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight", + ), + ) + self.add_module( + "cross_attn_norm_k", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight", + ), + ) + self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]()) diff --git a/lightx2v/models/networks/wan/weights/post_weights.py b/lightx2v/models/networks/wan/weights/post_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..1a02b1259b03bbe7e5b46cbd077b245fcc73472b --- /dev/null +++ b/lightx2v/models/networks/wan/weights/post_weights.py @@ -0,0 +1,7 @@ +from lightx2v.common.modules.weight_module import WeightModule + + +class WanPostWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.config = config diff --git a/lightx2v/models/networks/wan/weights/pre_weights.py b/lightx2v/models/networks/wan/weights/pre_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c6172014e3796528aaa0137473f12cc33ebd37d0 --- /dev/null +++ 
b/lightx2v/models/networks/wan/weights/pre_weights.py @@ -0,0 +1,80 @@ +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.utils.registry_factory import ( + CONV3D_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + TENSOR_REGISTER, +) + + +class WanPreWeights(WeightModule): + def __init__(self, config): + super().__init__() + self.in_dim = config["in_dim"] + self.dim = config["dim"] + self.patch_size = (1, 2, 2) + self.config = config + + self.add_module( + "patch_embedding", + CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size), + ) + self.add_module( + "text_embedding_0", + MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"), + ) + self.add_module( + "text_embedding_2", + MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"), + ) + self.add_module( + "time_embedding_0", + MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"), + ) + self.add_module( + "time_embedding_2", + MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"), + ) + self.add_module( + "time_projection_1", + MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"), + ) + + if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True): + self.add_module( + "proj_0", + LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"), + ) + self.add_module( + "proj_1", + MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"), + ) + self.add_module( + "proj_3", + MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"), + ) + self.add_module( + "proj_4", + LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"), + ) + + if config["model_cls"] == "wan2.1_distill" and config.get("enable_dynamic_cfg", False): + self.add_module( + "cfg_cond_proj_1", + MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"), + ) + self.add_module( + "cfg_cond_proj_2", + MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"), + ) + + if config["task"] == "flf2v" and config.get("use_image_encoder", True): + self.add_module( + "emb_pos", + TENSOR_REGISTER["Default"](f"img_emb.emb_pos"), + ) + if config["task"] == "animate": + self.add_module( + "pose_patch_embedding", + CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size), + ) diff --git a/lightx2v/models/networks/wan/weights/transformer_weights.py b/lightx2v/models/networks/wan/weights/transformer_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..6107c8ddd1b0641de2624b7c5ad68e5b0521a3a1 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/transformer_weights.py @@ -0,0 +1,578 @@ +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList +from lightx2v.utils.registry_factory import ( + ATTN_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, + TENSOR_REGISTER, +) + + +class WanTransformerWeights(WeightModule): + def __init__(self, config, lazy_load_path=None): + super().__init__() + self.blocks_num = config["num_layers"] + self.task = config["task"] + self.config = config + self.mm_type = config.get("dit_quant_scheme", "Default") + if self.mm_type != "Default": + assert config.get("dit_quantized") is True + if 
config.get("do_mm_calib", False): + self.mm_type = "Calib" + self.lazy_load = self.config.get("lazy_load", False) + self.blocks = WeightModuleList( + [ + WanTransformerAttentionBlock( + block_index=i, + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=False, + create_cpu_buffer=False, + block_prefix="blocks", + lazy_load=self.lazy_load, + lazy_load_path=lazy_load_path, + ) + for i in range(self.blocks_num) + ] + ) + self.register_offload_buffers(config, lazy_load_path) + self.add_module("blocks", self.blocks) + + # non blocks weights + self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]()) + self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")) + self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation")) + + def register_offload_buffers(self, config, lazy_load_path): + if config["cpu_offload"]: + if config["offload_granularity"] == "block": + self.offload_blocks_num = 2 + self.offload_block_cuda_buffers = WeightModuleList( + [ + WanTransformerAttentionBlock( + block_index=i, + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=True, + create_cpu_buffer=False, + block_prefix="blocks", + lazy_load=self.lazy_load, + lazy_load_path=lazy_load_path, + ) + for i in range(self.offload_blocks_num) + ] + ) + self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers) + self.offload_phase_cuda_buffers = None + + if self.lazy_load: + self.offload_blocks_num = 2 + self.offload_block_cpu_buffers = WeightModuleList( + [ + WanTransformerAttentionBlock( + block_index=i, + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=False, + create_cpu_buffer=True, + block_prefix="blocks", + lazy_load=self.lazy_load, + lazy_load_path=lazy_load_path, + ) + for i in range(self.offload_blocks_num) + ] + ) + self.add_module("offload_block_cpu_buffers", self.offload_block_cpu_buffers) + self.offload_phase_cpu_buffers = None + + elif config["offload_granularity"] == "phase": + self.offload_phase_cuda_buffers = WanTransformerAttentionBlock( + block_index=0, + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=True, + create_cpu_buffer=False, + block_prefix="blocks", + lazy_load=self.lazy_load, + lazy_load_path=lazy_load_path, + ).compute_phases + self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers) + self.offload_block_cuda_buffers = None + if self.lazy_load: + self.offload_phase_cpu_buffers = WeightModuleList( + [ + WanTransformerAttentionBlock( + block_index=i, + task=self.task, + mm_type=self.mm_type, + config=self.config, + create_cuda_buffer=False, + create_cpu_buffer=True, + block_prefix="blocks", + lazy_load=self.lazy_load, + lazy_load_path=lazy_load_path, + ).compute_phases + for i in range(2) + ] + ) + self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers) + self.offload_block_cpu_buffers = None + + def non_block_weights_to_cuda(self): + self.norm.to_cuda() + self.head.to_cuda() + self.head_modulation.to_cuda() + + def non_block_weights_to_cpu(self): + self.norm.to_cpu() + self.head.to_cpu() + self.head_modulation.to_cpu() + + +class WanTransformerAttentionBlock(WeightModule): + def __init__( + self, + block_index, + task, + mm_type, + config, + create_cuda_buffer=False, + create_cpu_buffer=False, + block_prefix="blocks", + lazy_load=False, + lazy_load_path=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = 
mm_type + self.task = task + self.config = config + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.quant_method = config.get("quant_method", None) + + self.lazy_load = lazy_load + if self.lazy_load: + self.lazy_load_file = lazy_load_path + else: + self.lazy_load_file = None + + self.compute_phases = WeightModuleList( + [ + WanSelfAttention( + block_index, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + WanCrossAttention( + block_index, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + WanFFN( + block_index, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ] + ) + + self.add_module("compute_phases", self.compute_phases) + + +class WanSelfAttention(WeightModule): + def __init__( + self, + block_index, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + + if self.config.get("sf_config", False): + self.attn_rms_type = "self_forcing" + else: + self.attn_rms_type = "sgl-kernel" + + self.add_module( + "modulation", + TENSOR_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.modulation", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.add_module( + "norm1", + LN_WEIGHT_REGISTER["Default"](), + ) + + self.add_module( + "self_attn_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.self_attn.q.weight", + f"{block_prefix}.{self.block_index}.self_attn.q.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.add_module( + "self_attn_k", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.self_attn.k.weight", + f"{block_prefix}.{self.block_index}.self_attn.k.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "self_attn_v", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.self_attn.v.weight", + f"{block_prefix}.{self.block_index}.self_attn.v.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "self_attn_o", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.self_attn.o.weight", + f"{block_prefix}.{self.block_index}.self_attn.o.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "self_attn_norm_q", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "self_attn_norm_k", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]] + if self.config["self_attn_1_type"] == 
"svg_attn": + attention_weights_cls.prepare( + head_num=self.config["num_heads"], + head_dim=self.config["dim"] // self.config["num_heads"], + sample_mse_max_row=self.config.get("svg_sample_mse_max_row", 10000), + num_sampled_rows=self.config.get("svg_num_sampled_rows", 64), + context_length=self.config.get("svg_context_length", 0), + sparsity=self.config.get("svg_sparsity", 0.25), + ) + if self.config["self_attn_1_type"] in [ + "svg_attn", + "radial_attn", + "nbhd_attn", + "nbhd_attn_flashinfer", + ]: + attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"] + # nbhd_attn setting + if self.config["self_attn_1_type"] in ["nbhd_attn", "nbhd_attn_flashinfer"]: + if "nbhd_attn_setting" in self.config: + if "coefficient" in self.config["nbhd_attn_setting"]: + attention_weights_cls.coefficient = self.config["nbhd_attn_setting"]["coefficient"] + if "min_width" in self.config["nbhd_attn_setting"]: + attention_weights_cls.min_width = self.config["nbhd_attn_setting"]["min_width"] + self.add_module("self_attn_1", attention_weights_cls()) + + if self.config["seq_parallel"]: + self.add_module( + "self_attn_1_parallel", + ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")](), + ) + + if self.quant_method in ["advanced_ptq"]: + self.add_module( + "smooth_norm1_weight", + TENSOR_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.affine_norm1.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "smooth_norm1_bias", + TENSOR_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.affine_norm1.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + +class WanCrossAttention(WeightModule): + def __init__( + self, + block_index, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + + if self.config.get("sf_config", False): + self.attn_rms_type = "self_forcing" + else: + self.attn_rms_type = "sgl-kernel" + + self.add_module( + "norm3", + LN_WEIGHT_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.norm3.weight", + f"{block_prefix}.{self.block_index}.norm3.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_q", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.q.weight", + f"{block_prefix}.{self.block_index}.cross_attn.q.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_k", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.k.weight", + f"{block_prefix}.{self.block_index}.cross_attn.k.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_v", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.v.weight", + f"{block_prefix}.{self.block_index}.cross_attn.v.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_o", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.o.weight", + 
f"{block_prefix}.{self.block_index}.cross_attn.o.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_norm_q", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_norm_k", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]()) + + if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and self.config["model_cls"] != "wan2.1_sf_mtxg2": + self.add_module( + "cross_attn_k_img", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight", + f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_v_img", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight", + f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "cross_attn_norm_k_img", + RMS_WEIGHT_REGISTER[self.attn_rms_type]( + f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]()) + + +class WanFFN(WeightModule): + def __init__( + self, + block_index, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + + self.add_module( + "norm2", + LN_WEIGHT_REGISTER["Default"](), + ) + + self.add_module( + "ffn_0", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.ffn.0.weight", + f"{block_prefix}.{self.block_index}.ffn.0.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "ffn_2", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.ffn.2.weight", + f"{block_prefix}.{self.block_index}.ffn.2.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + if self.quant_method in ["advanced_ptq"]: + self.add_module( + "smooth_norm2_weight", + TENSOR_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.affine_norm3.weight", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + self.add_module( + "smooth_norm2_bias", + TENSOR_REGISTER["Default"]( + f"{block_prefix}.{self.block_index}.affine_norm3.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) diff --git a/lightx2v/models/networks/wan/weights/vace/transformer_weights.py b/lightx2v/models/networks/wan/weights/vace/transformer_weights.py new file mode 100644 index 
0000000000000000000000000000000000000000..dd8127e7e82391868fd061b4917439fa860dde93 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/vace/transformer_weights.py @@ -0,0 +1,76 @@ +from lightx2v.common.modules.weight_module import WeightModuleList +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanTransformerAttentionBlock, + WanTransformerWeights, +) +from lightx2v.utils.registry_factory import ( + CONV3D_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, +) + + +class WanVaceTransformerWeights(WanTransformerWeights): + def __init__(self, config): + super().__init__(config) + self.patch_size = (1, 2, 2) + self.register_offload_buffers(config) + self.vace_blocks = WeightModuleList( + [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))] + ) + self.add_module("vace_blocks", self.vace_blocks) + self.add_module( + "vace_patch_embedding", + CONV3D_WEIGHT_REGISTER["Default"]("vace_patch_embedding.weight", "vace_patch_embedding.bias", stride=self.patch_size), + ) + + def register_offload_buffers(self, config, lazy_load_path=None): + super().register_offload_buffers(config, lazy_load_path) + if config["cpu_offload"]: + if config["offload_granularity"] == "block": + self.vace_offload_block_cuda_buffers = WeightModuleList( + [ + WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"), + WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"), + ] + ) + self.add_module("vace_offload_block_cuda_buffers", self.vace_offload_block_cuda_buffers) + self.vace_offload_phase_cuda_buffers = None + elif config["offload_granularity"] == "phase": + raise NotImplementedError + + def non_block_weights_to_cuda(self): + super().non_block_weights_to_cuda() + self.vace_patch_embedding.to_cuda() + + def non_block_weights_to_cpu(self): + super().non_block_weights_to_cpu() + self.vace_patch_embedding.to_cpu() + + +class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock): + def __init__(self, base_block_idx, block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix): + super().__init__(block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix) + if base_block_idx == 0: + self.compute_phases[0].add_module( + "before_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.before_proj.weight", + f"{block_prefix}.{self.block_index}.before_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) + + self.compute_phases[-1].add_module( + "after_proj", + MM_WEIGHT_REGISTER[self.mm_type]( + f"{block_prefix}.{self.block_index}.after_proj.weight", + f"{block_prefix}.{self.block_index}.after_proj.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, + ), + ) diff --git a/lightx2v/models/runners/__init__.py b/lightx2v/models/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/runners/base_runner.py b/lightx2v/models/runners/base_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..97acb7ec278f9230a06dc8347a364b279707f801 --- /dev/null +++ b/lightx2v/models/runners/base_runner.py @@ -0,0 +1,169 @@ +import os +from abc import ABC + +import torch +import 
torch.distributed as dist + +from lightx2v_platform.base.global_var import AI_DEVICE + + +class BaseRunner(ABC): + """Abstract base class for all Runners + + Defines interface methods that all subclasses must implement + """ + + def __init__(self, config): + self.config = config + self.vae_encoder_need_img_original = False + self.input_info = None + + def load_transformer(self): + """Load transformer model + + Returns: + Loaded transformer model instance + """ + pass + + def load_text_encoder(self): + """Load text encoder + + Returns: + Text encoder instance or list of text encoder instances + """ + pass + + def load_image_encoder(self): + """Load image encoder + + Returns: + Image encoder instance or None if not needed + """ + pass + + def load_vae(self): + """Load VAE encoder and decoder + + Returns: + Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances + """ + pass + + def run_image_encoder(self, img): + """Run image encoder + + Args: + img: Input image + + Returns: + Image encoding result + """ + pass + + def run_vae_encoder(self, img): + """Run VAE encoder + + Args: + img: Input image + + Returns: + Tuple of VAE encoding result and additional parameters + """ + pass + + def run_text_encoder(self, prompt, img): + """Run text encoder + + Args: + prompt: Input text prompt + img: Optional input image (for some models) + + Returns: + Text encoding result + """ + pass + + def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img): + """Combine encoder outputs for i2v task + + Args: + clip_encoder_out: CLIP encoder output + vae_encoder_out: VAE encoder output + text_encoder_output: Text encoder output + img: Original image + + Returns: + Combined encoder output dictionary + """ + pass + + def init_scheduler(self): + """Initialize scheduler""" + pass + + def load_vae_decoder(self): + """Load VAE decoder + + Default implementation: get decoder from load_vae method + Subclasses can override this method to provide different loading logic + + Returns: + VAE decoder instance + """ + if not hasattr(self, "vae_decoder") or self.vae_decoder is None: + _, self.vae_decoder = self.load_vae() + return self.vae_decoder + + def get_video_segment_num(self): + self.video_segment_num = 1 + + def init_run(self): + pass + + def init_run_segment(self, segment_idx): + self.segment_idx = segment_idx + + def run_segment(self, segment_idx=0): + pass + + def end_run_segment(self, segment_idx=None): + self.gen_video_final = self.gen_video + + def end_run(self): + pass + + def check_stop(self): + """Check if the stop signal is received""" + + rank, world_size = 0, 1 + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + stop_rank = int(os.getenv("WORKER_RANK", "0")) % world_size # same as worker hub target_rank + pause_rank = int(os.getenv("READER_RANK", "0")) % world_size # same as va_reader target_rank + + stopped, paused = 0, 0 + if rank == stop_rank and hasattr(self, "stop_signal") and self.stop_signal: + stopped = 1 + if rank == pause_rank and hasattr(self, "pause_signal") and self.pause_signal: + paused = 1 + + if world_size > 1: + if rank == stop_rank: + t1 = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE) + else: + t1 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE) + if rank == pause_rank: + t2 = torch.tensor([paused], dtype=torch.int32).to(device=AI_DEVICE) + else: + t2 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE) + dist.broadcast(t1, src=stop_rank) + dist.broadcast(t2, src=pause_rank) + 
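+ # After these broadcasts every rank sees the same stop/pause flags, so all processes act on a stop or pause together instead of diverging.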
stopped = t1.item() + paused = t2.item() + + if stopped == 1: + raise Exception(f"rank {rank} received stop_signal, stopping run (expected behavior)") + if paused == 1: + raise Exception(f"rank {rank} received pause_signal, pausing run (expected behavior)") diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..1850edf727cfd25574495ee6fb1a9cb69b95be92 --- /dev/null +++ b/lightx2v/models/runners/default_runner.py @@ -0,0 +1,419 @@ +import gc +import os + +import requests +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from PIL import Image +from loguru import logger +from requests.exceptions import RequestException + +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.generate_task_id import generate_task_id +from lightx2v.utils.global_paras import CALIB +from lightx2v.utils.memory_profiler import peak_memory_decorator +from lightx2v.utils.profiler import * +from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image +from lightx2v_platform.base.global_var import AI_DEVICE + +from .base_runner import BaseRunner + + +class DefaultRunner(BaseRunner): + def __init__(self, config): + super().__init__(config) + self.has_prompt_enhancer = False + self.progress_callback = None + if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None: + self.has_prompt_enhancer = True + if not self.check_sub_servers("prompt_enhancer"): + self.has_prompt_enhancer = False + logger.warning("No prompt enhancer server available, disabling prompt enhancer.") + if not self.has_prompt_enhancer: + self.config["use_prompt_enhancer"] = False + self.set_init_device() + self.init_scheduler() + + def init_modules(self): + logger.info("Initializing runner modules...") + if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): + self.load_model() + elif self.config.get("lazy_load", False): + assert self.config.get("cpu_offload", False) + if hasattr(self, "model"): + self.model.set_scheduler(self.scheduler) # set scheduler to model + if self.config["task"] == "i2v": + self.run_input_encoder = self._run_input_encoder_local_i2v + elif self.config["task"] == "flf2v": + self.run_input_encoder = self._run_input_encoder_local_flf2v + elif self.config["task"] == "t2v": + self.run_input_encoder = self._run_input_encoder_local_t2v + elif self.config["task"] == "vace": + self.run_input_encoder = self._run_input_encoder_local_vace + elif self.config["task"] == "animate": + self.run_input_encoder = self._run_input_encoder_local_animate + elif self.config["task"] == "s2v": + self.run_input_encoder = self._run_input_encoder_local_s2v + self.config.lock() # lock config to avoid modification + if self.config.get("compile", False) and hasattr(self.model, "compile"): + logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}") + self.model.compile(self.config.get("compile_shapes", [])) + + def set_init_device(self): + if self.config["cpu_offload"]: + self.init_device = torch.device("cpu") + else: + self.init_device = torch.device(AI_DEVICE) + + def load_vfi_model(self): + if self.config["video_frame_interpolation"].get("algo", None) == "rife": + from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper + + logger.info("Loading RIFE model...") + return 
RIFEWrapper(self.config["video_frame_interpolation"]["model_path"]) + else: + raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}") + + def load_vsr_model(self): + if "video_super_resolution" in self.config: + from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper + + logger.info("Loading VSR model...") + return VSRWrapper(self.config["video_super_resolution"]["model_path"]) + else: + return None + + @ProfilingContext4DebugL2("Load models") + def load_model(self): + self.model = self.load_transformer() + self.text_encoders = self.load_text_encoder() + self.image_encoder = self.load_image_encoder() + self.vae_encoder, self.vae_decoder = self.load_vae() + self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None + self.vsr_model = self.load_vsr_model() if "video_super_resolution" in self.config else None + + def check_sub_servers(self, task_type): + urls = self.config.get("sub_servers", {}).get(task_type, []) + available_servers = [] + for url in urls: + try: + status_url = f"{url}/v1/local/{task_type}/generate/service_status" + response = requests.get(status_url, timeout=2) + if response.status_code == 200: + available_servers.append(url) + else: + logger.warning(f"Service {url} returned status code {response.status_code}") + + except RequestException as e: + logger.warning(f"Failed to connect to {url}: {str(e)}") + continue + logger.info(f"{task_type} available servers: {available_servers}") + self.config["sub_servers"][task_type] = available_servers + return len(available_servers) > 0 + + def set_inputs(self, inputs): + self.input_info.seed = inputs.get("seed", 42) + self.input_info.prompt = inputs.get("prompt", "") + if self.config["use_prompt_enhancer"]: + self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "") + self.input_info.negative_prompt = inputs.get("negative_prompt", "") + if "image_path" in self.input_info.__dataclass_fields__: + self.input_info.image_path = inputs.get("image_path", "") + if "audio_path" in self.input_info.__dataclass_fields__: + self.input_info.audio_path = inputs.get("audio_path", "") + if "video_path" in self.input_info.__dataclass_fields__: + self.input_info.video_path = inputs.get("video_path", "") + self.input_info.save_result_path = inputs.get("save_result_path", "") + + def set_config(self, config_modify): + logger.info(f"modify config: {config_modify}") + with self.config.temporarily_unlocked(): + self.config.update(config_modify) + + def set_progress_callback(self, callback): + self.progress_callback = callback + + @peak_memory_decorator + def run_segment(self, segment_idx=0): + infer_steps = self.model.scheduler.infer_steps + + for step_index in range(infer_steps): + # only for single segment, check stop signal every step + with ProfilingContext4DebugL1( + f"Run Dit every step", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration, + metrics_labels=[step_index + 1, infer_steps], + ): + if self.video_segment_num == 1: + self.check_stop() + logger.info(f"==> step_index: {step_index + 1} / {infer_steps}") + + with ProfilingContext4DebugL1("step_pre"): + self.model.scheduler.step_pre(step_index=step_index) + + with ProfilingContext4DebugL1("🚀 infer_main"): + self.model.infer(self.inputs) + + with ProfilingContext4DebugL1("step_post"): + self.model.scheduler.step_post() + + if self.progress_callback: + current_step = segment_idx * infer_steps + step_index + 1 + total_all_steps = self.video_segment_num * 
infer_steps + self.progress_callback((current_step / total_all_steps) * 100, 100) + + if segment_idx is not None and segment_idx == self.video_segment_num - 1: + del self.inputs + torch.cuda.empty_cache() + + return self.model.scheduler.latents + + def run_step(self): + self.inputs = self.run_input_encoder() + if hasattr(self, "sr_version") and self.sr_version is not None: + self.config_sr["is_sr_running"] = True + self.inputs_sr = self.run_input_encoder() + self.config_sr["is_sr_running"] = False + + self.run_main(total_steps=1) + + def end_run(self): + self.model.scheduler.clear() + if hasattr(self, "inputs"): + del self.inputs + self.input_info = None + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + if hasattr(self.model, "model") and len(self.model.model) == 2: # MultiModelStruct + for model in self.model.model: + if hasattr(model.transformer_infer, "offload_manager"): + del model.transformer_infer.offload_manager + torch.cuda.empty_cache() + gc.collect() + del model + else: + if hasattr(self.model.transformer_infer, "offload_manager"): + del self.model.transformer_infer.offload_manager + torch.cuda.empty_cache() + gc.collect() + del self.model + if self.config.get("do_mm_calib", False): + calib_path = os.path.join(os.getcwd(), "calib.pt") + torch.save(CALIB, calib_path) + logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}") + torch.cuda.empty_cache() + gc.collect() + + def read_image_input(self, img_path): + if isinstance(img_path, Image.Image): + img_ori = img_path + else: + img_ori = Image.open(img_path).convert("RGB") + if GET_RECORDER_MODE(): + width, height = img_ori.size + monitor_cli.lightx2v_input_image_len.observe(width * height) + img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device) + self.input_info.original_size = img_ori.size + return img, img_ori + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_i2v(self): + img, img_ori = self.read_image_input(self.input_info.image_path) + clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None + vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img) + self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info + text_encoder_output = self.run_text_encoder(self.input_info) + torch.cuda.empty_cache() + gc.collect() + return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img) + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_t2v(self): + self.input_info.latent_shape = self.get_latent_shape_with_target_hw() # Important: set latent_shape in input_info + text_encoder_output = self.run_text_encoder(self.input_info) + torch.cuda.empty_cache() + gc.collect() + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": None, + } + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_flf2v(self): + first_frame, _ = self.read_image_input(self.input_info.image_path) + last_frame, _ = self.read_image_input(self.input_info.last_frame_path) + clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None + vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame) + self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info + text_encoder_output = self.run_text_encoder(self.input_info) + 
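+ # Release cached encoder memory before DiT inference; the other input-encoder variants follow the same cleanup pattern.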
torch.cuda.empty_cache() + gc.collect() + return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output) + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_vace(self): + src_video = self.input_info.src_video + src_mask = self.input_info.src_mask + src_ref_images = self.input_info.src_ref_images + src_video, src_mask, src_ref_images = self.prepare_source( + [src_video], + [src_mask], + [None if src_ref_images is None else src_ref_images.split(",")], + (self.config["target_width"], self.config["target_height"]), + ) + self.src_ref_images = src_ref_images + + vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask) + self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info + text_encoder_output = self.run_text_encoder(self.input_info) + torch.cuda.empty_cache() + gc.collect() + return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output) + + @ProfilingContext4DebugL2("Run Text Encoder") + def _run_input_encoder_local_animate(self): + text_encoder_output = self.run_text_encoder(self.input_info) + torch.cuda.empty_cache() + gc.collect() + return self.get_encoder_output_i2v(None, None, text_encoder_output, None) + + def _run_input_encoder_local_s2v(self): + pass + + def init_run(self): + self.gen_video_final = None + self.get_video_segment_num() + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.model = self.load_transformer() + self.model.set_scheduler(self.scheduler) + + self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"]) + if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]: + self.inputs["image_encoder_output"]["vae_encoder_out"] = None + + if hasattr(self, "sr_version") and self.sr_version is not None: + self.lq_latents_shape = self.model.scheduler.latents.shape + self.model_sr.set_scheduler(self.scheduler_sr) + self.config_sr["is_sr_running"] = True + self.inputs_sr = self.run_input_encoder() + self.config_sr["is_sr_running"] = False + + @ProfilingContext4DebugL2("Run DiT") + def run_main(self): + self.init_run() + if self.config.get("compile", False) and hasattr(self.model, "compile"): + self.model.select_graph_for_compile(self.input_info) + for segment_idx in range(self.video_segment_num): + logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}") + with ProfilingContext4DebugL1( + f"segment end2end {segment_idx + 1}/{self.video_segment_num}", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, + metrics_labels=["DefaultRunner"], + ): + self.check_stop() + # 1. default do nothing + self.init_run_segment(segment_idx) + # 2. main inference loop + latents = self.run_segment(segment_idx) + # 3. vae decoder + if self.config.get("use_stream_vae", False): + frames = [] + for frame_segment in self.run_vae_decoder_stream(latents): + frames.append(frame_segment) + logger.info(f"frame segment: {len(frames)} done") + self.gen_video = torch.cat(frames, dim=2) + else: + self.gen_video = self.run_vae_decoder(latents) + # 4. 
default do nothing + self.end_run_segment(segment_idx) + gen_video_final = self.process_images_after_vae_decoder() + self.end_run() + return gen_video_final + + @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"]) + def run_vae_decoder(self, latents): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_decoder = self.load_vae_decoder() + images = self.vae_decoder.decode(latents.to(GET_DTYPE())) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + return images + + @ProfilingContext4DebugL1("Run VAE Decoder Stream", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"]) + def run_vae_decoder_stream(self, latents): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_decoder = self.load_vae_decoder() + + for frame_segment in self.vae_decoder.decode_stream(latents.to(GET_DTYPE())): + yield frame_segment + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + + def post_prompt_enhancer(self): + while True: + for url in self.config["sub_servers"]["prompt_enhancer"]: + response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json() + if response["service_status"] == "idle": + response = requests.post( + f"{url}/v1/local/prompt_enhancer/generate", + json={ + "task_id": generate_task_id(), + "prompt": self.config["prompt"], + }, + ) + enhanced_prompt = response.json()["output"] + logger.info(f"Enhanced prompt: {enhanced_prompt}") + return enhanced_prompt + + def process_images_after_vae_decoder(self): + self.gen_video_final = vae_to_comfyui_image(self.gen_video_final) + + if "video_frame_interpolation" in self.config: + assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None + target_fps = self.config["video_frame_interpolation"]["target_fps"] + logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}") + self.gen_video_final = self.vfi_model.interpolate_frames( + self.gen_video_final, + source_fps=self.config.get("fps", 16), + target_fps=target_fps, + ) + + if self.input_info.return_result_tensor: + return {"video": self.gen_video_final} + elif self.input_info.save_result_path is not None: + if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"): + fps = self.config["video_frame_interpolation"]["target_fps"] + else: + fps = self.config.get("fps", 16) + + if not dist.is_initialized() or dist.get_rank() == 0: + logger.info(f"🎬 Start to save video 🎬") + + save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg") + logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅") + return {"video": None} + + @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"]) + def run_pipeline(self, input_info): + if GET_RECORDER_MODE(): + monitor_cli.lightx2v_worker_request_count.inc() + self.input_info = input_info + + if self.config["use_prompt_enhancer"]: + self.input_info.prompt_enhanced = 
self.post_prompt_enhancer() + + self.inputs = self.run_input_encoder() + + gen_video_final = self.run_main() + + if GET_RECORDER_MODE(): + monitor_cli.lightx2v_worker_request_success.inc() + return gen_video_final diff --git a/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_distill_runner.py b/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_distill_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..95787ed729063984c37eff9bb29aee607c4704ab --- /dev/null +++ b/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_distill_runner.py @@ -0,0 +1,18 @@ +from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner +from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15SRScheduler +from lightx2v.models.schedulers.hunyuan_video.step_distill.scheduler import HunyuanVideo15StepDistillScheduler +from lightx2v.utils.registry_factory import RUNNER_REGISTER + + +@RUNNER_REGISTER("hunyuan_video_1.5_distill") +class HunyuanVideo15DistillRunner(HunyuanVideo15Runner): + def __init__(self, config): + super().__init__(config) + + def init_scheduler(self): + self.scheduler = HunyuanVideo15StepDistillScheduler(self.config) + + if self.sr_version is not None: + self.scheduler_sr = HunyuanVideo15SRScheduler(self.config_sr) + else: + self.scheduler_sr = None diff --git a/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py b/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..1c4761556f4471e68c7ddbb3276ad62cbff664a5 --- /dev/null +++ b/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py @@ -0,0 +1,550 @@ +import copy +import gc +import os + +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from loguru import logger + +from lightx2v.models.input_encoders.hf.hunyuan15.byt5.model import ByT5TextEncoder +from lightx2v.models.input_encoders.hf.hunyuan15.qwen25.model import Qwen25VL_TextEncoder +from lightx2v.models.input_encoders.hf.hunyuan15.siglip.model import SiglipVisionEncoder +from lightx2v.models.networks.hunyuan_video.model import HunyuanVideo15Model +from lightx2v.models.runners.default_runner import DefaultRunner +from lightx2v.models.schedulers.hunyuan_video.feature_caching.scheduler import HunyuanVideo15SchedulerCaching +from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15SRScheduler, HunyuanVideo15Scheduler +from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import HunyuanVideo15VAE +from lightx2v.models.video_encoders.hf.hunyuanvideo15.lighttae_hy15 import LightTaeHy15 +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import * +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +@RUNNER_REGISTER("hunyuan_video_1.5") +class HunyuanVideo15Runner(DefaultRunner): + def __init__(self, config): + config["is_sr_running"] = False + + if "video_super_resolution" in config and "sr_version" in config["video_super_resolution"]: + self.sr_version = config["video_super_resolution"]["sr_version"] + else: + self.sr_version = None + + if self.sr_version is not None: + self.config_sr = copy.deepcopy(config) + self.config_sr["is_sr_running"] = False + self.config_sr["sample_shift"] = config["video_super_resolution"]["flow_shift"] # for SR 
model + self.config_sr["sample_guide_scale"] = config["video_super_resolution"]["guidance_scale"] # for SR model + self.config_sr["infer_steps"] = config["video_super_resolution"]["num_inference_steps"] + + super().__init__(config) + self.target_size_config = { + "360p": {"bucket_hw_base_size": 480, "bucket_hw_bucket_stride": 16}, + "480p": {"bucket_hw_base_size": 640, "bucket_hw_bucket_stride": 16}, + "720p": {"bucket_hw_base_size": 960, "bucket_hw_bucket_stride": 16}, + "1080p": {"bucket_hw_base_size": 1440, "bucket_hw_bucket_stride": 16}, + } + self.vision_num_semantic_tokens = 729 + self.vision_states_dim = 1152 + self.vae_cls = HunyuanVideo15VAE + self.tae_cls = LightTaeHy15 + + def init_scheduler(self): + if self.config["feature_caching"] == "NoCaching": + scheduler_class = HunyuanVideo15Scheduler + elif self.config.feature_caching in ["Mag", "Tea"]: + scheduler_class = HunyuanVideo15SchedulerCaching + else: + raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}") + self.scheduler = scheduler_class(self.config) + + if self.sr_version is not None: + self.scheduler_sr = HunyuanVideo15SRScheduler(self.config_sr) + else: + self.scheduler_sr = None + + def load_text_encoder(self): + qwen25vl_offload = self.config.get("qwen25vl_cpu_offload", self.config.get("cpu_offload")) + if qwen25vl_offload: + qwen25vl_device = torch.device("cpu") + else: + qwen25vl_device = torch.device(AI_DEVICE) + + qwen25vl_quantized = self.config.get("qwen25vl_quantized", False) + qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None) + qwen25vl_quantized_ckpt = self.config.get("qwen25vl_quantized_ckpt", None) + + text_encoder_path = os.path.join(self.config["model_path"], "text_encoder/llm") + logger.info(f"Loading text encoder from {text_encoder_path}") + text_encoder = Qwen25VL_TextEncoder( + dtype=torch.float16, + device=qwen25vl_device, + checkpoint_path=text_encoder_path, + cpu_offload=qwen25vl_offload, + qwen25vl_quantized=qwen25vl_quantized, + qwen25vl_quant_scheme=qwen25vl_quant_scheme, + qwen25vl_quant_ckpt=qwen25vl_quantized_ckpt, + ) + + byt5_offload = self.config.get("byt5_cpu_offload", self.config.get("cpu_offload")) + if byt5_offload: + byt5_device = torch.device("cpu") + else: + byt5_device = torch.device(AI_DEVICE) + + byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload) + text_encoders = [text_encoder, byt5] + return text_encoders + + def load_transformer(self): + model = HunyuanVideo15Model(self.config["model_path"], self.config, self.init_device) + if self.sr_version is not None: + self.config_sr["transformer_model_path"] = os.path.join(os.path.dirname(self.config.transformer_model_path), self.sr_version) + self.config_sr["is_sr_running"] = True + model_sr = HunyuanVideo15Model(self.config_sr["model_path"], self.config_sr, self.init_device) + self.config_sr["is_sr_running"] = False + else: + model_sr = None + + self.model_sr = model_sr + return model + + def get_latent_shape_with_target_hw(self, origin_size=None): + if origin_size is None: + width, height = self.config["aspect_ratio"].split(":") + else: + width, height = origin_size + target_size = self.config["transformer_model_name"].split("_")[0] + target_height, target_width = self.get_closest_resolution_given_original_size((int(width), int(height)), target_size) + latent_shape = [ + self.config.get("in_channels", 32), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + 
target_height // self.config["vae_stride"][1], + target_width // self.config["vae_stride"][2], + ] + + ori_latent_h, ori_latent_w = latent_shape[2], latent_shape[3] + if dist.is_initialized() and dist.get_world_size() > 1: + latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size()) + latent_shape[2], latent_shape[3] = latent_h, latent_w + logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}") + else: + latent_shape[2], latent_shape[3] = ori_latent_h, ori_latent_w + world_size_h, world_size_w = None, None + + self.vae_decoder.world_size_h = world_size_h + self.vae_decoder.world_size_w = world_size_w + + self.target_height = latent_shape[2] * self.config["vae_stride"][1] + self.target_width = latent_shape[3] * self.config["vae_stride"][2] + return latent_shape + + def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size): + """ + Adjust latent dimensions for optimal 2D grid splitting. + Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1. + """ + world_size_h, world_size_w = 1, 1 + if world_size <= 1: + return latent_h, latent_w, world_size_h, world_size_w + + # Define priority grids for different world sizes + priority_grids = [] + if world_size == 8: + # For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1 + priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)] + elif world_size == 4: + priority_grids = [(2, 2), (1, 4), (4, 1)] + elif world_size == 2: + priority_grids = [(1, 2), (2, 1)] + else: + # For other sizes, try factor pairs + for h in range(1, int(np.sqrt(world_size)) + 1): + if world_size % h == 0: + w = world_size // h + priority_grids.append((h, w)) + + # Try priority grids first + for world_size_h, world_size_w in priority_grids: + if latent_h % world_size_h == 0 and latent_w % world_size_w == 0: + return latent_h, latent_w, world_size_h, world_size_w + + # If no perfect fit, find minimal padding solution + best_grid = (1, world_size) # fallback + min_total_padding = float("inf") + + for world_size_h, world_size_w in priority_grids: + # Calculate required padding + pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h + pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w + total_padding = pad_h + pad_w + + # Prefer grids with minimal total padding + if total_padding < min_total_padding: + min_total_padding = total_padding + best_grid = (world_size_h, world_size_w) + + # Apply padding + world_size_h, world_size_w = best_grid + pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h + pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w + + return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w + + def get_sr_latent_shape_with_target_hw(self): + SizeMap = { + "480p": 640, + "720p": 960, + "1080p": 1440, + } + + sr_stride = 16 + base_size = SizeMap[self.config_sr["video_super_resolution"]["base_resolution"]] + sr_size = SizeMap[self.sr_version.split("_")[0]] + lr_video_height, lr_video_width = [x * 16 for x in self.lq_latents_shape[-2:]] + hr_bucket_map = self.build_bucket_map(lr_base_size=base_size, hr_base_size=sr_size, lr_patch_size=16, hr_patch_size=sr_stride) + target_width, target_height = hr_bucket_map((lr_video_width, lr_video_height)) + latent_shape = [ + self.config_sr.get("in_channels", 32), + (self.config_sr["target_video_length"] - 1) // self.config_sr["vae_stride"][0] + 1, + target_height // self.config_sr["vae_stride"][1], + 
target_width // self.config_sr["vae_stride"][2], + ] + self.target_sr_height = target_height + self.target_sr_width = target_width + return latent_shape + + def get_closest_resolution_given_original_size(self, origin_size, target_size): + bucket_hw_base_size = self.target_size_config[target_size]["bucket_hw_base_size"] + bucket_hw_bucket_stride = self.target_size_config[target_size]["bucket_hw_bucket_stride"] + + assert bucket_hw_base_size in [128, 256, 480, 512, 640, 720, 960, 1440], f"bucket_hw_base_size must be in [128, 256, 480, 512, 640, 720, 960], but got {bucket_hw_base_size}" + + crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, bucket_hw_bucket_stride) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list) + + height = closest_size[0] + width = closest_size[1] + + return height, width + + def generate_crop_size_list(self, base_size=256, patch_size=16, max_ratio=4.0): + num_patches = round((base_size / patch_size) ** 2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list + + def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list): + aspect_ratio = float(height) / float(width) + diff_ratios = ratios - aspect_ratio + + if aspect_ratio >= 1: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0] + else: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0] + + closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0] + closest_size = buckets[closest_ratio_id] + closest_ratio = ratios[closest_ratio_id] + + return closest_size, closest_ratio + + def run_text_encoder(self, input_info): + prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt + neg_prompt = input_info.negative_prompt + + # run qwen25vl + if self.config.get("enable_cfg", False) and self.config["cfg_parallel"]: + cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p") + cfg_p_rank = dist.get_rank(cfg_p_group) + if cfg_p_rank == 0: + context = self.text_encoders[0].infer([prompt]) + text_encoder_output = {"context": context} + else: + context_null = self.text_encoders[0].infer([neg_prompt]) + text_encoder_output = {"context_null": context_null} + else: + context = self.text_encoders[0].infer([prompt]) + context_null = self.text_encoders[0].infer([neg_prompt]) if self.config.get("enable_cfg", False) else None + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + # run byt5 + byt5_features, byt5_masks = self.text_encoders[1].infer([prompt]) + text_encoder_output.update({"byt5_features": byt5_features, "byt5_masks": byt5_masks}) + + return text_encoder_output + + def load_image_encoder(self): + siglip_offload = self.config.get("siglip_cpu_offload", self.config.get("cpu_offload")) + if siglip_offload: + siglip_device = torch.device("cpu") + else: + siglip_device = torch.device(AI_DEVICE) + image_encoder = SiglipVisionEncoder( + config=self.config, + device=siglip_device, + checkpoint_path=self.config["model_path"], + cpu_offload=siglip_offload, + ) + return image_encoder + + def load_vae_encoder(self): + # offload config + vae_offload = self.config.get("vae_cpu_offload", 
self.config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "checkpoint_path": self.config["model_path"], + "device": vae_device, + "cpu_offload": vae_offload, + "dtype": GET_DTYPE(), + "parallel": self.config["parallel"], + } + if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]: + return None + else: + return self.vae_cls(**vae_config) + + def load_vae_decoder(self): + # offload config + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "checkpoint_path": self.config["model_path"], + "device": vae_device, + "cpu_offload": vae_offload, + "dtype": GET_DTYPE(), + "parallel": self.config["parallel"], + } + if self.config.get("use_tae", False): + tae_path = self.config["tae_path"] + vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(AI_DEVICE) + else: + vae_decoder = self.vae_cls(**vae_config) + return vae_decoder + + def load_vae(self): + vae_encoder = self.load_vae_encoder() + if vae_encoder is None or self.config.get("use_tae", False): + vae_decoder = self.load_vae_decoder() + else: + vae_decoder = vae_encoder + return vae_encoder, vae_decoder + + def load_vsr_model(self): + if self.sr_version: + from lightx2v.models.runners.vsr.vsr_wrapper_hy15 import SRModel3DV2, Upsampler + + upsampler_cls = SRModel3DV2 if "720p" in self.sr_version else Upsampler + upsampler_path = os.path.join(self.config["model_path"], "upsampler", self.sr_version) + logger.info("Loading VSR model from {}".format(upsampler_path)) + upsampler = upsampler_cls.from_pretrained(upsampler_path).to(self.init_device) + + return upsampler + else: + return None + + def build_bucket_map(self, lr_base_size, hr_base_size, lr_patch_size, hr_patch_size): + lr_buckets = self.generate_crop_size_list(base_size=lr_base_size, patch_size=lr_patch_size) + hr_buckets = self.generate_crop_size_list(base_size=hr_base_size, patch_size=hr_patch_size) + + lr_aspect_ratios = np.array([w / h for w, h in lr_buckets]) + hr_aspect_ratios = np.array([w / h for w, h in hr_buckets]) + + hr_bucket_map = {} + for i, (lr_w, lr_h) in enumerate(lr_buckets): + lr_ratio = lr_aspect_ratios[i] + closest_hr_ratio_id = np.abs(hr_aspect_ratios - lr_ratio).argmin() + hr_bucket_map[(lr_w, lr_h)] = hr_buckets[closest_hr_ratio_id] + + def hr_bucket_fn(lr_bucket): + if lr_bucket not in hr_bucket_map: + raise ValueError(f"LR bucket {lr_bucket} not found in bucket map") + return hr_bucket_map[lr_bucket] + + hr_bucket_fn.map = hr_bucket_map + + return hr_bucket_fn + + @ProfilingContext4DebugL1("Run SR") + def run_sr(self, lq_latents): + self.config_sr["is_sr_running"] = True + + self.model_sr.scheduler.prepare( + seed=self.input_info.seed, latent_shape=self.latent_sr_shape, lq_latents=lq_latents, upsampler=self.vsr_model, image_encoder_output=self.inputs_sr["image_encoder_output"] + ) + + total_steps = self.model_sr.scheduler.infer_steps + for step_index in range(total_steps): + with ProfilingContext4DebugL1( + f"Run SR Dit every step", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration, + metrics_labels=[step_index + 1, total_steps], + ): + logger.info(f"==> step_index: {step_index + 1} / {total_steps}") + with ProfilingContext4DebugL1("step_pre"): + self.model_sr.scheduler.step_pre(step_index=step_index) + + with ProfilingContext4DebugL1("🚀 
infer_main"): + self.model_sr.infer(self.inputs_sr) + + with ProfilingContext4DebugL1("step_post"): + self.model_sr.scheduler.step_post() + + del self.inputs_sr + torch_device_module.empty_cache() + + self.config_sr["is_sr_running"] = False + return self.model_sr.scheduler.latents + + @ProfilingContext4DebugL1("Run VAE Decoder") + def run_vae_decoder(self, latents): + if self.sr_version: + latents = self.run_sr(latents) + images = super().run_vae_decoder(latents) + return images + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_t2v(self): + self.input_info.latent_shape = self.get_latent_shape_with_target_hw() # Important: set latent_shape in input_info + text_encoder_output = self.run_text_encoder(self.input_info) + + # vision_states is all zero, because we don't have any image input + siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(AI_DEVICE) + siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(AI_DEVICE)) + + torch_device_module.empty_cache() + gc.collect() + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": { + "siglip_output": siglip_output, + "siglip_mask": siglip_mask, + "cond_latents": None, + }, + } + + def read_image_input(self, img_path): + if isinstance(img_path, Image.Image): + img_ori = img_path + else: + img_ori = Image.open(img_path).convert("RGB") + return img_ori + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_i2v(self): + img_ori = self.read_image_input(self.input_info.image_path) + if self.sr_version and self.config_sr["is_sr_running"]: + self.latent_sr_shape = self.get_sr_latent_shape_with_target_hw() + self.input_info.latent_shape = self.get_latent_shape_with_target_hw(origin_size=img_ori.size) # Important: set latent_shape in input_info + siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None + cond_latents = self.run_vae_encoder(img_ori) + text_encoder_output = self.run_text_encoder(self.input_info) + torch_device_module.empty_cache() + gc.collect() + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": { + "siglip_output": siglip_output, + "siglip_mask": siglip_mask, + "cond_latents": cond_latents, + }, + } + + @ProfilingContext4DebugL1( + "Run Image Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_img_encode_duration, + metrics_labels=["WanRunner"], + ) + def run_image_encoder(self, first_frame, last_frame=None): + if self.sr_version and self.config_sr["is_sr_running"]: + target_width = self.target_sr_width + target_height = self.target_sr_height + else: + target_width = self.target_width + target_height = self.target_height + + input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height) + vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(AI_DEVICE), dtype=torch.bfloat16) + image_encoder_output = self.image_encoder.infer(vision_states) + image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(AI_DEVICE)) + return image_encoder_output, image_encoder_mask + + def resize_and_center_crop(self, image, target_width, target_height): + image = np.array(image) + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = 
Image.fromarray(image) + original_width, original_height = pil_image.size + scale_factor = max(target_width / original_width, target_height / original_height) + resized_width = int(round(original_width * scale_factor)) + resized_height = int(round(original_height * scale_factor)) + resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) + left = (resized_width - target_width) / 2 + top = (resized_height - target_height) / 2 + right = (resized_width + target_width) / 2 + bottom = (resized_height + target_height) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + return np.array(cropped_image) + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["WanRunner"], + ) + def run_vae_encoder(self, first_frame): + origin_size = first_frame.size + original_width, original_height = origin_size + + if self.sr_version and self.config_sr["is_sr_running"]: + target_width = self.target_sr_width + target_height = self.target_sr_height + else: + target_width = self.target_width + target_height = self.target_height + + scale_factor = max(target_width / original_width, self.target_height / original_height) + resize_width = int(round(original_width * scale_factor)) + resize_height = int(round(original_height * scale_factor)) + + ref_image_transform = transforms.Compose( + [ + transforms.Resize((resize_height, resize_width), interpolation=transforms.InterpolationMode.LANCZOS), + transforms.CenterCrop((target_height, target_width)), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(AI_DEVICE) + cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE())) + return cond_latents diff --git a/lightx2v/models/runners/qwen_image/qwen_image_runner.py b/lightx2v/models/runners/qwen_image/qwen_image_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc2b19b22a871fdb5e95d0d2ec5186f66339578 --- /dev/null +++ b/lightx2v/models/runners/qwen_image/qwen_image_runner.py @@ -0,0 +1,273 @@ +import gc +import math + +import torch +import torchvision.transforms.functional as TF +from PIL import Image +from loguru import logger + +from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder +from lightx2v.models.networks.qwen_image.lora_adapter import QwenImageLoraWrapper +from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel +from lightx2v.models.runners.default_runner import DefaultRunner +from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler +from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height, None + + +@RUNNER_REGISTER("qwen_image") +class QwenImageRunner(DefaultRunner): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = 
["latents", "prompt_embeds"] + + def __init__(self, config): + super().__init__(config) + + @ProfilingContext4DebugL2("Load models") + def load_model(self): + self.model = self.load_transformer() + self.text_encoders = self.load_text_encoder() + self.vae = self.load_vae() + + def load_transformer(self): + model = QwenImageTransformerModel(self.config) + if self.config.get("lora_configs") and self.config.lora_configs: + assert not self.config.get("dit_quantized", False) + lora_wrapper = QwenImageLoraWrapper(model) + for lora_config in self.config.lora_configs: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + return model + + def load_text_encoder(self): + text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config) + text_encoders = [text_encoder] + return text_encoders + + def load_image_encoder(self): + pass + + def load_vae(self): + vae = AutoencoderKLQwenImageVAE(self.config) + return vae + + def init_modules(self): + logger.info("Initializing runner modules...") + if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): + self.load_model() + elif self.config.get("lazy_load", False): + assert self.config.get("cpu_offload", False) + self.run_dit = self._run_dit_local + if self.config["task"] == "t2i": + self.run_input_encoder = self._run_input_encoder_local_t2i + elif self.config["task"] == "i2i": + self.run_input_encoder = self._run_input_encoder_local_i2i + else: + assert NotImplementedError + + self.model.set_scheduler(self.scheduler) + + @ProfilingContext4DebugL2("Run DiT") + def _run_dit_local(self, total_steps=None): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.model = self.load_transformer() + self.model.scheduler.prepare(self.input_info) + latents, generator = self.run(total_steps) + return latents, generator + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_t2i(self): + prompt = self.input_info.prompt + text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt) + torch_device_module.empty_cache() + gc.collect() + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": None, + } + + def read_image_input(self, img_path): + if isinstance(img_path, Image.Image): + img_ori = img_path + else: + img_ori = Image.open(img_path).convert("RGB") + if GET_RECORDER_MODE(): + width, height = img_ori.size + monitor_cli.lightx2v_input_image_len.observe(width * height) + img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) + self.input_info.original_size.append(img_ori.size) + return img, img_ori + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_i2i(self): + image_paths_list = self.input_info.image_path.split(",") + images_list = [] + for image_path in image_paths_list: + _, image = self.read_image_input(image_path) + images_list.append(image) + + prompt = self.input_info.prompt + text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt) + + image_encoder_output_list = [] + for vae_image in text_encoder_output["image_info"]["vae_image_list"]: + image_encoder_output = self.run_vae_encoder(image=vae_image) + image_encoder_output_list.append(image_encoder_output) + torch_device_module.empty_cache() + gc.collect() + return { + 
"text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output_list, + } + + @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"]) + def run_text_encoder(self, text, image_list=None, neg_prompt=None): + if GET_RECORDER_MODE(): + monitor_cli.lightx2v_input_prompt_len.observe(len(text)) + text_encoder_output = {} + if self.config["task"] == "t2i": + prompt_embeds, prompt_embeds_mask, _ = self.text_encoders[0].infer([text]) + text_encoder_output["prompt_embeds"] = prompt_embeds + text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask + if self.config["do_true_cfg"] and neg_prompt is not None: + neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt]) + text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds + text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask + elif self.config["task"] == "i2i": + prompt_embeds, prompt_embeds_mask, image_info = self.text_encoders[0].infer([text], image_list) + text_encoder_output["prompt_embeds"] = prompt_embeds + text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask + text_encoder_output["image_info"] = image_info + if self.config["do_true_cfg"] and neg_prompt is not None: + neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt], image_list) + text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds + text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask + return text_encoder_output + + @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"]) + def run_vae_encoder(self, image): + image_latents = self.vae.encode_vae_image(image, self.input_info) + return {"image_latents": image_latents} + + def run(self, total_steps=None): + if total_steps is None: + total_steps = self.model.scheduler.infer_steps + for step_index in range(total_steps): + logger.info(f"==> step_index: {step_index + 1} / {total_steps}") + + with ProfilingContext4DebugL1("step_pre"): + self.model.scheduler.step_pre(step_index=step_index) + + with ProfilingContext4DebugL1("🚀 infer_main"): + self.model.infer(self.inputs) + + with ProfilingContext4DebugL1("step_post"): + self.model.scheduler.step_post() + + if self.progress_callback: + self.progress_callback(((step_index + 1) / total_steps) * 100, 100) + + return self.model.scheduler.latents, self.model.scheduler.generator + + def set_target_shape(self): + if not self.config["_auto_resize"]: + width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]] + else: + width, height = self.input_info.original_size[-1] + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height) + multiple_of = self.vae.vae_scale_factor * 2 + width = calculated_width // multiple_of * multiple_of + height = calculated_height // multiple_of * multiple_of + self.input_info.auto_width = width + self.input_info.auto_hight = height + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae.vae_scale_factor * 2)) + num_channels_latents = self.model.in_channels // 4 + self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width) + + def set_img_shapes(self): + if self.config["task"] == "t2i": + width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]] + img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"] + elif self.config["task"] == "i2i": + img_shapes = [[(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2)]] + for image_height, image_width in self.inputs["text_encoder_output"]["image_info"]["vae_image_info_list"]: + img_shapes[0].append((1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2)) + + self.inputs["img_shapes"] = img_shapes + + def init_scheduler(self): + self.scheduler = QwenImageScheduler(self.config) + + def get_encoder_output_i2v(self): + pass + + def run_image_encoder(self): + pass + + @ProfilingContext4DebugL2("Load models") + def load_model(self): + self.model = self.load_transformer() + self.text_encoders = self.load_text_encoder() + self.image_encoder = self.load_image_encoder() + self.vae = self.load_vae() + self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None + + @ProfilingContext4DebugL1( + "Run VAE Decoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, + metrics_labels=["QwenImageRunner"], + ) + def run_vae_decoder(self, latents): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_decoder = self.load_vae() + images = self.vae.decode(latents, self.input_info) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_decoder + torch_device_module.empty_cache() + gc.collect() + return images + + def run_pipeline(self, input_info): + self.input_info = input_info + + self.inputs = self.run_input_encoder() + self.set_target_shape() + self.set_img_shapes() + + latents, generator = self.run_dit() + images = self.run_vae_decoder(latents) + self.end_run() + + image = images[0] + image.save(f"{input_info.save_result_path}") + + del latents, generator + torch_device_module.empty_cache() + gc.collect() + + # Return (images, audio) - audio is None for default runner + return images, None diff --git a/lightx2v/models/runners/vsr/utils/TCDecoder.py b/lightx2v/models/runners/vsr/utils/TCDecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..73f2083c8b1561cdeb525953eaf569706cd57895 --- /dev/null +++ b/lightx2v/models/runners/vsr/utils/TCDecoder.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +""" +Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned) +- Encoder removed +- Transplant/widening helpers removed +- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself +""" + +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from einops import rearrange +from tqdm.auto import tqdm + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +# ---------------------------- +# Utility / building blocks 
+# ---------------------------- + + +class IdentityConv2d(nn.Conv2d): + """Same-shape Conv2d initialized to identity (Dirac).""" + + def __init__(self, C, kernel_size=3, bias=False): + pad = kernel_size // 2 + super().__init__(C, C, kernel_size, padding=pad, bias=bias) + with torch.no_grad(): + init.dirac_(self.weight) + if self.bias is not None: + self.bias.zero_() + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = nn.ReLU(inplace=True) + + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False) + + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) + + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + + +class PixelShuffle3d(nn.Module): + def __init__(self, ff, hh, ww): + super().__init__() + self.ff = ff + self.hh = hh + self.ww = ww + + def forward(self, x): + # x: (B, C, F, H, W) + B, C, F, H, W = x.shape + if F % self.ff != 0: + first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1) + x = torch.cat([first_frame, x], dim=2) + return rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", ff=self.ff, hh=self.hh, ww=self.ww).transpose(1, 2) + + +# ---------------------------- +# Generic NTCHW graph executor (kept; used by decoder) +# ---------------------------- + + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None): + """ + Apply a sequential model with memblocks to the given input. + Args: + - model: nn.Sequential of blocks to apply + - x: input data, of dimensions NTCHW + - parallel: if True, parallelize over timesteps (fast but uses O(T) memory) + if False, each timestep will be processed sequentially (slow but uses O(1) memory) + - show_progress_bar: if True, enables tqdm progressbar display + + Returns NTCHW tensor of output data. 
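+ - mem: optional per-block state list from a previous call (one entry per block, None meaning no history); in sequential mode it lets a long video be decoded chunk by chunk, and the updated list is returned together with the output tensor.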
+ """ + assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor" + N, T, C, H, W = x.shape + if parallel: + x = x.reshape(N * T, C, H, W) + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + NT, C, H, W = x.shape + T = NT // N + _x = x.reshape(N, T, C, H, W) + mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + NT, C, H, W = x.shape + T = NT // N + x = x.view(N, T, C, H, W) + else: + out = [] + work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))] + progress_bar = tqdm(range(T), disable=not show_progress_bar) + while work_queue: + xt, i = work_queue.pop(0) + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt + else: + xt_new = b(xt, mem[i]) + mem[i].copy_(xt) + work_queue.insert(0, TWorkItem(xt_new, i + 1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt) + if len(mem[i]) > b.stride: + raise ValueError("TPool internal state invalid.") + elif len(mem[i]) == b.stride: + N_, C_, H_, W_ = xt.shape + xt = b(torch.cat(mem[i], 1).view(N_ * b.stride, C_, H_, W_)) + mem[i] = [] + work_queue.insert(0, TWorkItem(xt, i + 1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C_, H_, W_ = xt.shape + for xt_next in reversed(xt.view(N, b.stride * C_, H_, W_).chunk(b.stride, 1)): + work_queue.insert(0, TWorkItem(xt_next, i + 1)) + else: + xt = b(xt) + work_queue.insert(0, TWorkItem(xt, i + 1)) + progress_bar.close() + x = torch.stack(out, 1) + return x, mem + + +# ---------------------------- +# Decoder-only TAEHV +# ---------------------------- + + +class TAEHV(nn.Module): + image_channels = 3 + + def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), channels=[256, 128, 64, 64], latent_channels=16): + """Initialize TAEHV (decoder-only) with built-in deepening after every ReLU. + Deepening config: how_many_each=1, k=3 (fixed as requested). 
+ """ + super().__init__() + self.latent_channels = latent_channels + n_f = channels + self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 + + # Build the decoder "skeleton" + base_decoder = nn.Sequential( + Clamp(), + conv(self.latent_channels, n_f[0]), + nn.ReLU(inplace=True), + MemBlock(n_f[0], n_f[0]), + MemBlock(n_f[0], n_f[0]), + MemBlock(n_f[0], n_f[0]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), + TGrow(n_f[0], 1), + conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1]), + MemBlock(n_f[1], n_f[1]), + MemBlock(n_f[1], n_f[1]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), + TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), + conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2]), + MemBlock(n_f[2], n_f[2]), + MemBlock(n_f[2], n_f[2]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), + TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), + conv(n_f[2], n_f[3], bias=False), + nn.ReLU(inplace=True), + conv(n_f[3], TAEHV.image_channels), + ) + + # Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU + self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3) + + self.pixel_shuffle = PixelShuffle3d(4, 8, 8) + + if checkpoint_path is not None: + missing_keys = self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)), strict=False) + print("missing_keys", missing_keys) + + # Initialize decoder mem state + self.mem = [None] * len(self.decoder) + + @staticmethod + def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential: + """Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU).""" + new_layers = [] + for b in decoder: + new_layers.append(b) + if isinstance(b, nn.ReLU): + # Deduce channel count from preceding layer + C = None + if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d): + C = new_layers[-2].out_channels + elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock): + C = new_layers[-2].conv[-1].out_channels + if C is not None: + for _ in range(how_many_each): + new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False)) + new_layers.append(nn.ReLU(inplace=True)) + return nn.Sequential(*new_layers) + + def patch_tgrow_layers(self, sd): + """Patch TGrow layers to use a smaller kernel if needed (decoder-only).""" + new_sd = self.state_dict() + for i, layer in enumerate(self.decoder): + if isinstance(layer, TGrow): + key = f"decoder.{i}.conv.weight" + if key in sd and sd[key].shape[0] > new_sd[key].shape[0]: + sd[key] = sd[key][-new_sd[key].shape[0] :] + return sd + + def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None): + """Decode a sequence of frames from latents. + x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1]. + """ + trim_flag = self.mem[-8] is None # keeps original relative check + + if cond is not None: + x = torch.cat([self.pixel_shuffle(cond), x], dim=2) + + x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem) + + if trim_flag: + return x[:, self.frames_to_trim :] + return x + + def forward(self, *args, **kwargs): + raise NotImplementedError("Decoder-only model: call decode_video(...) 
instead.") + + def clean_mem(self): + self.mem = [None] * len(self.decoder) + + +class DotDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + + +class TAEW2_1DiffusersWrapper(nn.Module): + def __init__(self, pretrained_path=None, channels=[256, 128, 64, 64]): + super().__init__() + self.dtype = torch.bfloat16 + self.device = "cuda" + self.taehv = TAEHV(pretrained_path, channels=channels).to(self.dtype) + self.temperal_downsample = [True, True, False] # [sic] + self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16)) + + def decode(self, latents, return_dict=None): + n, c, t, h, w = latents.shape + return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),) + + def stream_decode_with_cond(self, latents, tiled=False, cond=None): + n, c, t, h, w = latents.shape + return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1) + + def clean_mem(self): + self.taehv.clean_mem() + + +# ---------------------------- +# Simplified builder (no small, no transplant, no post-hoc deepening) +# ---------------------------- + + +def build_tcdecoder(new_channels=[512, 256, 128, 128], device="cuda", dtype=torch.bfloat16, new_latent_channels=None): + """ + 构建“更宽”的 decoder;深度增强(IdentityConv2d+ReLU)已在 TAEHV 内部完成。 + - 不创建 small / 不做移植 + - base_ckpt_path 参数保留但不使用(接口兼容) + + 返回:big (单个模型) + """ + if new_latent_channels is not None: + big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train() + else: + big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train() + + big.clean_mem() + return big diff --git a/lightx2v/models/runners/vsr/utils/utils.py b/lightx2v/models/runners/vsr/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3451a87185854f514ac7ca8bd6967f333c8ccc39 --- /dev/null +++ b/lightx2v/models/runners/vsr/utils/utils.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +CACHE_T = 2 + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. 
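+ Pads only toward the past (2 * pad frames at the temporal front, none at the back), so each output frame depends only on the current and earlier frames; forward() optionally takes cache_x (the trailing frames of the previous chunk) to shrink that padding for streaming use.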
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + # print(cache_x.shape, x.shape) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + # print('cache!') + x = F.pad(x, padding, mode="replicate") # mode='replicate' + # print(x[0,0,:,0,0]) + + return super().forward(x) + + +class PixelShuffle3d(nn.Module): + def __init__(self, ff, hh, ww): + super().__init__() + self.ff = ff + self.hh = hh + self.ww = ww + + def forward(self, x): + # x: (B, C, F, H, W) + return rearrange(x, "b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w", ff=self.ff, hh=self.hh, ww=self.ww) + + +class Buffer_LQ4x_Proj(nn.Module): + def __init__(self, in_dim, out_dim, layer_num=30): + super().__init__() + self.ff = 1 + self.hh = 16 + self.ww = 16 + self.hidden_dim1 = 2048 + self.hidden_dim2 = 3072 + self.layer_num = layer_num + + self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) + + self.conv1 = CausalConv3d(in_dim * self.ff * self.hh * self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w + self.norm1 = RMS_norm(self.hidden_dim1, images=False) + self.act1 = nn.SiLU() + + self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w + self.norm2 = RMS_norm(self.hidden_dim2, images=False) + self.act2 = nn.SiLU() + + self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]) + + self.clip_idx = 0 + + def forward(self, video): + self.clear_cache() + # x: (B, C, F, H, W) + + t = video.shape[2] + iter_ = 1 + (t - 1) // 4 + first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) + video = torch.cat([first_frame, video], dim=2) + # print(video.shape) + + out_x = [] + for i in range(iter_): + x = self.pixel_shuffle(video[:, :, i * 4 : (i + 1) * 4, :, :]) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache["conv1"] = cache1_x + x = self.conv1(x, self.cache["conv1"]) + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache["conv2"] = cache2_x + if i == 0: + continue + x = self.conv2(x, self.cache["conv2"]) + x = self.norm2(x) + x = self.act2(x) + out_x.append(x) + out_x = torch.cat(out_x, dim=2) + # print(out_x.shape) + out_x = rearrange(out_x, "b c f h w -> b (f h w) c") + outputs = [] + for i in range(self.layer_num): + outputs.append(self.linear_layers[i](out_x)) + return outputs + + def clear_cache(self): + self.cache = {} + self.cache["conv1"] = None + self.cache["conv2"] = None + self.clip_idx = 0 + + def stream_forward(self, video_clip): + if self.clip_idx == 0: + # self.clear_cache() + first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) + video_clip = torch.cat([first_frame, video_clip], dim=2) + x = self.pixel_shuffle(video_clip) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache["conv1"] = cache1_x + x = self.conv1(x, self.cache["conv1"]) + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache["conv2"] = cache2_x + self.clip_idx += 1 + return None + else: + x = self.pixel_shuffle(video_clip) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache["conv1"] = cache1_x + x = self.conv1(x, 
self.cache["conv1"]) + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache["conv2"] = cache2_x + x = self.conv2(x, self.cache["conv2"]) + x = self.norm2(x) + x = self.act2(x) + out_x = rearrange(x, "b c f h w -> b (f h w) c") + outputs = [] + for i in range(self.layer_num): + outputs.append(self.linear_layers[i](out_x)) + self.clip_idx += 1 + return outputs diff --git a/lightx2v/models/runners/vsr/vsr_wrapper.py b/lightx2v/models/runners/vsr/vsr_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ade1b7677b42269083db23e71ec1016d9d1aaf60 --- /dev/null +++ b/lightx2v/models/runners/vsr/vsr_wrapper.py @@ -0,0 +1,162 @@ +import os +from typing import Optional + +import torch +from torch.nn import functional as F + +from lightx2v.utils.profiler import * + +try: + from diffsynth import FlashVSRTinyPipeline, ModelManager +except ImportError: + ModelManager = None + FlashVSRTinyPipeline = None + + +from .utils.TCDecoder import build_tcdecoder +from .utils.utils import Buffer_LQ4x_Proj + + +def largest_8n1_leq(n): # 8n+1 + return 0 if n < 1 else ((n - 1) // 8) * 8 + 1 + + +def compute_scaled_and_target_dims(w0: int, h0: int, scale: float = 4.0, multiple: int = 128): + if w0 <= 0 or h0 <= 0: + raise ValueError("Invalid original size") + if scale <= 0: + raise ValueError("scale must be > 0") + + sW = int(round(w0 * scale)) + sH = int(round(h0 * scale)) + + tW = (sW // multiple) * multiple + tH = (sH // multiple) * multiple + + if tW == 0 or tH == 0: + raise ValueError(f"Scaled size too small ({sW}x{sH}) for multiple={multiple}. Increase scale (got {scale}).") + + return sW, sH, tW, tH + + +def prepare_input_tensor(input_tensor, scale: float = 2.0, dtype=torch.bfloat16, device="cuda"): + """ + 视频预处理: [T,H,W,3] -> [1,C,F,H,W] + 1. GPU 上完成插值 + 中心裁剪 + 2. 自动 pad 帧数到 8n-3 + """ + + input_tensor = input_tensor.to(device=device, dtype=torch.float32) # [T,H,W,3] + total, h0, w0, _ = input_tensor.shape + + # 计算缩放与目标分辨率 + sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=128) + print(f"Scaled (x{scale:.2f}): {sW}x{sH} -> Target: {tW}x{tH}") + + # Pad 帧数到 8n-3 + idx = list(range(total)) + [total - 1] * 4 + F_target = largest_8n1_leq(len(idx)) + if F_target == 0: + raise RuntimeError(f"Not enough frames after padding. 
Got {len(idx)}.") + idx = idx[:F_target] + print(f"Target Frames (8n-3): {F_target - 4}") + + # Gather the frames and convert to tensor layout [B,C,H,W] + frames = input_tensor[idx] # [F,H,W,3] + frames = frames.permute(0, 3, 1, 2) * 2.0 - 1.0 # [F,3,H,W] -> [-1,1] + + # Upsample (bicubic) + frames = F.interpolate(frames, scale_factor=scale, mode="bicubic", align_corners=False) + _, _, sH, sW = frames.shape + + # Center crop + left = (sW - tW) // 2 + top = (sH - tH) // 2 + frames = frames[:, :, top : top + tH, left : left + tW] + + # Output [1, C, F, H, W] + vid = frames.permute(1, 0, 2, 3).unsqueeze(0).to(dtype) + return vid, tH, tW, F_target + + +def init_pipeline(model_path): + # print(torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device())) + mm = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + mm.load_models( + [ + model_path + "/diffusion_pytorch_model_streaming_dmd.safetensors", + ] + ) + pipe = FlashVSRTinyPipeline.from_model_manager(mm, device="cuda") + pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to("cuda", dtype=torch.bfloat16) + LQ_proj_in_path = model_path + "/LQ_proj_in.ckpt" + if os.path.exists(LQ_proj_in_path): + pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(LQ_proj_in_path, map_location="cpu"), strict=True) + pipe.denoising_model().LQ_proj_in.to("cuda") + + multi_scale_channels = [512, 256, 128, 128] + pipe.TCDecoder = build_tcdecoder(new_channels=multi_scale_channels, new_latent_channels=16 + 768) + mis = pipe.TCDecoder.load_state_dict(torch.load(model_path + "/TCDecoder.ckpt"), strict=False) + # print(mis) + + pipe.to("cuda") + pipe.enable_vram_management(num_persistent_param_in_dit=None) + pipe.init_cross_kv() + pipe.load_models_to_device(["dit", "vae"]) + return pipe + + +class VSRWrapper: + def __init__(self, model_path, device: Optional[torch.device] = None): + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Setup torch for optimal performance + torch.set_grad_enabled(False) + if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # Load model + self.dtype, self.device = torch.bfloat16, "cuda" + self.sparse_ratio = 2.0 # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable. + with ProfilingContext4DebugL2("Load VSR model"): + self.pipe = init_pipeline(model_path) + self._warm_up() + + def _warm_up(self): + dummy = torch.zeros((25, 384, 640, 3), dtype=torch.float32, device=self.device) + _ = self.super_resolve_frames(dummy, seed=0, scale=2.0) + torch.cuda.synchronize() + del dummy + + @ProfilingContext4DebugL2("VSR video") + def super_resolve_frames( + self, + video: torch.Tensor, # [T,H,W,C] + seed: float = 0.0, + scale: float = 2.0, + ) -> torch.Tensor: + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + LQ, th, tw, F = prepare_input_tensor(video, scale=scale, dtype=self.dtype, device=self.device) + + video = self.pipe( + prompt="", + negative_prompt="", + cfg_scale=1.0, + num_inference_steps=1, + seed=seed, + LQ_video=LQ, + num_frames=F, + height=th, + width=tw, + is_full_block=False, + if_buffer=True, + topk_ratio=self.sparse_ratio * 768 * 1280 / (th * tw), + kv_ratio=3.0, + local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results. 
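+ # Note: topk_ratio above equals sparse_ratio for a 768x1280 output and scales inversely with output area, i.e. the sparsity setting is normalized to a 768x1280 reference resolution.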
+ color_fix=True, + ) + video = (video + 1.0) / 2.0 # 将 [-1,1] 映射到 [0,1] + video = video.permute(1, 2, 3, 0).clamp(0.0, 1.0) # [C,T,H,W] -> [T,H,W,C] + return video diff --git a/lightx2v/models/runners/vsr/vsr_wrapper_hy15.py b/lightx2v/models/runners/vsr/vsr_wrapper_hy15.py new file mode 100644 index 0000000000000000000000000000000000000000..68e0ef55d11d72210614b1b24cde54244848e71c --- /dev/null +++ b/lightx2v/models/runners/vsr/vsr_wrapper_hy15.py @@ -0,0 +1,152 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from einops import rearrange +from torch import Tensor + +from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import ( + CausalConv3d, + RMS_norm, + ResnetBlock, + forward_with_checkpointing, + swish, +) + + +class UpsamplerType(Enum): + LEARNED = "learned" + FIXED = "fixed" + NONE = "none" + LEARNED_FIXED = "learned_fixed" + + +@dataclass +class UpsamplerConfig: + load_from: str + enable: bool = False + hidden_channels: int = 128 + num_blocks: int = 16 + model_type: UpsamplerType = UpsamplerType.NONE + version: str = "720p" + + +class SRResidualCausalBlock3D(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.block = nn.Sequential( + CausalConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + CausalConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + CausalConv3d(channels, channels, kernel_size=3), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + + +class SRModel3DV2(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int | None = None, + num_blocks: int = 6, + global_residual: bool = False, + ): + super().__init__() + if hidden_channels is None: + hidden_channels = 64 + self.in_conv = CausalConv3d(in_channels, hidden_channels, kernel_size=3) + self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)]) + self.out_conv = CausalConv3d(hidden_channels, out_channels, kernel_size=3) + self.global_residual = bool(global_residual) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + y = self.in_conv(x) + for blk in self.blocks: + y = blk(y) + y = self.out_conv(y) + if self.global_residual and (y.shape == residual.shape): + y = y + residual + return y + + +class Upsampler(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int = 2, + is_residual: bool = False, + ): + super().__init__() + self.num_res_blocks = num_res_blocks + self.block_out_channels = block_out_channels + self.z_channels = z_channels + + # assert block_out_channels[0] % z_channels == 0 + block_in = block_out_channels[0] + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3) + + self.up = nn.ModuleList() + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + + self.up.append(up) + + self.norm_out = RMS_norm(block_in, images=False) + self.conv_out = CausalConv3d(block_in, out_channels, 
kernel_size=3) + + self.gradient_checkpointing = False + self.is_residual = is_residual + + def forward(self, z: Tensor, target_shape: Sequence[int] = None) -> Tensor: + """ + Args: + z: (B, C, T, H, W) + target_shape: (H, W) + """ + use_checkpointing = bool(self.training and self.gradient_checkpointing) + if target_shape is not None and z.shape[-2:] != target_shape: + bsz = z.shape[0] + z = rearrange(z, "b c f h w -> (b f) c h w") + z = F.interpolate(z, size=target_shape, mode="bilinear", align_corners=False) + z = rearrange(z, "(b f) c h w -> b c f h w", b=bsz) + + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # upsampling + for i_level in range(len(self.block_out_channels)): + for i_block in range(self.num_res_blocks + 1): + h = forward_with_checkpointing( + self.up[i_level].block[i_block], + h, + use_checkpointing=use_checkpointing, + ) + if hasattr(self.up[i_level], "upsample"): + h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h diff --git a/lightx2v/models/runners/wan/__init__.py b/lightx2v/models/runners/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/runners/wan/wan_animate_runner.py b/lightx2v/models/runners/wan/wan_animate_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..014a848732d5cc490e1920fd765ae95f5e644e85 --- /dev/null +++ b/lightx2v/models/runners/wan/wan_animate_runner.py @@ -0,0 +1,418 @@ +import gc +from copy import deepcopy + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from loguru import logger + +try: + from decord import VideoReader +except ImportError: + VideoReader = None + logger.info("If you want to run animate model, please install decord.") + + +from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder +from lightx2v.models.input_encoders.hf.animate.motion_encoder import Generator +from lightx2v.models.networks.wan.animate_model import WanAnimateModel +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.models.runners.wan.wan_runner import WanRunner +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import load_weights, remove_substrings_from_keys +from lightx2v_platform.base.global_var import AI_DEVICE + + +@RUNNER_REGISTER("wan2.2_animate") +class WanAnimateRunner(WanRunner): + def __init__(self, config): + super().__init__(config) + assert self.config["task"] == "animate" + + def inputs_padding(self, array, target_len): + idx = 0 + flip = False + target_array = [] + while len(target_array) < target_len: + target_array.append(deepcopy(array[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == len(array) - 1: + flip = not flip + return target_array[:target_len] + + def get_valid_len(self, real_len, clip_len=81, overlap=1): + real_clip_len = clip_len - overlap + last_clip_num = (real_len - overlap) % real_clip_len + if last_clip_num == 0: + extra = 0 + else: + extra = real_clip_len - last_clip_num + target_len = real_len + extra + return target_len + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, 
device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, dtype=GET_DTYPE(), device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def padding_resize( + self, + img_ori, + height=512, + width=512, + padding_color=(0, 0, 0), + interpolation=cv2.INTER_LINEAR, + ): + ori_height = img_ori.shape[0] + ori_width = img_ori.shape[1] + channel = img_ori.shape[2] + + img_pad = np.zeros((height, width, channel)) + if channel == 1: + img_pad[:, :, 0] = padding_color[0] + else: + img_pad[:, :, 0] = padding_color[0] + img_pad[:, :, 1] = padding_color[1] + img_pad[:, :, 2] = padding_color[2] + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation) + padding = int((width - new_width) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[:, padding : padding + new_width, :] = img + else: + new_height = int(width / ori_width * ori_height) + img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation) + padding = int((height - new_height) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[padding : padding + new_height, :, :] = img + + img_pad = np.uint8(img_pad) + + return img_pad + + def prepare_source(self, src_pose_path, src_face_path, src_ref_path): + pose_video_reader = VideoReader(src_pose_path) + pose_len = len(pose_video_reader) + pose_idxs = list(range(pose_len)) + cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy() + + face_video_reader = VideoReader(src_face_path) + face_len = len(face_video_reader) + face_idxs = list(range(face_len)) + face_images = face_video_reader.get_batch(face_idxs).asnumpy() + height, width = cond_images[0].shape[:2] + refer_images = cv2.imread(src_ref_path)[..., ::-1] + refer_images = self.padding_resize(refer_images, height=height, width=width) + return cond_images, face_images, refer_images + + def prepare_source_for_replace(self, src_bg_path, src_mask_path): + bg_video_reader = VideoReader(src_bg_path) + bg_len = len(bg_video_reader) + bg_idxs = list(range(bg_len)) + bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy() + + mask_video_reader = VideoReader(src_mask_path) + mask_len = len(mask_video_reader) + mask_idxs = list(range(mask_len)) + mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy() + mask_images = mask_images[:, :, :, 0] / 255 + return bg_images, mask_images + + @ProfilingContext4DebugL2("Run Image Encoders") + def run_image_encoders( + self, + conditioning_pixel_values, + refer_t_pixel_values, + bg_pixel_values, + mask_pixel_values, + face_pixel_values, + ): + clip_encoder_out = self.run_image_encoder(self.refer_pixel_values) + vae_encoder_out, pose_latents = self.run_vae_encoder( + conditioning_pixel_values, + refer_t_pixel_values, + bg_pixel_values, + mask_pixel_values, + ) + return {"image_encoder_output": {"clip_encoder_out": clip_encoder_out, "vae_encoder_out": vae_encoder_out, "pose_latents": pose_latents, "face_pixel_values": face_pixel_values}} + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["WanAnimateRunner"], + ) + def run_vae_encoder( + self, + 
conditioning_pixel_values, + refer_t_pixel_values, + bg_pixel_values, + mask_pixel_values, + ): + H, W = self.refer_pixel_values.shape[-2], self.refer_pixel_values.shape[-1] + pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0)) # c t h w + ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0)) # c t h w + + mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1) + y_ref = torch.concat([mask_ref, ref_latents]) + + if self.mask_reft_len > 0: + if self.config["replace_flag"]: + y_reft = self.vae_encoder.encode( + torch.concat( + [ + refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len], + bg_pixel_values[:, self.mask_reft_len :], + ], + dim=1, + ) + .to(AI_DEVICE) + .unsqueeze(0) + ) + mask_pixel_values = 1 - mask_pixel_values + mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3) + mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest") + mask_pixel_values = mask_pixel_values[:, 0, :, :] + + msk_reft = self.get_i2v_mask( + self.latent_t, + self.latent_h, + self.latent_w, + self.mask_reft_len, + mask_pixel_values=mask_pixel_values.unsqueeze(0), + ) + else: + y_reft = self.vae_encoder.encode( + torch.concat( + [ + torch.nn.functional.interpolate( + refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len].cpu(), + size=(H, W), + mode="bicubic", + ), + torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()), + ], + dim=1, + ) + .to(AI_DEVICE) + .unsqueeze(0) + ) + msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len) + else: + if self.config["replace_flag"]: + mask_pixel_values = 1 - mask_pixel_values + mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3) + mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest") + mask_pixel_values = mask_pixel_values[:, 0, :, :] + y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0)) + msk_reft = self.get_i2v_mask( + self.latent_t, + self.latent_h, + self.latent_w, + self.mask_reft_len, + mask_pixel_values=mask_pixel_values.unsqueeze(0), + ) + else: + y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda")) + msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len) + + y_reft = torch.concat([msk_reft, y_reft]) + y = torch.concat([y_ref, y_reft], dim=1) + + return y, pose_latents + + def prepare_input(self): + src_pose_path = self.input_info.src_pose_path + src_face_path = self.input_info.src_face_path + src_ref_path = self.input_info.src_ref_images + self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path) + self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1) # chw + self.latent_t = self.config["target_video_length"] // self.config["vae_stride"][0] + 1 + self.latent_h = self.refer_pixel_values.shape[-2] // self.config["vae_stride"][1] + self.latent_w = self.refer_pixel_values.shape[-1] // self.config["vae_stride"][2] + self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w] + self.real_frame_len = len(self.cond_images) + target_len = self.get_valid_len( + self.real_frame_len, + self.config["target_video_length"], + overlap=self.config["refert_num"] if "refert_num" in self.config else 1, 
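The conditioning tensor `y` returned by `run_vae_encoder` stacks a 4-channel temporal mask on top of the VAE latent channels, once for the single reference frame and once for the clip being generated, and then joins the two parts along time; that is why `latent_shape` above uses `latent_t + 1` frames. A shape walk-through with dummy tensors (16 latent channels is the default `num_channels_latents`; the other sizes are illustrative):

```python
import torch

lat_t, lat_h, lat_w = 21, 60, 104

mask_ref   = torch.zeros(4, 1, lat_h, lat_w)       # visibility mask for the reference frame
ref_latent = torch.zeros(16, 1, lat_h, lat_w)      # VAE latent of the reference image
y_ref = torch.cat([mask_ref, ref_latent])          # (20, 1, H, W)

msk_reft = torch.zeros(4, lat_t, lat_h, lat_w)     # per-frame mask for the generated clip
reft_lat = torch.zeros(16, lat_t, lat_h, lat_w)    # encoded previous / background frames (or zeros)
y_reft = torch.cat([msk_reft, reft_lat])           # (20, lat_t, H, W)

y = torch.cat([y_ref, y_reft], dim=1)              # reference frame prepended along time
print(y.shape)                                     # torch.Size([20, 22, 60, 104])
```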
+ ) + logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len)) + self.cond_images = self.inputs_padding(self.cond_images, target_len) + self.face_images = self.inputs_padding(self.face_images, target_len) + + if self.config["replace_flag"] if "replace_flag" in self.config else False: + src_bg_path = self.input_info.src_bg_path + src_mask_path = self.input_info.src_mask_path + self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path) + self.bg_images = self.inputs_padding(self.bg_images, target_len) + self.mask_images = self.inputs_padding(self.mask_images, target_len) + + def get_video_segment_num(self): + total_frames = len(self.cond_images) + self.move_frames = self.config["target_video_length"] - self.config["refert_num"] + if total_frames <= self.config["target_video_length"]: + self.video_segment_num = 1 + else: + self.video_segment_num = 1 + (total_frames - self.config["target_video_length"] + self.move_frames - 1) // self.move_frames + + def init_run(self): + self.all_out_frames = [] + self.prepare_input() + super().init_run() + + @ProfilingContext4DebugL1( + "Run VAE Decoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, + metrics_labels=["WanAnimateRunner"], + ) + def run_vae_decoder(self, latents): + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + self.vae_decoder = self.load_vae_decoder() + images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE())) + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + return images + + @ProfilingContext4DebugL1( + "Init run segment", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_init_run_segment_duration, + metrics_labels=["WanAnimateRunner"], + ) + def init_run_segment(self, segment_idx): + start = segment_idx * self.move_frames + end = start + self.config["target_video_length"] + if start == 0: + self.mask_reft_len = 0 + else: + self.mask_reft_len = self.config["refert_num"] + + conditioning_pixel_values = torch.tensor( + np.stack(self.cond_images[start:end]) / 127.5 - 1, + device="cuda", + dtype=GET_DTYPE(), + ).permute(3, 0, 1, 2) # c t h w + + face_pixel_values = torch.tensor( + np.stack(self.face_images[start:end]) / 127.5 - 1, + device="cuda", + dtype=GET_DTYPE(), + ).permute(0, 3, 1, 2) # thwc->tchw + + if start == 0: + height, width = self.refer_images.shape[:2] + refer_t_pixel_values = torch.zeros( + 3, + self.config["refert_num"], + height, + width, + device="cuda", + dtype=GET_DTYPE(), + ) # c t h w + else: + refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().to(AI_DEVICE) # c t h w + + bg_pixel_values, mask_pixel_values = None, None + if self.config["replace_flag"] if "replace_flag" in self.config else False: + bg_pixel_values = torch.tensor( + np.stack(self.bg_images[start:end]) / 127.5 - 1, + device="cuda", + dtype=GET_DTYPE(), + ).permute(3, 0, 1, 2) # c t h w, + + mask_pixel_values = torch.tensor( + np.stack(self.mask_images[start:end])[:, :, :, None], + device="cuda", + dtype=GET_DTYPE(), + ).permute(3, 0, 1, 2) # c t h w, + + self.inputs.update( + self.run_image_encoders( + conditioning_pixel_values, + refer_t_pixel_values, + bg_pixel_values, + 
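`get_video_segment_num` above plans a sliding window in which every segment after the first re-uses the last `refert_num` frames of its predecessor as conditioning; the overlap is dropped again when the segments are stitched back together and the ping-pong padding is trimmed. A standalone sketch of that bookkeeping (helper names and the 81/1 defaults are illustrative):

```python
import torch

def count_segments(total_frames: int, clip_len: int = 81, refert_num: int = 1) -> int:
    """Number of sliding windows needed when each window after the first
    re-uses the last refert_num frames of the previous one."""
    if total_frames <= clip_len:
        return 1
    move = clip_len - refert_num
    return 1 + (total_frames - clip_len + move - 1) // move   # ceil division

def stitch_segments(segments, refert_num=1, real_frame_len=None):
    """Concatenate per-segment videos (B, C, T, H, W) along time, dropping the
    refert_num overlap frames of every segment after the first, then trim the
    padded tail back to the real frame count."""
    trimmed = [seg if i == 0 else seg[:, :, refert_num:] for i, seg in enumerate(segments)]
    video = torch.cat(trimmed, dim=2)
    return video if real_frame_len is None else video[:, :, :real_frame_len]

# 161 padded frames -> two 81-frame segments sharing one frame; stitching and
# trimming recovers the 100 real frames.
assert count_segments(161) == 2
segs = [torch.zeros(1, 3, 81, 8, 8), torch.zeros(1, 3, 81, 8, 8)]
print(stitch_segments(segs, real_frame_len=100).shape)   # torch.Size([1, 3, 100, 8, 8])
```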
mask_pixel_values, + face_pixel_values, + ) + ) + + if start != 0: + self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape) + + def end_run_segment(self, segment_idx): + if segment_idx != 0: + self.gen_video = self.gen_video[:, :, self.config["refert_num"] :] + self.all_out_frames.append(self.gen_video.cpu()) + + def process_images_after_vae_decoder(self): + self.gen_video_final = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len] + del self.all_out_frames + gc.collect() + super().process_images_after_vae_decoder() + + @ProfilingContext4DebugL1( + "Run Image Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_img_encode_duration, + metrics_labels=["WanAnimateRunner"], + ) + def run_image_encoder(self, img): # CHW + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + self.image_encoder = self.load_image_encoder() + clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE()) + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + del self.image_encoder + torch.cuda.empty_cache() + gc.collect() + return clip_encoder_out + + def load_transformer(self): + model = WanAnimateModel( + self.config["model_path"], + self.config, + self.init_device, + ) + + if self.config.get("lora_configs") and self.config.lora_configs: + assert not self.config.get("dit_quantized", False) + lora_wrapper = WanLoraWrapper(model) + for lora_config in self.config.lora_configs: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + + motion_encoder, face_encoder = self.load_encoders() + model.set_animate_encoders(motion_encoder, face_encoder) + return model + + def load_encoders(self): + motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE) + face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE) + motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.") + face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.") + motion_encoder.load_state_dict(motion_weight_dict) + face_encoder.load_state_dict(face_weight_dict) + return motion_encoder, face_encoder diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4032e1da6d38dde6fe9e7fb9cd39453bd84b41a7 --- /dev/null +++ b/lightx2v/models/runners/wan/wan_audio_runner.py @@ -0,0 +1,923 @@ +import gc +import io +import json +import os +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchaudio as ta +import torchvision.transforms.functional as TF +from PIL import Image, ImageCms, ImageOps +from einops import rearrange +from loguru import logger +from torchvision.transforms import InterpolationMode +from 
torchvision.transforms.functional import resize + +from lightx2v.deploy.common.va_controller import VAController +from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter +from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel +from lightx2v.models.networks.wan.audio_model import WanAudioModel +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.models.runners.wan.wan_runner import WanRunner +from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler +from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image_inplace +from lightx2v_platform.base.global_var import AI_DEVICE + +warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio") +warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io") + + +def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size): + assert sp_size > 0 and (sp_size & (sp_size - 1)) == 0, "sp_size must be a power of 2" + + h_ratio, w_ratio = 1, 1 + while sp_size != 1: + sp_size //= 2 + if patched_h % 2 == 0: + patched_h //= 2 + h_ratio *= 2 + elif patched_w % 2 == 0: + patched_w //= 2 + w_ratio *= 2 + else: + if patched_h > patched_w: + patched_h //= 2 + h_ratio *= 2 + else: + patched_w //= 2 + w_ratio *= 2 + return patched_h * h_ratio, patched_w * w_ratio + + +def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w): + tgt_ar = tgt_h / tgt_w + ori_ar = ori_h / ori_w + if abs(ori_ar - tgt_ar) < 0.01: + return 0, ori_h, 0, ori_w + if ori_ar > tgt_ar: + crop_h = int(tgt_ar * ori_w) + y0 = (ori_h - crop_h) // 2 + y1 = y0 + crop_h + return y0, y1, 0, ori_w + else: + crop_w = int(ori_h / tgt_ar) + x0 = (ori_w - crop_w) // 2 + x1 = x0 + crop_w + return 0, ori_h, x0, x1 + + +def isotropic_crop_resize(frames: torch.Tensor, size: tuple): + """ + frames: (C, H, W) or (T, C, H, W) or (N, C, H, W) + size: (H, W) + """ + original_shape = frames.shape + + if len(frames.shape) == 3: + frames = frames.unsqueeze(0) + elif len(frames.shape) == 4 and frames.shape[0] > 1: + pass + + ori_h, ori_w = frames.shape[2:] + h, w = size + y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w) + cropped_frames = frames[:, :, y0:y1, x0:x1] + resized_frames = resize(cropped_frames, [h, w], InterpolationMode.BICUBIC, antialias=True) + + if len(original_shape) == 3: + resized_frames = resized_frames.squeeze(0) + + return resized_frames + + +def fixed_shape_resize(img, target_height, target_width): + orig_height, orig_width = img.shape[-2:] + + target_ratio = target_height / target_width + orig_ratio = orig_height / orig_width + + if orig_ratio > target_ratio: + crop_width = orig_width + crop_height = int(crop_width * target_ratio) + else: + crop_height = orig_height + crop_width = int(crop_height / target_ratio) + + cropped_img = TF.center_crop(img, [crop_height, crop_width]) + + resized_img = TF.resize(cropped_img, [target_height, target_width], antialias=True) + + h, w = resized_img.shape[-2:] + return resized_img, h, w + + +def resize_image(img, resize_mode="adaptive", bucket_shape=None, fixed_area=None, fixed_shape=None): + assert resize_mode in ["adaptive", "keep_ratio_fixed_area", "fixed_min_area", "fixed_max_area", "fixed_shape", "fixed_min_side"] + + if resize_mode == 
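`get_optimal_patched_size_with_sp` keeps the patched token grid splittable across a power-of-two sequence-parallel group: it peels one factor of two per rank doubling, preferring whichever axis is currently even, and only rounds the larger axis down when both are odd, so the returned `h * w` is always divisible by `sp_size`. A small usage sketch (assuming the package is importable as laid out in this diff):

```python
from lightx2v.models.runners.wan.wan_audio_runner import get_optimal_patched_size_with_sp

print(get_optimal_patched_size_with_sp(30, 52, 4))  # (30, 52): already splittable, unchanged
print(get_optimal_patched_size_with_sp(45, 77, 2))  # (45, 76): both axes odd, the larger one rounds down
print(get_optimal_patched_size_with_sp(45, 77, 8))  # (44, 76): 44 * 76 is divisible by 8
```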
"fixed_shape": + assert fixed_shape is not None + logger.info(f"[wan_audio] fixed_shape_resize fixed_height: {fixed_shape[0]}, fixed_width: {fixed_shape[1]}") + return fixed_shape_resize(img, fixed_shape[0], fixed_shape[1]) + + if bucket_shape is not None: + """ + "adaptive_shape": { + "0.667": [[480, 832], [544, 960], [720, 1280]], + "1.500": [[832, 480], [960, 544], [1280, 720]], + "1.000": [[480, 480], [576, 576], [704, 704], [960, 960]] + } + """ + bucket_config = {} + for ratio, resolutions in bucket_shape.items(): + bucket_config[float(ratio)] = np.array(resolutions, dtype=np.int64) + # logger.info(f"[wan_audio] use custom bucket_shape: {bucket_config}") + else: + bucket_config = { + 0.667: np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), + 1.500: np.array([[832, 480], [960, 544], [1280, 720]], dtype=np.int64), + 1.000: np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), + } + # logger.info(f"[wan_audio] use default bucket_shape: {bucket_config}") + + ori_height = img.shape[-2] + ori_weight = img.shape[-1] + ori_ratio = ori_height / ori_weight + + if resize_mode == "adaptive": + aspect_ratios = np.array(np.array(list(bucket_config.keys()))) + closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio)) + closet_ratio = aspect_ratios[closet_aspect_idx] + if ori_ratio < 1.0: + target_h, target_w = 480, 832 + elif ori_ratio == 1.0: + target_h, target_w = 480, 480 + else: + target_h, target_w = 832, 480 + for resolution in bucket_config[closet_ratio]: + if ori_height * ori_weight >= resolution[0] * resolution[1]: + target_h, target_w = resolution + elif resize_mode == "keep_ratio_fixed_area": + area_in_pixels = 480 * 832 + if fixed_area == "480p": + area_in_pixels = 480 * 832 + elif fixed_area == "720p": + area_in_pixels = 720 * 1280 + else: + area_in_pixels = 480 * 832 + target_h = round(np.sqrt(area_in_pixels * ori_ratio)) + target_w = round(np.sqrt(area_in_pixels / ori_ratio)) + elif resize_mode == "fixed_min_area": + aspect_ratios = np.array(np.array(list(bucket_config.keys()))) + closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio)) + closet_ratio = aspect_ratios[closet_aspect_idx] + target_h, target_w = bucket_config[closet_ratio][0] + elif resize_mode == "fixed_min_side": + min_side = 720 + if fixed_area == "720p": + min_side = 720 + elif fixed_area == "480p": + min_side = 480 + else: + logger.warning(f"[wan_audio] fixed_area is not '480p' or '720p', using default 480p: {fixed_area}") + min_side = 480 + if ori_ratio < 1.0: + target_h = min_side + target_w = round(target_h / ori_ratio) + else: + target_w = min_side + target_h = round(target_w * ori_ratio) + elif resize_mode == "fixed_max_area": + aspect_ratios = np.array(np.array(list(bucket_config.keys()))) + closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio)) + closet_ratio = aspect_ratios[closet_aspect_idx] + target_h, target_w = bucket_config[closet_ratio][-1] + + cropped_img = isotropic_crop_resize(img, (target_h, target_w)) + logger.info(f"[wan_audio] resize_image: {img.shape} -> {cropped_img.shape}, resize_mode: {resize_mode}, target_h: {target_h}, target_w: {target_w}") + return cropped_img, target_h, target_w + + +@dataclass +class AudioSegment: + """Data class for audio segment information""" + + audio_array: torch.Tensor + start_frame: int + end_frame: int + + +class FramePreprocessorTorchVersion: + """Handles frame preprocessing including noise and masking""" + + def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 
0.1): + self.noise_mean = noise_mean + self.noise_std = noise_std + self.mask_rate = mask_rate + + def add_noise(self, frames: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor: + """Add noise to frames""" + + device = frames.device + shape = frames.shape + bs = 1 if len(shape) == 4 else shape[0] + + # Generate sigma values on the same device + sigma = torch.normal(mean=self.noise_mean, std=self.noise_std, size=(bs,), device=device, generator=generator) + sigma = torch.exp(sigma) + + for _ in range(1, len(shape)): + sigma = sigma.unsqueeze(-1) + + # Generate noise on the same device + noise = torch.randn(*shape, device=device, generator=generator) * sigma + return frames + noise + + def add_mask(self, frames: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor: + """Add mask to frames""" + + device = frames.device + h, w = frames.shape[-2:] + + # Generate mask on the same device + mask = torch.rand(h, w, device=device, generator=generator) > self.mask_rate + return frames * mask + + def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor: + """Process previous frames with noise and masking""" + frames = self.add_noise(frames, torch.Generator(device=frames.device)) + frames = self.add_mask(frames, torch.Generator(device=frames.device)) + return frames + + +class AudioProcessor: + """Handles audio loading and segmentation""" + + def __init__(self, audio_sr: int = 16000, target_fps: int = 16): + self.audio_sr = audio_sr + self.target_fps = target_fps + self.audio_frame_rate = audio_sr // target_fps + + def load_audio(self, audio_path: str): + audio_array, ori_sr = ta.load(audio_path) + audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr) + return audio_array + + def load_multi_person_audio(self, audio_paths: List[str]): + audio_arrays = [] + max_len = 0 + + for audio_path in audio_paths: + audio_array = self.load_audio(audio_path) + audio_arrays.append(audio_array) + max_len = max(max_len, audio_array.numel()) + + num_files = len(audio_arrays) + padded = torch.zeros(num_files, max_len, dtype=torch.float32) + + for i, arr in enumerate(audio_arrays): + length = arr.numel() + padded[i, :length] = arr + + return padded + + def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]: + """Calculate audio range for given frame range""" + return round(start_frame * self.audio_frame_rate), round(end_frame * self.audio_frame_rate) + + def segment_audio(self, audio_array: torch.Tensor, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]: + """ + Segment audio based on frame requirements + audio_array is (N, T) tensor + """ + segments = [] + segments_idx = self.init_segments_idx(expected_frames, max_num_frames, prev_frame_length) + + audio_start, audio_end = self.get_audio_range(0, expected_frames) + audio_array_ori = audio_array[:, audio_start:audio_end] + + for idx, (start_idx, end_idx) in enumerate(segments_idx): + audio_start, audio_end = self.get_audio_range(start_idx, end_idx) + audio_array = audio_array_ori[:, audio_start:audio_end] + + if idx < len(segments_idx) - 1: + end_idx = segments_idx[idx + 1][0] + else: # for last segments + if audio_array.shape[1] < audio_end - audio_start: + padding_len = audio_end - audio_start - audio_array.shape[1] + audio_array = F.pad(audio_array, (0, padding_len)) + # Adjust end_idx to account for the frames added by padding + end_idx = end_idx - padding_len // self.audio_frame_rate + + 
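`FramePreprocessorTorchVersion` lightly corrupts the overlap frames that are fed back as conditioning, so the model does not latch onto artifacts in its own previous output: Gaussian noise whose per-sample sigma is log-normally distributed, followed by random pixel dropout. A compact standalone sketch of the same two steps (`degrade_prev_frames` is an illustrative name):

```python
import torch

def degrade_prev_frames(frames, noise_mean=-3.0, noise_std=0.5, mask_rate=0.1, generator=None):
    """Add noise with sigma = exp(N(noise_mean, noise_std)) per sample, then
    zero out roughly mask_rate of all spatial positions."""
    bs = 1 if frames.ndim == 4 else frames.shape[0]
    log_sigma = noise_mean + noise_std * torch.randn(bs, device=frames.device, generator=generator)
    sigma = torch.exp(log_sigma).view(bs, *([1] * (frames.ndim - 1)))   # broadcast over C, T, H, W
    noisy = frames + torch.randn(frames.shape, device=frames.device, generator=generator) * sigma
    h, w = frames.shape[-2:]
    keep = torch.rand(h, w, device=frames.device, generator=generator) > mask_rate
    return noisy * keep

prev = torch.zeros(1, 3, 5, 64, 64)      # (B, C, T, H, W), values in [-1, 1]
print(degrade_prev_frames(prev).shape)   # torch.Size([1, 3, 5, 64, 64])
```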
segments.append(AudioSegment(audio_array, start_idx, end_idx)) + del audio_array, audio_array_ori + return segments + + def init_segments_idx(self, total_frame: int, clip_frame: int = 81, overlap_frame: int = 5) -> list[tuple[int, int, int]]: + """Initialize segment indices with overlap""" + start_end_list = [] + min_frame = clip_frame + for start in range(0, total_frame, clip_frame - overlap_frame): + is_last = start + clip_frame >= total_frame + end = min(start + clip_frame, total_frame) + if end - start < min_frame: + end = start + min_frame + if ((end - start) - 1) % 4 != 0: + end = start + (((end - start) - 1) // 4) * 4 + 1 + start_end_list.append((start, end)) + if is_last: + break + return start_end_list + + +def load_image(image: Union[str, Image.Image], to_rgb: bool = True) -> Image.Image: + _image = image + if isinstance(image, str): + if os.path.isfile(image): + _image = Image.open(image) + else: + raise ValueError(f"Incorrect path. {image} is not a valid path.") + # orientation transpose + _image = ImageOps.exif_transpose(_image) + # convert color space to sRGB + icc_profile = _image.info.get("icc_profile") + if icc_profile: + srgb_profile = ImageCms.createProfile("sRGB") + input_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc_profile)) + _image = ImageCms.profileToProfile(_image, input_profile, srgb_profile) + # convert to "RGB" + if to_rgb: + _image = _image.convert("RGB") + + return _image + + +@RUNNER_REGISTER("seko_talk") +class WanAudioRunner(WanRunner): # type:ignore + def __init__(self, config): + super().__init__(config) + self.prev_frame_length = self.config.get("prev_frame_length", 5) + self.frame_preprocessor = FramePreprocessorTorchVersion() + + def init_scheduler(self): + """Initialize consistency model scheduler""" + self.scheduler = EulerScheduler(self.config) + + def read_audio_input(self, audio_path): + """Read audio input - handles both single and multi-person scenarios""" + audio_sr = self.config.get("audio_sr", 16000) + target_fps = self.config.get("target_fps", 16) + self._audio_processor = AudioProcessor(audio_sr, target_fps) + + if not isinstance(audio_path, str): + return [], 0, None, 0 + + # Get audio files from person objects or legacy format + audio_files, mask_files = self.get_audio_files_from_audio_path(audio_path) + + # Load audio based on single or multi-person mode + if len(audio_files) == 1: + audio_array = self._audio_processor.load_audio(audio_files[0]) + audio_array = audio_array.unsqueeze(0) # Add batch dimension for consistency + else: + audio_array = self._audio_processor.load_multi_person_audio(audio_files) + + video_duration = self.config.get("video_duration", 5) + audio_len = int(audio_array.shape[1] / audio_sr * target_fps) + if GET_RECORDER_MODE(): + monitor_cli.lightx2v_input_audio_len.observe(audio_len) + + expected_frames = min(max(1, int(video_duration * target_fps)), audio_len) + if expected_frames < int(video_duration * target_fps): + logger.warning(f"Input video duration is greater than actual audio duration, using audio duration instead: audio_duration={audio_len / target_fps}, video_duration={video_duration}") + + # Segment audio + audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length) + + # Mask latent for multi-person s2v + if mask_files is not None: + mask_latents = [self.process_single_mask(mask_file) for mask_file in mask_files] + mask_latents = torch.cat(mask_latents, dim=0) + else: + mask_latents = None + + return 
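`init_segments_idx` builds the overlapping windows that drive segment-by-segment generation: each window starts `clip_frame - overlap_frame` after the previous one, short tails are padded up to a full window, and window lengths are snapped to `4k + 1` so they map onto whole VAE latent frames. A standalone sketch (illustrative function name):

```python
def audio_segment_windows(total_frames: int, clip_frames: int = 81, overlap: int = 5):
    """Overlapping (start, end) windows; the last window may extend past
    total_frames, in which case the audio is zero-padded downstream."""
    windows = []
    for start in range(0, total_frames, clip_frames - overlap):
        is_last = start + clip_frames >= total_frames
        end = min(start + clip_frames, total_frames)
        if end - start < clip_frames:
            end = start + clip_frames                       # pad short tail to a full window
        if (end - start - 1) % 4 != 0:
            end = start + ((end - start - 1) // 4) * 4 + 1  # snap to 4k + 1 frames
        windows.append((start, end))
        if is_last:
            break
    return windows

print(audio_segment_windows(200))  # [(0, 81), (76, 157), (152, 233)]
```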
audio_segments, expected_frames, mask_latents, len(audio_files) + + def get_audio_files_from_audio_path(self, audio_path): + if os.path.isdir(audio_path): + audio_files = [] + mask_files = [] + logger.info(f"audio_path is a directory, loading config.json from {audio_path}") + audio_config_path = os.path.join(audio_path, "config.json") + assert os.path.exists(audio_config_path), "config.json not found in audio_path" + with open(audio_config_path, "r") as f: + audio_config = json.load(f) + for talk_object in audio_config["talk_objects"]: + audio_files.append(os.path.join(audio_path, talk_object["audio"])) + mask_files.append(os.path.join(audio_path, talk_object["mask"])) + else: + logger.info(f"audio_path is a file without mask: {audio_path}") + audio_files = [audio_path] + mask_files = None + + return audio_files, mask_files + + def process_single_mask(self, mask_file): + mask_img = load_image(mask_file) + mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) + + if mask_img.shape[1] == 3: # If it is an RGB three-channel image + mask_img = mask_img[:, :1] # Only take the first channel + + mask_img, h, w = resize_image( + mask_img, + resize_mode=self.config.get("resize_mode", "adaptive"), + bucket_shape=self.config.get("bucket_shape", None), + fixed_area=self.config.get("fixed_area", None), + fixed_shape=self.config.get("fixed_shape", None), + ) + + mask_latent = torch.nn.functional.interpolate( + mask_img, # (1, 1, H, W) + size=(h // 16, w // 16), + mode="bicubic", + ) + + mask_latent = (mask_latent > 0).to(torch.int8) + return mask_latent + + def read_image_input(self, img_path): + if isinstance(img_path, Image.Image): + ref_img = img_path + else: + ref_img = load_image(img_path) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE) + + ref_img, h, w = resize_image( + ref_img, + resize_mode=self.config.get("resize_mode", "adaptive"), + bucket_shape=self.config.get("bucket_shape", None), + fixed_area=self.config.get("fixed_area", None), + fixed_shape=self.config.get("fixed_shape", None), + ) + logger.info(f"[wan_audio] resize_image target_h: {h}, target_w: {w}") + patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1] + patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2] + + patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1) + + latent_h = patched_h * self.config["patch_size"][1] + latent_w = patched_w * self.config["patch_size"][2] + + latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w) + target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]] + + logger.info(f"[wan_audio] target_h: {target_shape[0]}, target_w: {target_shape[1]}, latent_h: {latent_h}, latent_w: {latent_w}") + + ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic") + return ref_img, latent_shape, target_shape + + @ProfilingContext4DebugL1( + "Run Image Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_img_encode_duration, + metrics_labels=["WanAudioRunner"], + ) + def run_image_encoder(self, first_frame, last_frame=None): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.image_encoder = self.load_image_encoder() + clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None + if self.config.get("lazy_load", False) or 
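`read_image_input` derives the working resolution by making the image divisible by both the spatial VAE stride and the DiT patch size, then scales back up to the pixel resolution the VAE encoder will actually see. The arithmetic as a standalone sketch (the `(4, 8, 8)` stride, `(1, 2, 2)` patch size, 81 frames and 16 latent channels are the usual defaults and are assumptions here):

```python
def plan_latent_shape(h, w, vae_stride=(4, 8, 8), patch_size=(1, 2, 2),
                      num_frames=81, num_channels_latents=16):
    """Latent shape and pixel working resolution for an input of size h x w:
    divide by the spatial stride and patch size, then multiply back so both
    the VAE and the patch embedding divide evenly."""
    patched_h = h // vae_stride[1] // patch_size[1]
    patched_w = w // vae_stride[2] // patch_size[2]
    latent_h, latent_w = patched_h * patch_size[1], patched_w * patch_size[2]
    latent_t = (num_frames - 1) // vae_stride[0] + 1
    latent_shape = [num_channels_latents, latent_t, latent_h, latent_w]
    target_hw = (latent_h * vae_stride[1], latent_w * vae_stride[2])
    return latent_shape, target_hw

print(plan_latent_shape(480, 832))  # ([16, 21, 60, 104], (480, 832))
```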
self.config.get("unload_modules", False): + del self.image_encoder + torch.cuda.empty_cache() + gc.collect() + return clip_encoder_out + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["WanAudioRunner"], + ) + def run_vae_encoder(self, img): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_encoder = self.load_vae_encoder() + + img = rearrange(img, "1 C H W -> 1 C 1 H W") + vae_encoder_out = self.vae_encoder.encode(img.to(GET_DTYPE())) + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_encoder + torch.cuda.empty_cache() + gc.collect() + return vae_encoder_out + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_s2v(self): + img, latent_shape, target_shape = self.read_image_input(self.input_info.image_path) + if self.config.get("f2v_process", False): + self.ref_img = img + self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info + self.input_info.target_shape = target_shape # Important: set target_shape in input_info + clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None + vae_encode_out = self.run_vae_encoder(img) + + audio_segments, expected_frames, person_mask_latens, audio_num = self.read_audio_input(self.input_info.audio_path) + self.input_info.audio_num = audio_num + self.input_info.with_mask = person_mask_latens is not None + text_encoder_output = self.run_text_encoder(self.input_info) + torch.cuda.empty_cache() + gc.collect() + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encode_out, + }, + "audio_segments": audio_segments, + "expected_frames": expected_frames, + "person_mask_latens": person_mask_latens, + } + + def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]: + """Prepare previous latents for conditioning""" + dtype = GET_DTYPE() + + tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1] + prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=AI_DEVICE) + + if prev_video is not None: + # Extract and process last frames + last_frames = prev_video[:, :, -prev_frame_length:].clone().to(AI_DEVICE) + if self.config["model_cls"] != "wan2.2_audio" and not self.config.get("f2v_process", False): + last_frames = self.frame_preprocessor.process_prev_frames(last_frames) + prev_frames[:, :, :prev_frame_length] = last_frames + prev_len = (prev_frame_length - 1) // 4 + 1 + else: + prev_len = 0 + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_encoder = self.load_vae_encoder() + + _, nframe, height, width = self.model.scheduler.latents.shape + with ProfilingContext4DebugL1( + "vae_encoder in init run segment", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_pre_latent_duration, + metrics_labels=["WanAudioRunner"], + ): + if self.config["model_cls"] == "wan2.2_audio": + if prev_video is not None: + prev_latents = self.vae_encoder.encode(prev_frames.to(dtype)) + else: + prev_latents = None + prev_mask = self.model.scheduler.mask + else: + prev_latents = self.vae_encoder.encode(prev_frames.to(dtype)) + + frames_n = (nframe - 1) * 4 + 1 + prev_mask = 
torch.ones((1, frames_n, height, width), device=AI_DEVICE, dtype=dtype) + prev_frame_len = max((prev_len - 1) * 4 + 1, 0) + prev_mask[:, prev_frame_len:] = 0 + prev_mask = self._wan_mask_rearrange(prev_mask) + + if prev_latents is not None: + if prev_latents.shape[-2:] != (height, width): + logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={tgt_h}, tgt_w={tgt_w}") + prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False) + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_encoder + torch.cuda.empty_cache() + gc.collect() + + return {"prev_latents": prev_latents, "prev_mask": prev_mask, "prev_len": prev_len} + + def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor: + """Rearrange mask for WAN model""" + if mask.ndim == 3: + mask = mask[None] + assert mask.ndim == 4 + _, t, h, w = mask.shape + assert t == ((t - 1) // 4 * 4 + 1) + mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) + mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1) + mask = mask.view(mask.shape[1] // 4, 4, h, w) + return mask.transpose(0, 1).contiguous() + + def get_video_segment_num(self): + self.video_segment_num = len(self.inputs["audio_segments"]) + + def init_run(self): + super().init_run() + self.scheduler.set_audio_adapter(self.audio_adapter) + if self.config.get("f2v_process", False): + self.prev_video = self.ref_img.unsqueeze(2) + else: + self.prev_video = None + if self.input_info.return_result_tensor: + self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.input_info.target_shape[0], self.input_info.target_shape[1], 3), dtype=torch.float32, device="cpu") + self.cut_audio_final = torch.zeros((self.inputs["expected_frames"] * self._audio_processor.audio_frame_rate), dtype=torch.float32, device="cpu") + else: + self.gen_video_final = None + self.cut_audio_final = None + + @ProfilingContext4DebugL1( + "Init run segment", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_init_run_segment_duration, + metrics_labels=["WanAudioRunner"], + ) + def init_run_segment(self, segment_idx, audio_array=None): + self.segment_idx = segment_idx + if audio_array is not None: + end_idx = audio_array.shape[0] // self._audio_processor.audio_frame_rate - self.prev_frame_length + audio_tensor = torch.Tensor(audio_array).float().unsqueeze(0) + self.segment = AudioSegment(audio_tensor, 0, end_idx) + else: + self.segment = self.inputs["audio_segments"][segment_idx] + + self.input_info.seed = self.input_info.seed + segment_idx + torch.manual_seed(self.input_info.seed) + # logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}") + + if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"): + self.audio_encoder = self.load_audio_encoder() + + features_list = [] + for i in range(self.segment.audio_array.shape[0]): + feat = self.audio_encoder.infer(self.segment.audio_array[i]) + feat = self.audio_adapter.forward_audio_proj(feat, self.model.scheduler.latents.shape[1]) + features_list.append(feat.squeeze(0)) + audio_features = torch.stack(features_list, dim=0) + + self.inputs["audio_encoder_output"] = audio_features + self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=self.prev_frame_length) + + # Reset 
scheduler for non-first segments + if segment_idx > 0: + self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape, self.inputs["previmg_encoder_output"]) + + @ProfilingContext4DebugL1( + "End run segment", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_end_run_segment_duration, + metrics_labels=["WanAudioRunner"], + ) + def end_run_segment(self, segment_idx): + self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float) + useful_length = self.segment.end_frame - self.segment.start_frame + video_seg = self.gen_video[:, :, :useful_length].cpu() + audio_seg = self.segment.audio_array[:, : useful_length * self._audio_processor.audio_frame_rate] + audio_seg = audio_seg.sum(dim=0) # Multiple audio tracks, mixed into one track + video_seg = vae_to_comfyui_image_inplace(video_seg) + + # [Warning] Need check whether video segment interpolation works... + if "video_frame_interpolation" in self.config and self.vfi_model is not None: + target_fps = self.config["video_frame_interpolation"]["target_fps"] + logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}") + video_seg = self.vfi_model.interpolate_frames( + video_seg, + source_fps=self.config.get("fps", 16), + target_fps=target_fps, + ) + + if "video_super_resolution" in self.config and self.vsr_model is not None: + # logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}") + video_seg = self.vsr_model.super_resolve_frames( + video_seg, + seed=self.config["video_super_resolution"]["seed"], + scale=self.config["video_super_resolution"]["scale"], + ) + + if self.va_controller.recorder is not None: + self.va_controller.pub_livestream(video_seg, audio_seg, self.gen_video[:, :, :useful_length]) + elif self.input_info.return_result_tensor: + self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg) + self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg) + + # Update prev_video for next iteration + self.prev_video = self.gen_video + + del video_seg, audio_seg + torch.cuda.empty_cache() + + @ProfilingContext4DebugL1( + "End run segment stream", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_end_run_segment_duration, + metrics_labels=["WanAudioRunner"], + ) + def end_run_segment_stream(self, latents): + valid_length = self.segment.end_frame - self.segment.start_frame + frame_segments = [] + frame_idx = 0 + + # frame_segment: 1*C*1*H*W, 1*C*4*H*W, 1*C*4*H*W, ... 
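The decoded video and the cut audio track stay aligned per segment because one video frame corresponds to a fixed number of audio samples, `audio_sr // target_fps` (1000 samples at 16 kHz and 16 fps). A one-line sketch of that mapping as used when slicing the audio array per segment (illustrative function name):

```python
def audio_slice_for_frames(start_frame: int, end_frame: int, audio_sr: int = 16000, fps: int = 16):
    """Map a video frame range to the matching range of audio samples."""
    samples_per_frame = audio_sr // fps
    return start_frame * samples_per_frame, end_frame * samples_per_frame

print(audio_slice_for_frames(76, 157))  # (76000, 157000)
```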
+ for origin_seg in self.run_vae_decoder_stream(latents): + origin_seg = torch.clamp(origin_seg, -1, 1).to(torch.float) + valid_T = min(valid_length - frame_idx, origin_seg.shape[2]) + + video_seg = vae_to_comfyui_image_inplace(origin_seg[:, :, :valid_T].cpu()) + audio_start = frame_idx * self._audio_processor.audio_frame_rate + audio_end = (frame_idx + valid_T) * self._audio_processor.audio_frame_rate + audio_seg = self.segment.audio_array[:, audio_start:audio_end].sum(dim=0) + + if self.va_controller.recorder is not None: + self.va_controller.pub_livestream(video_seg, audio_seg, origin_seg[:, :, :valid_T]) + + frame_segments.append(origin_seg) + frame_idx += valid_T + del video_seg, audio_seg + + # Update prev_video for next iteration + self.prev_video = torch.cat(frame_segments, dim=2) + torch.cuda.empty_cache() + + def run_main(self): + try: + self.va_controller = None + self.va_controller = VAController(self) + logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}") + + # fixed audio segments inputs + if self.va_controller.reader is None: + return super().run_main() + + self.va_controller.start() + self.init_run() + if self.config.get("compile", False) and hasattr(self.model, "comple"): + self.model.select_graph_for_compile(self.input_info) + # steam audio input, video segment num is unlimited + self.video_segment_num = 1000000 + segment_idx = 0 + fail_count, max_fail_count = 0, 10 + self.va_controller.before_control() + + while True: + with ProfilingContext4DebugL1(f"stream segment get audio segment {segment_idx}"): + control = self.va_controller.next_control() + if control.action == "immediate": + self.prev_video = control.data + elif control.action == "wait": + time.sleep(0.01) + continue + + audio_array = self.va_controller.reader.get_audio_segment() + if audio_array is None: + fail_count += 1 + logger.warning(f"Failed to get audio chunk {fail_count} times") + if fail_count > max_fail_count: + raise Exception(f"Failed to get audio chunk {fail_count} times, stop reader") + continue + + with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"): + try: + # reset pause signal + self.pause_signal = False + self.init_run_segment(segment_idx, audio_array) + self.check_stop() + latents = self.run_segment(segment_idx) + self.check_stop() + if self.config.get("use_stream_vae", False): + self.end_run_segment_stream(latents) + else: + self.gen_video = self.run_vae_decoder(latents) + self.check_stop() + self.end_run_segment(segment_idx) + segment_idx += 1 + fail_count = 0 + except Exception as e: + if "pause_signal, pause running" in str(e): + logger.warning(f"model infer audio pause: {e}, should continue") + else: + raise + finally: + if hasattr(self.model, "inputs"): + self.end_run() + if self.va_controller is not None: + self.va_controller.clear() + self.va_controller = None + + @ProfilingContext4DebugL1("Process after vae decoder") + def process_images_after_vae_decoder(self): + if self.input_info.return_result_tensor: + audio_waveform = self.cut_audio_final.unsqueeze(0).unsqueeze(0) + comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr} + return {"video": self.gen_video_final, "audio": comfyui_audio} + return {"video": None, "audio": None} + + def load_transformer(self): + """Load transformer with LoRA support""" + base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device) + if self.config.get("lora_configs") and self.config["lora_configs"]: + assert not 
self.config.get("dit_quantized", False) + lora_wrapper = WanLoraWrapper(base_model) + for lora_config in self.config["lora_configs"]: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + + return base_model + + def load_audio_encoder(self): + audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large")) + audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False)) + model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload) + return model + + def load_audio_adapter(self): + audio_adapter_offload = self.config.get("audio_adapter_cpu_offload", self.config.get("cpu_offload", False)) + if audio_adapter_offload: + device = torch.device("cpu") + else: + device = torch.device(AI_DEVICE) + audio_adapter = AudioAdapter( + attention_head_dim=self.config["dim"] // self.config["num_heads"], + num_attention_heads=self.config["num_heads"], + base_num_layers=self.config["num_layers"], + interval=1, + audio_feature_dim=1024, + time_freq_dim=256, + projection_transformer_layers=4, + mlp_dims=(1024, 1024, 32 * 1024), + quantized=self.config.get("adapter_quantized", False), + quant_scheme=self.config.get("adapter_quant_scheme", None), + cpu_offload=audio_adapter_offload, + ) + + audio_adapter.to(device) + load_from_rank0 = self.config.get("load_from_rank0", False) + weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0) + audio_adapter.load_state_dict(weights_dict, strict=False) + return audio_adapter.to(dtype=GET_DTYPE()) + + def load_model(self): + super().load_model() + with ProfilingContext4DebugL2("Load audio encoder and adapter"): + self.audio_encoder = self.load_audio_encoder() + self.audio_adapter = self.load_audio_adapter() + + def get_latent_shape_with_lat_hw(self, latent_h, latent_w): + latent_shape = [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + return latent_shape + + +@RUNNER_REGISTER("wan2.2_audio") +class Wan22AudioRunner(WanAudioRunner): + def __init__(self, config): + super().__init__(config) + + def load_vae_decoder(self): + # offload config + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + vae_config = { + "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"), + "device": vae_device, + "cpu_offload": vae_offload, + "offload_cache": self.config.get("vae_offload_cache", False), + } + vae_decoder = Wan2_2_VAE(**vae_config) + return vae_decoder + + def load_vae_encoder(self): + # offload config + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + vae_config = { + "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"), + "device": vae_device, + "cpu_offload": vae_offload, + "offload_cache": self.config.get("vae_offload_cache", False), + } + if self.config.task not in ["i2v", "s2v"]: + return None + else: + return 
Wan2_2_VAE(**vae_config) + + def load_vae(self): + vae_encoder = self.load_vae_encoder() + vae_decoder = self.load_vae_decoder() + return vae_encoder, vae_decoder diff --git a/lightx2v/models/runners/wan/wan_distill_runner.py b/lightx2v/models/runners/wan/wan_distill_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..1f55754d0bf1bb4391f4ab98fd0998ad3606eae2 --- /dev/null +++ b/lightx2v/models/runners/wan/wan_distill_runner.py @@ -0,0 +1,200 @@ +import os + +from loguru import logger + +from lightx2v.models.networks.wan.distill_model import WanDistillModel +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner +from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER + + +@RUNNER_REGISTER("wan2.1_distill") +class WanDistillRunner(WanRunner): + def __init__(self, config): + super().__init__(config) + + def load_transformer(self): + if self.config.get("lora_configs") and self.config["lora_configs"]: + model = WanModel( + self.config["model_path"], + self.config, + self.init_device, + ) + lora_wrapper = WanLoraWrapper(model) + for lora_config in self.config["lora_configs"]: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + else: + model = WanDistillModel(self.config["model_path"], self.config, self.init_device) + return model + + def init_scheduler(self): + if self.config["feature_caching"] == "NoCaching": + self.scheduler = WanStepDistillScheduler(self.config) + else: + raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}") + + +class MultiDistillModelStruct(MultiModelStruct): + def __init__(self, model_list, config, boundary_step_index=2): + self.model = model_list # [high_noise_model, low_noise_model] + assert len(self.model) == 2, "MultiModelStruct only supports 2 models now." 
+ self.config = config + self.boundary_step_index = boundary_step_index + self.cur_model_index = -1 + logger.info(f"boundary step index: {self.boundary_step_index}") + + @ProfilingContext4DebugL2("Swtich models in infer_main costs") + def get_current_model_index(self): + if self.scheduler.step_index < self.boundary_step_index: + logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}") + # self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0] + if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model": + if self.cur_model_index == -1: + self.to_cuda(model_index=0) + elif self.cur_model_index == 1: # 1 -> 0 + self.offload_cpu(model_index=1) + self.to_cuda(model_index=0) + self.cur_model_index = 0 + else: + logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}") + # self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1] + if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model": + if self.cur_model_index == -1: + self.to_cuda(model_index=1) + elif self.cur_model_index == 0: # 0 -> 1 + self.offload_cpu(model_index=0) + self.to_cuda(model_index=1) + self.cur_model_index = 1 + + def infer(self, inputs): + self.get_current_model_index() + if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): + self.model[self.cur_model_index].infer(inputs) + else: + if self.model[self.cur_model_index] is not None: + self.model[self.cur_model_index].infer(inputs) + else: + if self.cur_model_index == 0: + high_noise_model = WanDistillModel( + self.high_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_high_noise", + ) + high_noise_model.set_scheduler(self.scheduler) + self.model[0] = high_noise_model + self.model[0].infer(inputs) + elif self.cur_model_index == 1: + low_noise_model = WanDistillModel( + self.low_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_low_noise", + ) + low_noise_model.set_scheduler(self.scheduler) + self.model[1] = low_noise_model + self.model[1].infer(inputs) + + +@RUNNER_REGISTER("wan2.2_moe_distill") +class Wan22MoeDistillRunner(WanDistillRunner): + def __init__(self, config): + super().__init__(config) + if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None): + self.high_noise_model_path = self.config["high_noise_quantized_ckpt"] + elif self.config.get("high_noise_original_ckpt", None): + self.high_noise_model_path = self.config["high_noise_original_ckpt"] + else: + self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model") + if not os.path.isdir(self.high_noise_model_path): + self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model") + if not os.path.isdir(self.high_noise_model_path): + raise FileNotFoundError(f"High Noise Model does not find") + + if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None): + self.low_noise_model_path = self.config["low_noise_quantized_ckpt"] + elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None): + self.low_noise_model_path = self.config["low_noise_original_ckpt"] + else: + self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model") + if not os.path.isdir(self.low_noise_model_path): + self.low_noise_model_path = 
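`MultiDistillModelStruct.get_current_model_index` routes denoising steps before `boundary_step_index` to the high-noise expert and the remaining steps to the low-noise expert, optionally swapping the idle expert to CPU when model-level offload is enabled. A minimal sketch of that routing rule (class and attribute names are illustrative, not the runner's API):

```python
import torch

class NoiseExpertSwitcher:
    """Two-expert switching: high-noise expert for early steps, low-noise
    expert afterwards; with offload enabled, the idle expert is parked on CPU
    before the active one is moved to the accelerator."""

    def __init__(self, high, low, boundary_step_index=2, offload=False, device="cuda"):
        self.experts = [high, low]          # [high_noise_model, low_noise_model]
        self.boundary = boundary_step_index
        self.offload = offload
        self.device = device
        self.current = -1                   # -1: nothing resident yet

    def pick(self, step_index: int):
        wanted = 0 if step_index < self.boundary else 1
        if self.offload and wanted != self.current:
            if self.current in (0, 1):
                self.experts[self.current].to("cpu")
            self.experts[wanted].to(self.device)
        self.current = wanted
        return self.experts[wanted]

# Dummy experts over a 4-step distilled schedule: the first two steps use the
# high-noise expert, the last two the low-noise expert.
sw = NoiseExpertSwitcher(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4), boundary_step_index=2, device="cpu")
print([sw.experts.index(sw.pick(i)) for i in range(4)])  # [0, 0, 1, 1]
```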
os.path.join(self.config["model_path"], "distill_models", "low_noise_model") + if not os.path.isdir(self.high_noise_model_path): + raise FileNotFoundError(f"Low Noise Model does not find") + + def load_transformer(self): + if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): + use_high_lora, use_low_lora = False, False + if self.config.get("lora_configs") and self.config["lora_configs"]: + for lora_config in self.config["lora_configs"]: + if lora_config.get("name", "") == "high_noise_model": + use_high_lora = True + elif lora_config.get("name", "") == "low_noise_model": + use_low_lora = True + + if use_high_lora: + high_noise_model = WanModel( + self.high_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_high_noise", + ) + high_lora_wrapper = WanLoraWrapper(high_noise_model) + for lora_config in self.config["lora_configs"]: + if lora_config.get("name", "") == "high_noise_model": + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = high_lora_wrapper.load_lora(lora_path) + high_lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}") + else: + high_noise_model = WanDistillModel( + self.high_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_high_noise", + ) + + if use_low_lora: + low_noise_model = WanModel( + self.low_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_low_noise", + ) + low_lora_wrapper = WanLoraWrapper(low_noise_model) + for lora_config in self.config["lora_configs"]: + if lora_config.get("name", "") == "low_noise_model": + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = low_lora_wrapper.load_lora(lora_path) + low_lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}") + else: + low_noise_model = WanDistillModel( + self.low_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_low_noise", + ) + + return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"]) + else: + model_struct = MultiDistillModelStruct([None, None], self.config, self.config["boundary_step_index"]) + model_struct.low_noise_model_path = self.low_noise_model_path + model_struct.high_noise_model_path = self.high_noise_model_path + model_struct.init_device = self.init_device + return model_struct + + def init_scheduler(self): + if self.config["feature_caching"] == "NoCaching": + self.scheduler = Wan22StepDistillScheduler(self.config) + else: + raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}") diff --git a/lightx2v/models/runners/wan/wan_matrix_game2_runner.py b/lightx2v/models/runners/wan/wan_matrix_game2_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..9aed80f8d9951f879ec04ee960691058ff63f9bd --- /dev/null +++ b/lightx2v/models/runners/wan/wan_matrix_game2_runner.py @@ -0,0 +1,327 @@ +import os + +import torch +from diffusers.utils.loading_utils import load_image +from torchvision.transforms import v2 + +from lightx2v.models.input_encoders.hf.wan.matrix_game2.clip import CLIPModel +from lightx2v.models.input_encoders.hf.wan.matrix_game2.conditions import Bench_actions_gta_drive, Bench_actions_templerun, Bench_actions_universal +from lightx2v.models.networks.wan.matrix_game2_model import WanSFMtxg2Model +from 
lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner +from lightx2v.models.video_encoders.hf.wan.vae_sf import WanMtxg2VAE +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v_platform.base.global_var import AI_DEVICE + + +class VAEWrapper: + def __init__(self, vae): + self.vae = vae + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + return getattr(self.vae, name) + + def encode(self, x): + raise NotImplementedError + + def decode(self, latents): + return NotImplementedError + + +class WanxVAEWrapper(VAEWrapper): + def __init__(self, vae, clip): + self.vae = vae + self.vae.requires_grad_(False) + self.vae.eval() + self.clip = clip + if clip is not None: + self.clip.requires_grad_(False) + self.clip.eval() + + def encode(self, x, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + x = self.vae.encode(x, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) # already scaled + return x # torch.stack(x, dim=0) + + def clip_img(self, x): + x = self.clip(x) + return x + + def decode(self, latents, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + videos = self.vae.decode(latents, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return videos # self.vae.decode(videos, dim=0) # already scaled + + def to(self, device, dtype): + # 移动 vae 到指定设备 + self.vae = self.vae.to(device, dtype) + + # 如果 clip 存在,也移动到指定设备 + if self.clip is not None: + self.clip = self.clip.to(device, dtype) + + return self + + +def get_wanx_vae_wrapper(model_path, weight_dtype): + vae = WanMtxg2VAE(pretrained_path=os.path.join(model_path, "Wan2.1_VAE.pth")).to(weight_dtype) + clip = CLIPModel(checkpoint_path=os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), tokenizer_path=os.path.join(model_path, "xlm-roberta-large")) + return WanxVAEWrapper(vae, clip) + + +def get_current_action(mode="universal"): + CAM_VALUE = 0.1 + if mode == "universal": + print() + print("-" * 30) + print("PRESS [I, K, J, L, U] FOR CAMERA TRANSFORM\n (I: up, K: down, J: left, L: right, U: no move)") + print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)") + print("-" * 30) + CAMERA_VALUE_MAP = {"i": [CAM_VALUE, 0], "k": [-CAM_VALUE, 0], "j": [0, -CAM_VALUE], "l": [0, CAM_VALUE], "u": [0, 0]} + KEYBOARD_IDX = {"w": [1, 0, 0, 0], "s": [0, 1, 0, 0], "a": [0, 0, 1, 0], "d": [0, 0, 0, 1], "q": [0, 0, 0, 0]} + flag = 0 + while flag != 1: + try: + idx_mouse = input("Please input the mouse action (e.g. `U`):\n").strip().lower() + idx_keyboard = input("Please input the keyboard action (e.g. `W`):\n").strip().lower() + if idx_mouse in CAMERA_VALUE_MAP.keys() and idx_keyboard in KEYBOARD_IDX.keys(): + flag = 1 + except Exception as e: + pass + mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).to(AI_DEVICE) + keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE) + elif mode == "gta_drive": + print() + print("-" * 30) + print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)") + print("-" * 30) + CAMERA_VALUE_MAP = {"a": [0, -CAM_VALUE], "d": [0, CAM_VALUE], "q": [0, 0]} + KEYBOARD_IDX = {"w": [1, 0], "s": [0, 1], "q": [0, 0]} + flag = 0 + while flag != 1: + try: + indexes = input("Please input the actions (split with ` `):\n(e.g. 
`W` for forward, `W A` for forward and left)\n").strip().lower().split(" ") + idx_mouse = [] + idx_keyboard = [] + for i in indexes: + if i in CAMERA_VALUE_MAP.keys(): + idx_mouse += [i] + elif i in KEYBOARD_IDX.keys(): + idx_keyboard += [i] + if len(idx_mouse) == 0: + idx_mouse += ["q"] + if len(idx_keyboard) == 0: + idx_keyboard += ["q"] + assert idx_mouse in [["a"], ["d"], ["q"]] and idx_keyboard in [["q"], ["w"], ["s"]] + flag = 1 + except Exception as e: + pass + mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).to(AI_DEVICE) + keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).to(AI_DEVICE) + elif mode == "templerun": + print() + print("-" * 30) + print("PRESS [W, S, A, D, Z, C, Q] FOR ACTIONS\n (W: jump, S: slide, A: left side, D: right side, Z: turn left, C: turn right, Q: no move)") + print("-" * 30) + KEYBOARD_IDX = { + "w": [0, 1, 0, 0, 0, 0, 0], + "s": [0, 0, 1, 0, 0, 0, 0], + "a": [0, 0, 0, 0, 0, 1, 0], + "d": [0, 0, 0, 0, 0, 0, 1], + "z": [0, 0, 0, 1, 0, 0, 0], + "c": [0, 0, 0, 0, 1, 0, 0], + "q": [1, 0, 0, 0, 0, 0, 0], + } + flag = 0 + while flag != 1: + try: + idx_keyboard = input("Please input the action: \n(e.g. `W` for forward, `Z` for turning left)\n").strip().lower() + if idx_keyboard in KEYBOARD_IDX.keys(): + flag = 1 + except Exception as e: + pass + keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE) + + if mode != "templerun": + return {"mouse": mouse_cond, "keyboard": keyboard_cond} + return {"keyboard": keyboard_cond} + + +@RUNNER_REGISTER("wan2.1_sf_mtxg2") +class WanSFMtxg2Runner(WanSFRunner): + def __init__(self, config): + super().__init__(config) + self.frame_process = v2.Compose( + [ + v2.Resize(size=(352, 640), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + self.device = torch.device("cuda") + self.weight_dtype = torch.bfloat16 + + def load_text_encoder(self): + from lightx2v.models.input_encoders.hf.wan.matrix_game2.conditions import MatrixGame2_Bench + + return MatrixGame2_Bench() + + def load_image_encoder(self): + wrapper = get_wanx_vae_wrapper(self.config["model_path"], torch.float16) + wrapper.requires_grad_(False) + wrapper.eval() + return wrapper.to(self.device, self.weight_dtype) + + def _resizecrop(self, image, th, tw): + w, h = image.size + if h / w > th / tw: + new_w = int(w) + new_h = int(new_w * th / tw) + else: + new_h = int(h) + new_w = int(new_h * tw / th) + left = (w - new_w) / 2 + top = (h - new_h) / 2 + right = (w + new_w) / 2 + bottom = (h + new_h) / 2 + image = image.crop((left, top, right, bottom)) + return image + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_i2v(self): + # image + image = load_image(self.input_info.image_path) + image = self._resizecrop(image, 352, 640) + image = self.frame_process(image)[None, :, None, :, :].to(dtype=self.weight_dtype, device=self.device) + padding_video = torch.zeros_like(image).repeat(1, 1, 4 * (self.config["num_output_frames"] - 1), 1, 1) + img_cond = torch.concat([image, padding_video], dim=2) + tiler_kwargs = {"tiled": True, "tile_size": [44, 80], "tile_stride": [23, 38]} + img_cond = self.image_encoder.encode(img_cond, device=self.device, **tiler_kwargs).to(self.device) + mask_cond = torch.ones_like(img_cond) + mask_cond[:, :, 1:] = 0 + cond_concat = torch.cat([mask_cond[:, :4], img_cond], dim=1) + visual_context = self.image_encoder.clip.encode_video(image) + image_encoder_output = {"cond_concat": cond_concat.to(device=self.device, dtype=self.weight_dtype), 
"visual_context": visual_context.to(device=self.device, dtype=self.weight_dtype)} + + # text + text_encoder_output = {} + num_frames = (self.config["num_output_frames"] - 1) * 4 + 1 + if self.config["mode"] == "universal": + cond_data = Bench_actions_universal(num_frames) + mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype) + text_encoder_output["mouse_cond"] = mouse_condition + elif self.config["mode"] == "gta_drive": + cond_data = Bench_actions_gta_drive(num_frames) + mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype) + text_encoder_output["mouse_cond"] = mouse_condition + else: + cond_data = Bench_actions_templerun(num_frames) + keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype) + text_encoder_output["keyboard_cond"] = keyboard_condition + + # set shape + self.input_info.latent_shape = [16, self.config["num_output_frames"], 44, 80] + + return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output} + + def load_transformer(self): + model = WanSFMtxg2Model( + self.config["model_path"], + self.config, + self.init_device, + ) + return model + + def init_run_segment(self, segment_idx): + self.segment_idx = segment_idx + + if self.config["streaming"]: + self.inputs["current_actions"] = get_current_action(mode=self.config["mode"]) + + @ProfilingContext4DebugL2("Run DiT") + def run_main(self): + self.init_run() + if self.config.get("compile", False): + self.model.select_graph_for_compile(self.input_info) + + stop = "" + while stop != "n": + for segment_idx in range(self.video_segment_num): + logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}") + with ProfilingContext4DebugL1( + f"segment end2end {segment_idx + 1}/{self.video_segment_num}", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, + metrics_labels=["DefaultRunner"], + ): + self.check_stop() + # 1. default do nothing + self.init_run_segment(segment_idx) + # 2. main inference loop + latents = self.run_segment(segment_idx=segment_idx) + # 3. vae decoder + self.gen_video = self.run_vae_decoder(latents) + # 4. default do nothing + self.end_run_segment(segment_idx) + + # 5. 
stop or not + if self.config["streaming"]: + stop = input("Press `n` to stop generation: ").strip().lower() + if stop == "n": + break + stop = "n" + + gen_video_final = self.process_images_after_vae_decoder() + self.end_run() + return gen_video_final + + @ProfilingContext4DebugL2("Run DiT") + def run_main_live(self, total_steps=None): + try: + self.init_video_recorder() + logger.info(f"init video_recorder: {self.video_recorder}") + rank, world_size = self.get_rank_and_world_size() + if rank == world_size - 1: + assert self.video_recorder is not None, "video_recorder is required for stream audio input for rank 2" + self.video_recorder.start(self.width, self.height) + if world_size > 1: + dist.barrier() + self.init_run() + if self.config.get("compile", False): + self.model.select_graph_for_compile(self.input_info) + + stop = "" + while stop != "n": + for segment_idx in range(self.video_segment_num): + logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}") + with ProfilingContext4DebugL1( + f"segment end2end {segment_idx + 1}/{self.video_segment_num}", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, + metrics_labels=["DefaultRunner"], + ): + self.check_stop() + # 1. default do nothing + self.init_run_segment(segment_idx) + # 2. main inference loop + latents = self.run_segment(segment_idx=segment_idx) + # 3. vae decoder + self.gen_video = self.run_vae_decoder(latents) + # 4. default do nothing + self.end_run_segment(segment_idx) + + # 5. stop or not + if self.config["streaming"]: + stop = input("Press `n` to stop generation: ").strip().lower() + if stop == "n": + break + stop = "n" + finally: + if hasattr(self.model, "inputs"): + self.end_run() + if self.video_recorder: + self.video_recorder.stop() + self.video_recorder = None diff --git a/lightx2v/models/runners/wan/wan_runner.py b/lightx2v/models/runners/wan/wan_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..490b4f04d84a5e485c7fe3cd9c611ce6bc1511d2 --- /dev/null +++ b/lightx2v/models/runners/wan/wan_runner.py @@ -0,0 +1,638 @@ +import gc +import os + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from PIL import Image +from loguru import logger + +from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel +from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.runners.default_runner import DefaultRunner +from lightx2v.models.schedulers.wan.changing_resolution.scheduler import ( + WanScheduler4ChangingResolutionInterface, +) +from lightx2v.models.schedulers.wan.feature_caching.scheduler import ( + WanSchedulerCaching, + WanSchedulerTaylorCaching, +) +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.models.video_encoders.hf.wan.vae import WanVAE +from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE +from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import * +from lightx2v_platform.base.global_var import AI_DEVICE + + +@RUNNER_REGISTER("wan2.1") +class WanRunner(DefaultRunner): + def 
__init__(self, config): + super().__init__(config) + self.vae_cls = WanVAE + self.tiny_vae_cls = WanVAE_tiny + self.vae_name = config.get("vae_name", "Wan2.1_VAE.pth") + self.tiny_vae_name = "taew2_1.pth" + + def load_transformer(self): + model = WanModel( + self.config["model_path"], + self.config, + self.init_device, + ) + if self.config.get("lora_configs") and self.config.lora_configs: + assert not self.config.get("dit_quantized", False) + lora_wrapper = WanLoraWrapper(model) + for lora_config in self.config.lora_configs: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + return model + + def load_image_encoder(self): + image_encoder = None + if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True): + # offload config + clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)) + if clip_offload: + clip_device = torch.device("cpu") + else: + clip_device = torch.device(AI_DEVICE) + # quant_config + clip_quantized = self.config.get("clip_quantized", False) + if clip_quantized: + clip_quant_scheme = self.config.get("clip_quant_scheme", None) + assert clip_quant_scheme is not None + tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0] + clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth" + clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name) + clip_original_ckpt = None + else: + clip_quantized_ckpt = None + clip_quant_scheme = None + clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" + clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name) + + image_encoder = CLIPModel( + dtype=torch.float16, + device=clip_device, + checkpoint_path=clip_original_ckpt, + clip_quantized=clip_quantized, + clip_quantized_ckpt=clip_quantized_ckpt, + quant_scheme=clip_quant_scheme, + cpu_offload=clip_offload, + use_31_block=self.config.get("use_31_block", True), + load_from_rank0=self.config.get("load_from_rank0", False), + ) + + return image_encoder + + def load_text_encoder(self): + # offload config + t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload")) + if t5_offload: + t5_device = torch.device("cpu") + else: + t5_device = torch.device(AI_DEVICE) + tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl") + # quant_config + t5_quantized = self.config.get("t5_quantized", False) + if t5_quantized: + t5_quant_scheme = self.config.get("t5_quant_scheme", None) + assert t5_quant_scheme is not None + tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0] + t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth" + t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name) + t5_original_ckpt = None + else: + t5_quant_scheme = None + t5_quantized_ckpt = None + t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth" + t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name) + + text_encoder = T5EncoderModel( + text_len=self.config["text_len"], + dtype=torch.bfloat16, + device=t5_device, + checkpoint_path=t5_original_ckpt, + tokenizer_path=tokenizer_path, + shard_fn=None, + cpu_offload=t5_offload, + t5_quantized=t5_quantized, + t5_quantized_ckpt=t5_quantized_ckpt, + 
quant_scheme=t5_quant_scheme, + load_from_rank0=self.config.get("load_from_rank0", False), + ) + text_encoders = [text_encoder] + return text_encoders + + def load_vae_encoder(self): + # offload config + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name), + "device": vae_device, + "parallel": self.config["parallel"], + "use_tiling": self.config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "dtype": GET_DTYPE(), + "load_from_rank0": self.config.get("load_from_rank0", False), + "use_lightvae": self.config.get("use_lightvae", False), + } + if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]: + return None + else: + return self.vae_cls(**vae_config) + + def load_vae_decoder(self): + # offload config + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + if vae_offload: + vae_device = torch.device("cpu") + else: + vae_device = torch.device(AI_DEVICE) + + vae_config = { + "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name), + "device": vae_device, + "parallel": self.config["parallel"], + "use_tiling": self.config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "use_lightvae": self.config.get("use_lightvae", False), + "dtype": GET_DTYPE(), + "load_from_rank0": self.config.get("load_from_rank0", False), + } + if self.config.get("use_tae", False): + tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name) + vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE) + else: + vae_decoder = self.vae_cls(**vae_config) + return vae_decoder + + def load_vae(self): + vae_encoder = self.load_vae_encoder() + if vae_encoder is None or self.config.get("use_tae", False): + vae_decoder = self.load_vae_decoder() + else: + vae_decoder = vae_encoder + return vae_encoder, vae_decoder + + def init_scheduler(self): + if self.config["feature_caching"] == "NoCaching": + scheduler_class = WanScheduler + elif self.config["feature_caching"] == "TaylorSeer": + scheduler_class = WanSchedulerTaylorCaching + elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]: + scheduler_class = WanSchedulerCaching + else: + raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}") + + if self.config.get("changing_resolution", False): + self.scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config) + else: + self.scheduler = scheduler_class(self.config) + + @ProfilingContext4DebugL1( + "Run Text Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_text_encode_duration, + metrics_labels=["WanRunner"], + ) + def run_text_encoder(self, input_info): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.text_encoders = self.load_text_encoder() + + prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt + if GET_RECORDER_MODE(): + monitor_cli.lightx2v_input_prompt_len.observe(len(prompt)) + neg_prompt = input_info.negative_prompt + + if self.config.get("enable_cfg", False) and self.config["cfg_parallel"]: + cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p") + cfg_p_rank = 
dist.get_rank(cfg_p_group) + if cfg_p_rank == 0: + context = self.text_encoders[0].infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context]) + text_encoder_output = {"context": context} + else: + context_null = self.text_encoders[0].infer([neg_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null]) + text_encoder_output = {"context_null": context_null} + else: + context = self.text_encoders[0].infer([prompt]) + context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context]) + if self.config.get("enable_cfg", False): + context_null = self.text_encoders[0].infer([neg_prompt]) + context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null]) + else: + context_null = None + text_encoder_output = { + "context": context, + "context_null": context_null, + } + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.text_encoders[0] + torch.cuda.empty_cache() + gc.collect() + + return text_encoder_output + + @ProfilingContext4DebugL1( + "Run Image Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_img_encode_duration, + metrics_labels=["WanRunner"], + ) + def run_image_encoder(self, first_frame, last_frame=None): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.image_encoder = self.load_image_encoder() + if last_frame is None: + clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE()) + else: + clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE()) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.image_encoder + torch.cuda.empty_cache() + gc.collect() + return clip_encoder_out + + def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size): + """ + Adjust latent dimensions for optimal 2D grid splitting. + Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1. 
+ """ + world_size_h, world_size_w = 1, 1 + if world_size <= 1: + return latent_h, latent_w, world_size_h, world_size_w + + # Define priority grids for different world sizes + priority_grids = [] + if world_size == 8: + # For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1 + priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)] + elif world_size == 4: + priority_grids = [(2, 2), (1, 4), (4, 1)] + elif world_size == 2: + priority_grids = [(1, 2), (2, 1)] + else: + # For other sizes, try factor pairs + for h in range(1, int(np.sqrt(world_size)) + 1): + if world_size % h == 0: + w = world_size // h + priority_grids.append((h, w)) + + # Try priority grids first + for world_size_h, world_size_w in priority_grids: + if latent_h % world_size_h == 0 and latent_w % world_size_w == 0: + return latent_h, latent_w, world_size_h, world_size_w + + # If no perfect fit, find minimal padding solution + best_grid = (1, world_size) # fallback + min_total_padding = float("inf") + + for world_size_h, world_size_w in priority_grids: + # Calculate required padding + pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h + pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w + total_padding = pad_h + pad_w + + # Prefer grids with minimal total padding + if total_padding < min_total_padding: + min_total_padding = total_padding + best_grid = (world_size_h, world_size_w) + + # Apply padding + world_size_h, world_size_w = best_grid + pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h + pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w + + return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["WanRunner"], + ) + def run_vae_encoder(self, first_frame, last_frame=None): + h, w = first_frame.shape[2:] + aspect_ratio = h / w + max_area = self.config["target_height"] * self.config["target_width"] + + # Calculate initial latent dimensions + ori_latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1]) + ori_latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2]) + + # Adjust latent dimensions for optimal 2D grid splitting when using distributed processing + if dist.is_initialized() and dist.get_world_size() > 1: + latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size()) + logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}") + else: + latent_h, latent_w = ori_latent_h, ori_latent_w + world_size_h, world_size_w = None, None + + latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w) # Important: latent_shape is used to set the input_info + + if self.config.get("changing_resolution", False): + assert last_frame is None + vae_encode_out_list = [] + for i in range(len(self.config["resolution_rate"])): + latent_h_tmp, latent_w_tmp = ( + int(latent_h * self.config["resolution_rate"][i]) // 2 * 2, + int(latent_w * self.config["resolution_rate"][i]) // 2 * 2, + ) + vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp, world_size_h=world_size_h, world_size_w=world_size_w)) + 
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w, world_size_h=world_size_h, world_size_w=world_size_w)) + return vae_encode_out_list, latent_shape + else: + if last_frame is not None: + first_frame_size = first_frame.shape[2:] + last_frame_size = last_frame.shape[2:] + if first_frame_size != last_frame_size: + last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1]) + last_frame_size = [ + round(last_frame_size[0] * last_frame_resize_ratio), + round(last_frame_size[1] * last_frame_resize_ratio), + ] + last_frame = TF.center_crop(last_frame, last_frame_size) + vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame, world_size_h=world_size_h, world_size_w=world_size_w) + return vae_encoder_out, latent_shape + + def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None, world_size_h=None, world_size_w=None): + h = lat_h * self.config["vae_stride"][1] + w = lat_w * self.config["vae_stride"][2] + msk = torch.ones( + 1, + self.config["target_video_length"], + lat_h, + lat_w, + device=torch.device(AI_DEVICE), + ) + if last_frame is not None: + msk[:, 1:-1] = 0 + else: + msk[:, 1:] = 0 + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_encoder = self.load_vae_encoder() + + if last_frame is not None: + vae_input = torch.concat( + [ + torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, self.config["target_video_length"] - 2, h, w), + torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + ], + dim=1, + ).to(AI_DEVICE) + else: + vae_input = torch.concat( + [ + torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, self.config["target_video_length"] - 1, h, w), + ], + dim=1, + ).to(AI_DEVICE) + + vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()), world_size_h=world_size_h, world_size_w=world_size_w) + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_encoder + torch.cuda.empty_cache() + gc.collect() + vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE()) + return vae_encoder_out + + def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None): + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out, + } + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + } + + def get_latent_shape_with_lat_hw(self, latent_h, latent_w): + latent_shape = [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + latent_h, + latent_w, + ] + return latent_shape + + def get_latent_shape_with_target_hw(self): + latent_shape = [ + self.config.get("num_channels_latents", 16), + (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, + int(self.config["target_height"]) // self.config["vae_stride"][1], + int(self.config["target_width"]) // self.config["vae_stride"][2], + ] + return latent_shape + + +class MultiModelStruct: + def __init__(self, model_list, config, 
boundary=0.875, num_train_timesteps=1000):
+        self.model = model_list  # [high_noise_model, low_noise_model]
+        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
+        self.config = config
+        self.boundary = boundary
+        self.boundary_timestep = self.boundary * num_train_timesteps
+        self.cur_model_index = -1
+        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")
+
+    @property
+    def device(self):
+        return self.model[self.cur_model_index].device
+
+    def set_scheduler(self, shared_scheduler):
+        self.scheduler = shared_scheduler
+        for model in self.model:
+            if model is not None:
+                model.set_scheduler(shared_scheduler)
+
+    def infer(self, inputs):
+        self.get_current_model_index()
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.model[self.cur_model_index].infer(inputs)
+        else:
+            # Lazy load: build the selected model on first use (paths and device are attached by the runner).
+            if self.model[self.cur_model_index] is not None:
+                self.model[self.cur_model_index].infer(inputs)
+            else:
+                if self.cur_model_index == 0:
+                    high_noise_model = WanModel(
+                        self.high_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_high_noise",
+                    )
+                    high_noise_model.set_scheduler(self.scheduler)
+                    self.model[0] = high_noise_model
+                    self.model[0].infer(inputs)
+                elif self.cur_model_index == 1:
+                    low_noise_model = WanModel(
+                        self.low_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_low_noise",
+                    )
+                    low_noise_model.set_scheduler(self.scheduler)
+                    self.model[1] = low_noise_model
+                    self.model[1].infer(inputs)
+
+    @ProfilingContext4DebugL2("Switch models in infer_main costs")
+    def get_current_model_index(self):
+        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
+            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
+            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
+                if self.cur_model_index == -1:
+                    self.to_cuda(model_index=0)
+                elif self.cur_model_index == 1:  # 1 -> 0
+                    self.offload_cpu(model_index=1)
+                    self.to_cuda(model_index=0)
+            self.cur_model_index = 0
+        else:
+            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
+            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
+                if self.cur_model_index == -1:
+                    self.to_cuda(model_index=1)
+                elif self.cur_model_index == 0:  # 0 -> 1
+                    self.offload_cpu(model_index=0)
+                    self.to_cuda(model_index=1)
+            self.cur_model_index = 1
+
+    def offload_cpu(self, model_index):
+        self.model[model_index].to_cpu()
+
+    def to_cuda(self, model_index):
+        self.model[model_index].to_cuda()
+
+
+@RUNNER_REGISTER("wan2.2_moe")
+class Wan22MoeRunner(WanRunner):
+    def __init__(self, config):
+        super().__init__(config)
+        self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
+        if not os.path.isdir(self.high_noise_model_path):
+            self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
+        if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
+            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
+        elif self.config.get("high_noise_original_ckpt", None):
+            self.high_noise_model_path = self.config["high_noise_original_ckpt"]
+
+        self.low_noise_model_path = 
os.path.join(self.config["model_path"], "low_noise_model") + if not os.path.isdir(self.low_noise_model_path): + self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model") + if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None): + self.low_noise_model_path = self.config["low_noise_quantized_ckpt"] + elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None): + self.low_noise_model_path = self.config["low_noise_original_ckpt"] + + def load_transformer(self): + # encoder -> high_noise_model -> low_noise_model -> vae -> video_output + if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): + high_noise_model = WanModel( + self.high_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_high_noise", + ) + low_noise_model = WanModel( + self.low_noise_model_path, + self.config, + self.init_device, + model_type="wan2.2_moe_low_noise", + ) + + if self.config.get("lora_configs") and self.config["lora_configs"]: + assert not self.config.get("dit_quantized", False) + + for lora_config in self.config["lora_configs"]: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + base_name = os.path.basename(lora_path) + if base_name.startswith("high"): + lora_wrapper = WanLoraWrapper(high_noise_model) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + elif base_name.startswith("low"): + lora_wrapper = WanLoraWrapper(low_noise_model) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + else: + raise ValueError(f"Unsupported LoRA path: {lora_path}") + + return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"]) + else: + model_struct = MultiModelStruct([None, None], self.config, self.config["boundary"]) + model_struct.low_noise_model_path = self.low_noise_model_path + model_struct.high_noise_model_path = self.high_noise_model_path + model_struct.init_device = self.init_device + return model_struct + + +@RUNNER_REGISTER("wan2.2") +class Wan22DenseRunner(WanRunner): + def __init__(self, config): + super().__init__(config) + self.vae_encoder_need_img_original = True + self.vae_cls = Wan2_2_VAE + self.tiny_vae_cls = Wan2_2_VAE_tiny + self.vae_name = "Wan2.2_VAE.pth" + self.tiny_vae_name = "taew2_2.pth" + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["Wan22DenseRunner"], + ) + def run_vae_encoder(self, img): + max_area = self.config.target_height * self.config.target_width + ih, iw = img.height, img.width + dh, dw = self.config.patch_size[1] * self.config.vae_stride[1], self.config.patch_size[2] * self.config.vae_stride[2] + ow, oh = best_output_size(iw, ih, dw, dh, max_area) + + scale = max(ow / iw, oh / ih) + img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS) + + # center-crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + assert img.width == ow and img.height == oh + + # to tensor + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1) + vae_encoder_out = self.get_vae_encoder_output(img) + latent_w, latent_h = ow // 
self.config["vae_stride"][2], oh // self.config["vae_stride"][1] + latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w) + return vae_encoder_out, latent_shape + + def get_vae_encoder_output(self, img): + z = self.vae_encoder.encode(img.unsqueeze(0).to(GET_DTYPE())) + return z diff --git a/lightx2v/models/runners/wan/wan_sf_runner.py b/lightx2v/models/runners/wan/wan_sf_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..48c3d8331524f23e6f3733f7a0eac5b4187c9ada --- /dev/null +++ b/lightx2v/models/runners/wan/wan_sf_runner.py @@ -0,0 +1,175 @@ +import gc + +import torch +from loguru import logger + +from lightx2v.deploy.common.video_recorder import VideoRecorder +from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper +from lightx2v.models.networks.wan.sf_model import WanSFModel +from lightx2v.models.runners.wan.wan_runner import WanRunner +from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler +from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.memory_profiler import peak_memory_decorator +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import vae_to_comfyui_image_inplace + + +@RUNNER_REGISTER("wan2.1_sf") +class WanSFRunner(WanRunner): + def __init__(self, config): + super().__init__(config) + self.vae_cls = WanSFVAE + self.is_live = config.get("is_live", False) + if self.is_live: + self.width = self.config["target_width"] + self.height = self.config["target_height"] + self.run_main = self.run_main_live + + def load_transformer(self): + model = WanSFModel( + self.config, + self.config, + self.init_device, + ) + if self.config.get("lora_configs") and self.config.lora_configs: + assert not self.config.get("dit_quantized", False) + lora_wrapper = WanLoraWrapper(model) + for lora_config in self.config.lora_configs: + lora_path = lora_config["path"] + strength = lora_config.get("strength", 1.0) + lora_name = lora_wrapper.load_lora(lora_path) + lora_wrapper.apply_lora(lora_name, strength) + logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") + return model + + def init_scheduler(self): + self.scheduler = WanSFScheduler(self.config) + + def set_target_shape(self): + self.num_output_frames = 21 + self.config.target_shape = [16, self.num_output_frames, 60, 104] + + def get_video_segment_num(self): + self.video_segment_num = self.scheduler.num_blocks + + @ProfilingContext4DebugL1("Run VAE Decoder") + def run_vae_decoder(self, latents): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae_decoder = self.load_vae_decoder() + images = self.vae_decoder.decode(latents.to(GET_DTYPE()), use_cache=True) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + return images + + def init_run(self): + super().init_run() + + @peak_memory_decorator + def run_segment(self, segment_idx=0): + infer_steps = self.model.scheduler.infer_steps + for step_index in range(infer_steps): + # only for single segment, check stop signal every step + if self.video_segment_num == 1: + self.check_stop() + logger.info(f"==> step_index: {step_index + 1} / {infer_steps}") + + with ProfilingContext4DebugL1("step_pre"): + self.model.scheduler.step_pre(seg_index=segment_idx, step_index=step_index, 
is_rerun=False) + + with ProfilingContext4DebugL1("🚀 infer_main"): + self.model.infer(self.inputs) + + with ProfilingContext4DebugL1("step_post"): + self.model.scheduler.step_post() + + if self.progress_callback: + current_step = segment_idx * infer_steps + step_index + 1 + total_all_steps = self.video_segment_num * infer_steps + self.progress_callback((current_step / total_all_steps) * 100, 100) + + return self.model.scheduler.stream_output + + def get_rank_and_world_size(self): + rank = 0 + world_size = 1 + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + return rank, world_size + + def init_video_recorder(self): + output_video_path = self.input_info.save_result_path + self.video_recorder = None + if isinstance(output_video_path, dict): + output_video_path = output_video_path["data"] + logger.info(f"init video_recorder with output_video_path: {output_video_path}") + rank, world_size = self.get_rank_and_world_size() + if output_video_path and rank == world_size - 1: + record_fps = self.config.get("target_fps", 16) + audio_sr = self.config.get("audio_sr", 16000) + if "video_frame_interpolation" in self.config and self.vfi_model is not None: + record_fps = self.config["video_frame_interpolation"]["target_fps"] + + self.video_recorder = VideoRecorder( + livestream_url=output_video_path, + fps=record_fps, + ) + + @ProfilingContext4DebugL1("End run segment") + def end_run_segment(self, segment_idx=None): + with ProfilingContext4DebugL1("step_pre_in_rerun"): + self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True) + with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"): + self.model.infer(self.inputs) + + self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video + if self.is_live: + if self.video_recorder: + stream_video = vae_to_comfyui_image_inplace(self.gen_video) + self.video_recorder.pub_video(stream_video) + + torch.cuda.empty_cache() + + @ProfilingContext4DebugL2("Run DiT") + def run_main_live(self, total_steps=None): + try: + self.init_video_recorder() + logger.info(f"init video_recorder: {self.video_recorder}") + rank, world_size = self.get_rank_and_world_size() + if rank == world_size - 1: + assert self.video_recorder is not None, "video_recorder is required for stream audio input for rank 2" + self.video_recorder.start(self.width, self.height) + if world_size > 1: + dist.barrier() + self.init_run() + if self.config.get("compile", False): + self.model.select_graph_for_compile(self.input_info) + + for segment_idx in range(self.video_segment_num): + logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}") + with ProfilingContext4DebugL1( + f"segment end2end {segment_idx + 1}/{self.video_segment_num}", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, + metrics_labels=["DefaultRunner"], + ): + self.check_stop() + # 1. default do nothing + self.init_run_segment(segment_idx) + # 2. main inference loop + latents = self.run_segment(segment_idx) + # 3. vae decoder + self.gen_video = self.run_vae_decoder(latents) + # 4. 
default do nothing + self.end_run_segment(segment_idx) + finally: + if hasattr(self.model, "inputs"): + self.end_run() + if self.video_recorder: + self.video_recorder.stop() + self.video_recorder = None diff --git a/lightx2v/models/runners/wan/wan_vace_runner.py b/lightx2v/models/runners/wan/wan_vace_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..2f5803680b09d505566f3a874e67a7caa7b2f800 --- /dev/null +++ b/lightx2v/models/runners/wan/wan_vace_runner.py @@ -0,0 +1,192 @@ +import gc + +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from PIL import Image + +from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor +from lightx2v.models.networks.wan.vace_model import WanVaceModel +from lightx2v.models.runners.wan.wan_runner import WanRunner +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import * +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER + + +@RUNNER_REGISTER("wan2.1_vace") +class WanVaceRunner(WanRunner): + def __init__(self, config): + super().__init__(config) + assert self.config["task"] == "vace" + self.vid_proc = VaceVideoProcessor( + downsample=tuple([x * y for x, y in zip(self.config["vae_stride"], self.config["patch_size"])]), + min_area=720 * 1280, + max_area=720 * 1280, + min_fps=self.config["fps"] if "fps" in self.config else 16, + max_fps=self.config["fps"] if "fps" in self.config else 16, + zero_start=True, + seq_len=75600, + keep_last=True, + ) + + def load_transformer(self): + model = WanVaceModel( + self.config["model_path"], + self.config, + self.init_device, + ) + return model + + def prepare_source(self, src_video, src_mask, src_ref_images, image_size, device=torch.device("cuda")): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 720 * 1280: + self.vid_proc.set_seq_len(75600) + elif area == 480 * 832: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f"image_size {image_size} is not supported") + + image_size = (image_size[1], image_size[0]) + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, self.config["target_video_length"], image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), 
device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top : top + new_height, left : left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["WanVaceRunner"], + ) + def run_vae_encoder(self, frames, ref_images, masks): + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + self.vae_encoder = self.load_vae_encoder() + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = [self.vae_encoder.encode(frame.unsqueeze(0).to(GET_DTYPE())) for frame in frames] + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = [self.vae_encoder.encode(inact.unsqueeze(0).to(GET_DTYPE())) for inact in inactive] + reactive = [self.vae_encoder.encode(react.unsqueeze(0).to(GET_DTYPE())) for react in reactive] + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0).to(GET_DTYPE())) for ref in refs] + else: + ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0).to(GET_DTYPE())) for ref in refs] + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + self.latent_shape = list(cat_latents[0].shape) + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + del self.vae_encoder + torch.cuda.empty_cache() + gc.collect() + return self.get_vae_encoder_output(cat_latents, masks, ref_images), self.set_input_info_latent_shape() + + def get_vae_encoder_output(self, cat_latents, masks, ref_images): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // self.config["vae_stride"][0]) + height = 2 * (int(height) // (self.config["vae_stride"][1] * 2)) + width = 2 * (int(width) // (self.config["vae_stride"][2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view(depth, height, self.config["vae_stride"][1], width, self.config["vae_stride"][1]) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape(self.config["vae_stride"][1] * self.config["vae_stride"][2], depth, height, width) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), 
size=(new_depth, height, width), mode="nearest-exact").squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)] + + def set_input_info_latent_shape(self): + latent_shape = self.latent_shape + latent_shape[0] = int(latent_shape[0] / 2) + return latent_shape + + @ProfilingContext4DebugL1( + "Run VAE Decoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, + metrics_labels=["WanVaceRunner"], + ) + def run_vae_decoder(self, latents): + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + self.vae_decoder = self.load_vae_decoder() + + if self.src_ref_images is not None: + assert len(self.src_ref_images) == 1 + refs = self.src_ref_images[0] + if refs is not None: + latents = latents[:, len(refs) :, :, :] + + images = self.vae_decoder.decode(latents.to(GET_DTYPE())) + + if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + + return images diff --git a/lightx2v/models/schedulers/__init__.py b/lightx2v/models/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/schedulers/hunyuan_video/__init__.py b/lightx2v/models/schedulers/hunyuan_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/schedulers/hunyuan_video/feature_caching/__init__.py b/lightx2v/models/schedulers/hunyuan_video/feature_caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/schedulers/hunyuan_video/feature_caching/scheduler.py b/lightx2v/models/schedulers/hunyuan_video/feature_caching/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5156b8fe268fa29f9594ba9cb98c5c01739d9c --- /dev/null +++ b/lightx2v/models/schedulers/hunyuan_video/feature_caching/scheduler.py @@ -0,0 +1,9 @@ +from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15Scheduler + + +class HunyuanVideo15SchedulerCaching(HunyuanVideo15Scheduler): + def __init__(self, config): + super().__init__(config) + + def clear(self): + self.transformer_infer.clear() diff --git a/lightx2v/models/schedulers/hunyuan_video/posemb_layers.py b/lightx2v/models/schedulers/hunyuan_video/posemb_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..948c0a6e07863c678dd0805f74d302123c05d7de --- /dev/null +++ b/lightx2v/models/schedulers/hunyuan_video/posemb_layers.py @@ -0,0 +1,283 @@ +from typing import List, Tuple, Union + +import torch + +from lightx2v_platform.base.global_var import AI_DEVICE + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. 
+ + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] + """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def reshape_for_broadcast( + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + x: torch.Tensor, +): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Notes: + When using FlashMHAModified, head_first should be False. + When using Attention, head_first should be True. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. 
The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + cos, sin = reshape_for_broadcast(freqs_cis, xq) # [S, D] + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + return xq_out, xk_out + + +def rotate_half_force_bf16(x): + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb_force_bf16( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + cos, sin = reshape_for_broadcast(freqs_cis, xq) # [S, D] + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = xq * cos + rotate_half_force_bf16(xq) * sin + xk_out = xk * cos + rotate_half_force_bf16(xk) * sin + return xq_out, xk_out + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + theta_rescale_factor: Union[float, List[float]] = 1.0, + interpolation_factor: Union[float, List[float]] = 1.0, + **kwds, +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. 
Defaults to 1.0. + + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] + + if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) + assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): + interpolation_factor = [interpolation_factor] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) + assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + **kwds, + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + **kwds, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + freqs = torch.outer(pos * interpolation_factor, freqs).to(AI_DEVICE) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis diff --git a/lightx2v/models/schedulers/hunyuan_video/scheduler.py b/lightx2v/models/schedulers/hunyuan_video/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9cbff6e1c2a56003a018a5146e732446e2fbe6 --- /dev/null +++ b/lightx2v/models/schedulers/hunyuan_video/scheduler.py @@ -0,0 +1,216 @@ +import torch +import torch.distributed as dist +from einops import rearrange +from torch.nn import functional as F + +from lightx2v.models.schedulers.scheduler import BaseScheduler +from lightx2v_platform.base.global_var import AI_DEVICE + +from .posemb_layers import get_nd_rotary_pos_embed + + +class HunyuanVideo15Scheduler(BaseScheduler): + def __init__(self, config): + super().__init__(config) + self.reverse = True + self.num_train_timesteps = 1000 + self.sample_shift = self.config["sample_shift"] + self.reorg_token = False + self.keep_latents_dtype_in_scheduler = True + self.sample_guide_scale = self.config["sample_guide_scale"] + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + + def prepare(self, seed, latent_shape, image_encoder_output=None): + self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16) + self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift) + self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3]) + self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token) + self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3])) + + def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16): + self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) + self.latents = torch.randn( + 1, + latent_shape[0], + latent_shape[1], + latent_shape[2], + latent_shape[3], + dtype=dtype, + device=AI_DEVICE, + generator=self.generator, + ) + + def set_timesteps(self, num_inference_steps, device, shift): + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + + # Apply timestep shift + if shift != 1.0: + sigmas = self.sd3_time_shift(sigmas, shift) + + if not self.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.num_train_timesteps).to(dtype=torch.float32, device=device) + + def sd3_time_shift(self, t: torch.Tensor, shift): + return (shift * t) / (1 + (shift - 1) * t) + + def get_task_mask(self, task_type, latent_target_length): + if task_type == "t2v": + mask = torch.zeros(latent_target_length) + elif task_type == "i2v": + mask = 
torch.zeros(latent_target_length) + mask[0] = 1.0 + else: + raise ValueError(f"{task_type} is not supported !") + return mask + + def _prepare_cond_latents_and_mask(self, task_type, cond_latents, latents, multitask_mask, reorg_token): + """ + Prepare multitask mask training logic. + + Args: + task_type: Type of task ("i2v" or "t2v") + cond_latents: Conditional latents tensor + latents: Main latents tensor + multitask_mask: Multitask mask tensor + reorg_token: Whether to reorganize tokens + + Returns: + tuple: (latents_concat, mask_concat) - may contain None values + """ + latents_concat = None + mask_concat = None + + if cond_latents is not None and task_type == "i2v": + latents_concat = cond_latents.repeat(1, 1, latents.shape[2], 1, 1) + latents_concat[:, :, 1:, :, :] = 0.0 + else: + if reorg_token: + latents_concat = torch.zeros(latents.shape[0], latents.shape[1] // 2, latents.shape[2], latents.shape[3], latents.shape[4]).to(latents.device) + else: + latents_concat = torch.zeros(latents.shape[0], latents.shape[1], latents.shape[2], latents.shape[3], latents.shape[4]).to(latents.device) + + mask_zeros = torch.zeros(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4]) + mask_ones = torch.ones(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4]) + mask_concat = self.merge_tensor_by_mask(mask_zeros.cpu(), mask_ones.cpu(), mask=multitask_mask.cpu(), dim=2).to(device=latents.device) + + return latents_concat, mask_concat + + def merge_tensor_by_mask(self, tensor_1, tensor_2, mask, dim): + assert tensor_1.shape == tensor_2.shape + # Mask is a 0/1 vector. Choose tensor_2 when the value is 1; otherwise, tensor_1 + masked_indices = torch.nonzero(mask).squeeze(1) + tmp = tensor_1.clone() + if dim == 0: + tmp[masked_indices] = tensor_2[masked_indices] + elif dim == 1: + tmp[:, masked_indices] = tensor_2[:, masked_indices] + elif dim == 2: + tmp[:, :, masked_indices] = tensor_2[:, :, masked_indices] + return tmp + + def step_post(self): + model_output = self.noise_pred.to(torch.float32) + sample = self.latents.to(torch.float32) + dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] + self.latents = sample + model_output * dt + + def prepare_cos_sin(self, rope_sizes): + target_ndim = 3 + head_dim = self.config["hidden_size"] // self.config["heads_num"] + rope_dim_list = self.config["rope_dim_list"] + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=AI_DEVICE) + cos_half = freqs_cos[:, ::2].contiguous() + sin_half = freqs_sin[:, ::2].contiguous() + cos_sin = torch.cat([cos_half, sin_half], dim=-1) + if self.seq_p_group is not None: + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + seqlen = cos_sin.shape[0] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + cos_sin = F.pad(cos_sin, (0, 0, 0, padding_size)) + cos_sin = torch.chunk(cos_sin, world_size, dim=0)[cur_rank] + return cos_sin + + +class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler): + def __init__(self, config): + super().__init__(config) + self.noise_scale = 0.7 + + def prepare(self, seed, latent_shape, lq_latents, upsampler, image_encoder_output=None): + dtype = lq_latents.dtype + 
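+        # Rough sketch of the super-resolution preparation below (descriptive, matching
+        # the code in this method):
+        #   1. sample fresh noise latents at the target high-resolution shape,
+        #   2. bilinearly resize the low-quality latents to that shape and refine them
+        #      with the latent upsampler,
+        #   3. re-noise them with `noise_scale` (0.7 by default) so the model corrects
+        #      rather than copies them,
+        #   4. pack them, together with the optional image condition, into `condition`,
+        #      plus a `zero_condition` copy with the low-quality latents and their flag
+        #      channel zeroed out.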
self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype) + self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift) + self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3])) + + tgt_shape = latent_shape[-2:] + bsz = lq_latents.shape[0] + lq_latents = rearrange(lq_latents, "b c f h w -> (b f) c h w") + lq_latents = F.interpolate(lq_latents, size=tgt_shape, mode="bilinear", align_corners=False) + lq_latents = rearrange(lq_latents, "(b f) c h w -> b c f h w", b=bsz) + + lq_latents = upsampler(lq_latents.to(dtype=torch.float32, device=device)) + lq_latents = lq_latents.to(dtype=dtype) + + lq_latents = self.add_noise_to_lq(lq_latents, self.noise_scale) + + condition = self.get_condition(lq_latents, image_encoder_output["cond_latents"], self.config["task"]) + c = lq_latents.shape[1] + + zero_condition = condition.clone() + zero_condition[:, c + 1 : 2 * c + 1] = torch.zeros_like(lq_latents) + zero_condition[:, 2 * c + 1] = 0 + + self.condition = condition + self.zero_condition = zero_condition + + def prepare_latents(self, seed, latent_shape, lq_latents, dtype=torch.bfloat16): + self.generator = torch.Generator(device=lq_latents.device).manual_seed(seed) + self.latents = torch.randn( + 1, + latent_shape[0], + latent_shape[1], + latent_shape[2], + latent_shape[3], + dtype=dtype, + device=lq_latents.device, + generator=self.generator, + ) + + def get_condition(self, lq_latents, img_cond, task): + """ + latents: shape (b c f h w) + """ + b, c, f, h, w = self.latents.shape + cond = torch.zeros([b, c * 2 + 2, f, h, w], device=lq_latents.device, dtype=lq_latents.dtype) + + cond[:, c + 1 : 2 * c + 1] = lq_latents + cond[:, 2 * c + 1] = 1 + if "t2v" in task: + return cond + elif "i2v" in task: + cond[:, :c, :1] = img_cond + cond[:, c + 1, 0] = 1 + return cond + else: + raise ValueError(f"Unsupported task: {task}") + + def add_noise_to_lq(self, lq_latents, strength=0.7): + def expand_dims(tensor: torch.Tensor, ndim: int): + shape = tensor.shape + (1,) * (ndim - tensor.ndim) + return tensor.reshape(shape) + + noise = torch.randn_like(lq_latents) + timestep = torch.tensor([1000.0], device=lq_latents.device) * strength + t = expand_dims(timestep, lq_latents.ndim) + return (1 - t / 1000.0) * lq_latents + (t / 1000.0) * noise diff --git a/lightx2v/models/schedulers/hunyuan_video/step_distill/scheduler.py b/lightx2v/models/schedulers/hunyuan_video/step_distill/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f0815739dcca1db4b465e1c112c73a5a5b23ab65 --- /dev/null +++ b/lightx2v/models/schedulers/hunyuan_video/step_distill/scheduler.py @@ -0,0 +1,33 @@ +import torch + +from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15Scheduler + + +class HunyuanVideo15StepDistillScheduler(HunyuanVideo15Scheduler): + def __init__(self, config): + super().__init__(config) + self.denoising_step_list = config["denoising_step_list"] + self.infer_steps = len(self.denoising_step_list) + + self.num_train_timesteps = 1000 + self.sigma_max = 1.0 + self.sigma_min = 0.0 + + def set_timesteps(self, num_inference_steps, device, shift): + sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) + self.sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_train_timesteps + 1)[:-1] + self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas) + self.timesteps = self.sigmas * self.num_train_timesteps + + self.denoising_step_index = [self.num_train_timesteps - x for x in 
self.denoising_step_list] + self.timesteps = self.timesteps[self.denoising_step_index].to(device) + self.sigmas = self.sigmas[self.denoising_step_index].to("cpu") + + def step_post(self): + flow_pred = self.noise_pred.to(torch.float32) + sigma = self.sigmas[self.step_index].item() + noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred + if self.step_index < self.infer_steps - 1: + sigma_n = self.sigmas[self.step_index + 1].item() + noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n + self.latents = noisy_image_or_video.to(self.latents.dtype) diff --git a/lightx2v/models/schedulers/qwen_image/scheduler.py b/lightx2v/models/schedulers/qwen_image/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7f8670ab60b06456089a9bb60b1bffba5dffd3 --- /dev/null +++ b/lightx2v/models/schedulers/qwen_image/scheduler.py @@ -0,0 +1,235 @@ +import inspect +import json +import os +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + +from lightx2v.models.schedulers.scheduler import BaseScheduler +from lightx2v_platform.base.global_var import AI_DEVICE + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional[Union[str, "torch.device"]] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + if isinstance(device, str): + device = torch.device(device) + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + print( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slightly speed up this function by passing a generator that was created on the {device} device." 
+ ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + + +class QwenImageScheduler(BaseScheduler): + def __init__(self, config): + super().__init__(config) + self.config = config + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler")) + with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f: + self.scheduler_config = json.load(f) + self.dtype = torch.bfloat16 + self.guidance_scale = 1.0 + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) + + return latent_image_ids.to(device=device, dtype=dtype) + + def prepare_latents(self, input_info): + shape = input_info.target_shape + width, height = shape[-1], shape[-2] + + latents = randn_tensor(shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype) + latents = self._pack_latents(latents, self.config["batchsize"], self.config["num_channels_latents"], height, width) + latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, AI_DEVICE, self.dtype) + + self.latents = latents + self.latent_image_ids = latent_image_ids + self.noise_pred = None + + def set_timesteps(self): + sigmas = np.linspace(1.0, 1 / self.config["infer_steps"], self.config["infer_steps"]) + image_seq_len = self.latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler_config.get("base_image_seq_len", 256), + 
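+            # calculate_shift() interpolates mu linearly in the packed sequence length
+            # between base_shift and max_shift. Rough worked example (assuming the usual
+            # 8x VAE stride and 2x2 latent packing): a 1024x1024 image packs to
+            # (1024 / 16) ** 2 = 4096 tokens, giving mu = max_shift = 1.15, while a
+            # 512x512 image packs to 1024 tokens, giving mu ~= 0.5 + 0.65 / 3840 * 768 ~= 0.63.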
self.scheduler_config.get("max_image_seq_len", 4096), + self.scheduler_config.get("base_shift", 0.5), + self.scheduler_config.get("max_shift", 1.15), + ) + num_inference_steps = self.config["infer_steps"] + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + AI_DEVICE, + sigmas=sigmas, + mu=mu, + ) + + self.timesteps = timesteps + self.infer_steps = num_inference_steps + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + self.num_warmup_steps = num_warmup_steps + + def prepare_guidance(self): + # handle guidance + if self.config["guidance_embeds"]: + guidance = torch.full([1], self.guidance_scale, device=AI_DEVICE, dtype=torch.float32) + guidance = guidance.expand(self.latents.shape[0]) + else: + guidance = None + self.guidance = guidance + + def prepare(self, input_info): + if self.config["task"] == "i2i": + self.generator = torch.Generator().manual_seed(input_info.seed) + elif self.config["task"] == "t2i": + self.generator = torch.Generator(device=AI_DEVICE).manual_seed(input_info.seed) + self.prepare_latents(input_info) + self.prepare_guidance() + self.set_timesteps() + + def step_post(self): + # compute the previous noisy sample x_t -> x_t-1 + t = self.timesteps[self.step_index] + latents = self.scheduler.step(self.noise_pred, t, self.latents, return_dict=False)[0] + self.latents = latents diff --git a/lightx2v/models/schedulers/scheduler.py b/lightx2v/models/schedulers/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a14403d0d2e962c8406cea0064de6f2b26db4ce8 --- /dev/null +++ b/lightx2v/models/schedulers/scheduler.py @@ -0,0 +1,22 @@ +from lightx2v.utils.envs import * + + +class BaseScheduler: + def __init__(self, config): + self.config = config + self.latents = None + self.step_index = 0 + self.infer_steps = config["infer_steps"] + self.caching_records = [True] * config["infer_steps"] + self.flag_df = False + self.transformer_infer = None + self.infer_condition = True # cfg status + self.keep_latents_dtype_in_scheduler = False + + def step_pre(self, step_index): + self.step_index = step_index + if GET_DTYPE() == GET_SENSITIVE_DTYPE() and not self.keep_latents_dtype_in_scheduler: + self.latents = self.latents.to(GET_DTYPE()) + + def clear(self): + pass diff --git a/lightx2v/models/schedulers/wan/audio/scheduler.py b/lightx2v/models/schedulers/wan/audio/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a9901b1a8a3defe661f615594ea1fb206330f5b5 --- /dev/null +++ b/lightx2v/models/schedulers/wan/audio/scheduler.py @@ -0,0 +1,127 @@ +import math + +import numpy as np +import torch +from loguru import logger + +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.envs import * +from lightx2v.utils.utils import masks_like +from lightx2v_platform.base.global_var import AI_DEVICE + + +class EulerScheduler(WanScheduler): + def __init__(self, config): + super().__init__(config) + d = config["dim"] // config["num_heads"] + self.rope_t_dim = d // 2 - 2 * (d // 6) + + if self.config["parallel"]: + self.sp_size = self.config["parallel"].get("seq_p_size", 1) + else: + self.sp_size = 1 + + if self.config["model_cls"] == "wan2.2_audio": + self.prev_latents = None + self.prev_len = 0 + + def set_audio_adapter(self, audio_adapter): + self.audio_adapter = audio_adapter + + def step_pre(self, step_index): + super().step_pre(step_index) + if self.audio_adapter.cpu_offload: + 
self.audio_adapter.time_embedding.to("cuda") + self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1)) + if self.audio_adapter.cpu_offload: + self.audio_adapter.time_embedding.to("cpu") + + if self.config["model_cls"] == "wan2.2_audio": + _, lat_f, lat_h, lat_w = self.latents.shape + F = (lat_f - 1) * self.config["vae_stride"][0] + 1 + per_latent_token_len = lat_h * lat_w // (self.config["patch_size"][1] * self.config["patch_size"][2]) + max_seq_len = ((F - 1) // self.config["vae_stride"][0] + 1) * per_latent_token_len + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten() + self.timestep_input = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * self.timestep_input]).unsqueeze(0) + + self.timestep_input = torch.cat( + [ + self.timestep_input, + torch.zeros( + (1, per_latent_token_len), # padding for reference frame latent + dtype=self.timestep_input.dtype, + device=self.timestep_input.device, + ), + ], + dim=1, + ) + + def prepare_latents(self, seed, latent_shape, dtype=torch.float32): + self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) + self.latents = torch.randn( + latent_shape[0], + latent_shape[1], + latent_shape[2], + latent_shape[3], + dtype=dtype, + device=AI_DEVICE, + generator=self.generator, + ) + if self.config["model_cls"] == "wan2.2_audio": + self.mask = masks_like(self.latents, zero=True, prev_len=self.prev_len) + if self.prev_latents is not None: + self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents + + def prepare(self, seed, latent_shape, image_encoder_output=None): + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32) + + self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=AI_DEVICE) + self.timesteps_ori = self.timesteps.clone() + + self.sigmas = self.timesteps_ori / self.num_train_timesteps + self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas) + + self.timesteps = self.sigmas * self.num_train_timesteps + + self.freqs[latent_shape[1] // self.patch_size[0] :, : self.rope_t_dim] = 0 + self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0] + 1, latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2])) + + def step_post(self): + model_output = self.noise_pred.to(torch.float32) + sample = self.latents.to(torch.float32) + sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype) + sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype) + x_t_next = sample + (sigma_next - sigma) * model_output + self.latents = x_t_next + if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None: + self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents + + def reset(self, seed, latent_shape, image_encoder_output=None): + if self.config["model_cls"] == "wan2.2_audio": + self.prev_latents = image_encoder_output["prev_latents"] + self.prev_len = image_encoder_output["prev_len"] + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + + def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim): + if in_tensor.ndim > tgt_n_dim: + logger.warning(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the 
original tensor will be returned") + return in_tensor + if in_tensor.ndim < tgt_n_dim: + in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)] + return in_tensor + + +class ConsistencyModelScheduler(EulerScheduler): + def step_post(self): + model_output = self.noise_pred.to(torch.float32) + sample = self.latents.to(torch.float32) + sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype) + sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype) + x0 = sample - model_output * sigma + x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator) + self.latents = x_t_next + if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None: + self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents diff --git a/lightx2v/models/schedulers/wan/changing_resolution/scheduler.py b/lightx2v/models/schedulers/wan/changing_resolution/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..87c6e6ce88c183fa15b11dee459123e2e3c447ad --- /dev/null +++ b/lightx2v/models/schedulers/wan/changing_resolution/scheduler.py @@ -0,0 +1,94 @@ +import torch + +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanScheduler4ChangingResolutionInterface: + def __new__(cls, father_scheduler, config): + class NewClass(WanScheduler4ChangingResolution, father_scheduler): + def __init__(self, config): + father_scheduler.__init__(self, config) + WanScheduler4ChangingResolution.__init__(self, config) + + return NewClass(config) + + +class WanScheduler4ChangingResolution: + def __init__(self, config): + if "resolution_rate" not in config: + config["resolution_rate"] = [0.75] + if "changing_resolution_steps" not in config: + config["changing_resolution_steps"] = [config.infer_steps // 2] + assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"]) + + def prepare_latents(self, seed, latent_shape, dtype=torch.float32): + self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) + self.latents_list = [] + for i in range(len(self.config["resolution_rate"])): + self.latents_list.append( + torch.randn( + latent_shape[0], + latent_shape[1], + int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2, + int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2, + dtype=dtype, + device=AI_DEVICE, + generator=self.generator, + ) + ) + + # add original resolution latents + self.latents_list.append( + torch.randn( + latent_shape[0], + latent_shape[1], + latent_shape[2], + latent_shape[3], + dtype=dtype, + device=AI_DEVICE, + generator=self.generator, + ) + ) + + # set initial latents + self.latents = self.latents_list[0] + self.changing_resolution_index = 0 + + def step_post(self): + if self.step_index + 1 in self.config["changing_resolution_steps"]: + self.step_post_upsample() + self.changing_resolution_index += 1 + else: + super().step_post() + + def step_post_upsample(self): + # 1. denoised sample to clean noise + model_output = self.noise_pred.to(torch.float32) + sample = self.latents.to(torch.float32) + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + denoised_sample = x0_pred.to(sample.dtype) + + # 2. 
upsample clean noise to target shape + denoised_sample_5d = denoised_sample.unsqueeze(0) # (C,T,H,W) -> (1,C,T,H,W) + + shape_to_upsampled = self.latents_list[self.changing_resolution_index + 1].shape[1:] + clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=shape_to_upsampled, mode="trilinear") + clean_noise = clean_noise.squeeze(0) # (1,C,T,H,W) -> (C,T,H,W) + + # 3. add noise to clean noise + noisy_sample = self.add_noise(clean_noise, self.latents_list[self.changing_resolution_index + 1], self.timesteps[self.step_index + 1]) + + # 4. update latents + self.latents = noisy_sample + + # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed + + # 5. update timesteps using shift + self.changing_resolution_index + 1 更激进的去噪 + self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift + self.changing_resolution_index + 1) + + def add_noise(self, original_samples, noise, timesteps): + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples diff --git a/lightx2v/models/schedulers/wan/feature_caching/scheduler.py b/lightx2v/models/schedulers/wan/feature_caching/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c306ab87efb56b5b57069efe0019d9deb66a68aa --- /dev/null +++ b/lightx2v/models/schedulers/wan/feature_caching/scheduler.py @@ -0,0 +1,18 @@ +from lightx2v.models.schedulers.wan.scheduler import WanScheduler + + +class WanSchedulerCaching(WanScheduler): + def __init__(self, config): + super().__init__(config) + + def clear(self): + self.transformer_infer.clear() + + +class WanSchedulerTaylorCaching(WanSchedulerCaching): + def __init__(self, config): + super().__init__(config) + + pattern = [True, False, False, False] + self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps] + self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps] diff --git a/lightx2v/models/schedulers/wan/scheduler.py b/lightx2v/models/schedulers/wan/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9f4eea85aaa5d4c7035ddd30247ae48f7fa760 --- /dev/null +++ b/lightx2v/models/schedulers/wan/scheduler.py @@ -0,0 +1,434 @@ +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch.nn import functional as F + +from lightx2v.models.schedulers.scheduler import BaseScheduler +from lightx2v.utils.utils import masks_like +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanScheduler(BaseScheduler): + def __init__(self, config): + super().__init__(config) + self.infer_steps = self.config["infer_steps"] + self.target_video_length = self.config["target_video_length"] + self.sample_shift = self.config["sample_shift"] + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + self.patch_size = (1, 2, 2) + self.shift = 1 + self.num_train_timesteps = 1000 + self.disable_corrector = [] + self.solver_order = 2 + self.noise_pred = None + self.sample_guide_scale = self.config["sample_guide_scale"] + self.caching_records_2 = [True] * self.config["infer_steps"] + self.head_size = self.config["dim"] // self.config["num_heads"] + self.freqs = torch.cat( + [ + self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)), + 
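+                # 3-D RoPE split of the per-head channels: the temporal axis gets
+                # head_size - 4 * (head_size // 6) channels, height and width get
+                # 2 * (head_size // 6) each. Illustrative numbers for head_size = 128:
+                # 44 + 42 + 42 channels, i.e. 22 + 21 + 21 complex frequencies.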
self.rope_params(1024, 2 * (self.head_size // 6)), + self.rope_params(1024, 2 * (self.head_size // 6)), + ], + dim=1, + ).to(torch.device(AI_DEVICE)) + + def rope_params(self, max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)), + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def prepare(self, seed, latent_shape, image_encoder_output=None): + if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]: + self.vae_encoder_out = image_encoder_output["vae_encoder_out"] + + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + + alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * self.num_train_timesteps + + self.model_outputs = [None] * self.solver_order + self.timestep_list = [None] * self.solver_order + self.last_sample = None + + self.sigmas = self.sigmas.to("cpu") + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift) + + self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2])) + + def prepare_cos_sin(self, grid_sizes): + c = self.head_size // 2 + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + f, h, w = grid_sizes + seq_len = f * h * w + cos_sin = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ) + if self.config.get("rope_type", "flashinfer") == "flashinfer": + cos_sin = cos_sin.reshape(seq_len, -1) + # Extract cos and sin parts separately and concatenate + cos_half = cos_sin.real.contiguous() + sin_half = cos_sin.imag.contiguous() + cos_sin = torch.cat([cos_half, sin_half], dim=-1) + if self.seq_p_group is not None: + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + seqlen = cos_sin.shape[0] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + cos_sin = F.pad(cos_sin, (0, 0, 0, padding_size)) + cos_sin = torch.chunk(cos_sin, world_size, dim=0)[cur_rank] + else: + cos_sin = cos_sin.reshape(seq_len, 1, -1) + if self.seq_p_group is not None: + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + seqlen = cos_sin.shape[0] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + cos_sin = F.pad(cos_sin, (0, 0, 0, 0, 0, padding_size)) + cos_sin = torch.chunk(cos_sin, world_size, dim=0)[cur_rank] + return cos_sin + + def prepare_latents(self, seed, latent_shape, dtype=torch.float32): + self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) + self.latents = torch.randn( + latent_shape[0], + latent_shape[1], + latent_shape[2], + latent_shape[3], + dtype=dtype, + device=AI_DEVICE, + generator=self.generator, + ) + if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]: + self.mask = masks_like(self.latents, zero=True) + self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * 
self.latents + + def set_timesteps( + self, + infer_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1] + + if shift is None: + shift = self.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + sigma_last = 0 + + timesteps = sigmas * self.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + assert len(self.timesteps) == self.infer_steps + self.model_outputs = [ + None, + ] * self.solver_order + self.lower_order_nums = 0 + self.last_sample = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + return x0_pred + + def reset(self, seed, latent_shape, step_index=None): + if step_index is not None: + self.step_index = step_index + self.model_outputs = [None] * self.solver_order + self.timestep_list = [None] * self.solver_order + self.last_sample = None + self.noise_pred = None + self.this_order = None + self.lower_order_nums = 0 + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyward argument") + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + B_h = torch.expm1(hh) + + for i in range(1, order + 1): + 
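+            # Build the small linear system R @ rhos_p = b used by this UniPC-style
+            # predictor: row i of R holds rks ** (i - 1), while b accumulates the
+            # factorial-scaled h * phi_k(h) terms divided by B(h), updated incrementally.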
R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyward argument") + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + B_h = torch.expm1(hh) + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step_pre(self, step_index): + super().step_pre(step_index) + self.timestep_input = torch.stack([self.timesteps[self.step_index]]) + if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]: + 
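+            # For wan2.2 i2v/s2v the conditioning tokens (mask == 0) stay clean, so their
+            # per-token timestep is zeroed while generated tokens keep the current value;
+            # the [:, ::2, ::2] slicing matches the 2x2 spatial patchification.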
self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten() + + def step_post(self): + model_output = self.noise_pred.to(torch.float32) + timestep = self.timesteps[self.step_index] + sample = self.latents.to(torch.float32) + + use_corrector = self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + this_order = min(self.solver_order, len(self.timesteps) - self.step_index) + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + self.latents = prev_sample + if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]: + self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents diff --git a/lightx2v/models/schedulers/wan/self_forcing/scheduler.py b/lightx2v/models/schedulers/wan/self_forcing/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..540116fe3d29397e45a5b2fc7338692a99b31c17 --- /dev/null +++ b/lightx2v/models/schedulers/wan/self_forcing/scheduler.py @@ -0,0 +1,105 @@ +import torch + +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanSFScheduler(WanScheduler): + def __init__(self, config): + super().__init__(config) + self.dtype = torch.bfloat16 + self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"] + self.num_output_frames = self.config["sf_config"]["num_output_frames"] + self.num_blocks = self.num_output_frames // self.num_frame_per_block + self.denoising_step_list = self.config["sf_config"]["denoising_step_list"] + self.infer_steps = len(self.denoising_step_list) + self.all_num_frames = [self.num_frame_per_block] * self.num_blocks + self.num_input_frames = 0 + self.denoising_strength = 1.0 + self.sigma_max = 1.0 + self.sigma_min = 0 + self.sf_shift = self.config["sf_config"]["shift"] + self.inverse_timesteps = False + self.extra_one_step = True + self.reverse_sigmas = False + self.num_inference_steps = self.config["sf_config"]["num_inference_steps"] + self.context_noise = 0 + + def prepare(self, seed, latent_shape, image_encoder_output=None): + self.latents = torch.randn(latent_shape, device=AI_DEVICE, dtype=self.dtype) + + timesteps = [] + for frame_block_idx, current_num_frames in enumerate(self.all_num_frames): + frame_steps = [] + + for step_index, current_timestep in enumerate(self.denoising_step_list): + timestep = torch.ones([self.num_frame_per_block], device=AI_DEVICE, dtype=torch.int64) * current_timestep + frame_steps.append(timestep) + + timesteps.append(frame_steps) + self.timesteps = timesteps + + self.noise_pred = torch.zeros(latent_shape, device=AI_DEVICE, dtype=self.dtype) + + sigma_start = 
self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength + if self.extra_one_step: + self.sigmas_sf = torch.linspace(sigma_start, self.sigma_min, self.num_inference_steps + 1)[:-1] + else: + self.sigmas_sf = torch.linspace(sigma_start, self.sigma_min, self.num_inference_steps) + if self.inverse_timesteps: + self.sigmas_sf = torch.flip(self.sigmas_sf, dims=[0]) + self.sigmas_sf = self.sf_shift * self.sigmas_sf / (1 + (self.sf_shift - 1) * self.sigmas_sf) + if self.reverse_sigmas: + self.sigmas_sf = 1 - self.sigmas_sf + self.sigmas_sf = self.sigmas_sf.to(AI_DEVICE) + + self.timesteps_sf = self.sigmas_sf * self.num_train_timesteps + self.timesteps_sf = self.timesteps_sf.to(AI_DEVICE) + + self.stream_output = None + + def step_pre(self, seg_index, step_index, is_rerun=False): + self.step_index = step_index + self.seg_index = seg_index + + if not GET_DTYPE() == GET_SENSITIVE_DTYPE(): + self.latents = self.latents.to(GET_DTYPE()) + + if not is_rerun: + self.timestep_input = torch.stack([self.timesteps[self.seg_index][self.step_index]]) + self.latents_input = self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)] + else: + # rerun with timestep zero to update KV cache using clean context + self.timestep_input = torch.ones_like(torch.stack([self.timesteps[self.seg_index][self.step_index]])) * self.context_noise + self.latents_input = self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)] + + def step_post(self): + # convert model outputs + current_start_frame = self.seg_index * self.num_frame_per_block + current_end_frame = (self.seg_index + 1) * self.num_frame_per_block + + flow_pred = self.noise_pred[:, current_start_frame:current_end_frame].transpose(0, 1) + xt = self.latents_input.transpose(0, 1) + timestep = self.timestep_input.squeeze(0) + + original_dtype = flow_pred.dtype + + flow_pred, xt, sigmas, timesteps = map(lambda x: x.double().to(flow_pred.device), [flow_pred, xt, self.sigmas_sf, self.timesteps_sf]) + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + x0_pred = x0_pred.to(original_dtype) + + # add noise + if self.step_index < self.infer_steps - 1: + timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=AI_DEVICE, dtype=torch.long) + timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1) + sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1) + noise_next = torch.randn_like(x0_pred) + sample_next = (1 - sigma_next) * x0_pred + sigma_next * noise_next + sample_next = sample_next.type_as(noise_next) + self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)] = sample_next.transpose(0, 1) + else: + self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)] = x0_pred.transpose(0, 1) + self.stream_output = x0_pred.transpose(0, 1) diff --git a/lightx2v/models/schedulers/wan/step_distill/scheduler.py b/lightx2v/models/schedulers/wan/step_distill/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..e59317422fc4bab5e37486a42d7ef1c6f84416bc --- /dev/null +++ 
b/lightx2v/models/schedulers/wan/step_distill/scheduler.py @@ -0,0 +1,76 @@ +import math +from typing import Union + +import torch + +from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanStepDistillScheduler(WanScheduler): + def __init__(self, config): + super().__init__(config) + self.denoising_step_list = config["denoising_step_list"] + self.infer_steps = len(self.denoising_step_list) + self.sample_shift = self.config["sample_shift"] + + self.num_train_timesteps = 1000 + self.sigma_max = 1.0 + self.sigma_min = 0.0 + + def prepare(self, seed, latent_shape, image_encoder_output=None): + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + self.set_denoising_timesteps(device=AI_DEVICE) + self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2])) + + def set_denoising_timesteps(self, device: Union[str, torch.device] = None): + sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) + self.sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_train_timesteps + 1)[:-1] + self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas) + self.timesteps = self.sigmas * self.num_train_timesteps + + self.denoising_step_index = [self.num_train_timesteps - x for x in self.denoising_step_list] + self.timesteps = self.timesteps[self.denoising_step_index].to(device) + self.sigmas = self.sigmas[self.denoising_step_index].to("cpu") + + def reset(self, seed, latent_shape, step_index=None): + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + + def add_noise(self, original_samples, noise, sigma): + sample = (1 - sigma) * original_samples + sigma * noise + return sample.type_as(noise) + + def step_post(self): + flow_pred = self.noise_pred.to(torch.float32) + sigma = self.sigmas[self.step_index].item() + noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred + if self.step_index < self.infer_steps - 1: + sigma_n = self.sigmas[self.step_index + 1].item() + noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n + self.latents = noisy_image_or_video.to(self.latents.dtype) + + +class Wan22StepDistillScheduler(WanStepDistillScheduler): + def __init__(self, config): + super().__init__(config) + self.boundary_step_index = config["boundary_step_index"] + + def set_denoising_timesteps(self, device: Union[str, torch.device] = None): + super().set_denoising_timesteps(device) + self.sigma_bound = self.sigmas[self.boundary_step_index].item() + + def calculate_alpha_beta_high(self, sigma): + alpha = (1 - sigma) / (1 - self.sigma_bound) + beta = math.sqrt(sigma**2 - (alpha * self.sigma_bound) ** 2) + return alpha, beta + + def step_post(self): + flow_pred = self.noise_pred.to(torch.float32) + sigma = self.sigmas[self.step_index].item() + noisy_image_or_video = self.latents.to(torch.float32) - flow_pred * sigma + # self.latent: x_t + if self.step_index < self.infer_steps - 1: + sigma_n = self.sigmas[self.step_index + 1].item() + noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n + + self.latents = noisy_image_or_video.to(self.latents.dtype) diff --git a/lightx2v/models/vfi/rife/model/loss.py b/lightx2v/models/vfi/rife/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..46607c2cb6861702e347ba91aeb5eede913a202f --- /dev/null +++ b/lightx2v/models/vfi/rife/model/loss.py @@ -0,0 +1,130 @@ +import numpy as np 
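+# Training losses for the RIFE frame-interpolation model: EPE (flow end-point error),
+# Ternary (census-transform photometric loss), SOBEL (edge/gradient loss) and a
+# VGG19-based perceptual loss with ImageNet mean/std normalization.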
+import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class EPE(nn.Module): + def __init__(self): + super(EPE, self).__init__() + + def forward(self, flow, gt, loss_mask): + loss_map = (flow - gt.detach()) ** 2 + loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + return loss_map * loss_mask + + +class Ternary(nn.Module): + def __init__(self): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) + + +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor( + [ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1], + ] + ).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat([pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[: N * C], sobel_stack_x[N * C :] + pred_Y, gt_Y = sobel_stack_y[: N * C], sobel_stack_y[N * C :] + + L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y) + loss = L1X + L1Y + return loss + + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + self.requires_grad = False + + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X, Y, indices=None): + X = self.normalize(X) + Y = self.normalize(Y) + indices = [2, 7, 12, 21, 30] + weights = 
[1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + X = self.vgg_pretrained_features[i](X) + Y = self.vgg_pretrained_features[i](Y) + if (i + 1) in indices: + loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 + k += 1 + return loss + + +if __name__ == "__main__": + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal(0, 1, (3, 3, 256, 256))).float().to(device) + ternary_loss = Ternary() + print(ternary_loss(img0, img1).shape) diff --git a/lightx2v/models/vfi/rife/model/pytorch_msssim/__init__.py b/lightx2v/models/vfi/rife/model/pytorch_msssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcd17450723e513f2ebd92de7e683d91ee19575 --- /dev/null +++ b/lightx2v/models/vfi/rife/model/pytorch_msssim/__init__.py @@ -0,0 +1,204 @@ +from math import exp + +import numpy as np +import torch +import torch.nn.functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + + +def create_window_3d(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + return window + + +def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
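+    # When val_range is None, the dynamic range L is inferred from img1: a maximum above 128
+    # suggests 8-bit data (max_val = 255), a minimum below -0.5 suggests tanh-style [-1, 1]
+    # data (min_val = -1), and L = max_val - min_val.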
+ if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) + mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq + sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d(real_size, channel=1).to(img1.device) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq + sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq + sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim(img1, img2, 
window_size=window_size, size_average=size_average, full=True, val_range=val_range) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs**weights + pow2 = mssim**weights + # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +# Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 3 channel for SSIM + self.channel = 3 + self.window = create_window(window_size, channel=self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + dssim = (1 - _ssim) / 2 + return dssim + + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) diff --git a/lightx2v/models/vfi/rife/model/warplayer.py b/lightx2v/models/vfi/rife/model/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..096d48fedea6a7ba5e9a5bab1fe5732f64b4d707 --- /dev/null +++ b/lightx2v/models/vfi/rife/model/warplayer.py @@ -0,0 +1,17 @@ +import torch + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True) diff --git a/lightx2v/models/vfi/rife/rife_comfyui_wrapper.py b/lightx2v/models/vfi/rife/rife_comfyui_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..0911c725947e6b5c7964ca0853b583d780f4687f --- /dev/null +++ b/lightx2v/models/vfi/rife/rife_comfyui_wrapper.py @@ -0,0 +1,138 @@ +import os +from typing import List, Optional, Tuple + +import torch +from torch.nn import functional as F + +from lightx2v.utils.profiler import * + + +class RIFEWrapper: + """Wrapper for RIFE model to 
work with ComfyUI Image tensors""" + + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + def __init__(self, model_path, device: Optional[torch.device] = None): + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Setup torch for optimal performance + torch.set_grad_enabled(False) + if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # Load model + from .train_log.RIFE_HDv3 import Model + + self.model = Model() + with ProfilingContext4DebugL2("Load RIFE model"): + self.model.load_model(model_path, -1) + self.model.eval() + self.model.device() + + @ProfilingContext4DebugL2("Interpolate frames") + def interpolate_frames( + self, + images: torch.Tensor, + source_fps: float, + target_fps: float, + scale: float = 1.0, + ) -> torch.Tensor: + """ + Interpolate frames from source FPS to target FPS + + Args: + images: ComfyUI Image tensor [N, H, W, C] in range [0, 1] + source_fps: Source frame rate + target_fps: Target frame rate + scale: Scale factor for processing + + Returns: + Interpolated ComfyUI Image tensor [M, H, W, C] in range [0, 1] + """ + # Validate input + assert images.dim() == 4 and images.shape[-1] == 3, "Input must be [N, H, W, C] with C=3" + + if source_fps == target_fps: + return images + + total_source_frames = images.shape[0] + height, width = images.shape[1:3] + + # Calculate padding for model + tmp = max(128, int(128 / scale)) + ph = ((height - 1) // tmp + 1) * tmp + pw = ((width - 1) // tmp + 1) * tmp + padding = (0, pw - width, 0, ph - height) + + # Calculate target frame positions + frame_positions = self._calculate_target_frame_positions(source_fps, target_fps, total_source_frames) + + # Prepare output tensor + output_frames = [] + + for source_idx1, source_idx2, interp_factor in frame_positions: + if interp_factor == 0.0 or source_idx1 == source_idx2: + # No interpolation needed, use the source frame directly + output_frames.append(images[source_idx1]) + else: + # Get frames to interpolate + frame1 = images[source_idx1] + frame2 = images[source_idx2] + + # Convert ComfyUI format [H, W, C] to RIFE format [1, C, H, W] + # Also convert from [0, 1] to [0, 1] (already in correct range) + I0 = frame1.permute(2, 0, 1).unsqueeze(0).to(self.device) + I1 = frame2.permute(2, 0, 1).unsqueeze(0).to(self.device) + + # Pad images + I0 = F.pad(I0, padding) + I1 = F.pad(I1, padding) + + # Perform interpolation + with torch.no_grad(): + interpolated = self.model.inference(I0, I1, timestep=interp_factor, scale=scale) + + # Convert back to ComfyUI format [H, W, C] + # Crop to original size and permute dimensions + interpolated_frame = interpolated[0, :, :height, :width].permute(1, 2, 0).cpu() + output_frames.append(interpolated_frame) + + # Stack all frames + return torch.stack(output_frames, dim=0) + + def _calculate_target_frame_positions(self, source_fps: float, target_fps: float, total_source_frames: int) -> List[Tuple[int, int, float]]: + """ + Calculate which frames need to be generated for the target frame rate. 
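+
+        Each target frame at time t = target_idx / target_fps is mapped to the fractional
+        source position t * source_fps: the integer part selects the two bracketing source
+        frames and the fractional part becomes the interpolation factor.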
+ + Returns: + List of (source_frame_index1, source_frame_index2, interpolation_factor) tuples + """ + frame_positions = [] + + # Calculate the time duration of the video + duration = (total_source_frames - 1) / source_fps + + # Calculate number of target frames + total_target_frames = int(duration * target_fps) + 1 + + for target_idx in range(total_target_frames): + # Calculate the time position of this target frame + target_time = target_idx / target_fps + + # Calculate the corresponding position in source frames + source_position = target_time * source_fps + + # Find the two source frames to interpolate between + source_idx1 = int(source_position) + source_idx2 = min(source_idx1 + 1, total_source_frames - 1) + + # Calculate interpolation factor (0 means use frame1, 1 means use frame2) + if source_idx1 == source_idx2: + interpolation_factor = 0.0 + else: + interpolation_factor = source_position - source_idx1 + + frame_positions.append((source_idx1, source_idx2, interpolation_factor)) + + return frame_positions diff --git a/lightx2v/models/vfi/rife/train_log/IFNet_HDv3.py b/lightx2v/models/vfi/rife/train_log/IFNet_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..099bc4a1180d8d35cda26db04c71788f456b6320 --- /dev/null +++ b/lightx2v/models/vfi/rife/train_log/IFNet_HDv3.py @@ -0,0 +1,214 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..model.warplayer import warp + +# from train_log.refine import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.LeakyReLU(0.2, True), + ) + + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True), + ) + + +class Head(nn.Module): + def __init__(self): + super(Head, self).__init__() + self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1) + self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x, feat=False): + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + if feat: + return [x0, x1, x2, x3] + return x3 + + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, 
scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + feat = tmp[:, 5:] + return flow, mask, feat + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7 + 8, c=192) + self.block1 = IFBlock(8 + 4 + 8 + 8, c=128) + self.block2 = IFBlock(8 + 4 + 8 + 8, c=96) + self.block3 = IFBlock(8 + 4 + 8 + 8, c=64) + self.block4 = IFBlock(8 + 4 + 8 + 8, c=32) + self.encode = Head() + + # not used during inference + """ + self.teacher = IFBlock(8+4+8+3+8, c=64) + self.caltime = nn.Sequential( + nn.Conv2d(16+9, 8, 3, 2, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(32, 64, 3, 2, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(64, 64, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(64, 64, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(64, 1, 3, 1, 1), + nn.Sigmoid() + ) + """ + + def forward( + self, + x, + timestep=0.5, + scale_list=[8, 4, 2, 1], + training=False, + fastmode=True, + ensemble=False, + ): + if not training: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + else: + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + loss_cons = 0 + block = [self.block0, self.block1, self.block2, self.block3, self.block4] + for i in range(5): + if flow is None: + flow, mask, feat = block[i]( + torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), + None, + scale=scale_list[i], + ) + if ensemble: + print("warning: ensemble is not supported since RIFEv4.21") + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0, feat = block[i]( + torch.cat( + ( + warped_img0[:, :3], + warped_img1[:, :3], + wf0, + wf1, + timestep, + mask, + feat, + ), + 1, + ), + flow, + scale=scale_list[i], + ) + if ensemble: + print("warning: ensemble is not supported since RIFEv4.21") + else: + mask = m0 + flow = flow + fd + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask = torch.sigmoid(mask) + merged[4] = warped_img0 * mask + warped_img1 * (1 - mask) + if not fastmode: + print("contextnet is removed") + """ + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[4] = torch.clamp(merged[4] + res, 0, 1) + """ + return flow_list, mask_list[4], merged diff --git a/lightx2v/models/vfi/rife/train_log/RIFE_HDv3.py b/lightx2v/models/vfi/rife/train_log/RIFE_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..29d5caa8077de7600c661c4e0b9bf9b958770d3d --- /dev/null +++ b/lightx2v/models/vfi/rife/train_log/RIFE_HDv3.py @@ -0,0 +1,85 @@ +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW + +from ..model.loss import * +from .IFNet_HDv3 import * + 
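+# RIFE v4.25 wrapper: couples the IFNet optical-flow network with the training-time loss helpers.
+# Typical inference usage (sketch; the path is illustrative):
+#   model = Model()
+#   model.load_model("/path/to/rife/train_log", -1)
+#   model.eval()
+#   model.device()
+#   mid_frame = model.inference(img0, img1, timestep=0.5, scale=1.0)  # img0/img1: (B, 3, H, W) in [0, 1]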
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Model:
+    def __init__(self, local_rank=-1):
+        self.flownet = IFNet()
+        self.device()
+        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
+        self.epe = EPE()
+        self.version = 4.25
+        # self.vgg = VGGPerceptualLoss().to(device)
+        self.sobel = SOBEL()
+        if local_rank != -1:
+            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
+
+    def train(self):
+        self.flownet.train()
+
+    def eval(self):
+        self.flownet.eval()
+
+    def device(self):
+        self.flownet.to(device)
+
+    def load_model(self, path, rank=0):
+        def convert(param):
+            if rank == -1:
+                return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
+            else:
+                return param
+
+        if rank <= 0:
+            if torch.cuda.is_available():
+                self.flownet.load_state_dict(convert(torch.load(path)), False)
+            else:
+                self.flownet.load_state_dict(
+                    convert(torch.load(path, map_location="cpu")),
+                    False,
+                )
+
+    def save_model(self, path, rank=0):
+        if rank == 0:
+            torch.save(self.flownet.state_dict(), "{}/flownet.pkl".format(path))
+
+    def inference(self, img0, img1, timestep=0.5, scale=1.0):
+        imgs = torch.cat((img0, img1), 1)
+        scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
+        flow, mask, merged = self.flownet(imgs, timestep, scale_list)
+        return merged[-1]
+
+    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
+        # Training step kept from the original RIFE code base; inference callers only use inference() above.
+        for param_group in self.optimG.param_groups:
+            param_group["lr"] = learning_rate
+        img0 = imgs[:, :3]
+        img1 = imgs[:, 3:]
+        if training:
+            self.train()
+        else:
+            self.eval()
+        scale = [16, 8, 4, 2, 1]
+        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
+        loss_l1 = (merged[-1] - gt).abs().mean()
+        loss_smooth = self.sobel(flow[-1], flow[-1] * 0).mean()
+        # loss_vgg = self.vgg(merged[-1], gt)
+        loss_cons = torch.zeros_like(loss_l1)  # the teacher/consistency loss of the original training pipeline is not computed in this port; keep a zero placeholder so update() stays runnable
+        if training:
+            self.optimG.zero_grad()
+            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
+            loss_G.backward()
+            self.optimG.step()
+        else:
+            flow_teacher = flow[2]
+        return merged[-1], {
+            "mask": mask,
+            "flow": flow[-1][:, :2],
+            "loss_l1": loss_l1,
+            "loss_cons": loss_cons,
+            "loss_smooth": loss_smooth,
+        }
diff --git a/lightx2v/models/vfi/rife/train_log/refine.py b/lightx2v/models/vfi/rife/train_log/refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..190e1d41aed46300c02d947d1d96f90c27119cfc
--- /dev/null
+++ b/lightx2v/models/vfi/rife/train_log/refine.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..model.warplayer import warp
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(
+            in_planes,
+            out_planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=True,
+        ),
+        nn.LeakyReLU(0.2, True),
+    )
+
+
+def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(
+            in_planes,
+            out_planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=True,
+        ),
+    )
+
+
+def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
+    return nn.Sequential(
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=True,
+        ),
+        nn.LeakyReLU(0.2, True),
+    )
+
+
+class Conv2(nn.Module):
+    def __init__(self, in_planes, out_planes, stride=2):
+        super(Conv2, 
self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +c = 16 + + +class Contextnet(nn.Module): + def __init__(self): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x, flow): + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + + +class Unet(nn.Module): + def __init__(self): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2 * c) + self.down1 = Conv2(4 * c, 4 * c) + self.down2 = Conv2(8 * c, 8 * c) + self.down3 = Conv2(16 * c, 16 * c) + self.up0 = deconv(32 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = nn.Conv2d(c, 3, 3, 1, 1) + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/lightx2v/models/video_encoders/__init__.py b/lightx2v/models/video_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/video_encoders/hf/__init__.py b/lightx2v/models/video_encoders/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/video_encoders/hf/hunyuanvideo15/__init__.py b/lightx2v/models/video_encoders/hf/hunyuanvideo15/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..52f013e69ed6d9bd981f0e00ac868c2a55b7f400 --- /dev/null +++ b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py @@ -0,0 +1,910 @@ +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.autoencoders.vae import BaseOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange +from torch import Tensor, nn + 
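+# Causal 3D convolutional KL-VAE for HunyuanVideo 1.5: supports spatially/temporally tiled
+# encoding and decoding, optional CPU offload of the encoder/decoder, and multi-rank tiled
+# decoding of large latents (see HunyuanVideo15VAE.decode_dist_2d).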
+from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +@dataclass +class DecoderOutput(BaseOutput): + sample: torch.FloatTensor + posterior: Optional[DiagonalGaussianDistribution] = None + + +def swish(x: Tensor) -> Tensor: + """Applies the swish activation function.""" + return x * torch.sigmoid(x) + + +def forward_with_checkpointing(module, *inputs, use_checkpointing=False): + """Forward with optional gradient checkpointing.""" + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if use_checkpointing: + return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False) + else: + return module(*inputs) + + +# Optimized implementation of CogVideoXSafeConv3d +# https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py#L38 +class PatchCausalConv3d(nn.Conv3d): + r"""Causal Conv3d with efficient patch processing for large tensors.""" + + def find_split_indices(self, seq_len, part_num): + ideal_interval = seq_len / part_num + possible_indices = list(range(0, seq_len, self.stride[0])) + selected_indices = [] + + for i in range(1, part_num): + closest = min(possible_indices, key=lambda x: abs(x - round(i * ideal_interval))) + if closest not in selected_indices: + selected_indices.append(closest) + + merged_indices = [] + prev_idx = 0 + for idx in selected_indices: + if idx - prev_idx >= self.kernel_size[0]: + merged_indices.append(idx) + prev_idx = idx + + return merged_indices + + def forward(self, input): + T = input.shape[2] # input: NCTHW + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if T > self.kernel_size[0] and memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + split_indices = self.find_split_indices(T, part_num) + input_chunks = torch.tensor_split(input, split_indices, dim=2) if len(split_indices) > 0 else [input] + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))] + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super().forward(input) + + +class RMS_norm(nn.Module): + """Root Mean Square Layer Normalization for Channel-First or Last""" + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Conv3d(nn.Conv3d): + """Perform Conv3d on patches with memory-efficient symmetric padding.""" + + def forward(self, input): + B, C, T, H, W = input.shape + memory_count = (C * T * H * W) * 2 / 1024**3 + n_split = math.ceil(memory_count / 2) + if memory_count > 2 and input.shape[-3] % n_split == 0: + chunks = torch.chunk(input, chunks=n_split, dim=-3) + padded_chunks = [] + for i in range(len(chunks)): + if self.padding[0] > 0: + padded_chunk = 
F.pad( + chunks[i], + (0, 0, 0, 0, self.padding[0], self.padding[0]), + mode="constant" if self.padding_mode == "zeros" else self.padding_mode, + value=0, + ) + if i > 0: + padded_chunk[:, :, : self.padding[0]] = chunks[i - 1][:, :, -self.padding[0] :] + if i < len(chunks) - 1: + padded_chunk[:, :, -self.padding[0] :] = chunks[i + 1][:, :, : self.padding[0]] + else: + padded_chunk = chunks[i] + padded_chunks.append(padded_chunk) + padding_bak = self.padding + self.padding = (0, self.padding[1], self.padding[2]) + outputs = [] + for chunk in padded_chunks: + outputs.append(super().forward(chunk)) + self.padding = padding_bak + return torch.cat(outputs, dim=-3) + else: + return super().forward(input) + + +class CausalConv3d(nn.Module): + """Causal Conv3d with configurable padding for temporal axis.""" + + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + pad_mode="replicate", + disable_causal=False, + enable_patch_conv=False, + **kwargs, + ): + super().__init__() + + self.pad_mode = pad_mode + if disable_causal: + padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2) + else: + padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T + self.time_causal_padding = padding + + if enable_patch_conv: + self.conv = PatchCausalConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + else: + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + """Prepare a causal attention mask for 3D videos. + + Args: + n_frame (int): Number of frames (temporal length). + n_hw (int): Product of height and width. + dtype: Desired mask dtype. + device: Device for the mask. + batch_size (int, optional): If set, expands for batch. + + Returns: + torch.Tensor: Causal attention mask. 
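+            The mask has shape (seq_len, seq_len) with seq_len = n_frame * n_hw, or
+            (batch_size, seq_len, seq_len) when batch_size is given; a token in frame f
+            gets 0 for every position of frames 0..f and -inf elsewhere.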
+ """ + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class AttnBlock(nn.Module): + """Self-attention block for 3D video tensors.""" + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = RMS_norm(in_channels, images=False) + + self.q = Conv3d(in_channels, in_channels, kernel_size=1) + self.k = Conv3d(in_channels, in_channels, kernel_size=1) + self.v = Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, f, h, w = q.shape + q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous() + k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous() + v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous() + attention_mask = prepare_causal_attention_mask(f, h * w, h_.dtype, h_.device, batch_size=b) + h_ = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask.unsqueeze(1)) + + return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + """ResNet-style block for 3D video tensors.""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = RMS_norm(in_channels, images=False) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3) + + self.norm2 = RMS_norm(out_channels, images=False) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3) + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + assert out_channels % factor == 0 + self.conv = CausalConv3d(in_channels, out_channels // factor, kernel_size=3) + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + def forward(self, x: Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + if self.add_temporal_downsample: + h_first = h[:, :, :1, :, :] + h_first = rearrange(h_first, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2) + h_first = torch.cat([h_first, h_first], dim=1) + h_next = h[:, :, 1:, :, :] + h_next = rearrange(h_next, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = rearrange(x_first, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2) + B, C, T, H, W = x_first.shape + x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, 
W).mean(dim=2) + + x_next = x[:, :, 1:, :, :] + x_next = rearrange(x_next, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + B, C, T, H, W = x_next.shape + x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + shortcut = torch.cat([x_first, x_next], dim=2) + else: + h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + + return h + shortcut + + +class Upsample(nn.Module): + """Hierarchical upsampling with temporal/ spatial support.""" + + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = CausalConv3d(in_channels, out_channels * factor, kernel_size=3) + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + def forward(self, x: Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + if self.add_temporal_upsample: + h_first = h[:, :, :1, :, :] + h_first = rearrange(h_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2) + h_first = h_first[:, : h_first.shape[1] // 2] + h_next = h[:, :, 1:, :, :] + h_next = rearrange(h_next, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = rearrange(x_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2) + x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1) + + x_next = x[:, :, 1:, :, :] + x_next = rearrange(x_next, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2) + x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = torch.cat([x_first, x_next], dim=2) + + else: + h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2) + return h + shortcut + + +class Encoder(nn.Module): + """Hierarchical video encoder with temporal and spatial factorization.""" + + def __init__( + self, + in_channels: int, + z_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + ffactor_temporal: int, + downsample_match_channel: bool = True, + ): + super().__init__() + assert block_out_channels[-1] % (2 * z_channels) == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + # downsampling + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3) + + self.down = nn.ModuleList() + block_in = block_out_channels[0] + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + + add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial)) + add_temporal_downsample = add_spatial_downsample and bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal)) + if add_spatial_downsample or add_temporal_downsample: + assert i_level < 
len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in + down.downsample = Downsample(block_in, block_out, add_temporal_downsample) + block_in = block_out + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = RMS_norm(block_in, images=False) + self.conv_out = CausalConv3d(block_in, 2 * z_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through the encoder.""" + use_checkpointing = bool(self.training and self.gradient_checkpointing) + + # downsampling + h = self.conv_in(x) + for i_level in range(len(self.block_out_channels)): + for i_block in range(self.num_res_blocks): + h = forward_with_checkpointing(self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing) + if hasattr(self.down[i_level], "downsample"): + h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing) + + # middle + h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing) + + # end + group_size = self.block_out_channels[-1] // (2 * self.z_channels) + shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2) + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + h += shortcut + return h + + +class Decoder(nn.Module): + """Hierarchical video decoder with upsampling factories.""" + + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + ffactor_temporal: int, + upsample_match_channel: bool = True, + ): + super().__init__() + assert block_out_channels[0] % z_channels == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + block_in = block_out_channels[0] + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + + add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial)) + add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal)) + if add_spatial_upsample or add_temporal_upsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in + up.upsample = Upsample(block_in, block_out, add_temporal_upsample) + block_in = block_out + self.up.append(up) + + # end + self.norm_out = RMS_norm(block_in, images=False) + self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, z: Tensor) -> Tensor: + 
"""Forward pass through the decoder.""" + use_checkpointing = bool(self.training and self.gradient_checkpointing) + + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # middle + h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing) + + # upsampling + for i_level in range(len(self.block_out_channels)): + for i_block in range(self.num_res_blocks + 1): + h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing) + if hasattr(self.up[i_level], "upsample"): + h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class AutoencoderKLConv3D(ModelMixin, ConfigMixin): + """KL regularized 3D Conv VAE with advanced tiling and slicing strategies.""" + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int, + out_channels: int, + latent_channels: int, + block_out_channels: Tuple[int, ...], + layers_per_block: int, + ffactor_spatial: int, + ffactor_temporal: int, + sample_size: int, + sample_tsize: int, + scaling_factor: float = None, + shift_factor: Optional[float] = None, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + spatial_compression_ratio: int = 16, + time_compression_ratio: int = 4, + ): + super().__init__() + self.ffactor_spatial = ffactor_spatial + self.ffactor_temporal = ffactor_temporal + self.scaling_factor = scaling_factor + self.shift_factor = shift_factor + + self.encoder = Encoder( + in_channels=in_channels, + z_channels=latent_channels, + block_out_channels=block_out_channels, + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + ffactor_temporal=ffactor_temporal, + downsample_match_channel=downsample_match_channel, + ) + self.decoder = Decoder( + z_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + ffactor_temporal=ffactor_temporal, + upsample_match_channel=upsample_match_channel, + ) + + self.use_slicing = False + self.use_spatial_tiling = False + self.use_temporal_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = sample_size + self.tile_latent_min_size = sample_size // ffactor_spatial + self.tile_sample_min_tsize = sample_tsize + self.tile_latent_min_tsize = sample_tsize // ffactor_temporal + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + """Enable or disable gradient checkpointing on encoder and decoder.""" + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_temporal_tiling(self, use_tiling: bool = True): + self.use_temporal_tiling = use_tiling + + def disable_temporal_tiling(self): + self.enable_temporal_tiling(False) + + def enable_spatial_tiling(self, use_tiling: bool = True): + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + self.enable_spatial_tiling(False) + + def enable_tiling(self, use_tiling: bool = True): + self.enable_spatial_tiling(use_tiling) + + def disable_tiling(self): + 
self.disable_spatial_tiling() + + def enable_slicing(self): + self.use_slicing = True + + def disable_slicing(self): + self.use_slicing = False + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int): + """Blend tensor b horizontally into a at blend_extent region.""" + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int): + """Blend tensor b vertically into a at blend_extent region.""" + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int): + """Blend tensor b temporally into a at blend_extent region.""" + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent) + return b + + def spatial_tiled_encode(self, x: torch.Tensor): + """Tiled spatial encoding for large inputs via overlapping.""" + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + return moments + + def temporal_tiled_encode(self, x: torch.Tensor): + """Tiled temporal encoding for large video sequences.""" + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_latent_min_tsize - blend_extent + + row = [] + for i in range(0, T, overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :] + if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size): + tile = self.spatial_tiled_encode(tile) + else: + tile = self.encoder(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, : t_limit + 1, :, :]) + moments = torch.cat(result_row, dim=-3) + return moments + + def spatial_tiled_decode(self, z: torch.Tensor): + """Tiled spatial decoding for large latent maps.""" + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * 
self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + return dec + + def temporal_tiled_decode(self, z: torch.Tensor): + """Tiled temporal decoding for long sequence latents.""" + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_sample_min_tsize - blend_extent + assert 0 < overlap_size < self.tile_latent_min_tsize + + row = [] + for i in range(0, T, overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :] + if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size): + decoded = self.spatial_tiled_decode(tile) + else: + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, : t_limit + 1, :, :]) + dec = torch.cat(result_row, dim=-3) + return dec + + @torch.no_grad() + def encode(self, x: Tensor, return_dict: bool = True): + if self.cpu_offload: + self.encoder = self.encoder.to(AI_DEVICE) + + def _encode(x): + if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize: + return self.temporal_tiled_encode(x) + if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.spatial_tiled_encode(x) + return self.encoder(x) + + assert len(x.shape) == 5 # (B, C, T, H, W) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [_encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = _encode(x) + posterior = DiagonalGaussianDistribution(h) + if self.cpu_offload: + self.encoder = self.encoder.to("cpu") + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + @torch.no_grad() + def decode(self, z: Tensor, return_dict: bool = True, generator=None): + if self.cpu_offload: + self.decoder = self.decoder.to(AI_DEVICE) + + def _decode(z): + if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize: + return self.temporal_tiled_decode(z) + if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.spatial_tiled_decode(z) + return self.decoder(z) + + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [_decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = _decode(z) + if self.cpu_offload: + self.decoder = self.decoder.to("cpu") + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + @torch.no_grad() + def forward(self, 
sample: torch.Tensor, sample_posterior: bool = False, return_posterior: bool = True, return_dict: bool = True): + """Forward autoencoder pass. Returns both reconstruction and optionally the posterior.""" + posterior = self.encode(sample).latent_dist + z = posterior.sample() if sample_posterior else posterior.mode() + dec = self.decode(z).sample + return DecoderOutput(sample=dec, posterior=posterior) if return_dict else (dec, posterior) + + +class HunyuanVideo15VAE: + def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False, parallel=False): + self.vae = AutoencoderKLConv3D.from_pretrained(os.path.join(checkpoint_path, "vae")).to(dtype).to(device) + self.vae.cpu_offload = cpu_offload + self.parallel = parallel + self.world_size_h, self.world_size_w = None, None + + @torch.no_grad() + def encode(self, x): + return self.vae.encode(x).latent_dist.mode() * self.vae.config.scaling_factor + + @torch.no_grad() + def decode(self, z): + z = z / self.vae.config.scaling_factor + + self.vae.enable_tiling() + if self.parallel and self.world_size_h is not None and self.world_size_w is not None: + video_frames = self.decode_dist_2d(z, self.world_size_h, self.world_size_w) + self.world_size_h, self.world_size_w = None, None + else: + video_frames = self.vae.decode(z, return_dict=False)[0] + self.vae.disable_tiling() + return video_frames + + @torch.no_grad() + def decode_dist_2d(self, z, world_size_h, world_size_w): + cur_rank = dist.get_rank() + cur_rank_h = cur_rank // world_size_w + cur_rank_w = cur_rank % world_size_w + total_h = z.shape[3] + total_w = z.shape[4] + + chunk_h = total_h // world_size_h + chunk_w = total_w // world_size_w + + padding_size = 1 + + # Calculate H dimension slice + if cur_rank_h == 0: + h_start = 0 + h_end = chunk_h + 2 * padding_size + elif cur_rank_h == world_size_h - 1: + h_start = total_h - (chunk_h + 2 * padding_size) + h_end = total_h + else: + h_start = cur_rank_h * chunk_h - padding_size + h_end = (cur_rank_h + 1) * chunk_h + padding_size + + # Calculate W dimension slice + if cur_rank_w == 0: + w_start = 0 + w_end = chunk_w + 2 * padding_size + elif cur_rank_w == world_size_w - 1: + w_start = total_w - (chunk_w + 2 * padding_size) + w_end = total_w + else: + w_start = cur_rank_w * chunk_w - padding_size + w_end = (cur_rank_w + 1) * chunk_w + padding_size + + # Extract the latent chunk for this process + zs_chunk = z[:, :, :, h_start:h_end, w_start:w_end].contiguous() + + # Decode the chunk + images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0] + + # Remove padding from decoded chunk + spatial_ratio = 16 + if cur_rank_h == 0: + decoded_h_start = 0 + decoded_h_end = chunk_h * spatial_ratio + elif cur_rank_h == world_size_h - 1: + decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio + decoded_h_end = images_chunk.shape[3] + else: + decoded_h_start = padding_size * spatial_ratio + decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio + + if cur_rank_w == 0: + decoded_w_start = 0 + decoded_w_end = chunk_w * spatial_ratio + elif cur_rank_w == world_size_w - 1: + decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio + decoded_w_end = images_chunk.shape[4] + else: + decoded_w_start = padding_size * spatial_ratio + decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio + + images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous() + + # Gather all chunks + total_processes = world_size_h * world_size_w + full_images = 
[torch.empty_like(images_chunk) for _ in range(total_processes)] + + dist.all_gather(full_images, images_chunk) + + self.device_synchronize() + + # Reconstruct the full image tensor + image_rows = [] + for h_idx in range(world_size_h): + image_cols = [] + for w_idx in range(world_size_w): + process_idx = h_idx * world_size_w + w_idx + image_cols.append(full_images[process_idx]) + image_rows.append(torch.cat(image_cols, dim=4)) + + images = torch.cat(image_rows, dim=3) + + return images + + def device_synchronize( + self, + ): + torch_device_module.synchronize() + + +if __name__ == "__main__": + vae = HunyuanVideo15VAE(checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5", dtype=torch.float16, device="cuda") + z = torch.randn(1, 32, 31, 30, 53, dtype=torch.float16, device="cuda") + video_frames = vae.decode(z) + print(video_frames.shape) diff --git a/lightx2v/models/video_encoders/hf/hunyuanvideo15/lighttae_hy15.py b/lightx2v/models/video_encoders/hf/hunyuanvideo15/lighttae_hy15.py new file mode 100644 index 0000000000000000000000000000000000000000..74b319eaf5df8d8fb2c8d481bdc1493f09b61cd2 --- /dev/null +++ b/lightx2v/models/video_encoders/hf/hunyuanvideo15/lighttae_hy15.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +from lightx2v.models.video_encoders.hf.tae import TAEHV + + +class LightTaeHy15(nn.Module): + def __init__(self, vae_path="lighttae_hy1_5.pth", dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + self.taehv = TAEHV(vae_path, model_type="hy15", latent_channels=32, patch_size=2).to(self.dtype) + self.scaling_factor = 1.03682 + + @torch.no_grad() + def decode(self, latents, parallel=True, show_progress_bar=True, skip_trim=False): + latents = latents / self.scaling_factor + return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel, show_progress_bar).transpose(1, 2) diff --git a/lightx2v/models/video_encoders/hf/qwen_image/__init__.py b/lightx2v/models/video_encoders/hf/qwen_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/models/video_encoders/hf/qwen_image/vae.py b/lightx2v/models/video_encoders/hf/qwen_image/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..92595a89026e683b8f667be358cf9be076fd2930 --- /dev/null +++ b/lightx2v/models/video_encoders/hf/qwen_image/vae.py @@ -0,0 +1,140 @@ +import gc +import json +import os +from typing import Optional + +import torch + +from lightx2v_platform.base.global_var import AI_DEVICE + +try: + from diffusers import AutoencoderKLQwenImage + from diffusers.image_processor import VaeImageProcessor +except ImportError: + AutoencoderKLQwenImage = None + VaeImageProcessor = None + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class AutoencoderKLQwenImageVAE: + def __init__(self, config): + self.config = config + + self.cpu_offload = 
config.get("cpu_offload", False) + if self.cpu_offload: + self.device = torch.device("cpu") + else: + self.device = torch.device(AI_DEVICE) + self.dtype = torch.bfloat16 + self.latent_channels = config["vae_z_dim"] + self.load() + + def load(self): + self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(self.config["model_path"], "vae")).to(self.device).to(self.dtype) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2) + with open(os.path.join(self.config["model_path"], "vae", "config.json"), "r") as f: + vae_config = json.load(f) + self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8 + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batchsize, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batchsize, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batchsize, channels // (2 * 2), 1, height, width) + + return latents + + @torch.no_grad() + def decode(self, latents, input_info): + if self.cpu_offload: + self.model.to(torch.device("cuda")) + if self.config["task"] == "t2i": + width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]] + elif self.config["task"] == "i2i": + width, height = input_info.auto_width, input_info.auto_hight + latents = self._unpack_latents(latents, height, width, self.config["vae_scale_factor"]) + latents = latents.to(self.dtype) + latents_mean = torch.tensor(self.config["vae_latents_mean"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.config["vae_latents_std"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + images = self.model.decode(latents, return_dict=False)[0][:, :, 0] + images = self.image_processor.postprocess(images, output_type="pil") + if self.cpu_offload: + self.model.to(torch.device("cpu")) + torch.cuda.empty_cache() + gc.collect() + return images + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batchsize, num_channels_latents, height, width): + latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4) + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [retrieve_latents(self.model.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0])] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax") + latents_mean = torch.tensor(self.model.config["latents_mean"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_std = torch.tensor(self.model.config["latents_std"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype) + 
image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + @torch.no_grad() + def encode_vae_image(self, image, input_info): + if self.config["task"] == "i2i": + self.generator = torch.Generator().manual_seed(input_info.seed) + elif self.config["task"] == "t2i": + self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed) + + if self.cpu_offload: + self.model.to(torch.device("cuda")) + + num_channels_latents = self.config["transformer_in_channels"] // 4 + image = image.to(self.model.device).to(self.dtype) + + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=self.generator) + else: + image_latents = image + if self.config["batchsize"] > image_latents.shape[0] and self.config["batchsize"] % image_latents.shape[0] == 0: + # expand init_latents for batchsize + additional_image_per_prompt = self.config["batchsize"] // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif self.config["batchsize"] > image_latents.shape[0] and self.config["batchsize"] % image_latents.shape[0] != 0: + raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config['batchsize']} text prompts.") + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents(image_latents, self.config["batchsize"], num_channels_latents, image_latent_height, image_latent_width) + + if self.cpu_offload: + self.model.to(torch.device("cpu")) + torch.cuda.empty_cache() + gc.collect() + return image_latents diff --git a/lightx2v/models/video_encoders/hf/tae.py b/lightx2v/models/video_encoders/hf/tae.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf63e2ddfae61750f624125ae3eac2c296495cb --- /dev/null +++ b/lightx2v/models/video_encoders/hf/tae.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +""" +Tiny AutoEncoder for Hunyuan Video +(DNN for encoding / decoding videos to Hunyuan Video's latent space) +""" + +import os +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import load_file +from tqdm.auto import tqdm + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out, act_func): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = act_func + + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False) + + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) + + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return 
x.reshape(-1, C, H, W) + + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + """ + Apply a sequential model with memblocks to the given input. + Args: + - model: nn.Sequential of blocks to apply + - x: input data, of dimensions NTCHW + - parallel: if True, parallelize over timesteps (fast but uses O(T) memory) + if False, each timestep will be processed sequentially (slow but uses O(1) memory) + - show_progress_bar: if True, enables tqdm progressbar display + + Returns NTCHW tensor of output data. + """ + assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor" + N, T, C, H, W = x.shape + if parallel: + x = x.reshape(N * T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + NT, C, H, W = x.shape + T = NT // N + _x = x.reshape(N, T, C, H, W) + mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + NT, C, H, W = x.shape + T = NT // N + x = x.view(N, T, C, H, W) + else: + # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode... + # need to fix :( + out = [] + # iterate over input timesteps and also iterate over blocks. + # because of the cursed TPool/TGrow blocks, this is not a nested loop, + # it's actually a ***graph traversal*** problem! so let's make a queue + work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))] + # in addition to manually managing our queue, we also need to manually manage our progressbar. + # we'll update it for every source node that we consume. + progress_bar = tqdm(range(T), disable=not show_progress_bar) + # we'll also need a separate addressable memory per node as well + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.pop(0) + if i == 0: + # new source node consumed + progress_bar.update(1) + if i == len(model): + # reached end of the graph, append result to output list + out.append(xt) + else: + # fetch the block to process + b = model[i] + if isinstance(b, MemBlock): + # mem blocks are simple since we're visiting the graph in causal order + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt + else: + xt_new = b(xt, mem[i]) + mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? 
doesn't help though + # add successor to work queue + work_queue.insert(0, TWorkItem(xt_new, i + 1)) + elif isinstance(b, TPool): + # pool blocks are miserable + if mem[i] is None: + mem[i] = [] # pool memory is itself a queue of inputs to pool + mem[i].append(xt) + if len(mem[i]) > b.stride: + # pool mem is in invalid state, we should have pooled before this + raise ValueError("???") + elif len(mem[i]) < b.stride: + # pool mem is not yet full, go back to processing the work queue + pass + else: + # pool mem is ready, run the pool block + N, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W)) + # reset the pool mem + mem[i] = [] + # add successor to work queue + work_queue.insert(0, TWorkItem(xt, i + 1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + # each tgrow has multiple successor nodes + for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)): + # add successor to work queue + work_queue.insert(0, TWorkItem(xt_next, i + 1)) + else: + # normal block with no funny business + xt = b(xt) + # add successor to work queue + work_queue.insert(0, TWorkItem(xt, i + 1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + + +class TAEHV(nn.Module): + def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), patch_size=1, latent_channels=16, model_type="wan21"): + """Initialize pretrained TAEHV from the given checkpoint. + + Arg: + checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1. + decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview. + decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview. + patch_size: input/output pixelshuffle patch-size for this model. + latent_channels: number of latent channels (z dim) for this model. 
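+            model_type: model family variant ("wan21", "wan22", or "hy15"). "wan22" switches to
+                patch_size=2 with 48 latent channels; "hy15" uses LeakyReLU activations instead of ReLU.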
+ """ + super().__init__() + self.patch_size = patch_size + self.latent_channels = latent_channels + self.image_channels = 3 + self.is_cogvideox = checkpoint_path is not None and "taecvx" in checkpoint_path + # if checkpoint_path is not None and "taew2_2" in checkpoint_path: + # self.patch_size, self.latent_channels = 2, 48 + self.model_type = model_type + if model_type == "wan22": + self.patch_size, self.latent_channels = 2, 48 + if model_type == "hy15": + act_func = nn.LeakyReLU(0.2, inplace=True) + else: + act_func = nn.ReLU(inplace=True) + + self.encoder = nn.Sequential( + conv(self.image_channels * self.patch_size**2, 64), + act_func, + TPool(64, 2), + conv(64, 64, stride=2, bias=False), + MemBlock(64, 64, act_func), + MemBlock(64, 64, act_func), + MemBlock(64, 64, act_func), + TPool(64, 2), + conv(64, 64, stride=2, bias=False), + MemBlock(64, 64, act_func), + MemBlock(64, 64, act_func), + MemBlock(64, 64, act_func), + TPool(64, 1), + conv(64, 64, stride=2, bias=False), + MemBlock(64, 64, act_func), + MemBlock(64, 64, act_func), + MemBlock(64, 64, act_func), + conv(64, self.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), + conv(self.latent_channels, n_f[0]), + act_func, + MemBlock(n_f[0], n_f[0], act_func), + MemBlock(n_f[0], n_f[0], act_func), + MemBlock(n_f[0], n_f[0], act_func), + nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), + TGrow(n_f[0], 1), + conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), + MemBlock(n_f[1], n_f[1], act_func), + MemBlock(n_f[1], n_f[1], act_func), + nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), + TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), + conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), + MemBlock(n_f[2], n_f[2], act_func), + MemBlock(n_f[2], n_f[2], act_func), + nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), + TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), + conv(n_f[2], n_f[3], bias=False), + act_func, + conv(n_f[3], self.image_channels * self.patch_size**2), + ) + if checkpoint_path is not None: + ext = os.path.splitext(checkpoint_path)[1].lower() + + if ext == ".pth": + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + elif ext == ".safetensors": + state_dict = load_file(checkpoint_path, device="cpu") + else: + raise ValueError(f"Unsupported checkpoint format: {ext}. Supported formats: .pth, .safetensors") + + self.load_state_dict(self.patch_tgrow_layers(state_dict)) + + def patch_tgrow_layers(self, sd): + """Patch TGrow layers to use a smaller kernel if needed. + + Args: + sd: state dict to patch + """ + new_sd = self.state_dict() + for i, layer in enumerate(self.decoder): + if isinstance(layer, TGrow): + key = f"decoder.{i}.conv.weight" + if sd[key].shape[0] > new_sd[key].shape[0]: + # take the last-timestep output channels + sd[key] = sd[key][-new_sd[key].shape[0] :] + return sd + + def encode_video(self, x, parallel=True, show_progress_bar=True): + """Encode a sequence of frames. + + Args: + x: input NTCHW RGB (C=3) tensor with values in [0, 1]. + parallel: if True, all frames will be processed at once. + (this is faster but may require more memory). + if False, frames will be processed sequentially. + Returns NTCHW latent tensor with ~Gaussian values. 
+ """ + if self.patch_size > 1: + x = F.pixel_unshuffle(x, self.patch_size) + if x.shape[1] % 4 != 0: + # pad at end to multiple of 4 + n_pad = 4 - x.shape[1] % 4 + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) + return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar) + + def decode_video(self, x, parallel=True, show_progress_bar=True): + """Decode a sequence of frames. + + Args: + x: input NTCHW latent (C=12) tensor with ~Gaussian values. + parallel: if True, all frames will be processed at once. + (this is faster but may require more memory). + if False, frames will be processed sequentially. + Returns NTCHW RGB tensor with ~[0, 1] values. + """ + skip_trim = self.is_cogvideox and x.shape[1] % 2 == 0 + x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar) + if self.model_type == "hy15": + x = x.clamp_(-1, 1) + else: + x = x.clamp_(0, 1) + if self.patch_size > 1: + x = F.pixel_shuffle(x, self.patch_size) + if skip_trim: + # skip trimming for cogvideox to make frame counts match. + # this still doesn't have correct temporal alignment for certain frame counts + # (cogvideox seems to pad at the start?), but for multiple-of-4 it's fine. + return x + return x[:, self.frames_to_trim :] diff --git a/lightx2v/models/video_encoders/hf/vid_recon.py b/lightx2v/models/video_encoders/hf/vid_recon.py new file mode 100644 index 0000000000000000000000000000000000000000..972d3b27e6a716b6396eb4609cfcba6647b270ab --- /dev/null +++ b/lightx2v/models/video_encoders/hf/vid_recon.py @@ -0,0 +1,94 @@ +import argparse + +import cv2 +import torch +from loguru import logger + +from lightx2v.models.video_encoders.hf.wan.vae import WanVAE +from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE +from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny + + +class VideoTensorReader: + def __init__(self, video_file_path): + self.cap = cv2.VideoCapture(video_file_path) + assert self.cap.isOpened(), f"Could not load {video_file_path}" + self.fps = self.cap.get(cv2.CAP_PROP_FPS) + + def __iter__(self): + return self + + def __next__(self): + ret, frame = self.cap.read() + if not ret: + self.cap.release() + raise StopIteration # End of video or error + return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW + + +class VideoTensorWriter: + def __init__(self, video_file_path, width_height, fps=30): + self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, width_height) + assert self.writer.isOpened(), f"Could not create writer for {video_file_path}" + + def write(self, frame_tensor): + assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??" 
+        self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))  # RGB CHW -> BGR HWC
+
+    def __del__(self):
+        if hasattr(self, "writer"):
+            self.writer.release()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Encode and decode videos with a TAE or WanVAE model for round-trip reconstruction")
+    parser.add_argument("video_paths", nargs="+", help="Paths to input video files (multiple allowed)")
+    parser.add_argument("--checkpoint", "-c", help="Path to the model checkpoint file")
+    parser.add_argument("--device", "-d", default=None, help='Computing device (e.g., "cuda", "mps", "cpu"; default: auto-detect available device)')
+    parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float32"], help="Data type for model computation (default: bfloat16)")
+    parser.add_argument("--model_type", choices=["taew2_1", "taew2_2", "vaew2_1", "vaew2_2"], required=True, help="Type of the model to use (choices: taew2_1, taew2_2, vaew2_1, vaew2_2)")
+    parser.add_argument("--use_lightvae", default=False, action="store_true", help="Use the pruned LightVAE variant (only valid with vaew2_1)")
+
+    args = parser.parse_args()
+    if args.use_lightvae:
+        assert args.model_type in ["vaew2_1"]
+
+    if args.device:
+        dev = torch.device(args.device)
+    else:
+        dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+    dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+    model_map = {"taew2_1": WanVAE_tiny, "taew2_2": Wan2_2_VAE_tiny, "vaew2_1": WanVAE, "vaew2_2": Wan2_2_VAE}
+
+    dtype = dtype_map[args.dtype]
+
+    model_args = {"vae_path": args.checkpoint, "dtype": dtype, "device": dev}
+    if args.model_type == "vaew2_1":
+        model_args.update({"use_lightvae": args.use_lightvae})
+
+    model = model_map[args.model_type](**model_args)
+    if args.model_type.startswith("tae"):
+        model = model.to(dev)
+
+    # Process each input video
+    for idx, video_path in enumerate(args.video_paths):
+        logger.info(f"Processing video {video_path}...")
+        # Read video
+        video_in = VideoTensorReader(video_path)
+        video = torch.stack(list(video_in), 0)[None]  # Add batch dimension
+        vid_dev = video.to(dev, dtype).div_(255.0)  # Normalize to [0,1]
+        # Encode
+        vid_enc = model.encode_video(vid_dev)
+        if isinstance(vid_enc, tuple):
+            vid_enc = vid_enc[0]
+        # Decode
+        vid_dec = model.decode_video(vid_enc)
+        # Save reconstructed video
+        video_out_path = f"{video_path}.reconstructed_{idx}.mp4"
+        frame_size = (vid_dec.shape[-1], vid_dec.shape[-2])
+        fps = int(round(video_in.fps))
+        video_out = VideoTensorWriter(video_out_path, frame_size, fps)
+        for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
+            video_out.write(frame)
+        logger.info(f" Reconstructed video saved to {video_out_path}")
diff --git a/lightx2v/models/video_encoders/hf/wan/__init__.py b/lightx2v/models/video_encoders/hf/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lightx2v/models/video_encoders/hf/wan/vae.py b/lightx2v/models/video_encoders/hf/wan/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b870af6357b9924a65107b7141524878ebdc085
--- /dev/null
+++ b/lightx2v/models/video_encoders/hf/wan/vae.py
@@ -0,0 +1,1457 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
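+# Causal 3D VAE for Wan 2.1: CausalConv3d encoder/decoder with per-convolution feature caching,
+# plus tiled and distributed (1D / 2D spatially sharded) encode/decode helpers wrapped by WanVAE below.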
+ +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from loguru import logger + +from lightx2v.utils.utils import load_weights +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +__all__ = [ + "WanVAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x) + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = 
cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
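+    Attention is applied per frame: time is folded into the batch dimension and the h*w
+    spatial positions form the token sequence.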
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +class Encoder3d(nn.Module): + def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + dims = [int(d * (1 - pruning_rate)) for d in dims] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = 
layer(x) + return x + + +class Decoder3d(nn.Module): + def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0, pruning_rate=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + dims = [int(d * (1 - pruning_rate)) for d in dims] + + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) 
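+        # With the default temperal_downsample=[True, True, False] this evaluates to 2 ** 3 = 8,
+        # matching the three stride-2 spatial downsampling stages in the encoder.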
+ + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + pruning_rate, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + pruning_rate, + ) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def blend_v(self, a, b, blend_extent): + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a, b, blend_extent): + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x, scale): + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
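+        # With the defaults above (256-pixel tiles, 192-pixel stride) adjacent tiles share 64 input
+        # pixels; after 8x compression that is an 8-cell latent overlap blended by blend_v/blend_h
+        # before each tile is cropped back to the 24-cell latent stride.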
+ rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + mu, log_var = self.conv1(tile).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + + time.append(mu) + + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z, scale): + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
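+        # Decoding blends in pixel space: with the defaults, 64 output rows/columns are cross-faded
+        # between neighbouring tiles before each tile is cropped back to the 192-pixel stride, and
+        # each latent frame is decoded one step at a time through the causal feature cache.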
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.conv2(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + return dec + + def encode(self, x, scale, return_mu=False): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + + self.clear_cache() + if return_mu: + return mu, log_var + else: + return mu + + def decode(self, z, scale): + self.clear_cache() + + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + else: + out_ = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + out = torch.cat([out, out_], 2) + + self.clear_cache() + return out + + def decode_stream(self, z, scale): + self.clear_cache() + + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + out = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + yield out + + def cached_decode(self, z, scale): + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = 
torch.cat([out, out_], 2) + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False, scale=[0, 1]): + mu, log_var = self.encode(imgs, scale, return_mu=True) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std), mu, log_var + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def encode_video(self, x, scale=[0, 1]): + assert x.ndim == 5 # NTCHW + assert x.shape[2] % 3 == 0 + x = x.transpose(1, 2) + y = x.mul(2).sub_(1) + y, mu, log_var = self.sample(y, scale=scale) + return y.transpose(1, 2).to(x), mu, log_var + + def decode_video(self, x, scale=[0, 1]): + assert x.ndim == 5 # NTCHW + assert x.shape[2] % self.z_dim == 0 + x = x.transpose(1, 2) + # B, C, T, H, W + y = x + y = self.decode(y, scale).clamp_(-1, 1) + y = y.mul_(0.5).add_(0.5).clamp_(0, 1) # NCTHW + return y.transpose(1, 2).to(x) + + +def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, pruning_rate=0.0, **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0, + pruning_rate=pruning_rate, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + # load checkpoint + weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0) + for k in weights_dict.keys(): + if weights_dict[k].dtype != dtype: + weights_dict[k] = weights_dict[k].to(dtype) + model.load_state_dict(weights_dict, assign=True) + + return model + + +class WanVAE: + def __init__( + self, + z_dim=16, + vae_path="cache/vae_step_411000.pth", + dtype=torch.float, + device="cuda", + parallel=False, + use_tiling=False, + cpu_offload=False, + use_2d_split=True, + load_from_rank0=False, + use_lightvae=False, + ): + self.dtype = dtype + self.device = device + self.parallel = parallel + self.use_tiling = use_tiling + self.cpu_offload = cpu_offload + self.use_2d_split = use_2d_split + if use_lightvae: + pruning_rate = 0.75 # 0.75 + else: + pruning_rate = 0.0 + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean, dtype=dtype, device=AI_DEVICE) + self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=AI_DEVICE) + self.scale = [self.mean, self.inv_std] + + # (height, width, world_size) -> (world_size_h, world_size_w) + self.grid_table = { + # world_size = 2 + (60, 104, 2): (1, 2), + (68, 120, 2): (1, 2), + (90, 160, 2): (1, 2), + (60, 60, 2): (1, 2), + (72, 72, 2): (1, 2), + (88, 88, 2): (1, 2), + (120, 120, 2): (1, 2), + (104, 60, 2): (2, 1), + (120, 68, 2): (2, 1), + (160, 90, 2): (2, 1), + # world_size = 4 + (60, 104, 4): (2, 2), + (68, 120, 4): (2, 2), + (90, 
160, 4): (2, 2), + (60, 60, 4): (2, 2), + (72, 72, 4): (2, 2), + (88, 88, 4): (2, 2), + (120, 120, 4): (2, 2), + (104, 60, 4): (2, 2), + (120, 68, 4): (2, 2), + (160, 90, 4): (2, 2), + # world_size = 8 + (60, 104, 8): (2, 4), + (68, 120, 8): (2, 4), + (90, 160, 8): (2, 4), + (60, 60, 8): (2, 4), + (72, 72, 8): (2, 4), + (88, 88, 8): (2, 4), + (120, 120, 8): (2, 4), + (104, 60, 8): (4, 2), + (120, 68, 8): (4, 2), + (160, 90, 8): (4, 2), + } + + # init model + self.model = ( + _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate) + .eval() + .requires_grad_(False) + .to(device) + .to(dtype) + ) + + def _calculate_2d_grid(self, latent_height, latent_width, world_size): + if (latent_height, latent_width, world_size) in self.grid_table: + best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)] + # logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent") + return best_h, best_w + + best_h, best_w = 1, world_size + min_aspect_diff = float("inf") + + for h in range(1, world_size + 1): + if world_size % h == 0: + w = world_size // h + if latent_height % h == 0 and latent_width % w == 0: + # Calculate how close this grid is to square + aspect_diff = abs((latent_height / h) - (latent_width / w)) + if aspect_diff < min_aspect_diff: + min_aspect_diff = aspect_diff + best_h, best_w = h, w + # logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent") + self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w) + return best_h, best_w + + def current_device(self): + return next(self.model.parameters()).device + + def to_cpu(self): + self.model.encoder = self.model.encoder.to("cpu") + self.model.decoder = self.model.decoder.to("cpu") + self.model = self.model.to("cpu") + self.mean = self.mean.cpu() + self.inv_std = self.inv_std.cpu() + self.scale = [self.mean, self.inv_std] + + def to_cuda(self): + self.model.encoder = self.model.encoder.to(AI_DEVICE) + self.model.decoder = self.model.decoder.to(AI_DEVICE) + self.model = self.model.to(AI_DEVICE) + self.mean = self.mean.to(AI_DEVICE) + self.inv_std = self.inv_std.to(AI_DEVICE) + self.scale = [self.mean, self.inv_std] + + def encode_dist(self, video, world_size, cur_rank, split_dim): + spatial_ratio = 8 + + if split_dim == 3: + total_latent_len = video.shape[3] // spatial_ratio + elif split_dim == 4: + total_latent_len = video.shape[4] // spatial_ratio + else: + raise ValueError(f"Unsupported split_dim: {split_dim}") + + splited_chunk_len = total_latent_len // world_size + padding_size = 1 + + video_chunk_len = splited_chunk_len * spatial_ratio + video_padding_len = padding_size * spatial_ratio + + if cur_rank == 0: + if split_dim == 3: + video_chunk = video[:, :, :, : video_chunk_len + 2 * video_padding_len, :].contiguous() + elif split_dim == 4: + video_chunk = video[:, :, :, :, : video_chunk_len + 2 * video_padding_len].contiguous() + elif cur_rank == world_size - 1: + if split_dim == 3: + video_chunk = video[:, :, :, -(video_chunk_len + 2 * video_padding_len) :, :].contiguous() + elif split_dim == 4: + video_chunk = video[:, :, :, :, -(video_chunk_len + 2 * video_padding_len) :].contiguous() + else: + start_idx = cur_rank * video_chunk_len - video_padding_len + end_idx = (cur_rank + 1) * video_chunk_len + video_padding_len + if split_dim == 3: + video_chunk = video[:, :, :, start_idx:end_idx, :].contiguous() + elif 
split_dim == 4: + video_chunk = video[:, :, :, :, start_idx:end_idx].contiguous() + + if self.use_tiling: + encoded_chunk = self.model.tiled_encode(video_chunk, self.scale) + else: + encoded_chunk = self.model.encode(video_chunk, self.scale) + + if cur_rank == 0: + if split_dim == 3: + encoded_chunk = encoded_chunk[:, :, :, :splited_chunk_len, :].contiguous() + elif split_dim == 4: + encoded_chunk = encoded_chunk[:, :, :, :, :splited_chunk_len].contiguous() + elif cur_rank == world_size - 1: + if split_dim == 3: + encoded_chunk = encoded_chunk[:, :, :, -splited_chunk_len:, :].contiguous() + elif split_dim == 4: + encoded_chunk = encoded_chunk[:, :, :, :, -splited_chunk_len:].contiguous() + else: + if split_dim == 3: + encoded_chunk = encoded_chunk[:, :, :, padding_size:-padding_size, :].contiguous() + elif split_dim == 4: + encoded_chunk = encoded_chunk[:, :, :, :, padding_size:-padding_size].contiguous() + + full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)] + dist.all_gather(full_encoded, encoded_chunk) + + self.device_synchronize() + + encoded = torch.cat(full_encoded, dim=split_dim) + + return encoded.squeeze(0) + + def encode_dist_2d(self, video, world_size_h, world_size_w, cur_rank_h, cur_rank_w): + spatial_ratio = 8 + + # Calculate chunk sizes for both dimensions + total_latent_h = video.shape[3] // spatial_ratio + total_latent_w = video.shape[4] // spatial_ratio + + chunk_h = total_latent_h // world_size_h + chunk_w = total_latent_w // world_size_w + + padding_size = 1 + video_chunk_h = chunk_h * spatial_ratio + video_chunk_w = chunk_w * spatial_ratio + video_padding_h = padding_size * spatial_ratio + video_padding_w = padding_size * spatial_ratio + + # Calculate H dimension slice + if cur_rank_h == 0: + h_start = 0 + h_end = video_chunk_h + 2 * video_padding_h + elif cur_rank_h == world_size_h - 1: + h_start = video.shape[3] - (video_chunk_h + 2 * video_padding_h) + h_end = video.shape[3] + else: + h_start = cur_rank_h * video_chunk_h - video_padding_h + h_end = (cur_rank_h + 1) * video_chunk_h + video_padding_h + + # Calculate W dimension slice + if cur_rank_w == 0: + w_start = 0 + w_end = video_chunk_w + 2 * video_padding_w + elif cur_rank_w == world_size_w - 1: + w_start = video.shape[4] - (video_chunk_w + 2 * video_padding_w) + w_end = video.shape[4] + else: + w_start = cur_rank_w * video_chunk_w - video_padding_w + w_end = (cur_rank_w + 1) * video_chunk_w + video_padding_w + + # Extract the video chunk for this process + video_chunk = video[:, :, :, h_start:h_end, w_start:w_end].contiguous() + + # Encode the chunk + if self.use_tiling: + encoded_chunk = self.model.tiled_encode(video_chunk, self.scale) + else: + encoded_chunk = self.model.encode(video_chunk, self.scale) + + # Remove padding from encoded chunk + if cur_rank_h == 0: + encoded_h_start = 0 + encoded_h_end = chunk_h + elif cur_rank_h == world_size_h - 1: + encoded_h_start = encoded_chunk.shape[3] - chunk_h + encoded_h_end = encoded_chunk.shape[3] + else: + encoded_h_start = padding_size + encoded_h_end = encoded_chunk.shape[3] - padding_size + + if cur_rank_w == 0: + encoded_w_start = 0 + encoded_w_end = chunk_w + elif cur_rank_w == world_size_w - 1: + encoded_w_start = encoded_chunk.shape[4] - chunk_w + encoded_w_end = encoded_chunk.shape[4] + else: + encoded_w_start = padding_size + encoded_w_end = encoded_chunk.shape[4] - padding_size + + encoded_chunk = encoded_chunk[:, :, :, encoded_h_start:encoded_h_end, encoded_w_start:encoded_w_end].contiguous() + + # Gather all chunks + 
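# all_gather is collective: every rank contributes its de-padded chunk and receives all of them, and the full latent grid is stitched back together below +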
total_processes = world_size_h * world_size_w + full_encoded = [torch.empty_like(encoded_chunk) for _ in range(total_processes)] + + dist.all_gather(full_encoded, encoded_chunk) + + self.device_synchronize() + + # Reconstruct the full encoded tensor + encoded_rows = [] + for h_idx in range(world_size_h): + encoded_cols = [] + for w_idx in range(world_size_w): + process_idx = h_idx * world_size_w + w_idx + encoded_cols.append(full_encoded[process_idx]) + encoded_rows.append(torch.cat(encoded_cols, dim=4)) + + encoded = torch.cat(encoded_rows, dim=3) + + return encoded.squeeze(0) + + def encode(self, video, world_size_h=None, world_size_w=None): + """ + video: one video with shape [1, C, T, H, W]. + """ + if self.cpu_offload: + self.to_cuda() + + if self.parallel: + world_size = dist.get_world_size() + cur_rank = dist.get_rank() + height, width = video.shape[3], video.shape[4] + + if self.use_2d_split: + if world_size_h is None or world_size_w is None: + world_size_h, world_size_w = self._calculate_2d_grid(height // 8, width // 8, world_size) + cur_rank_h = cur_rank // world_size_w + cur_rank_w = cur_rank % world_size_w + out = self.encode_dist_2d(video, world_size_h, world_size_w, cur_rank_h, cur_rank_w) + else: + # Original 1D splitting logic + if width % world_size == 0: + out = self.encode_dist(video, world_size, cur_rank, split_dim=4) + elif height % world_size == 0: + out = self.encode_dist(video, world_size, cur_rank, split_dim=3) + else: + logger.info("Fall back to naive encode mode") + if self.use_tiling: + out = self.model.tiled_encode(video, self.scale).squeeze(0) + else: + out = self.model.encode(video, self.scale).squeeze(0) + else: + if self.use_tiling: + out = self.model.tiled_encode(video, self.scale).squeeze(0) + else: + out = self.model.encode(video, self.scale).squeeze(0) + + if self.cpu_offload: + self.to_cpu() + return out + + def decode_dist(self, zs, world_size, cur_rank, split_dim): + splited_total_len = zs.shape[split_dim] + splited_chunk_len = splited_total_len // world_size + padding_size = 1 + + if cur_rank == 0: + if split_dim == 2: + zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous() + elif split_dim == 3: + zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous() + elif cur_rank == world_size - 1: + if split_dim == 2: + zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous() + elif split_dim == 3: + zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous() + else: + if split_dim == 2: + zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous() + elif split_dim == 3: + zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous() + + decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode + images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1) + + if cur_rank == 0: + if split_dim == 2: + images = images[:, :, :, : splited_chunk_len * 8, :].contiguous() + elif split_dim == 3: + images = images[:, :, :, :, : splited_chunk_len * 8].contiguous() + elif cur_rank == world_size - 1: + if split_dim == 2: + images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous() + elif split_dim == 3: + images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous() + else: + if split_dim == 2: + images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous() + elif split_dim == 3: + images = images[:, :, :, :, 8 * padding_size : 
-8 * padding_size].contiguous() + + full_images = [torch.empty_like(images) for _ in range(world_size)] + dist.all_gather(full_images, images) + + self.device_synchronize() + + images = torch.cat(full_images, dim=split_dim + 1) + + return images + + def decode_dist_2d(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w): + total_h = zs.shape[2] + total_w = zs.shape[3] + + chunk_h = total_h // world_size_h + chunk_w = total_w // world_size_w + + padding_size = 2 + + # Calculate H dimension slice + if cur_rank_h == 0: + h_start = 0 + h_end = chunk_h + 2 * padding_size + elif cur_rank_h == world_size_h - 1: + h_start = total_h - (chunk_h + 2 * padding_size) + h_end = total_h + else: + h_start = cur_rank_h * chunk_h - padding_size + h_end = (cur_rank_h + 1) * chunk_h + padding_size + + # Calculate W dimension slice + if cur_rank_w == 0: + w_start = 0 + w_end = chunk_w + 2 * padding_size + elif cur_rank_w == world_size_w - 1: + w_start = total_w - (chunk_w + 2 * padding_size) + w_end = total_w + else: + w_start = cur_rank_w * chunk_w - padding_size + w_end = (cur_rank_w + 1) * chunk_w + padding_size + + # Extract the latent chunk for this process + zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous() + + # Decode the chunk + decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode + images_chunk = decode_func(zs_chunk.unsqueeze(0), self.scale).clamp_(-1, 1) + + # Remove padding from decoded chunk + spatial_ratio = 8 + if cur_rank_h == 0: + decoded_h_start = 0 + decoded_h_end = chunk_h * spatial_ratio + elif cur_rank_h == world_size_h - 1: + decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio + decoded_h_end = images_chunk.shape[3] + else: + decoded_h_start = padding_size * spatial_ratio + decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio + + if cur_rank_w == 0: + decoded_w_start = 0 + decoded_w_end = chunk_w * spatial_ratio + elif cur_rank_w == world_size_w - 1: + decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio + decoded_w_end = images_chunk.shape[4] + else: + decoded_w_start = padding_size * spatial_ratio + decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio + + images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous() + + # Gather all chunks + total_processes = world_size_h * world_size_w + full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)] + + dist.all_gather(full_images, images_chunk) + + self.device_synchronize() + + # Reconstruct the full image tensor + image_rows = [] + for h_idx in range(world_size_h): + image_cols = [] + for w_idx in range(world_size_w): + process_idx = h_idx * world_size_w + w_idx + image_cols.append(full_images[process_idx]) + image_rows.append(torch.cat(image_cols, dim=4)) + + images = torch.cat(image_rows, dim=3) + + return images + + def decode_dist_2d_stream(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w): + total_h = zs.shape[2] + total_w = zs.shape[3] + + chunk_h = total_h // world_size_h + chunk_w = total_w // world_size_w + + padding_size = 2 + + # Calculate H dimension slice + if cur_rank_h == 0: + h_start = 0 + h_end = chunk_h + 2 * padding_size + elif cur_rank_h == world_size_h - 1: + h_start = total_h - (chunk_h + 2 * padding_size) + h_end = total_h + else: + h_start = cur_rank_h * chunk_h - padding_size + h_end = (cur_rank_h + 1) * chunk_h + padding_size + + # Calculate W dimension slice + if cur_rank_w == 0: + w_start = 0 + w_end = chunk_w + 2 * 
padding_size + elif cur_rank_w == world_size_w - 1: + w_start = total_w - (chunk_w + 2 * padding_size) + w_end = total_w + else: + w_start = cur_rank_w * chunk_w - padding_size + w_end = (cur_rank_w + 1) * chunk_w + padding_size + + # Extract the latent chunk for this process + zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous() + + for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale): + images_chunk = image.clamp_(-1, 1) + # Remove padding from decoded chunk + spatial_ratio = 8 + if cur_rank_h == 0: + decoded_h_start = 0 + decoded_h_end = chunk_h * spatial_ratio + elif cur_rank_h == world_size_h - 1: + decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio + decoded_h_end = images_chunk.shape[3] + else: + decoded_h_start = padding_size * spatial_ratio + decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio + + if cur_rank_w == 0: + decoded_w_start = 0 + decoded_w_end = chunk_w * spatial_ratio + elif cur_rank_w == world_size_w - 1: + decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio + decoded_w_end = images_chunk.shape[4] + else: + decoded_w_start = padding_size * spatial_ratio + decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio + + images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous() + + # Gather all chunks + total_processes = world_size_h * world_size_w + full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)] + + dist.all_gather(full_images, images_chunk) + + self.device_synchronize() + + # Reconstruct the full image tensor + image_rows = [] + for h_idx in range(world_size_h): + image_cols = [] + for w_idx in range(world_size_w): + process_idx = h_idx * world_size_w + w_idx + image_cols.append(full_images[process_idx]) + image_rows.append(torch.cat(image_cols, dim=4)) + + images = torch.cat(image_rows, dim=3) + + yield images + + def decode(self, zs): + if self.cpu_offload: + self.to_cuda() + + if self.parallel: + world_size = dist.get_world_size() + cur_rank = dist.get_rank() + latent_height, latent_width = zs.shape[2], zs.shape[3] + + if self.use_2d_split: + world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size) + cur_rank_h = cur_rank // world_size_w + cur_rank_w = cur_rank % world_size_w + images = self.decode_dist_2d(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w) + else: + # Original 1D splitting logic + if latent_width % world_size == 0: + images = self.decode_dist(zs, world_size, cur_rank, split_dim=3) + elif latent_height % world_size == 0: + images = self.decode_dist(zs, world_size, cur_rank, split_dim=2) + else: + logger.info("Fall back to naive decode mode") + images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1) + else: + decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode + images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1) + + if self.cpu_offload: + images = images.cpu() + self.to_cpu() + + return images + + def decode_stream(self, zs): + if self.cpu_offload: + self.to_cuda() + + if self.parallel: + world_size = dist.get_world_size() + cur_rank = dist.get_rank() + latent_height, latent_width = zs.shape[2], zs.shape[3] + + world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size) + cur_rank_h = cur_rank // world_size_w + cur_rank_w = cur_rank % world_size_w + for images in self.decode_dist_2d_stream(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w): + yield 
images + else: + for image in self.model.decode_stream(zs.unsqueeze(0), self.scale): + yield image.clamp_(-1, 1) + + if self.cpu_offload: + self.to_cpu() + + def encode_video(self, vid): + return self.model.encode_video(vid) + + def decode_video(self, vid_enc): + return self.model.decode_video(vid_enc) + + def device_synchronize( + self, + ): + torch_device_module.synchronize() diff --git a/lightx2v/models/video_encoders/hf/wan/vae_2_2.py b/lightx2v/models/video_encoders/hf/wan/vae_2_2.py new file mode 100644 index 0000000000000000000000000000000000000000..d345857407cf34d322605c1bfcba0f866c8e5f85 --- /dev/null +++ b/lightx2v/models/video_encoders/hf/wan/vae_2_2.py @@ -0,0 +1,1045 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from lightx2v.utils.utils import load_weights +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + +__all__ = [ + "Wan2_2_VAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolution. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation.
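+ The input is upcast to float32 before interpolation and cast back afterwards, since nearest-neighbor upsampling is not implemented for bfloat16 on some PyTorch backends.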
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + # nn.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( 
+ RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +def patchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 6: + x = x.squeeze(0) + + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = 
x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample 
blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + ) + ) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, offload_cache=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not 
None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + idx = feat_idx[0] + x = layer(x, feat_cache, feat_idx) + if offload_cache: + for _idx in range(idx, feat_idx[0]): + if isinstance(feat_cache[_idx], torch.Tensor): + feat_cache[_idx] = feat_cache[_idx].cpu() + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + idx = feat_idx[0] + x = layer(x, feat_cache, feat_idx, first_chunk) + if offload_cache: + for _idx in range(idx, feat_idx[0]): + if isinstance(feat_cache[_idx], torch.Tensor): + feat_cache[_idx] = feat_cache[_idx].cpu() + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x.cpu() if offload_cache else cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def forward(self, x, scale=[0, 1]): + mu = self.encode(x, scale) + x_recon = self.decode(mu, scale) + return x_recon, mu + + def encode(self, x, scale, return_mu=False): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + if return_mu: + return mu, log_var + else: + return mu + + def decode(self, z, scale, offload_cache=False): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + 
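# decode the latent one temporal slice at a time; self._feat_map carries the causal-conv cache across iterations +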
iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, offload_cache=offload_cache) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, offload_cache=offload_cache) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False, scale=[0, 1]): + mu, log_var = self.encode(imgs, scale, return_mu=True) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std), mu, log_var + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def encode_video(self, x, scale=[0, 1]): + assert x.ndim == 5 # NTCHW + assert x.shape[2] % 3 == 0 + x = x.transpose(1, 2) + y = x.mul(2).sub_(1) + y, mu, log_var = self.sample(y, scale=scale) + return y.transpose(1, 2).to(x), mu, log_var + + def decode_video(self, x, scale=[0, 1]): + assert x.ndim == 5 # NTCHW + assert x.shape[2] % self.z_dim == 0 + x = x.transpose(1, 2) + # B, C, T, H, W + y = x + y = self.decode(y, scale).clamp_(-1, 1) + y = y.mul_(0.5).add_(0.5).clamp_(0, 1) # NCTHW + return y.transpose(1, 2).to(x) + + +def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, load_from_rank0=False, **kwargs): + # params + cfg = dict( + dim=dim, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f"loading {pretrained_path}") + weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0) + for k in weights_dict.keys(): + if weights_dict[k].dtype != dtype: + weights_dict[k] = weights_dict[k].to(dtype) + model.load_state_dict(weights_dict, assign=True) + + return model + + +class Wan2_2_VAE: + def __init__( + self, + z_dim=48, + c_dim=160, + vae_path=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + dtype=torch.float, + device="cuda", + cpu_offload=False, + offload_cache=False, + load_from_rank0=False, + **kwargs, + ): + self.dtype = dtype + self.device = device + self.cpu_offload = cpu_offload + self.offload_cache = offload_cache + + self.mean = torch.tensor( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ], + dtype=dtype, + device=AI_DEVICE, + ) + self.std = torch.tensor( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, 
+ 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ], + dtype=dtype, + device=AI_DEVICE, + ) + self.inv_std = 1.0 / self.std + self.scale = [self.mean, self.inv_std] + # init model + self.model = ( + _video_vae( + pretrained_path=vae_path, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0 + ) + .eval() + .requires_grad_(False) + .to(device) + .to(dtype) + ) + + def to_cpu(self): + self.model.encoder = self.model.encoder.to("cpu") + self.model.decoder = self.model.decoder.to("cpu") + self.model = self.model.to("cpu") + self.mean = self.mean.cpu() + self.inv_std = self.inv_std.cpu() + self.scale = [self.mean, self.inv_std] + + def to_cuda(self): + self.model.encoder = self.model.encoder.to(AI_DEVICE) + self.model.decoder = self.model.decoder.to(AI_DEVICE) + self.model = self.model.to(AI_DEVICE) + self.mean = self.mean.to(AI_DEVICE) + self.inv_std = self.inv_std.to(AI_DEVICE) + self.scale = [self.mean, self.inv_std] + + def encode(self, video): + if self.cpu_offload: + self.to_cuda() + out = self.model.encode(video, self.scale).float().squeeze(0) + if self.cpu_offload: + self.to_cpu() + return out + + def decode(self, zs): + if self.cpu_offload: + self.to_cuda() + images = self.model.decode(zs.unsqueeze(0), self.scale, offload_cache=self.offload_cache if self.cpu_offload else False).float().clamp_(-1, 1) + if self.cpu_offload: + images = images.cpu().float() + self.to_cpu() + return images + + def encode_video(self, vid): + return self.model.encode_video(vid) + + def decode_video(self, vid_enc): + return self.model.decode_video(vid_enc) diff --git a/lightx2v/models/video_encoders/hf/wan/vae_sf.py b/lightx2v/models/video_encoders/hf/wan/vae_sf.py new file mode 100644 index 0000000000000000000000000000000000000000..f480e9f336426970791e54765bfa71286a4f12ce --- /dev/null +++ b/lightx2v/models/video_encoders/hf/wan/vae_sf.py @@ -0,0 +1,348 @@ +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from lightx2v.models.video_encoders.hf.wan.vae import WanVAE_, _video_vae +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +class WanSFVAE: + def __init__( + self, + z_dim=16, + vae_path="cache/vae_step_411000.pth", + dtype=torch.float, + device="cuda", + parallel=False, + use_tiling=False, + cpu_offload=False, + use_2d_split=True, + load_from_rank0=False, + **kwargs, + ): + self.dtype = dtype + self.device = device + self.parallel = parallel + self.use_tiling = use_tiling + self.cpu_offload = cpu_offload + self.use_2d_split = use_2d_split + + mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921] + std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160] + self.mean = torch.tensor(mean, dtype=torch.float32) + self.std = torch.tensor(std, dtype=torch.float32) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, 
dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype) + self.model.clear_cache() + self.upsampling_factor = 8 + + def to_cpu(self): + self.model.encoder = self.model.encoder.to("cpu") + self.model.decoder = self.model.decoder.to("cpu") + self.model = self.model.to("cpu") + self.mean = self.mean.cpu() + self.std = self.std.cpu() + self.scale = [self.mean, 1.0 / self.std] + + def to_cuda(self): + self.model.encoder = self.model.encoder.to(AI_DEVICE) + self.model.decoder = self.model.decoder.to(AI_DEVICE) + self.model = self.model.to(AI_DEVICE) + self.mean = self.mean.to(AI_DEVICE) + self.std = self.std.to(AI_DEVICE) + self.scale = [self.mean, 1.0 / self.std] + + def decode(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor: + # from [batch_size, num_frames, num_channels, height, width] + # to [batch_size, num_channels, num_frames, height, width] + latent = latent.transpose(0, 1).unsqueeze(0) + zs = latent.permute(0, 2, 1, 3, 4) + if use_cache: + assert latent.shape[0] == 1, "Batch size must be 1 when using cache" + + device, dtype = latent.device, latent.dtype + scale = [self.mean.to(device=device, dtype=dtype), 1.0 / self.std.to(device=device, dtype=dtype)] + + if use_cache: + decode_function = self.model.cached_decode + else: + decode_function = self.model.decode + + output = [] + for u in zs: + output.append(decode_function(u.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0)) + output = torch.stack(output, dim=0) + # from [batch_size, num_channels, num_frames, height, width] + # to [batch_size, num_frames, num_channels, height, width] + output = output.permute(0, 2, 1, 3, 4).squeeze(0) + return output + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= H: + continue + for w in range(0, W, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= W: + continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + for h, h_, w, w_ in tasks: # tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h : target_h + hidden_states_batch.shape[3], + target_w : target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h : target_h + hidden_states_batch.shape[3], + target_w : target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + return values + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return 
x + + def encode(self, videos, device, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) + if tiled: + tile_size = (tile_size[0] * 8, tile_size[1] * 8) + tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + +class WanMtxg2VAE(nn.Module): + def __init__(self, pretrained_path=None, z_dim=16): + super().__init__() + mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921] + std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + # init model + self.model = ( + WanVAE_( + dim=96, + z_dim=z_dim, + num_res_blocks=2, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + dropout=0.0, + pruning_rate=0.0, + ) + .eval() + .requires_grad_(False) + ) + if pretrained_path is not None: + self.model.load_state_dict(torch.load(pretrained_path, map_location="cpu"), assign=True) + self.upsampling_factor = 8 + + def to(self, *args, **kwargs): + self.mean = self.mean.to(*args, **kwargs) + self.std = self.std.to(*args, **kwargs) + self.scale = [self.mean, 1.0 / self.std] + self.model = self.model.to(*args, **kwargs) + return self + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + def tiled_decode(self, hidden_states, device, tile_size, tile_stride): + _, _, T, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= H: + continue + for w in range(0, W, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= W: + continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" # TODO + computation_device = device + + out_T = T * 4 - 3 + weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + + for h, h_, w, w_ in tasks: # tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) + + mask = 
self.build_mask( + hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_h = h * self.upsampling_factor + target_w = w * self.upsampling_factor + values[ + :, + :, + :, + target_h : target_h + hidden_states_batch.shape[3], + target_w : target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h : target_h + hidden_states_batch.shape[3], + target_w : target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.clamp_(-1, 1) + return values + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= H: + continue + for w in range(0, W, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= W: + continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + + for h, h_, w, w_ in tasks: # tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h : target_h + hidden_states_batch.shape[3], + target_w : target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h : target_h + hidden_states_batch.shape[3], + target_w : target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + return values + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return x + + def single_decode(self, hidden_state, device): + hidden_state = hidden_state.to(device) + video = self.model.decode(hidden_state, self.scale) + return video.clamp_(-1, 1) + + def encode(self, videos, device, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + # videos: torch.Size([1, 3, 597, 352, 640]), device='cuda:0', dtype=torch.bfloat16 + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) # torch.Size([1, 3, 597, 352, 640]) torch.bfloat16 device(type='cpu') + if tiled: # True + tile_size = (tile_size[0] * 8, tile_size[1] * 8) + tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + def decode(self, hidden_states, 
device, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos diff --git a/lightx2v/models/video_encoders/hf/wan/vae_tiny.py b/lightx2v/models/video_encoders/hf/wan/vae_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..d67dd50beed29f73b3f96571d7b42e2f97f1247e --- /dev/null +++ b/lightx2v/models/video_encoders/hf/wan/vae_tiny.py @@ -0,0 +1,216 @@ +import torch +import torch.nn as nn + +from lightx2v.models.video_encoders.hf.tae import TAEHV +from lightx2v.utils.memory_profiler import peak_memory_decorator + + +class DotDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + + +class WanVAE_tiny(nn.Module): + def __init__(self, vae_path="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False): + super().__init__() + self.dtype = dtype + self.device = torch.device("cuda") + self.taehv = TAEHV(vae_path).to(self.dtype) + self.temperal_downsample = [True, True, False] + self.need_scaled = need_scaled + + if self.need_scaled: + self.latents_mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + + self.latents_std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + + self.z_dim = 16 + + @peak_memory_decorator + @torch.no_grad() + def decode(self, latents): + latents = latents.unsqueeze(0) + + if self.need_scaled: + latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # low-memory, set parallel=True for faster + higher memory + return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1) + + @torch.no_grad() + def encode_video(self, vid): + return self.taehv.encode_video(vid) + + @torch.no_grad() + def decode_video(self, vid_enc): + return self.taehv.decode_video(vid_enc) + + +class Wan2_2_VAE_tiny(nn.Module): + def __init__(self, vae_path="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False): + super().__init__() + self.dtype = dtype + self.device = torch.device("cuda") + self.taehv = TAEHV(vae_path, model_type="wan22").to(self.dtype) + self.need_scaled = need_scaled + if self.need_scaled: + self.latents_mean = [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ] + + self.latents_std = [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 
0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ] + + self.z_dim = 48 + + @peak_memory_decorator + @torch.no_grad() + def decode(self, latents): + latents = latents.unsqueeze(0) + + if self.need_scaled: + latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # low-memory, set parallel=True for faster + higher memory + return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1) + + @torch.no_grad() + def encode_video(self, vid): + return self.taehv.encode_video(vid) + + @torch.no_grad() + def decode_video(self, vid_enc): + return self.taehv.decode_video(vid_enc) diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f90371b6618f60caf0503656894557de5dbb99 --- /dev/null +++ b/lightx2v/pipeline.py @@ -0,0 +1,360 @@ +# please do not set envs in this file, it will be imported by the __init__.py file +# os.environ["TOKENIZERS_PARALLELISM"] = "false" +# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +# os.environ["DTYPE"] = "BF16" +# os.environ["SENSITIVE_LAYER_DTYPE"] = "None" +# os.environ["PROFILING_DEBUG_LEVEL"] = "2" + +import json + +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 +from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401 +# from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401 +from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401 +from lightx2v.utils.input_info import set_input_info +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.set_config import print_config, set_config, set_parallel_config +from lightx2v.utils.utils import seed_all + + +def dict_like(cls): + cls.__getitem__ = lambda self, key: getattr(self, key) + cls.__setitem__ = lambda self, key, value: setattr(self, key, value) + cls.__delitem__ = lambda self, key: delattr(self, key) + cls.__contains__ = lambda self, key: hasattr(self, key) + + def update(self, *args, **kwargs): + for arg in args: + if isinstance(arg, dict): + items = arg.items() + else: + items = arg + for k, v in items: + setattr(self, k, v) + for k, v in kwargs.items(): + setattr(self, k, v) + + def get(self, key, default=None): + return getattr(self, key, default) + + cls.get = get + 
cls.update = update + + return cls + + +@dict_like +class LightX2VPipeline: + def __init__( + self, + task, + model_path, + model_cls, + sf_model_path=None, + dit_original_ckpt=None, + low_noise_original_ckpt=None, + high_noise_original_ckpt=None, + transformer_model_name=None, + ): + self.task = task + self.model_path = model_path + self.model_cls = model_cls + self.sf_model_path = sf_model_path + self.dit_original_ckpt = dit_original_ckpt + self.low_noise_original_ckpt = low_noise_original_ckpt + self.high_noise_original_ckpt = high_noise_original_ckpt + self.transformer_model_name = transformer_model_name + + if self.model_cls in [ + "wan2.1", + "wan2.1_distill", + "wan2.1_vace", + "wan2.1_sf", + "wan2.1_sf_mtxg2", + "seko_talk", + "wan2.2_moe", + "wan2.2_moe_audio", + "wan2.2_audio", + "wan2.2_moe_distill", + "wan2.2_animate", + ]: + self.vae_stride = (4, 8, 8) + if self.model_cls.startswith("wan2.2_moe"): + self.use_image_encoder = False + elif self.model_cls in ["wan2.2"]: + self.vae_stride = (4, 16, 16) + self.num_channels_latents = 48 + elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: + self.vae_stride = (4, 16, 16) + self.num_channels_latents = 32 + + def create_generator( + self, + attn_mode="flash_attn2", + infer_steps=50, + num_frames=81, + height=480, + width=832, + guidance_scale=5.0, + sample_shift=5.0, + fps=16, + aspect_ratio="16:9", + boundary=0.900, + boundary_step_index=2, + denoising_step_list=[1000, 750, 500, 250], + config_json=None, + rope_type="torch", + ): + if config_json is not None: + self.set_infer_config_json(config_json) + else: + self.set_infer_config( + attn_mode, + rope_type, + infer_steps, + num_frames, + height, + width, + guidance_scale, + sample_shift, + fps, + aspect_ratio, + boundary, + boundary_step_index, + denoising_step_list, + ) + + config = set_config(self) + print_config(config) + self.runner = self._init_runner(config) + logger.info(f"Initializing {self.model_cls} runner for {self.task} task...") + logger.info(f"Model path: {self.model_path}") + logger.info("LightGenerator initialized successfully!") + + def set_infer_config( + self, + attn_mode, + rope_type, + infer_steps, + num_frames, + height, + width, + guidance_scale, + sample_shift, + fps, + aspect_ratio, + boundary, + boundary_step_index, + denoising_step_list, + ): + self.infer_steps = infer_steps + self.target_width = width + self.target_height = height + self.target_video_length = num_frames + self.sample_guide_scale = guidance_scale + self.sample_shift = sample_shift + if self.sample_guide_scale == 1: + self.enable_cfg = False + else: + self.enable_cfg = True + self.rope_type = rope_type + self.fps = fps + self.aspect_ratio = aspect_ratio + self.boundary = boundary + self.boundary_step_index = boundary_step_index + self.denoising_step_list = denoising_step_list + if self.model_cls.startswith("wan"): + self.self_attn_1_type = attn_mode + self.cross_attn_1_type = attn_mode + self.cross_attn_2_type = attn_mode + elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: + self.attn_type = attn_mode + + def set_infer_config_json(self, config_json): + logger.info(f"Loading infer config from {config_json}") + with open(config_json, "r") as f: + config_json = json.load(f) + self.update(config_json) + + def enable_lightvae( + self, + use_lightvae=False, + use_tae=False, + vae_path=None, + tae_path=None, + ): + self.use_lightvae = use_lightvae + self.use_tae = use_tae + self.vae_path = vae_path + self.tae_path = tae_path + if self.use_tae and 
self.model_cls.startswith("wan") and "lighttae" in tae_path: + self.need_scaled = True + + def enable_quantize( + self, + dit_quantized=False, + text_encoder_quantized=False, + image_encoder_quantized=False, + dit_quantized_ckpt=None, + low_noise_quantized_ckpt=None, + high_noise_quantized_ckpt=None, + text_encoder_quantized_ckpt=False, + image_encoder_quantized_ckpt=False, + quant_scheme="fp8-sgl", + ): + self.dit_quantized = dit_quantized + self.dit_quant_scheme = quant_scheme + self.dit_quantized_ckpt = dit_quantized_ckpt + self.low_noise_quantized_ckpt = low_noise_quantized_ckpt + self.high_noise_quantized_ckpt = high_noise_quantized_ckpt + + if self.model_cls.startswith("wan"): + self.t5_quant_scheme = quant_scheme + self.t5_quantized = text_encoder_quantized + self.t5_quantized_ckpt = text_encoder_quantized_ckpt + self.clip_quant_scheme = quant_scheme + self.clip_quantized = image_encoder_quantized + self.clip_quantized_ckpt = image_encoder_quantized_ckpt + elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: + self.qwen25vl_quantized = text_encoder_quantized + self.qwen25vl_quantized_ckpt = text_encoder_quantized_ckpt + self.qwen25vl_quant_scheme = quant_scheme + + def enable_offload( + self, + cpu_offload=False, + offload_granularity="block", + text_encoder_offload=False, + image_encoder_offload=False, + vae_offload=False, + ): + self.cpu_offload = cpu_offload + self.offload_granularity = offload_granularity + self.vae_offload = vae_offload + if self.model_cls in [ + "wan2.1", + "wan2.1_distill", + "wan2.1_vace", + "wan2.1_sf", + "wan2.1_sf_mtxg2", + "seko_talk", + "wan2.2_moe", + "wan2.2", + "wan2.2_moe_audio", + "wan2.2_audio", + "wan2.2_moe_distill", + "wan2.2_animate", + ]: + self.t5_cpu_offload = text_encoder_offload + self.clip_encoder_offload = image_encoder_offload + + elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: + self.qwen25vl_cpu_offload = text_encoder_offload + self.siglip_cpu_offload = image_encoder_offload + self.byt5_cpu_offload = image_encoder_offload + + def enable_compile( + self, + ): + self.compile = True + self.compile_shapes = [ + [480, 832], + [544, 960], + [720, 1280], + [832, 480], + [960, 544], + [1280, 720], + [480, 480], + [576, 576], + [704, 704], + [960, 960], + ] + + def enable_lora(self, lora_configs): + self.lora_configs = lora_configs + + def enable_cache( + self, + cache_method="Tea", + coefficients=[], + teacache_thresh=0.15, + use_ret_steps=False, + magcache_calibration=False, + magcache_K=6, + magcache_thresh=0.24, + magcache_retention_ratio=0.2, + magcache_ratios=[], + ): + self.feature_caching = cache_method + if cache_method == "Tea": + self.coefficients = coefficients + self.teacache_thresh = teacache_thresh + self.use_ret_steps = use_ret_steps + elif cache_method == "Mag": + self.magcache_calibration = magcache_calibration + self.magcache_K = magcache_K + self.magcache_thresh = magcache_thresh + self.magcache_retention_ratio = magcache_retention_ratio + self.magcache_ratios = magcache_ratios + + def enable_parallel(self, cfg_p_size=1, seq_p_size=1, seq_p_attn_type="ulysses"): + self._init_parallel() + self.parallel = { + "cfg_p_size": cfg_p_size, + "seq_p_size": seq_p_size, + "seq_p_attn_type": seq_p_attn_type, + } + set_parallel_config(self) + + @torch.no_grad() + def generate( + self, + seed, + prompt, + negative_prompt, + save_result_path, + image_path=None, + last_frame_path=None, + audio_path=None, + src_ref_images=None, + src_video=None, + src_mask=None, + 
return_result_tensor=False, + ): + # Run inference (following LightX2V pattern) + self.seed = seed + + self.image_path = image_path + self.last_frame_path = last_frame_path + self.audio_path = audio_path + self.src_ref_images = src_ref_images + self.src_video = src_video + self.src_mask = src_mask + self.prompt = prompt + self.negative_prompt = negative_prompt + self.save_result_path = save_result_path + self.return_result_tensor = return_result_tensor + seed_all(self.seed) + input_info = set_input_info(self) + self.runner.run_pipeline(input_info) + logger.info("Video generated successfully!") + logger.info(f"Video Saved in {save_result_path}") + + def _init_runner(self, config): + torch.set_grad_enabled(False) + runner = RUNNER_REGISTER[config["model_cls"]](config) + runner.init_modules() + return runner + + def _init_parallel(self): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) diff --git a/lightx2v/server/README.md b/lightx2v/server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7fa9aceaf3ac2775de7223d2b516ce5abafeaf67 --- /dev/null +++ b/lightx2v/server/README.md @@ -0,0 +1,438 @@ +# LightX2V Server + +## Overview + +The LightX2V server is a distributed video/image generation service built with FastAPI that processes image-to-video and text-to-image tasks using a multi-process architecture with GPU support. It implements a sophisticated task queue system with distributed inference capabilities for high-throughput generation workloads. + +## Directory Structure + +``` +server/ +├── __init__.py +├── __main__.py # Entry point +├── main.py # Server startup +├── config.py # Configuration +├── task_manager.py # Task management +├── schema.py # Data models (VideoTaskRequest, ImageTaskRequest) +├── api/ +│ ├── __init__.py +│ ├── router.py # Main router aggregation +│ ├── deps.py # Dependency injection container +│ ├── server.py # ApiServer class +│ ├── files.py # /v1/files/* +│ ├── service_routes.py # /v1/service/* +│ └── tasks/ +│ ├── __init__.py +│ ├── common.py # Common task operations +│ ├── video.py # POST /v1/tasks/video +│ └── image.py # POST /v1/tasks/image +├── services/ +│ ├── __init__.py +│ ├── file_service.py # File service (unified download) +│ ├── distributed_utils.py # Distributed manager +│ ├── inference/ +│ │ ├── __init__.py +│ │ ├── worker.py # TorchrunInferenceWorker +│ │ └── service.py # DistributedInferenceService +│ └── generation/ +│ ├── __init__.py +│ ├── base.py # Base generation service +│ ├── video.py # VideoGenerationService +│ └── image.py # ImageGenerationService +├── media/ +│ ├── __init__.py +│ ├── base.py # MediaHandler base class +│ ├── image.py # ImageHandler +│ └── audio.py # AudioHandler +└── metrics/ # Prometheus metrics +``` + +## Architecture + +### System Architecture + +```mermaid +flowchart TB + Client[Client] -->|Send API Request| Router[FastAPI Router] + + subgraph API Layer + Router --> TaskRoutes[Task APIs] + Router --> FileRoutes[File APIs] + Router --> ServiceRoutes[Service Status APIs] + + TaskRoutes --> CreateVideoTask["POST /v1/tasks/video - Create Video Task"] + TaskRoutes --> CreateImageTask["POST /v1/tasks/image - Create Image Task"] + TaskRoutes --> CreateVideoTaskForm["POST /v1/tasks/video/form - Form Create Video"] + TaskRoutes --> CreateImageTaskForm["POST /v1/tasks/image/form - Form Create Image"] + TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"] + TaskRoutes --> GetTaskStatus["GET /v1/tasks/{id}/status - Get Status"] + TaskRoutes --> GetTaskResult["GET 
/v1/tasks/{id}/result - Get Result"] + TaskRoutes --> StopTask["DELETE /v1/tasks/{id} - Stop Task"] + + FileRoutes --> DownloadFile["GET /v1/files/download/{path} - Download File"] + + ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"] + ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"] + end + + subgraph Task Management + TaskManager[Task Manager] + TaskQueue[Task Queue] + TaskStatus[Task Status] + TaskResult[Task Result] + + CreateVideoTask --> TaskManager + CreateImageTask --> TaskManager + TaskManager --> TaskQueue + TaskManager --> TaskStatus + TaskManager --> TaskResult + end + + subgraph File Service + FileService[File Service] + DownloadMedia[Download Media] + SaveFile[Save File] + GetOutputPath[Get Output Path] + + FileService --> DownloadMedia + FileService --> SaveFile + FileService --> GetOutputPath + end + + subgraph Media Handlers + MediaHandler[MediaHandler Base] + ImageHandler[ImageHandler] + AudioHandler[AudioHandler] + + MediaHandler --> ImageHandler + MediaHandler --> AudioHandler + end + + subgraph Processing Thread + ProcessingThread[Processing Thread] + NextTask[Get Next Task] + ProcessTask[Process Single Task] + + ProcessingThread --> NextTask + ProcessingThread --> ProcessTask + end + + subgraph Generation Services + VideoService[VideoGenerationService] + ImageService[ImageGenerationService] + BaseService[BaseGenerationService] + + BaseService --> VideoService + BaseService --> ImageService + end + + subgraph Distributed Inference Service + InferenceService[DistributedInferenceService] + SubmitTask[Submit Task] + Worker[TorchrunInferenceWorker] + ProcessRequest[Process Request] + RunPipeline[Run Inference Pipeline] + + InferenceService --> SubmitTask + SubmitTask --> Worker + Worker --> ProcessRequest + ProcessRequest --> RunPipeline + end + + TaskQueue --> ProcessingThread + ProcessTask --> VideoService + ProcessTask --> ImageService + VideoService --> InferenceService + ImageService --> InferenceService + GetTaskResult --> FileService + DownloadFile --> FileService + VideoService --> FileService + ImageService --> FileService + FileService --> MediaHandler +``` + +## Task Processing Flow + +```mermaid +sequenceDiagram + participant C as Client + participant API as API Server + participant TM as TaskManager + participant PT as Processing Thread + participant GS as GenerationService
(Video/Image) + participant FS as FileService + participant DIS as DistributedInferenceService + participant TIW0 as TorchrunInferenceWorker
(Rank 0) + participant TIW1 as TorchrunInferenceWorker
(Rank 1..N) + + C->>API: POST /v1/tasks/video
or /v1/tasks/image + API->>TM: create_task() + TM->>TM: Generate task_id + TM->>TM: Add to queue
(status: PENDING) + API->>PT: ensure_processing_thread() + API-->>C: TaskResponse
(task_id, status: pending) + + Note over PT: Processing Loop + PT->>TM: get_next_pending_task() + TM-->>PT: task_id + + PT->>TM: acquire_processing_lock() + PT->>TM: start_task()
(status: PROCESSING) + + PT->>PT: Select service by task type + PT->>GS: generate_with_stop_event() + + alt Image is URL + GS->>FS: download_media(url, "image") + FS->>FS: HTTP download
with retry + FS-->>GS: image_path + else Image is Base64 + GS->>GS: save_base64_image() + GS-->>GS: image_path + else Image is local path + GS->>GS: use existing path + end + + alt Audio is URL (Video only) + GS->>FS: download_media(url, "audio") + FS->>FS: HTTP download
with retry + FS-->>GS: audio_path + else Audio is Base64 + GS->>GS: save_base64_audio() + GS-->>GS: audio_path + end + + GS->>DIS: submit_task_async(task_data) + DIS->>TIW0: process_request(task_data) + + Note over TIW0,TIW1: Torchrun-based Distributed Processing + TIW0->>TIW0: Check if processing + TIW0->>TIW0: Set processing = True + + alt Multi-GPU Mode (world_size > 1) + TIW0->>TIW1: broadcast_task_data()
(via DistributedManager) + Note over TIW1: worker_loop() listens for broadcasts + TIW1->>TIW1: Receive task_data + end + + par Parallel Inference across all ranks + TIW0->>TIW0: runner.set_inputs(task_data) + TIW0->>TIW0: runner.run_pipeline() + and + Note over TIW1: If world_size > 1 + TIW1->>TIW1: runner.set_inputs(task_data) + TIW1->>TIW1: runner.run_pipeline() + end + + Note over TIW0,TIW1: Synchronization + alt Multi-GPU Mode + TIW0->>TIW1: barrier() for sync + TIW1->>TIW0: barrier() response + end + + TIW0->>TIW0: Set processing = False + TIW0->>DIS: Return result (only rank 0) + TIW1->>TIW1: Return None (non-rank 0) + + DIS-->>GS: TaskResponse + GS-->>PT: TaskResponse + + PT->>TM: complete_task()
(status: COMPLETED) + PT->>TM: release_processing_lock() + + Note over C: Client Polling + C->>API: GET /v1/tasks/{task_id}/status + API->>TM: get_task_status() + TM-->>API: status info + API-->>C: Task Status + + C->>API: GET /v1/tasks/{task_id}/result + API->>TM: get_task_status() + API->>FS: stream_file_response() + FS-->>API: Video/Image Stream + API-->>C: Output File +``` + +## Task States + +```mermaid +stateDiagram-v2 + [*] --> PENDING: create_task() + PENDING --> PROCESSING: start_task() + PROCESSING --> COMPLETED: complete_task() + PROCESSING --> FAILED: fail_task() + PENDING --> CANCELLED: cancel_task() + PROCESSING --> CANCELLED: cancel_task() + COMPLETED --> [*] + FAILED --> [*] + CANCELLED --> [*] +``` + +## API Endpoints + +### Task APIs + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/v1/tasks/video` | POST | Create video generation task | +| `/v1/tasks/video/form` | POST | Create video task with form data | +| `/v1/tasks/image` | POST | Create image generation task | +| `/v1/tasks/image/form` | POST | Create image task with form data | +| `/v1/tasks` | GET | List all tasks | +| `/v1/tasks/queue/status` | GET | Get queue status | +| `/v1/tasks/{task_id}/status` | GET | Get task status | +| `/v1/tasks/{task_id}/result` | GET | Get task result (stream) | +| `/v1/tasks/{task_id}` | DELETE | Cancel task | +| `/v1/tasks/all/running` | DELETE | Cancel all running tasks | + +### File APIs + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/v1/files/download/{path}` | GET | Download output file | + +### Service APIs + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/v1/service/status` | GET | Get service status | +| `/v1/service/metadata` | GET | Get service metadata | + +## Request Models + +### VideoTaskRequest + +```python +class VideoTaskRequest(BaseTaskRequest): + num_fragments: int = 1 + target_video_length: int = 81 + audio_path: str = "" + video_duration: int = 5 + talk_objects: Optional[list[TalkObject]] = None +``` + +### ImageTaskRequest + +```python +class ImageTaskRequest(BaseTaskRequest): + aspect_ratio: str = "16:9" +``` + +### BaseTaskRequest (Common Fields) + +```python +class BaseTaskRequest(BaseModel): + task_id: str # auto-generated + prompt: str = "" + use_prompt_enhancer: bool = False + negative_prompt: str = "" + image_path: str = "" # URL, base64, or local path + save_result_path: str = "" + infer_steps: int = 5 + seed: int # auto-generated +``` + +## Configuration + +### Environment Variables + +see `lightx2v/server/config.py` + +### Command Line Arguments + +```bash +# Single GPU +python -m lightx2v.server \ + --model_path /path/to/model \ + --model_cls wan2.1_distill \ + --task i2v \ + --host 0.0.0.0 \ + --port 8000 \ + --config_json /path/to/xxx_config.json +``` + +```bash +# Multi-GPU with torchrun +torchrun --nproc_per_node=2 -m lightx2v.server \ + --model_path /path/to/model \ + --model_cls wan2.1_distill \ + --task i2v \ + --host 0.0.0.0 \ + --port 8000 \ + --config_json /path/to/xxx_dist_config.json +``` + +## Key Features + +### 1. Distributed Processing + +- **Multi-process architecture** for GPU parallelization +- **Master-worker pattern** with rank 0 as coordinator +- **PyTorch distributed** backend (NCCL for GPU, Gloo for CPU) +- **Automatic GPU allocation** across processes +- **Task broadcasting** with chunked pickle serialization + +### 2. 
Task Queue Management + +- **Thread-safe** task queue with locks +- **Sequential processing** with single processing thread +- **Configurable queue limits** with overflow protection +- **Task prioritization** (FIFO) +- **Automatic cleanup** of old completed tasks +- **Cancellation support** for pending and running tasks + +### 3. File Management + +- **Multiple input formats**: URL, base64, file upload +- **HTTP downloads** with exponential backoff retry +- **Streaming responses** for large video files +- **Cache management** with automatic cleanup +- **File validation** and format detection +- **Unified media handling** via MediaHandler pattern + +### 4. Separate Video/Image Endpoints + +- **Dedicated endpoints** for video and image generation +- **Type-specific request models** (VideoTaskRequest, ImageTaskRequest) +- **Automatic service routing** based on task type +- **Backward compatible** with legacy `/v1/tasks` endpoint + +## Performance Considerations + +1. **Single Task Processing**: Tasks are processed sequentially to manage GPU memory effectively +2. **Multi-GPU Support**: Distributes inference across available GPUs for parallelization +3. **Connection Pooling**: Reuses HTTP connections to reduce overhead +4. **Streaming Responses**: Large files are streamed to avoid memory issues +5. **Queue Management**: Automatic task cleanup prevents memory leaks +6. **Process Isolation**: Distributed workers run in separate processes for stability + +## Monitoring and Debugging + +### Logging + +The server uses `loguru` for structured logging. Logs include: + +- Request/response details +- Task lifecycle events +- Worker process status +- Error traces with context + +### Health Checks + +- `/v1/service/status` - Overall service health +- `/v1/tasks/queue/status` - Queue status and processing state +- Process monitoring via system tools (htop, nvidia-smi) + +### Common Issues + +1. **GPU Out of Memory**: Reduce `nproc_per_node` or adjust model batch size +2. **Task Timeout**: Increase `LIGHTX2V_TASK_TIMEOUT` for longer videos +3. **Queue Full**: Increase `LIGHTX2V_MAX_QUEUE_SIZE` or add rate limiting + +## Security Considerations + +1. **Input Validation**: All inputs validated with Pydantic schemas + +## License + +See the main project LICENSE file for licensing information. 
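+
+## Appendix: Minimal Client Example
+
+The snippet below is a minimal client sketch against the task endpoints documented above: it creates a video task, polls its status, and streams the finished file to disk. It assumes the server is reachable at `http://localhost:8000` (see `--host`/`--port`), that the terminal status strings are the lower-case values used by `TaskResponse` (`pending`, `completed`, ...), and it relies on the third-party `requests` package; all field values are placeholders to adapt to your deployment.
+
+```python
+import time
+
+import requests
+
+BASE_URL = "http://localhost:8000"  # assumed host/port of the running server
+
+# 1. Create a video task (fields follow VideoTaskRequest; values are placeholders).
+resp = requests.post(
+    f"{BASE_URL}/v1/tasks/video",
+    json={
+        "prompt": "a cat surfing a wave at sunset",
+        "image_path": "/path/to/input.jpg",  # URL, base64 string, or server-local path
+        "infer_steps": 5,
+        "seed": 42,
+    },
+)
+resp.raise_for_status()
+task_id = resp.json()["task_id"]
+
+# 2. Poll the status endpoint until the task reaches a terminal state.
+while True:
+    status = requests.get(f"{BASE_URL}/v1/tasks/{task_id}/status").json()
+    if status.get("status") in ("completed", "failed", "cancelled"):
+        break
+    time.sleep(2)
+
+# 3. Stream the result to disk once the task has completed.
+if status.get("status") == "completed":
+    with requests.get(f"{BASE_URL}/v1/tasks/{task_id}/result", stream=True) as r:
+        r.raise_for_status()
+        with open("output.mp4", "wb") as f:
+            for chunk in r.iter_content(chunk_size=1 << 20):
+                f.write(chunk)
+```
+
+Image tasks follow the same flow via `POST /v1/tasks/image`; when uploading files directly instead of passing a path, use the `/form` variants of the create endpoints.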
diff --git a/lightx2v/server/__init__.py b/lightx2v/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/server/__main__.py b/lightx2v/server/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd9f620b68226e042b115f1075ce0bffdeeb902d --- /dev/null +++ b/lightx2v/server/__main__.py @@ -0,0 +1,28 @@ +import argparse + +from .main import run_server + + +def main(): + parser = argparse.ArgumentParser(description="LightX2V Server") + + parser.add_argument("--model_path", type=str, required=True, help="Path to model") + parser.add_argument("--model_cls", type=str, required=True, help="Model class name") + + parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host") + parser.add_argument("--port", type=int, default=8000, help="Server port") + + args, unknown = parser.parse_known_args() + + for i in range(0, len(unknown), 2): + if unknown[i].startswith("--"): + key = unknown[i][2:] + if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"): + value = unknown[i + 1] + setattr(args, key, value) + + run_server(args) + + +if __name__ == "__main__": + main() diff --git a/lightx2v/server/api/__init__.py b/lightx2v/server/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27afd2929740066ceba6d365e04bc177553447f9 --- /dev/null +++ b/lightx2v/server/api/__init__.py @@ -0,0 +1,7 @@ +from .router import create_api_router +from .server import ApiServer + +__all__ = [ + "create_api_router", + "ApiServer", +] diff --git a/lightx2v/server/api/deps.py b/lightx2v/server/api/deps.py new file mode 100644 index 0000000000000000000000000000000000000000..34068ef62f9ada00e7a211fd2e67be697e0797ac --- /dev/null +++ b/lightx2v/server/api/deps.py @@ -0,0 +1,54 @@ +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +import httpx +from loguru import logger + +from ..services import DistributedInferenceService, FileService, ImageGenerationService, VideoGenerationService + + +class ServiceContainer: + _instance: Optional["ServiceContainer"] = None + + def __init__(self): + self.file_service: Optional[FileService] = None + self.inference_service: Optional[DistributedInferenceService] = None + self.video_service: Optional[VideoGenerationService] = None + self.image_service: Optional[ImageGenerationService] = None + self.max_queue_size: int = 10 + + @classmethod + def get_instance(cls) -> "ServiceContainer": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def initialize(self, cache_dir: Path, inference_service: DistributedInferenceService, max_queue_size: int = 10): + self.file_service = FileService(cache_dir) + self.inference_service = inference_service + self.video_service = VideoGenerationService(self.file_service, inference_service) + self.image_service = ImageGenerationService(self.file_service, inference_service) + self.max_queue_size = max_queue_size + + +def get_services() -> ServiceContainer: + return ServiceContainer.get_instance() + + +async def validate_url_async(url: str) -> bool: + if not url or not url.startswith("http"): + return True + + try: + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + return False + + timeout = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0) + async with httpx.AsyncClient(verify=False, timeout=timeout) as client: + response = await client.head(url, follow_redirects=True) + return 
response.status_code < 400 + except Exception as e: + logger.warning(f"URL validation failed for {url}: {str(e)}") + return False diff --git a/lightx2v/server/api/files.py b/lightx2v/server/api/files.py new file mode 100644 index 0000000000000000000000000000000000000000..2e378802b892415576d56d9919b0093379918eb9 --- /dev/null +++ b/lightx2v/server/api/files.py @@ -0,0 +1,69 @@ +from pathlib import Path + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from loguru import logger + +from .deps import get_services + +router = APIRouter() + + +def _stream_file_response(file_path: Path, filename: str | None = None) -> StreamingResponse: + services = get_services() + assert services.file_service is not None, "File service is not initialized" + + try: + resolved_path = file_path.resolve() + + if not str(resolved_path).startswith(str(services.file_service.output_video_dir.resolve())): + raise HTTPException(status_code=403, detail="Access to this file is not allowed") + + if not resolved_path.exists() or not resolved_path.is_file(): + raise HTTPException(status_code=404, detail=f"File not found: {file_path}") + + file_size = resolved_path.stat().st_size + actual_filename = filename or resolved_path.name + + mime_type = "application/octet-stream" + if actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")): + mime_type = "video/mp4" + elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")): + mime_type = "image/jpeg" + + headers = { + "Content-Disposition": f'attachment; filename="{actual_filename}"', + "Content-Length": str(file_size), + "Accept-Ranges": "bytes", + } + + def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024): + with open(file_path, "rb") as file: + while chunk := file.read(chunk_size): + yield chunk + + return StreamingResponse( + file_stream_generator(str(resolved_path)), + media_type=mime_type, + headers=headers, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error occurred while processing file stream response: {e}") + raise HTTPException(status_code=500, detail="File transfer failed") + + +@router.get("/download/{file_path:path}") +async def download_file(file_path: str): + services = get_services() + assert services.file_service is not None, "File service is not initialized" + + try: + full_path = services.file_service.output_video_dir / file_path + return _stream_file_response(full_path) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error occurred while processing file download request: {e}") + raise HTTPException(status_code=500, detail="File download failed") diff --git a/lightx2v/server/api/router.py b/lightx2v/server/api/router.py new file mode 100644 index 0000000000000000000000000000000000000000..01e0b82367120911fca0a24f8da5b847ccb36629 --- /dev/null +++ b/lightx2v/server/api/router.py @@ -0,0 +1,25 @@ +from fastapi import APIRouter + +from .files import router as files_router +from .service_routes import router as service_router +from .tasks import common_router, image_router, video_router + + +def create_api_router() -> APIRouter: + api_router = APIRouter() + + tasks_router = APIRouter(prefix="/v1/tasks", tags=["tasks"]) + tasks_router.include_router(common_router) + tasks_router.include_router(video_router, prefix="/video", tags=["video"]) + tasks_router.include_router(image_router, prefix="/image", tags=["image"]) + + # backward compatibility : POST /v1/tasks default to video task + from .tasks.video import 
create_video_task + + tasks_router.post("/", response_model_exclude_unset=True, deprecated=True)(create_video_task) + + api_router.include_router(tasks_router) + api_router.include_router(files_router, prefix="/v1/files", tags=["files"]) + api_router.include_router(service_router, prefix="/v1/service", tags=["service"]) + + return api_router diff --git a/lightx2v/server/api/server.py b/lightx2v/server/api/server.py new file mode 100644 index 0000000000000000000000000000000000000000..55bb1d116ac85be374d4411a5c7d420437142443 --- /dev/null +++ b/lightx2v/server/api/server.py @@ -0,0 +1,125 @@ +import asyncio +import threading +import time +from pathlib import Path +from typing import Any, Optional + +from fastapi import FastAPI +from loguru import logger +from starlette.responses import RedirectResponse + +from ..services import DistributedInferenceService +from ..task_manager import TaskStatus, task_manager +from .deps import ServiceContainer, get_services +from .router import create_api_router + + +class ApiServer: + def __init__(self, max_queue_size: int = 10, app: Optional[FastAPI] = None): + self.app = app or FastAPI(title="LightX2V API", version="1.0.0") + self.max_queue_size = max_queue_size + + self.processing_thread = None + self.stop_processing = threading.Event() + + self._setup_routes() + + def _setup_routes(self): + @self.app.get("/") + def redirect_to_docs(): + return RedirectResponse(url="/docs") + + api_router = create_api_router() + self.app.include_router(api_router) + + def _ensure_processing_thread_running(self): + if self.processing_thread is None or not self.processing_thread.is_alive(): + self.stop_processing.clear() + self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True) + self.processing_thread.start() + logger.info("Started task processing thread") + + def _task_processing_loop(self): + logger.info("Task processing loop started") + + asyncio.set_event_loop(asyncio.new_event_loop()) + loop = asyncio.get_event_loop() + + while not self.stop_processing.is_set(): + task_id = task_manager.get_next_pending_task() + + if task_id is None: + time.sleep(1) + continue + + task_info = task_manager.get_task(task_id) + if task_info and task_info.status == TaskStatus.PENDING: + logger.info(f"Processing task {task_id}") + loop.run_until_complete(self._process_single_task(task_info)) + + loop.close() + logger.info("Task processing loop stopped") + + async def _process_single_task(self, task_info: Any): + services = get_services() + + task_id = task_info.task_id + message = task_info.message + + lock_acquired = task_manager.acquire_processing_lock(task_id, timeout=1) + if not lock_acquired: + logger.error(f"Task {task_id} failed to acquire processing lock") + task_manager.fail_task(task_id, "Failed to acquire processing lock") + return + + try: + task_manager.start_task(task_id) + + if task_info.stop_event.is_set(): + logger.info(f"Task {task_id} cancelled before processing") + task_manager.fail_task(task_id, "Task cancelled") + return + + from ..schema import ImageTaskRequest + + if isinstance(message, ImageTaskRequest): + generation_service = services.image_service + else: + generation_service = services.video_service + + result = await generation_service.generate_with_stop_event(message, task_info.stop_event) + + if result: + task_manager.complete_task(task_id, result.save_result_path) + logger.info(f"Task {task_id} completed successfully") + else: + if task_info.stop_event.is_set(): + task_manager.fail_task(task_id, "Task cancelled during 
processing") + logger.info(f"Task {task_id} cancelled during processing") + else: + task_manager.fail_task(task_id, "Generation failed") + logger.error(f"Task {task_id} generation failed") + + except Exception as e: + logger.exception(f"Task {task_id} processing failed: {str(e)}") + task_manager.fail_task(task_id, str(e)) + finally: + if lock_acquired: + task_manager.release_processing_lock(task_id) + + def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService): + container = ServiceContainer.get_instance() + container.initialize(cache_dir, inference_service, self.max_queue_size) + self._ensure_processing_thread_running() + + async def cleanup(self): + self.stop_processing.set() + if self.processing_thread and self.processing_thread.is_alive(): + self.processing_thread.join(timeout=5) + + services = get_services() + if services.file_service: + await services.file_service.cleanup() + + def get_app(self) -> FastAPI: + return self.app diff --git a/lightx2v/server/api/service_routes.py b/lightx2v/server/api/service_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..296700b0dbc1626017f1b9b4af880f422731777f --- /dev/null +++ b/lightx2v/server/api/service_routes.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter + +from ..task_manager import task_manager +from .deps import get_services + +router = APIRouter() + + +@router.get("/status") +async def get_service_status(): + return task_manager.get_service_status() + + +@router.get("/metadata") +async def get_service_metadata(): + services = get_services() + assert services.inference_service is not None, "Inference service is not initialized" + return services.inference_service.server_metadata() diff --git a/lightx2v/server/api/tasks/__init__.py b/lightx2v/server/api/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e20f9071670138df9eb1a5f2fddbd8ec8b0f4d53 --- /dev/null +++ b/lightx2v/server/api/tasks/__init__.py @@ -0,0 +1,9 @@ +from .common import router as common_router +from .image import router as image_router +from .video import router as video_router + +__all__ = [ + "common_router", + "video_router", + "image_router", +] diff --git a/lightx2v/server/api/tasks/common.py b/lightx2v/server/api/tasks/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3aac65439143499002deac50606f44601f1993be --- /dev/null +++ b/lightx2v/server/api/tasks/common.py @@ -0,0 +1,147 @@ +import gc +from pathlib import Path + +import torch +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from loguru import logger + +from ...schema import StopTaskResponse +from ...task_manager import TaskStatus, task_manager +from ..deps import get_services + +router = APIRouter() + + +def _stream_file_response(file_path: Path, filename: str | None = None) -> StreamingResponse: + services = get_services() + assert services.file_service is not None, "File service is not initialized" + + try: + resolved_path = file_path.resolve() + + if not str(resolved_path).startswith(str(services.file_service.output_video_dir.resolve())): + raise HTTPException(status_code=403, detail="Access to this file is not allowed") + + if not resolved_path.exists() or not resolved_path.is_file(): + raise HTTPException(status_code=404, detail=f"File not found: {file_path}") + + file_size = resolved_path.stat().st_size + actual_filename = filename or resolved_path.name + + mime_type = "application/octet-stream" + if 
actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")): + mime_type = "video/mp4" + elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")): + mime_type = "image/jpeg" + + headers = { + "Content-Disposition": f'attachment; filename="{actual_filename}"', + "Content-Length": str(file_size), + "Accept-Ranges": "bytes", + } + + def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024): + with open(file_path, "rb") as file: + while chunk := file.read(chunk_size): + yield chunk + + return StreamingResponse( + file_stream_generator(str(resolved_path)), + media_type=mime_type, + headers=headers, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error occurred while processing file stream response: {e}") + raise HTTPException(status_code=500, detail="File transfer failed") + + +@router.get("/") +async def list_tasks(): + return task_manager.get_all_tasks() + + +@router.get("/queue/status") +async def get_queue_status(): + services = get_services() + service_status = task_manager.get_service_status() + return { + "is_processing": task_manager.is_processing(), + "current_task": service_status.get("current_task"), + "pending_count": task_manager.get_pending_task_count(), + "active_count": task_manager.get_active_task_count(), + "queue_size": services.max_queue_size, + "queue_available": services.max_queue_size - task_manager.get_active_task_count(), + } + + +@router.get("/{task_id}/status") +async def get_task_status(task_id: str): + status = task_manager.get_task_status(task_id) + if not status: + raise HTTPException(status_code=404, detail="Task not found") + return status + + +@router.get("/{task_id}/result") +async def get_task_result(task_id: str): + services = get_services() + assert services.video_service is not None, "Video service is not initialized" + assert services.file_service is not None, "File service is not initialized" + + try: + task_status = task_manager.get_task_status(task_id) + + if not task_status: + raise HTTPException(status_code=404, detail="Task not found") + + if task_status.get("status") != TaskStatus.COMPLETED.value: + raise HTTPException(status_code=404, detail="Task not completed") + + save_result_path = task_status.get("save_result_path") + if not save_result_path: + raise HTTPException(status_code=404, detail="Task result file does not exist") + + full_path = Path(save_result_path) + if not full_path.is_absolute(): + full_path = services.file_service.output_video_dir / save_result_path + + return _stream_file_response(full_path) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error occurred while getting task result: {e}") + raise HTTPException(status_code=500, detail="Failed to get task result") + + +@router.delete("/{task_id}", response_model=StopTaskResponse) +async def stop_task(task_id: str): + try: + if task_manager.cancel_task(task_id): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info(f"Task {task_id} stopped successfully.") + return StopTaskResponse(stop_status="success", reason="Task stopped successfully.") + else: + return StopTaskResponse(stop_status="do_nothing", reason="Task not found or already completed.") + except Exception as e: + logger.error(f"Error occurred while stopping task {task_id}: {str(e)}") + return StopTaskResponse(stop_status="error", reason=str(e)) + + +@router.delete("/all/running", response_model=StopTaskResponse) +async def stop_all_running_tasks(): + try: + task_manager.cancel_all_tasks() + 
gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("All tasks stopped successfully.") + return StopTaskResponse(stop_status="success", reason="All tasks stopped successfully.") + except Exception as e: + logger.error(f"Error occurred while stopping all tasks: {str(e)}") + return StopTaskResponse(stop_status="error", reason=str(e)) diff --git a/lightx2v/server/api/tasks/image.py b/lightx2v/server/api/tasks/image.py new file mode 100644 index 0000000000000000000000000000000000000000..791d8feb8091833d444e05a7ca8ebdd0322a49cf --- /dev/null +++ b/lightx2v/server/api/tasks/image.py @@ -0,0 +1,97 @@ +import asyncio +import uuid +from pathlib import Path + +from fastapi import APIRouter, File, Form, HTTPException, UploadFile +from loguru import logger + +from ...schema import ImageTaskRequest, TaskResponse +from ...task_manager import task_manager +from ..deps import get_services, validate_url_async + +router = APIRouter() + + +def _write_file_sync(file_path: Path, content: bytes) -> None: + with open(file_path, "wb") as buffer: + buffer.write(content) + + +@router.post("/", response_model=TaskResponse) +async def create_image_task(message: ImageTaskRequest): + try: + if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"): + if not await validate_url_async(message.image_path): + raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}") + + task_id = task_manager.create_task(message) + message.task_id = task_id + + return TaskResponse( + task_id=task_id, + task_status="pending", + save_result_path=message.save_result_path, + ) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) + except Exception as e: + logger.error(f"Failed to create image task: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/form", response_model=TaskResponse) +async def create_image_task_form( + image_file: UploadFile = File(None), + prompt: str = Form(default=""), + save_result_path: str = Form(default=""), + use_prompt_enhancer: bool = Form(default=False), + negative_prompt: str = Form(default=""), + infer_steps: int = Form(default=5), + seed: int = Form(default=42), + aspect_ratio: str = Form(default="16:9"), +): + services = get_services() + assert services.file_service is not None, "File service is not initialized" + + async def save_file_async(file: UploadFile, target_dir: Path) -> str: + if not file or not file.filename: + return "" + + file_extension = Path(file.filename).suffix + unique_filename = f"{uuid.uuid4()}{file_extension}" + file_path = target_dir / unique_filename + + content = await file.read() + await asyncio.to_thread(_write_file_sync, file_path, content) + + return str(file_path) + + image_path = "" + if image_file and image_file.filename: + image_path = await save_file_async(image_file, services.file_service.input_image_dir) + + message = ImageTaskRequest( + prompt=prompt, + use_prompt_enhancer=use_prompt_enhancer, + negative_prompt=negative_prompt, + image_path=image_path, + save_result_path=save_result_path, + infer_steps=infer_steps, + seed=seed, + aspect_ratio=aspect_ratio, + ) + + try: + task_id = task_manager.create_task(message) + message.task_id = task_id + + return TaskResponse( + task_id=task_id, + task_status="pending", + save_result_path=message.save_result_path, + ) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) + except Exception as e: + logger.error(f"Failed to create image form 
task: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightx2v/server/api/tasks/video.py b/lightx2v/server/api/tasks/video.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f1797cdf6fd13e9829ffb5f7029c998f05f563 --- /dev/null +++ b/lightx2v/server/api/tasks/video.py @@ -0,0 +1,109 @@ +import asyncio +import uuid +from pathlib import Path + +from fastapi import APIRouter, File, Form, HTTPException, UploadFile +from loguru import logger + +from ...schema import TaskResponse, VideoTaskRequest +from ...task_manager import task_manager +from ..deps import get_services, validate_url_async + +router = APIRouter() + + +def _write_file_sync(file_path: Path, content: bytes) -> None: + with open(file_path, "wb") as buffer: + buffer.write(content) + + +@router.post("/", response_model=TaskResponse) +async def create_video_task(message: VideoTaskRequest): + try: + if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"): + if not await validate_url_async(message.image_path): + raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}") + + task_id = task_manager.create_task(message) + message.task_id = task_id + + return TaskResponse( + task_id=task_id, + task_status="pending", + save_result_path=message.save_result_path, + ) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) + except Exception as e: + logger.error(f"Failed to create video task: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/form", response_model=TaskResponse) +async def create_video_task_form( + image_file: UploadFile = File(...), + prompt: str = Form(default=""), + save_result_path: str = Form(default=""), + use_prompt_enhancer: bool = Form(default=False), + negative_prompt: str = Form(default=""), + num_fragments: int = Form(default=1), + infer_steps: int = Form(default=5), + target_video_length: int = Form(default=81), + seed: int = Form(default=42), + audio_file: UploadFile = File(None), + video_duration: int = Form(default=5), + target_fps: int = Form(default=16), +): + services = get_services() + assert services.file_service is not None, "File service is not initialized" + + async def save_file_async(file: UploadFile, target_dir: Path) -> str: + if not file or not file.filename: + return "" + + file_extension = Path(file.filename).suffix + unique_filename = f"{uuid.uuid4()}{file_extension}" + file_path = target_dir / unique_filename + + content = await file.read() + await asyncio.to_thread(_write_file_sync, file_path, content) + + return str(file_path) + + image_path = "" + if image_file and image_file.filename: + image_path = await save_file_async(image_file, services.file_service.input_image_dir) + + audio_path = "" + if audio_file and audio_file.filename: + audio_path = await save_file_async(audio_file, services.file_service.input_audio_dir) + + message = VideoTaskRequest( + prompt=prompt, + use_prompt_enhancer=use_prompt_enhancer, + negative_prompt=negative_prompt, + image_path=image_path, + num_fragments=num_fragments, + save_result_path=save_result_path, + infer_steps=infer_steps, + target_video_length=target_video_length, + seed=seed, + audio_path=audio_path, + video_duration=video_duration, + target_fps=target_fps, + ) + + try: + task_id = task_manager.create_task(message) + message.task_id = task_id + + return TaskResponse( + task_id=task_id, + task_status="pending", + save_result_path=message.save_result_path, + ) + except 
RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) + except Exception as e: + logger.error(f"Failed to create video form task: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightx2v/server/config.py b/lightx2v/server/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a1abb6190aaf5ca64746a98c0256d724abf59d4e --- /dev/null +++ b/lightx2v/server/config.py @@ -0,0 +1,63 @@ +import os +from dataclasses import dataclass +from pathlib import Path + +from loguru import logger + + +@dataclass +class ServerConfig: + host: str = "0.0.0.0" + port: int = 8000 + max_queue_size: int = 10 + + task_timeout: int = 300 + task_history_limit: int = 1000 + + http_timeout: int = 30 + http_max_retries: int = 3 + + cache_dir: str = str(Path(__file__).parent.parent / "server_cache") + max_upload_size: int = 500 * 1024 * 1024 # 500MB + + @classmethod + def from_env(cls) -> "ServerConfig": + config = cls() + + if env_host := os.environ.get("LIGHTX2V_HOST"): + config.host = env_host + + if env_port := os.environ.get("LIGHTX2V_PORT"): + try: + config.port = int(env_port) + except ValueError: + logger.warning(f"Invalid port in environment: {env_port}") + + if env_queue_size := os.environ.get("LIGHTX2V_MAX_QUEUE_SIZE"): + try: + config.max_queue_size = int(env_queue_size) + except ValueError: + logger.warning(f"Invalid max queue size: {env_queue_size}") + + # MASTER_ADDR is now managed by torchrun, no need to set manually + + if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"): + config.cache_dir = env_cache_dir + + return config + + def validate(self) -> bool: + valid = True + + if self.max_queue_size <= 0: + logger.error("max_queue_size must be positive") + valid = False + + if self.task_timeout <= 0: + logger.error("task_timeout must be positive") + valid = False + + return valid + + +server_config = ServerConfig.from_env() diff --git a/lightx2v/server/main.py b/lightx2v/server/main.py new file mode 100644 index 0000000000000000000000000000000000000000..790ef6bc8ba155d6820e50bd7a054ce9a827c25a --- /dev/null +++ b/lightx2v/server/main.py @@ -0,0 +1,59 @@ +import os +import sys +from pathlib import Path + +import uvicorn +from loguru import logger + +from .api import ApiServer +from .config import server_config +from .services import DistributedInferenceService + + +def run_server(args): + inference_service = None + try: + rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + + logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...") + + if hasattr(args, "host") and args.host: + server_config.host = args.host + if hasattr(args, "port") and args.port: + server_config.port = args.port + + if not server_config.validate(): + raise RuntimeError("Invalid server configuration") + + inference_service = DistributedInferenceService() + if not inference_service.start_distributed_inference(args): + raise RuntimeError("Failed to start distributed inference service") + logger.info(f"Rank {rank}: Inference service started successfully") + + if rank == 0: + cache_dir = Path(server_config.cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + + api_server = ApiServer(max_queue_size=server_config.max_queue_size) + api_server.initialize_services(cache_dir, inference_service) + + app = api_server.get_app() + + logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}") + uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info") + else: + 
logger.info(f"Rank {rank}: Starting worker loop") + import asyncio + + asyncio.run(inference_service.run_worker_loop()) + + except KeyboardInterrupt: + logger.info(f"Server rank {rank} interrupted by user") + if inference_service: + inference_service.stop_distributed_inference() + except Exception as e: + logger.error(f"Server rank {rank} failed: {e}") + if inference_service: + inference_service.stop_distributed_inference() + sys.exit(1) diff --git a/lightx2v/server/media/__init__.py b/lightx2v/server/media/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bed40c9531d692f4c34e0631a1c6f2164eaa82b2 --- /dev/null +++ b/lightx2v/server/media/__init__.py @@ -0,0 +1,13 @@ +from .audio import AudioHandler, is_base64_audio, save_base64_audio +from .base import MediaHandler +from .image import ImageHandler, is_base64_image, save_base64_image + +__all__ = [ + "MediaHandler", + "ImageHandler", + "AudioHandler", + "is_base64_image", + "save_base64_image", + "is_base64_audio", + "save_base64_audio", +] diff --git a/lightx2v/server/media/audio.py b/lightx2v/server/media/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..e72a1c8b259f5b147be5029b7d3240c966a6740e --- /dev/null +++ b/lightx2v/server/media/audio.py @@ -0,0 +1,74 @@ +from typing import Dict + +from .base import MediaHandler + + +class AudioHandler(MediaHandler): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def get_media_signatures(self) -> Dict[bytes, str]: + return { + b"ID3": "mp3", + b"\xff\xfb": "mp3", + b"\xff\xf3": "mp3", + b"\xff\xf2": "mp3", + b"OggS": "ogg", + b"fLaC": "flac", + } + + def get_data_url_prefix(self) -> str: + return "data:audio/" + + def get_data_url_pattern(self) -> str: + return r"data:audio/(\w+);base64,(.+)" + + def get_default_extension(self) -> str: + return "mp3" + + def is_base64(self, data: str) -> bool: + if data.startswith(self.get_data_url_prefix()): + return True + + try: + import base64 + + if len(data) % 4 == 0: + base64.b64decode(data, validate=True) + decoded = base64.b64decode(data[:100]) + for signature in self.get_media_signatures().keys(): + if decoded.startswith(signature): + return True + if decoded.startswith(b"RIFF") and b"WAVE" in decoded[:12]: + return True + if decoded[:4] in [b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"]: + return True + except Exception: + return False + + return False + + def detect_extension(self, data: bytes) -> str: + for signature, ext in self.get_media_signatures().items(): + if data.startswith(signature): + return ext + if data.startswith(b"RIFF") and b"WAVE" in data[:12]: + return "wav" + if data[:4] in [b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"]: + return "m4a" + return self.get_default_extension() + + +_handler = AudioHandler() + + +def is_base64_audio(data: str) -> bool: + return _handler.is_base64(data) + + +def save_base64_audio(base64_data: str, output_dir: str) -> str: + return _handler.save_base64(base64_data, output_dir) diff --git a/lightx2v/server/media/base.py b/lightx2v/server/media/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9145879f055b51df23c2f359bd92c6c113bdd2f0 --- /dev/null +++ b/lightx2v/server/media/base.py @@ -0,0 +1,86 @@ +import base64 +import os +import re +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, Optional, Tuple + +from loguru import logger + + +class MediaHandler(ABC): + @abstractmethod + 
def get_media_signatures(self) -> Dict[bytes, str]: + """Return the binary signatures of this media type and their corresponding file extensions.""" + pass + + @abstractmethod + def get_data_url_prefix(self) -> str: + """Return the data URL prefix, e.g. 'data:image/' or 'data:audio/'.""" + pass + + @abstractmethod + def get_data_url_pattern(self) -> str: + """Return the regex pattern for data URL.""" + pass + + @abstractmethod + def get_default_extension(self) -> str: + """Return the default extension for this media type.""" + pass + + def is_base64(self, data: str) -> bool: + if data.startswith(self.get_data_url_prefix()): + return True + + try: + if len(data) % 4 == 0: + base64.b64decode(data, validate=True) + decoded = base64.b64decode(data[:100]) + for signature in self.get_media_signatures().keys(): + if decoded.startswith(signature): + return True + except Exception as e: + logger.warning(f"Error checking base64 {self.__class__.__name__}: {e}") + return False + + return False + + def extract_base64_data(self, data: str) -> Tuple[str, Optional[str]]: + if data.startswith("data:"): + match = re.match(self.get_data_url_pattern(), data) + if match: + format_type = match.group(1) + base64_data = match.group(2) + return base64_data, format_type + + return data, None + + def detect_extension(self, data: bytes) -> str: + for signature, ext in self.get_media_signatures().items(): + if data.startswith(signature): + return ext + return self.get_default_extension() + + def save_base64(self, base64_data: str, output_dir: str) -> str: + Path(output_dir).mkdir(parents=True, exist_ok=True) + + data, format_type = self.extract_base64_data(base64_data) + file_id = str(uuid.uuid4()) + + try: + media_data = base64.b64decode(data) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") + + if format_type: + ext = format_type + else: + ext = self.detect_extension(media_data) + + file_path = os.path.join(output_dir, f"{file_id}.{ext}") + with open(file_path, "wb") as f: + f.write(media_data) + + return file_path diff --git a/lightx2v/server/media/image.py b/lightx2v/server/media/image.py new file mode 100644 index 0000000000000000000000000000000000000000..3d50d6ba41edf64dbcb047666c5ecb0542a04ea6 --- /dev/null +++ b/lightx2v/server/media/image.py @@ -0,0 +1,68 @@ +from typing import Dict + +from .base import MediaHandler + + +class ImageHandler(MediaHandler): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def get_media_signatures(self) -> Dict[bytes, str]: + return { + b"\x89PNG\r\n\x1a\n": "png", + b"\xff\xd8\xff": "jpg", + b"GIF87a": "gif", + b"GIF89a": "gif", + } + + def get_data_url_prefix(self) -> str: + return "data:image/" + + def get_data_url_pattern(self) -> str: + return r"data:image/(\w+);base64,(.+)" + + def get_default_extension(self) -> str: + return "png" + + def is_base64(self, data: str) -> bool: + if data.startswith(self.get_data_url_prefix()): + return True + + try: + import base64 + + if len(data) % 4 == 0: + base64.b64decode(data, validate=True) + decoded = base64.b64decode(data[:100]) + for signature in self.get_media_signatures().keys(): + if decoded.startswith(signature): + return True + if len(decoded) > 12 and decoded[8:12] == b"WEBP": + return True + except Exception: + return False + + return False + + def detect_extension(self, data: bytes) -> str: + for signature, ext in self.get_media_signatures().items(): + if data.startswith(signature): + return ext + if len(data) > 12 
and data[8:12] == b"WEBP": + return "webp" + return self.get_default_extension() + + +_handler = ImageHandler() + + +def is_base64_image(data: str) -> bool: + return _handler.is_base64(data) + + +def save_base64_image(base64_data: str, output_dir: str) -> str: + return _handler.save_base64(base64_data, output_dir) diff --git a/lightx2v/server/metrics/__init__.py b/lightx2v/server/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..947bf3876cca3a22e147bf9e6fe10fa61f87663d --- /dev/null +++ b/lightx2v/server/metrics/__init__.py @@ -0,0 +1,6 @@ +# -*-coding=utf-8-*- + +from .metrics import server_process +from .monitor import Monitor + +monitor_cli = Monitor() diff --git a/lightx2v/server/metrics/metrics.py b/lightx2v/server/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..bba3c31f7fcc52f5974e41e98085999283f94474 --- /dev/null +++ b/lightx2v/server/metrics/metrics.py @@ -0,0 +1,361 @@ +# -*-coding=utf-8-*- +import threading +from typing import List, Tuple + +from loguru import logger +from prometheus_client import Counter, Gauge, Histogram, start_http_server +from pydantic import BaseModel + + +class MetricsConfig(BaseModel): + name: str + desc: str + type_: str + labels: List[str] = [] + buckets: Tuple[float, ...] = ( + 0.1, + 0.5, + 1.0, + 2.5, + 5.0, + 10.0, + 30.0, + 60.0, + 120.0, + 300.0, + 600.0, + ) + + +HYBRID_10_50MS_BUCKETS = ( + 0.001, # 1ms + 0.005, # 5ms + 0.008, # 8ms + 0.010, # 10ms + 0.012, # 12ms + 0.015, # 15ms + 0.020, # 20ms + 0.025, # 25ms + 0.030, # 30ms + 0.035, # 35ms + 0.040, # 40ms + 0.045, # 45ms + 0.050, # 50ms + 0.060, # 60ms + 0.075, # 75ms + 0.100, # 100ms + 0.150, # 150ms + 0.200, # 200ms + 0.500, # 500ms + 1.0, # 1s + 2.0, # 2s + 5.0, # 5s + 10.0, # 10s +) + +HYBRID_60_120MS_BUCKETS = ( + 0.010, # 10ms + 0.030, # 30ms + 0.050, # 50ms + 0.060, # 60ms + 0.065, # 65ms + 0.070, # 70ms + 0.075, # 75ms + 0.080, # 80ms + 0.085, # 85ms + 0.090, # 90ms + 0.095, # 95ms + 0.100, # 100ms + 0.110, # 110ms + 0.120, # 120ms + 0.150, # 150ms + 0.200, # 200ms + 0.300, # 200ms + 0.400, # 200ms + 0.500, # 500ms + 1.0, # 1s + 2.0, # 2s + 5.0, # 5s + 10.0, # 10s +) + +HYBRID_300MS_1600MS_BUCKETS = ( + 0.010, # 10ms + 0.050, # 50ms + 0.100, # 100ms + 0.150, # 150ms + 0.200, # 200ms + 0.250, # 250ms + 0.300, # 300ms + 0.350, # 350ms + 0.400, # 400ms + 0.450, # 450ms + 0.500, # 500ms + 0.550, # 550ms + 0.600, # 600ms + 0.650, # 650ms + 0.700, # 700ms + 0.750, # 750ms + 0.800, # 800ms + 0.850, # 850ms + 0.900, # 900ms + 0.950, # 950ms + 1.000, # 1s + 1.100, # 1.1s + 1.200, # 1.2s + 1.300, # 1.3s + 1.400, # 1.4s + 1.500, # 1.5s + 1.600, # 1.6s + 2.000, # 2s + 3.000, # 3s +) + +HYBRID_1_30S_BUCKETS = ( + 1.0, # 1s + 1.5, # 1.5s + 2.0, # 2s + 2.5, # 2.5s + 3.0, # 3s + 3.5, # 3.5s + 4.0, # 4s + 4.5, # 4.5s + 5.0, # 5s + 5.5, # 5.5s + 6.0, # 6s + 6.5, # 6.5s + 7.0, # 7s + 7.5, # 7.5s + 8.0, # 8s + 8.5, # 8.5s + 9.0, # 9s + 9.5, # 9.5s + 10.0, # 10s + 11.0, # 11s + 12.0, # 12s + 13.0, # 13s + 15.0, # 15s + 16.0, # 16s + 17.0, # 17s + 18.0, # 18s + 19.0, # 19s + 20.0, # 20s + 21.0, # 21s + 22.0, # 22s + 23.0, # 23s + 25.0, # 25s + 30.0, # 30s +) + +HYBRID_30_900S_BUCKETS = ( + 1.0, # 1s + 5.0, # 5s + 10.0, # 10s + 20.0, # 20s + 30.0, # 30s + 35.0, # 35s + 40.0, # 40s + 50.0, # 50s + 60.0, # 1min + 70.0, # 1min10s + 80.0, # 1min20s + 90.0, # 1min30s + 100.0, # 1min40s + 110.0, # 1min50s + 120.0, # 2min + 130.0, # 2min10s + 140.0, # 2min20s + 150.0, # 2min30s + 180.0, # 3min + 240.0, # 4min + 300.0, # 5min + 600.0, # 
10min + 900.0, # 15min +) + + +METRICS_INFO = { + "lightx2v_worker_request_count": MetricsConfig( + name="lightx2v_worker_request_count", + desc="The total number of requests", + type_="counter", + ), + "lightx2v_worker_request_success": MetricsConfig( + name="lightx2v_worker_request_success", + desc="The number of successful requests", + type_="counter", + ), + "lightx2v_worker_request_failure": MetricsConfig( + name="lightx2v_worker_request_failure", + desc="The number of failed requests", + type_="counter", + labels=["error_type"], + ), + "lightx2v_worker_request_duration": MetricsConfig( + name="lightx2v_worker_request_duration", + desc="Duration of the request (s)", + type_="histogram", + labels=["model_cls"], + ), + "lightx2v_input_audio_len": MetricsConfig( + name="lightx2v_input_audio_len", + desc="Length of the input audio", + type_="histogram", + buckets=( + 1.0, + 2.0, + 3.0, + 5.0, + 7.0, + 10.0, + 20.0, + 30.0, + 45.0, + 60.0, + 75.0, + 90.0, + 105.0, + 120.0, + ), + ), + "lightx2v_input_image_len": MetricsConfig( + name="lightx2v_input_image_len", + desc="Length of the input image", + type_="histogram", + ), + "lightx2v_input_prompt_len": MetricsConfig( + name="lightx2v_input_prompt_len", + desc="Length of the input prompt", + type_="histogram", + ), + "lightx2v_load_model_duration": MetricsConfig( + name="lightx2v_load_model_duration", + desc="Duration of load model (s)", + type_="histogram", + ), + "lightx2v_run_per_step_dit_duration": MetricsConfig( + name="lightx2v_run_per_step_dit_duration", + desc="Duration of run per step Dit (s)", + type_="histogram", + labels=["step_no", "total_steps"], + buckets=HYBRID_30_900S_BUCKETS, + ), + "lightx2v_run_text_encode_duration": MetricsConfig( + name="lightx2v_run_text_encode_duration", + desc="Duration of run text encode (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_1_30S_BUCKETS, + ), + "lightx2v_run_img_encode_duration": MetricsConfig( + name="lightx2v_run_img_encode_duration", + desc="Duration of run img encode (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_10_50MS_BUCKETS, + ), + "lightx2v_run_vae_encoder_image_duration": MetricsConfig( + name="lightx2v_run_vae_encoder_image_duration", + desc="Duration of run vae encode for image (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_60_120MS_BUCKETS, + ), + "lightx2v_run_vae_encoder_pre_latent_duration": MetricsConfig( + name="lightx2v_run_vae_encoder_pre_latent_duration", + desc="Duration of run vae encode for pre latents (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_1_30S_BUCKETS, + ), + "lightx2v_run_vae_decode_duration": MetricsConfig( + name="lightx2v_run_vae_decode_duration", + desc="Duration of run vae decode (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_1_30S_BUCKETS, + ), + "lightx2v_run_init_run_segment_duration": MetricsConfig( + name="lightx2v_run_init_run_segment_duration", + desc="Duration of run init_run_segment (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_1_30S_BUCKETS, + ), + "lightx2v_run_end_run_segment_duration": MetricsConfig( + name="lightx2v_run_end_run_segment_duration", + desc="Duration of run end_run_segment (s)", + type_="histogram", + labels=["model_cls"], + buckets=HYBRID_300MS_1600MS_BUCKETS, + ), + "lightx2v_run_segments_end2end_duration": MetricsConfig( + name="lightx2v_run_segments_end2end_duration", + desc="Duration of run segments end2end (s)", + type_="histogram", + labels=["model_cls"], + ), +} + + +class MetricsClient: 
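+ """Registry of Prometheus metrics: every entry in METRICS_INFO is created as a Counter, Histogram, or Gauge and attached to this client as an attribute of the same name."""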
+ def __init__(self): + self.init_metrics() + + def init_metrics(self): + for metric_name, config in METRICS_INFO.items(): + if config.type_ == "counter": + self.register_counter(config.name, config.desc, config.labels) + elif config.type_ == "histogram": + self.register_histogram(config.name, config.desc, config.labels, buckets=config.buckets) + elif config.type_ == "gauge": + self.register_gauge(config.name, config.desc, config.labels) + else: + logger.warning(f"Unsupported metric type: {config.type_} for {metric_name}") + + def register_counter(self, name, desc, labels): + metric_instance = Counter(name, desc, labels) + setattr(self, name, metric_instance) + + def register_histogram(self, name, desc, labels, buckets=None): + buckets = buckets or ( + 0.1, + 0.5, + 1.0, + 2.5, + 5.0, + 10.0, + 30.0, + 60.0, + 120.0, + 300.0, + 600.0, + ) + metric_instance = Histogram(name, desc, labels, buckets=buckets) + setattr(self, name, metric_instance) + + def register_gauge(self, name, desc, labels): + metric_instance = Gauge(name, desc, labels) + setattr(self, name, metric_instance) + + +class MetricsServer: + def __init__(self, port=8000): + self.port = port + self.server_thread = None + + def start_server(self): + def run_server(): + start_http_server(self.port) + logger.info(f"Metrics server started on port {self.port}") + + self.server_thread = threading.Thread(target=run_server) + self.server_thread.daemon = True + self.server_thread.start() + + +def server_process(metric_port=8001): + metrics = MetricsServer( + port=metric_port, + ) + metrics.start_server() diff --git a/lightx2v/server/metrics/monitor.py b/lightx2v/server/metrics/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..5636d27687e66e4bf4c299edfcf4c1980fcc77a8 --- /dev/null +++ b/lightx2v/server/metrics/monitor.py @@ -0,0 +1,22 @@ +# -*-coding=utf-8-*- +import threading + +from .metrics import MetricsClient + + +class Monitor(MetricsClient): + _instance = None + _lock = threading.Lock() + _initialized = False  # Initialization flag so the singleton runs __init__ only once + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, *args, **kwargs): + if not self.__class__._initialized: + super().__init__(*args, **kwargs) + self.__class__._initialized = True diff --git a/lightx2v/server/schema.py b/lightx2v/server/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..8d420766a52bcf960d78f57f380e1dfcc370a27b --- /dev/null +++ b/lightx2v/server/schema.py @@ -0,0 +1,73 @@ +import random +from typing import Optional + +from pydantic import BaseModel, Field + +from ..utils.generate_task_id import generate_task_id + + +def generate_random_seed() -> int: + return random.randint(0, 2**32 - 1) + + +class TalkObject(BaseModel): + audio: str = Field(..., description="Audio path") + mask: str = Field(..., description="Mask path") + + +class BaseTaskRequest(BaseModel): + task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)") + prompt: str = Field("", description="Generation prompt") + use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer") + negative_prompt: str = Field("", description="Negative prompt") + image_path: str = Field("", description="Base64 encoded image or URL") + save_result_path: str = Field("", description="Save result path (optional, defaults to task_id, suffix auto-detected)") + infer_steps: int = Field(5, 
description="Inference steps") + seed: int = Field(default_factory=generate_random_seed, description="Random seed (auto-generated if not set)") + + def __init__(self, **data): + super().__init__(**data) + if not self.save_result_path: + self.save_result_path = f"{self.task_id}" + + def get(self, key, default=None): + return getattr(self, key, default) + + +class VideoTaskRequest(BaseTaskRequest): + num_fragments: int = Field(1, description="Number of fragments") + target_video_length: int = Field(81, description="Target video length") + audio_path: str = Field("", description="Input audio path (Wan-Audio)") + video_duration: int = Field(5, description="Video duration (Wan-Audio)") + talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)") + target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)") + resize_mode: Optional[str] = Field("adaptive", description="Resize mode (adaptive, keep_ratio_fixed_area, fixed_min_area, fixed_max_area, fixed_shape, fixed_min_side)") + + +class ImageTaskRequest(BaseTaskRequest): + aspect_ratio: str = Field("16:9", description="Output aspect ratio") + + +class TaskRequest(BaseTaskRequest): + num_fragments: int = Field(1, description="Number of fragments") + target_video_length: int = Field(81, description="Target video length (video only)") + audio_path: str = Field("", description="Input audio path (Wan-Audio)") + video_duration: int = Field(5, description="Video duration (Wan-Audio)") + talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)") + aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)") + target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)") + + +class TaskStatusMessage(BaseModel): + task_id: str = Field(..., description="Task ID") + + +class TaskResponse(BaseModel): + task_id: str + task_status: str + save_result_path: str + + +class StopTaskResponse(BaseModel): + stop_status: str + reason: str diff --git a/lightx2v/server/services/__init__.py b/lightx2v/server/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd07e18c86f5b3170fdac29d7abd821a6f873774 --- /dev/null +++ b/lightx2v/server/services/__init__.py @@ -0,0 +1,11 @@ +from .file_service import FileService +from .generation import ImageGenerationService, VideoGenerationService +from .inference import DistributedInferenceService, TorchrunInferenceWorker + +__all__ = [ + "FileService", + "DistributedInferenceService", + "TorchrunInferenceWorker", + "VideoGenerationService", + "ImageGenerationService", +] diff --git a/lightx2v/server/services/distributed_utils.py b/lightx2v/server/services/distributed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5d66e32cc62f4ee53201e168b550a675d5529d --- /dev/null +++ b/lightx2v/server/services/distributed_utils.py @@ -0,0 +1,141 @@ +import os +import pickle +from datetime import timedelta +from typing import Any, Optional + +import torch +import torch.distributed as dist +from loguru import logger + + +class DistributedManager: + def __init__(self): + self.is_initialized = False + self.rank = 0 + self.world_size = 1 + self.device = "cpu" + self.task_pg = None + + CHUNK_SIZE = 1024 * 1024 + + def init_process_group(self) -> bool: + try: + self.rank = int(os.environ.get("LOCAL_RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + + if self.world_size > 1: + 
backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method="env://") + logger.info(f"Setup backend: {backend}") + + task_timeout = timedelta(days=30) + self.task_pg = dist.new_group(backend="gloo", timeout=task_timeout) + logger.info("Created gloo process group for task distribution with 30-day timeout") + + if torch.cuda.is_available(): + torch.cuda.set_device(self.rank) + self.device = f"cuda:{self.rank}" + else: + self.device = "cpu" + else: + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + + self.is_initialized = True + logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully") + return True + + except Exception as e: + logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}") + return False + + def cleanup(self): + try: + if dist.is_initialized(): + dist.destroy_process_group() + logger.info(f"Rank {self.rank} distributed environment cleaned up") + except Exception as e: + logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}") + finally: + self.is_initialized = False + self.task_pg = None + + def barrier(self): + if self.is_initialized: + if torch.cuda.is_available() and dist.get_backend() == "nccl": + dist.barrier(device_ids=[torch.cuda.current_device()]) + else: + dist.barrier() + + def is_rank_zero(self) -> bool: + return self.rank == 0 + + def _broadcast_byte_chunks(self, data_bytes: bytes) -> None: + total_length = len(data_bytes) + num_full_chunks = total_length // self.CHUNK_SIZE + remaining = total_length % self.CHUNK_SIZE + + for i in range(num_full_chunks): + start_idx = i * self.CHUNK_SIZE + end_idx = start_idx + self.CHUNK_SIZE + chunk = data_bytes[start_idx:end_idx] + task_tensor = torch.tensor(list(chunk), dtype=torch.uint8) + dist.broadcast(task_tensor, src=0, group=self.task_pg) + + if remaining: + chunk = data_bytes[-remaining:] + task_tensor = torch.tensor(list(chunk), dtype=torch.uint8) + dist.broadcast(task_tensor, src=0, group=self.task_pg) + + def _receive_byte_chunks(self, total_length: int) -> bytes: + if total_length <= 0: + return b"" + + received = bytearray() + remaining = total_length + + while remaining > 0: + chunk_length = min(self.CHUNK_SIZE, remaining) + task_tensor = torch.empty(chunk_length, dtype=torch.uint8) + dist.broadcast(task_tensor, src=0, group=self.task_pg) + received.extend(task_tensor.numpy()) + remaining -= chunk_length + + return bytes(received) + + def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]: + if not self.is_initialized: + return None + + if self.is_rank_zero(): + if task_data is None: + stop_signal = torch.tensor([1], dtype=torch.int32) + else: + stop_signal = torch.tensor([0], dtype=torch.int32) + + dist.broadcast(stop_signal, src=0, group=self.task_pg) + + if task_data is not None: + task_bytes = pickle.dumps(task_data) + task_length = torch.tensor([len(task_bytes)], dtype=torch.int32) + + dist.broadcast(task_length, src=0, group=self.task_pg) + self._broadcast_byte_chunks(task_bytes) + + return task_data + else: + return None + else: + stop_signal = torch.tensor([0], dtype=torch.int32) + dist.broadcast(stop_signal, src=0, group=self.task_pg) + + if stop_signal.item() == 1: + return None + else: + task_length = torch.tensor([0], dtype=torch.int32) + + dist.broadcast(task_length, src=0, group=self.task_pg) + total_length = int(task_length.item()) + + task_bytes = self._receive_byte_chunks(total_length) + 
task_data = pickle.loads(task_bytes) + return task_data diff --git a/lightx2v/server/services/file_service.py b/lightx2v/server/services/file_service.py new file mode 100644 index 0000000000000000000000000000000000000000..95c7264acb5b8360cece52c797abb299b1de4227 --- /dev/null +++ b/lightx2v/server/services/file_service.py @@ -0,0 +1,153 @@ +import asyncio +import uuid +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +import httpx +from loguru import logger + + +class FileService: + def __init__(self, cache_dir: Path): + self.cache_dir = cache_dir + self.input_image_dir = cache_dir / "inputs" / "imgs" + self.input_audio_dir = cache_dir / "inputs" / "audios" + self.output_video_dir = cache_dir / "outputs" + + self._http_client = None + self._client_lock = asyncio.Lock() + + self.max_retries = 3 + self.retry_delay = 1.0 + self.max_retry_delay = 10.0 + + for directory in [ + self.input_image_dir, + self.output_video_dir, + self.input_audio_dir, + ]: + directory.mkdir(parents=True, exist_ok=True) + + async def _get_http_client(self) -> httpx.AsyncClient: + async with self._client_lock: + if self._http_client is None or self._http_client.is_closed: + timeout = httpx.Timeout( + connect=10.0, + read=30.0, + write=10.0, + pool=5.0, + ) + limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0) + self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True) + return self._http_client + + async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response: + if max_retries is None: + max_retries = self.max_retries + + last_exception = None + retry_delay = self.retry_delay + + for attempt in range(max_retries): + try: + client = await self._get_http_client() + response = await client.get(url) + + if response.status_code == 200: + return response + elif response.status_code >= 500: + logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}") + last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response) + else: + raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response) + + except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e: + logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}") + last_exception = e + except httpx.HTTPStatusError as e: + if e.response and e.response.status_code < 500: + raise + last_exception = e + except Exception as e: + logger.error(f"Unexpected error downloading {url}: {str(e)}") + last_exception = e + + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 2, self.max_retry_delay) + + error_msg = f"All {max_retries} connection attempts failed for {url}" + if last_exception: + error_msg += f": {str(last_exception)}" + raise httpx.ConnectError(error_msg) + + async def download_media(self, url: str, media_type: str = "image") -> Path: + try: + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL format: {url}") + + response = await self._download_with_retry(url) + + media_name = Path(parsed_url.path).name + if not media_name: + default_ext = "jpg" if media_type == "image" else "mp3" + media_name = f"{uuid.uuid4()}.{default_ext}" + + if media_type == "image": + target_dir = self.input_image_dir + else: 
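+ # Any media type other than "image" (currently only audio) is stored under the audio input directory.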
+ target_dir = self.input_audio_dir + + media_path = target_dir / media_name + media_path.parent.mkdir(parents=True, exist_ok=True) + + with open(media_path, "wb") as f: + f.write(response.content) + + logger.info(f"Successfully downloaded {media_type} from {url} to {media_path}") + return media_path + + except httpx.ConnectError as e: + logger.error(f"Connection error downloading {media_type} from {url}: {str(e)}") + raise ValueError(f"Failed to connect to {url}: {str(e)}") + except httpx.TimeoutException as e: + logger.error(f"Timeout downloading {media_type} from {url}: {str(e)}") + raise ValueError(f"Download timeout for {url}: {str(e)}") + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error downloading {media_type} from {url}: {str(e)}") + raise ValueError(f"HTTP error for {url}: {str(e)}") + except ValueError: + raise + except Exception as e: + logger.error(f"Unexpected error downloading {media_type} from {url}: {str(e)}") + raise ValueError(f"Failed to download {media_type} from {url}: {str(e)}") + + async def download_image(self, image_url: str) -> Path: + return await self.download_media(image_url, "image") + + async def download_audio(self, audio_url: str) -> Path: + return await self.download_media(audio_url, "audio") + + def save_uploaded_file(self, file_content: bytes, filename: str) -> Path: + file_extension = Path(filename).suffix + unique_filename = f"{uuid.uuid4()}{file_extension}" + file_path = self.input_image_dir / unique_filename + + with open(file_path, "wb") as f: + f.write(file_content) + + return file_path + + def get_output_path(self, save_result_path: str) -> Path: + video_path = Path(save_result_path) + if not video_path.is_absolute(): + return self.output_video_dir / save_result_path + return video_path + + async def cleanup(self): + async with self._client_lock: + if self._http_client and not self._http_client.is_closed: + await self._http_client.aclose() + self._http_client = None diff --git a/lightx2v/server/services/generation/__init__.py b/lightx2v/server/services/generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b52108656fb2d4f227454e82250ddfaa8cfe9d4 --- /dev/null +++ b/lightx2v/server/services/generation/__init__.py @@ -0,0 +1,9 @@ +from .base import BaseGenerationService +from .image import ImageGenerationService +from .video import VideoGenerationService + +__all__ = [ + "BaseGenerationService", + "VideoGenerationService", + "ImageGenerationService", +] diff --git a/lightx2v/server/services/generation/base.py b/lightx2v/server/services/generation/base.py new file mode 100644 index 0000000000000000000000000000000000000000..12b9629a8f7c8a2f37e7e7f6834925f30e137d72 --- /dev/null +++ b/lightx2v/server/services/generation/base.py @@ -0,0 +1,146 @@ +import json +import uuid +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from loguru import logger + +from ...media import is_base64_audio, is_base64_image, save_base64_audio, save_base64_image +from ...schema import TaskResponse +from ..file_service import FileService +from ..inference import DistributedInferenceService + + +class BaseGenerationService(ABC): + def __init__(self, file_service: FileService, inference_service: DistributedInferenceService): + self.file_service = file_service + self.inference_service = inference_service + + @abstractmethod + def get_output_extension(self) -> str: + pass + + @abstractmethod + def get_task_type(self) -> str: + pass + + def _is_target_task_type(self) -> bool: + if 
self.inference_service.worker and self.inference_service.worker.runner: + task_type = self.inference_service.worker.runner.config.get("task", "t2v") + return task_type in self.get_task_type().split(",") + return False + + async def _process_image_path(self, image_path: str, task_data: Dict[str, Any]) -> None: + if not image_path: + return + + if image_path.startswith("http"): + downloaded_path = await self.file_service.download_image(image_path) + task_data["image_path"] = str(downloaded_path) + elif is_base64_image(image_path): + saved_path = save_base64_image(image_path, str(self.file_service.input_image_dir)) + task_data["image_path"] = str(saved_path) + else: + task_data["image_path"] = image_path + + async def _process_audio_path(self, audio_path: str, task_data: Dict[str, Any]) -> None: + if not audio_path: + return + + if audio_path.startswith("http"): + downloaded_path = await self.file_service.download_audio(audio_path) + task_data["audio_path"] = str(downloaded_path) + elif is_base64_audio(audio_path): + saved_path = save_base64_audio(audio_path, str(self.file_service.input_audio_dir)) + task_data["audio_path"] = str(saved_path) + else: + task_data["audio_path"] = audio_path + + async def _process_talk_objects(self, talk_objects: list, task_data: Dict[str, Any]) -> None: + if not talk_objects: + return + + task_data["talk_objects"] = [{} for _ in range(len(talk_objects))] + + for index, talk_object in enumerate(talk_objects): + if talk_object.audio.startswith("http"): + audio_path = await self.file_service.download_audio(talk_object.audio) + task_data["talk_objects"][index]["audio"] = str(audio_path) + elif is_base64_audio(talk_object.audio): + audio_path = save_base64_audio(talk_object.audio, str(self.file_service.input_audio_dir)) + task_data["talk_objects"][index]["audio"] = str(audio_path) + else: + task_data["talk_objects"][index]["audio"] = talk_object.audio + + if talk_object.mask.startswith("http"): + mask_path = await self.file_service.download_image(talk_object.mask) + task_data["talk_objects"][index]["mask"] = str(mask_path) + elif is_base64_image(talk_object.mask): + mask_path = save_base64_image(talk_object.mask, str(self.file_service.input_image_dir)) + task_data["talk_objects"][index]["mask"] = str(mask_path) + else: + task_data["talk_objects"][index]["mask"] = talk_object.mask + + temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8] + temp_path.mkdir(parents=True, exist_ok=True) + task_data["audio_path"] = str(temp_path) + + config_path = temp_path / "config.json" + with open(config_path, "w") as f: + json.dump({"talk_objects": task_data["talk_objects"]}, f) + + def _prepare_output_path(self, save_result_path: str, task_data: Dict[str, Any]) -> None: + actual_save_path = self.file_service.get_output_path(save_result_path) + if not actual_save_path.suffix: + actual_save_path = actual_save_path.with_suffix(self.get_output_extension()) + task_data["save_result_path"] = str(actual_save_path) + task_data["video_path"] = actual_save_path.name + + async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]: + try: + task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"} + task_data["task_id"] = message.task_id + + if stop_event.is_set(): + logger.info(f"Task {message.task_id} cancelled before processing") + return None + + if hasattr(message, "image_path") and message.image_path: + await self._process_image_path(message.image_path, task_data) + logger.info(f"Task {message.task_id} image 
path: {task_data.get('image_path')}") + + if hasattr(message, "audio_path") and message.audio_path: + await self._process_audio_path(message.audio_path, task_data) + logger.info(f"Task {message.task_id} audio path: {task_data.get('audio_path')}") + + if hasattr(message, "talk_objects") and message.talk_objects: + await self._process_talk_objects(message.talk_objects, task_data) + + self._prepare_output_path(message.save_result_path, task_data) + task_data["seed"] = message.seed + task_data["resize_mode"] = message.resize_mode + + result = await self.inference_service.submit_task_async(task_data) + + if result is None: + if stop_event.is_set(): + logger.info(f"Task {message.task_id} cancelled during processing") + return None + raise RuntimeError("Task processing failed") + + if result.get("status") == "success": + actual_save_path = self.file_service.get_output_path(message.save_result_path) + if not actual_save_path.suffix: + actual_save_path = actual_save_path.with_suffix(self.get_output_extension()) + return TaskResponse( + task_id=message.task_id, + task_status="completed", + save_result_path=actual_save_path.name, + ) + else: + error_msg = result.get("error", "Inference failed") + raise RuntimeError(error_msg) + + except Exception as e: + logger.exception(f"Task {message.task_id} processing failed: {str(e)}") + raise diff --git a/lightx2v/server/services/generation/image.py b/lightx2v/server/services/generation/image.py new file mode 100644 index 0000000000000000000000000000000000000000..badb6954f7a74535f2d0a50172a0f3a26e748a77 --- /dev/null +++ b/lightx2v/server/services/generation/image.py @@ -0,0 +1,66 @@ +from typing import Any, Optional + +from loguru import logger + +from ...schema import TaskResponse +from ..file_service import FileService +from ..inference import DistributedInferenceService +from .base import BaseGenerationService + + +class ImageGenerationService(BaseGenerationService): + def __init__(self, file_service: FileService, inference_service: DistributedInferenceService): + super().__init__(file_service, inference_service) + + def get_output_extension(self) -> str: + return ".png" + + def get_task_type(self) -> str: + return "t2i,i2i" + + async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]: + try: + task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"} + task_data["task_id"] = message.task_id + + if hasattr(message, "aspect_ratio"): + task_data["aspect_ratio"] = message.aspect_ratio + + if stop_event.is_set(): + logger.info(f"Task {message.task_id} cancelled before processing") + return None + + if hasattr(message, "image_path") and message.image_path: + await self._process_image_path(message.image_path, task_data) + logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}") + + self._prepare_output_path(message.save_result_path, task_data) + task_data["seed"] = message.seed + + result = await self.inference_service.submit_task_async(task_data) + + if result is None: + if stop_event.is_set(): + logger.info(f"Task {message.task_id} cancelled during processing") + return None + raise RuntimeError("Task processing failed") + + if result.get("status") == "success": + actual_save_path = self.file_service.get_output_path(message.save_result_path) + if not actual_save_path.suffix: + actual_save_path = actual_save_path.with_suffix(self.get_output_extension()) + return TaskResponse( + task_id=message.task_id, + task_status="completed", + 
save_result_path=actual_save_path.name, + ) + else: + error_msg = result.get("error", "Inference failed") + raise RuntimeError(error_msg) + + except Exception as e: + logger.exception(f"Task {message.task_id} processing failed: {str(e)}") + raise + + async def generate_image_with_stop_event(self, message: Any, stop_event) -> Optional[Any]: + return await self.generate_with_stop_event(message, stop_event) diff --git a/lightx2v/server/services/generation/video.py b/lightx2v/server/services/generation/video.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc4651eb66f22581d2cc9a48d7626378e1fa6fd --- /dev/null +++ b/lightx2v/server/services/generation/video.py @@ -0,0 +1,22 @@ +from typing import Any, Optional + +from ..file_service import FileService +from ..inference import DistributedInferenceService +from .base import BaseGenerationService + + +class VideoGenerationService(BaseGenerationService): + def __init__(self, file_service: FileService, inference_service: DistributedInferenceService): + super().__init__(file_service, inference_service) + + def get_output_extension(self) -> str: + return ".mp4" + + def get_task_type(self) -> str: + return "t2v,i2v,s2v" + + async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]: + return await super().generate_with_stop_event(message, stop_event) + + async def generate_video_with_stop_event(self, message: Any, stop_event) -> Optional[Any]: + return await self.generate_with_stop_event(message, stop_event) diff --git a/lightx2v/server/services/inference/__init__.py b/lightx2v/server/services/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8199414fa17c0176947c83e90a250b187615b270 --- /dev/null +++ b/lightx2v/server/services/inference/__init__.py @@ -0,0 +1,7 @@ +from .service import DistributedInferenceService +from .worker import TorchrunInferenceWorker + +__all__ = [ + "TorchrunInferenceWorker", + "DistributedInferenceService", +] diff --git a/lightx2v/server/services/inference/service.py b/lightx2v/server/services/inference/service.py new file mode 100644 index 0000000000000000000000000000000000000000..20cf0e7ee907054e1d9a8267083c8f6ea5f188b1 --- /dev/null +++ b/lightx2v/server/services/inference/service.py @@ -0,0 +1,81 @@ +from typing import Optional + +from loguru import logger + +from .worker import TorchrunInferenceWorker + + +class DistributedInferenceService: + def __init__(self): + self.worker = None + self.is_running = False + self.args = None + + def start_distributed_inference(self, args) -> bool: + self.args = args + if self.is_running: + logger.warning("Distributed inference service is already running") + return True + + try: + self.worker = TorchrunInferenceWorker() + + if not self.worker.init(args): + raise RuntimeError("Worker initialization failed") + + self.is_running = True + logger.info(f"Rank {self.worker.rank} inference service started successfully") + return True + + except Exception as e: + logger.error(f"Error starting inference service: {str(e)}") + self.stop_distributed_inference() + return False + + def stop_distributed_inference(self): + if not self.is_running: + return + + try: + if self.worker: + self.worker.cleanup() + logger.info("Inference service stopped") + except Exception as e: + logger.error(f"Error stopping inference service: {str(e)}") + finally: + self.worker = None + self.is_running = False + + async def submit_task_async(self, task_data: dict) -> Optional[dict]: + if not self.is_running or not self.worker: + 
logger.error("Inference service is not started") + return None + + if self.worker.rank != 0: + return None + + try: + if self.worker.processing: + logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}") + + self.worker.processing = True + result = await self.worker.process_request(task_data) + self.worker.processing = False + return result + except Exception as e: + self.worker.processing = False + logger.error(f"Failed to process task: {str(e)}") + return { + "task_id": task_data.get("task_id", "unknown"), + "status": "failed", + "error": str(e), + "message": f"Task processing failed: {str(e)}", + } + + def server_metadata(self): + assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first." + return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path} + + async def run_worker_loop(self): + if self.worker and self.worker.rank != 0: + await self.worker.worker_loop() diff --git a/lightx2v/server/services/inference/worker.py b/lightx2v/server/services/inference/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..f6591472451141cb8089bdca3c9e6bc7bb62fc16 --- /dev/null +++ b/lightx2v/server/services/inference/worker.py @@ -0,0 +1,128 @@ +import asyncio +import os +from typing import Any, Dict + +import torch +from easydict import EasyDict +from loguru import logger + +from lightx2v.infer import init_runner +from lightx2v.utils.input_info import set_input_info +from lightx2v.utils.set_config import set_config, set_parallel_config + +from ..distributed_utils import DistributedManager + + +class TorchrunInferenceWorker: + def __init__(self): + self.rank = int(os.environ.get("LOCAL_RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.runner = None + self.dist_manager = DistributedManager() + self.processing = False + + def init(self, args) -> bool: + try: + if self.world_size > 1: + if not self.dist_manager.init_process_group(): + raise RuntimeError("Failed to initialize distributed process group") + else: + self.dist_manager.rank = 0 + self.dist_manager.world_size = 1 + self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.dist_manager.is_initialized = False + + config = set_config(args) + + if config["parallel"]: + set_parallel_config(config) + + if self.rank == 0: + logger.info(f"Config:\n {config}") + + self.runner = init_runner(config) + logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed") + + return True + + except Exception as e: + logger.exception(f"Rank {self.rank} initialization failed: {str(e)}") + return False + + async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]: + has_error = False + error_msg = "" + + try: + if self.world_size > 1 and self.rank == 0: + task_data = self.dist_manager.broadcast_task_data(task_data) + + task_data["task"] = self.runner.config["task"] + task_data["return_result_tensor"] = False + task_data["negative_prompt"] = task_data.get("negative_prompt", "") + + target_fps = task_data.pop("target_fps", None) + if target_fps is not None: + vfi_cfg = self.runner.config.get("video_frame_interpolation") + if vfi_cfg: + task_data["video_frame_interpolation"] = {**vfi_cfg, "target_fps": target_fps} + else: + logger.warning(f"Target FPS {target_fps} is set, but video frame interpolation is not configured") + + task_data = EasyDict(task_data) + input_info = 
set_input_info(task_data) + + self.runner.set_config(task_data) + self.runner.run_pipeline(input_info) + + await asyncio.sleep(0) + + except Exception as e: + has_error = True + error_msg = str(e) + logger.exception(f"Rank {self.rank} inference failed: {error_msg}") + + if self.world_size > 1: + self.dist_manager.barrier() + + if self.rank == 0: + if has_error: + return { + "task_id": task_data.get("task_id", "unknown"), + "status": "failed", + "error": error_msg, + "message": f"Inference failed: {error_msg}", + } + else: + return { + "task_id": task_data["task_id"], + "status": "success", + "save_result_path": task_data.get("video_path", task_data["save_result_path"]), + "message": "Inference completed", + } + else: + return None + + async def worker_loop(self): + while True: + task_data = None + try: + task_data = self.dist_manager.broadcast_task_data() + if task_data is None: + logger.info(f"Rank {self.rank} received stop signal") + break + + await self.process_request(task_data) + + except Exception as e: + logger.error(f"Rank {self.rank} worker loop error: {str(e)}") + if self.world_size > 1 and task_data is not None: + try: + self.dist_manager.barrier() + except Exception as barrier_error: + logger.error(f"Rank {self.rank} barrier failed after error: {barrier_error}") + break + continue + + def cleanup(self): + self.dist_manager.cleanup() diff --git a/lightx2v/server/task_manager.py b/lightx2v/server/task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9ddbaca0df1247d9a431541501af458ab34363 --- /dev/null +++ b/lightx2v/server/task_manager.py @@ -0,0 +1,215 @@ +import threading +import uuid +from collections import OrderedDict +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, Optional + +from loguru import logger + + +class TaskStatus(Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class TaskInfo: + task_id: str + status: TaskStatus + message: Any + start_time: datetime = field(default_factory=datetime.now) + end_time: Optional[datetime] = None + error: Optional[str] = None + save_result_path: Optional[str] = None + stop_event: threading.Event = field(default_factory=threading.Event) + thread: Optional[threading.Thread] = None + + +class TaskManager: + def __init__(self, max_queue_size: int = 100): + self.max_queue_size = max_queue_size + + self._tasks: OrderedDict[str, TaskInfo] = OrderedDict() + self._lock = threading.RLock() + + self._processing_lock = threading.Lock() + self._current_processing_task: Optional[str] = None + + self.total_tasks = 0 + self.completed_tasks = 0 + self.failed_tasks = 0 + + def create_task(self, message: Any) -> str: + with self._lock: + if hasattr(message, "task_id") and message.task_id in self._tasks: + raise RuntimeError(f"Task ID {message.task_id} already exists") + + active_tasks = sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]) + if active_tasks >= self.max_queue_size: + raise RuntimeError(f"Task queue is full (max {self.max_queue_size} tasks)") + + task_id = getattr(message, "task_id", str(uuid.uuid4())) + task_info = TaskInfo(task_id=task_id, status=TaskStatus.PENDING, message=message, save_result_path=getattr(message, "save_result_path", None)) + + self._tasks[task_id] = task_info + self.total_tasks += 1 + + self._cleanup_old_tasks() + + return task_id + + def start_task(self, task_id: str) -> 
TaskInfo: + with self._lock: + if task_id not in self._tasks: + raise KeyError(f"Task {task_id} not found") + + task = self._tasks[task_id] + task.status = TaskStatus.PROCESSING + task.start_time = datetime.now() + + self._tasks.move_to_end(task_id) + + return task + + def complete_task(self, task_id: str, save_result_path: Optional[str] = None): + with self._lock: + if task_id not in self._tasks: + logger.warning(f"Task {task_id} not found for completion") + return + + task = self._tasks[task_id] + task.status = TaskStatus.COMPLETED + task.end_time = datetime.now() + if save_result_path: + task.save_result_path = save_result_path + + self.completed_tasks += 1 + + def fail_task(self, task_id: str, error: str): + with self._lock: + if task_id not in self._tasks: + logger.warning(f"Task {task_id} not found for failure") + return + + task = self._tasks[task_id] + task.status = TaskStatus.FAILED + task.end_time = datetime.now() + task.error = error + + self.failed_tasks += 1 + + def cancel_task(self, task_id: str) -> bool: + with self._lock: + if task_id not in self._tasks: + return False + + task = self._tasks[task_id] + + if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]: + return False + + task.stop_event.set() + task.status = TaskStatus.CANCELLED + task.end_time = datetime.now() + task.error = "Task cancelled by user" + + if task.thread and task.thread.is_alive(): + task.thread.join(timeout=5) + + return True + + def cancel_all_tasks(self): + with self._lock: + for task_id, task in list(self._tasks.items()): + if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]: + self.cancel_task(task_id) + + def get_task(self, task_id: str) -> Optional[TaskInfo]: + with self._lock: + return self._tasks.get(task_id) + + def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: + task = self.get_task(task_id) + if not task: + return None + + return {"task_id": task.task_id, "status": task.status.value, "start_time": task.start_time, "end_time": task.end_time, "error": task.error, "save_result_path": task.save_result_path} + + def get_all_tasks(self): + with self._lock: + return {task_id: self.get_task_status(task_id) for task_id in self._tasks} + + def get_active_task_count(self) -> int: + with self._lock: + return sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]) + + def get_pending_task_count(self) -> int: + with self._lock: + return sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING) + + def is_processing(self) -> bool: + with self._lock: + return self._current_processing_task is not None + + def acquire_processing_lock(self, task_id: str, timeout: Optional[float] = None) -> bool: + acquired = self._processing_lock.acquire(timeout=timeout if timeout else False) + if acquired: + with self._lock: + self._current_processing_task = task_id + logger.info(f"Task {task_id} acquired processing lock") + return acquired + + def release_processing_lock(self, task_id: str): + with self._lock: + if self._current_processing_task == task_id: + self._current_processing_task = None + try: + self._processing_lock.release() + logger.info(f"Task {task_id} released processing lock") + except RuntimeError as e: + logger.warning(f"Task {task_id} tried to release lock but failed: {e}") + + def get_next_pending_task(self) -> Optional[str]: + with self._lock: + for task_id, task in self._tasks.items(): + if task.status == TaskStatus.PENDING: + return task_id + return None + + def get_service_status(self) -> Dict[str, Any]: + with 
self._lock: + active_tasks = [task_id for task_id, task in self._tasks.items() if task.status == TaskStatus.PROCESSING] + + pending_count = sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING) + + return { + "service_status": "busy" if self._current_processing_task else "idle", + "current_task": self._current_processing_task, + "active_tasks": active_tasks, + "pending_tasks": pending_count, + "queue_size": self.max_queue_size, + "total_tasks": self.total_tasks, + "completed_tasks": self.completed_tasks, + "failed_tasks": self.failed_tasks, + } + + def _cleanup_old_tasks(self, keep_count: int = 1000): + if len(self._tasks) <= keep_count: + return + + completed_tasks = [(task_id, task) for task_id, task in self._tasks.items() if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]] + + completed_tasks.sort(key=lambda x: x[1].end_time or x[1].start_time) + + remove_count = len(self._tasks) - keep_count + for task_id, _ in completed_tasks[:remove_count]: + del self._tasks[task_id] + logger.debug(f"Cleaned up old task: {task_id}") + + +task_manager = TaskManager() diff --git a/lightx2v/utils/__init__.py b/lightx2v/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v/utils/async_io.py b/lightx2v/utils/async_io.py new file mode 100644 index 0000000000000000000000000000000000000000..9b41263f09277a027961d7ac8e6d6666aeb5fb60 --- /dev/null +++ b/lightx2v/utils/async_io.py @@ -0,0 +1,83 @@ +import asyncio +import io +from pathlib import Path +from typing import Union + +import aiofiles +from PIL import Image +from loguru import logger + + +async def load_image_async(path: Union[str, Path]) -> Image.Image: + try: + async with aiofiles.open(path, "rb") as f: + data = await f.read() + + return await asyncio.to_thread(lambda: Image.open(io.BytesIO(data)).convert("RGB")) + except Exception as e: + logger.error(f"Failed to load image from {path}: {e}") + raise + + +async def save_video_async(video_path: Union[str, Path], video_data: bytes): + try: + video_path = Path(video_path) + video_path.parent.mkdir(parents=True, exist_ok=True) + + async with aiofiles.open(video_path, "wb") as f: + await f.write(video_data) + + logger.info(f"Video saved to {video_path}") + except Exception as e: + logger.error(f"Failed to save video to {video_path}: {e}") + raise + + +async def read_text_async(path: Union[str, Path], encoding: str = "utf-8") -> str: + try: + async with aiofiles.open(path, "r", encoding=encoding) as f: + return await f.read() + except Exception as e: + logger.error(f"Failed to read text from {path}: {e}") + raise + + +async def write_text_async(path: Union[str, Path], content: str, encoding: str = "utf-8"): + try: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + async with aiofiles.open(path, "w", encoding=encoding) as f: + await f.write(content) + + logger.info(f"Text written to {path}") + except Exception as e: + logger.error(f"Failed to write text to {path}: {e}") + raise + + +async def exists_async(path: Union[str, Path]) -> bool: + return await asyncio.to_thread(lambda: Path(path).exists()) + + +async def read_bytes_async(path: Union[str, Path]) -> bytes: + try: + async with aiofiles.open(path, "rb") as f: + return await f.read() + except Exception as e: + logger.error(f"Failed to read bytes from {path}: {e}") + raise + + +async def write_bytes_async(path: Union[str, Path], data: bytes): + try: + path = Path(path) + 
path.parent.mkdir(parents=True, exist_ok=True) + + async with aiofiles.open(path, "wb") as f: + await f.write(data) + + logger.debug(f"Bytes written to {path}") + except Exception as e: + logger.error(f"Failed to write bytes to {path}: {e}") + raise diff --git a/lightx2v/utils/custom_compiler.py b/lightx2v/utils/custom_compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..851192b80f42a07b35a8b76dc0250416b7d3a1a9 --- /dev/null +++ b/lightx2v/utils/custom_compiler.py @@ -0,0 +1,187 @@ +import functools +from typing import Dict, List, Optional + +import torch +from loguru import logger + + +def compiled_method(compile_options: Optional[Dict] = None): + def decorator(func): + func_name = func.__name__ + compile_opts = compile_options or {} + + state = { + "original_func": func, + "compiled_graphs": {}, + "compile_mode": False, + "selected_graph": None, + "selected_compiled": None, + } + + @functools.wraps(func) + def wrapper(self, *args, graph_name: Optional[str] = None, **kwargs): + if state["compile_mode"]: + if graph_name is None: + graph_name = f"graph_{len(state['compiled_graphs']) + 1:02d}" + + if graph_name not in state["compiled_graphs"]: + logger.info(f"[Compile] Compiling {func_name} as '{graph_name}'...") + + compiled_func = torch.compile(state["original_func"], **compile_opts) + + try: + result = compiled_func(self, *args, **kwargs) + state["compiled_graphs"][graph_name] = compiled_func + logger.info(f"[Compile] Compiled {func_name} as '{graph_name}'") + return result + except Exception as e: + logger.info(f"[Compile] Failed to compile {func_name} as '{graph_name}': {e}") + return state["original_func"](self, *args, **kwargs) + else: + logger.info(f"[Compile] Using existing compiled graph '{graph_name}'") + return state["compiled_graphs"][graph_name](self, *args, **kwargs) + + elif state["selected_compiled"]: + return state["selected_compiled"](self, *args, **kwargs) + else: + return state["original_func"](self, *args, **kwargs) + + def _enable_compile_mode(): + logger.info(f"[Compile] Enabling compile mode for {func_name}") + state["compile_mode"] = True + + def _disable_compile_mode(): + logger.info(f"[Compile] Disabling compile mode for {func_name}") + state["compile_mode"] = False + + def _select_graph(graph_name: str): + if graph_name not in state["compiled_graphs"]: + logger.warning(f"[Compile] Graph '{graph_name}' not found. 
Available graphs: {list(state['compiled_graphs'].keys())}, returning to original function.") + state["selected_graph"] = None + state["selected_compiled"] = None + else: + logger.info(f"[Compile] Selecting graph '{graph_name}' for {func_name}") + state["selected_graph"] = graph_name + state["selected_compiled"] = state["compiled_graphs"][graph_name] + logger.info(f"[Compile] {func_name} will now use graph '{graph_name}' for inference") + + def _unselect_graph(): + logger.info(f"[Compile] Unselecting graph for {func_name}, returning to original function") + state["selected_graph"] = None + state["selected_compiled"] = None + + def _get_status(): + return { + "available_graphs": list(state["compiled_graphs"].keys()), + "compiled_count": len(state["compiled_graphs"]), + "selected_graph": state["selected_graph"], + "compile_mode": state["compile_mode"], + "mode": "compile" if state["compile_mode"] else ("inference" if state["selected_compiled"] else "original"), + } + + def _clear_graphs(): + state["compiled_graphs"].clear() + state["selected_graph"] = None + state["selected_compiled"] = None + state["compile_mode"] = False + logger.info(f"[Compile] Cleared all compiled graphs for {func_name}") + + def _remove_graph(graph_name: str): + if graph_name in state["compiled_graphs"]: + del state["compiled_graphs"][graph_name] + if state["selected_graph"] == graph_name: + state["selected_graph"] = None + state["selected_compiled"] = None + logger.info(f"[Compile] Removed graph '{graph_name}' for {func_name}") + else: + logger.info(f"[Compile] Graph '{graph_name}' not found") + + wrapper._enable_compile_mode = _enable_compile_mode + wrapper._disable_compile_mode = _disable_compile_mode + wrapper._select_graph = _select_graph + wrapper._unselect_graph = _unselect_graph + wrapper._get_status = _get_status + wrapper._clear_graphs = _clear_graphs + wrapper._remove_graph = _remove_graph + wrapper._func_name = func_name + + return wrapper + + return decorator + + +class CompiledMethodsMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._compiled_methods = {} + self._discover_compiled_methods() + + def _discover_compiled_methods(self): + logger.info(f"[Compile] Discovering compiled methods for {self.__class__.__name__}...") + + for attr_name in dir(self): + attr = getattr(self, attr_name) + if hasattr(attr, "_enable_compile_mode"): + logger.info(f"[Compile] Found compiled method: {attr_name}") + self._compiled_methods[attr_name] = attr + + def enable_compile_mode(self, method_name: str = None): + if method_name: + if method_name not in self._compiled_methods: + raise ValueError(f"Method '{method_name}' is not a compiled method") + self._compiled_methods[method_name]._enable_compile_mode() + else: + for name, method in self._compiled_methods.items(): + method._enable_compile_mode() + logger.info("[Compile] Enabled compile mode for all methods") + + def disable_compile_mode(self, method_name: str = None): + if method_name: + if method_name not in self._compiled_methods: + raise ValueError(f"Method '{method_name}' is not a compiled method") + self._compiled_methods[method_name]._disable_compile_mode() + else: + for name, method in self._compiled_methods.items(): + method._disable_compile_mode() + logger.info("[Compile] Disabled compile mode for all methods") + + def select_graph(self, method_name: str, graph_name: str): + if method_name not in self._compiled_methods: + raise ValueError(f"Method '{method_name}' is not a compiled method") + + method = 
self._compiled_methods[method_name] + method._select_graph(graph_name) + + def unselect_graph(self, method_name: str): + if method_name not in self._compiled_methods: + raise ValueError(f"Method '{method_name}' is not a compiled method") + + method = self._compiled_methods[method_name] + method._unselect_graph() + + def get_compile_status(self) -> Dict: + status = {} + for method_name, method in self._compiled_methods.items(): + status[method_name] = method._get_status() + return status + + def get_compiled_methods(self) -> List[str]: + return list(self._compiled_methods.keys()) + + def clear_compiled_graphs(self, method_name: str = None): + if method_name: + if method_name in self._compiled_methods: + self._compiled_methods[method_name]._clear_graphs() + else: + logger.info(f"Method '{method_name}' not found") + else: + for method_name, method in self._compiled_methods.items(): + method._clear_graphs() + logger.info("[Compile] Cleared all compiled graphs") + + def remove_graph(self, method_name: str, graph_name: str): + if method_name not in self._compiled_methods: + raise ValueError(f"Method '{method_name}' is not a compiled method") + + method = self._compiled_methods[method_name] + method._remove_graph(graph_name) diff --git a/lightx2v/utils/envs.py b/lightx2v/utils/envs.py new file mode 100644 index 0000000000000000000000000000000000000000..585224f40ac8b123fee06a56d96ad55965e58ef1 --- /dev/null +++ b/lightx2v/utils/envs.py @@ -0,0 +1,49 @@ +import os +from functools import lru_cache + +import torch + +DTYPE_MAP = { + "BF16": torch.bfloat16, + "FP16": torch.float16, + "FP32": torch.float32, + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "torch.bfloat16": torch.bfloat16, + "torch.float16": torch.float16, + "torch.float32": torch.float32, +} + + +@lru_cache(maxsize=None) +def CHECK_PROFILING_DEBUG_LEVEL(target_level): + current_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0")) + return current_level >= target_level + + +@lru_cache(maxsize=None) +def GET_RUNNING_FLAG(): + RUNNING_FLAG = os.getenv("RUNNING_FLAG", "infer") + return RUNNING_FLAG + + +@lru_cache(maxsize=None) +def GET_DTYPE(): + RUNNING_FLAG = os.getenv("DTYPE", "BF16") + assert RUNNING_FLAG in ["BF16", "FP16"] + return DTYPE_MAP[RUNNING_FLAG] + + +@lru_cache(maxsize=None) +def GET_SENSITIVE_DTYPE(): + RUNNING_FLAG = os.getenv("SENSITIVE_LAYER_DTYPE", "None") + if RUNNING_FLAG == "None": + return GET_DTYPE() + return DTYPE_MAP[RUNNING_FLAG] + + +@lru_cache(maxsize=None) +def GET_RECORDER_MODE(): + RECORDER_MODE = int(os.getenv("RECORDER_MODE", "0")) + return RECORDER_MODE diff --git a/lightx2v/utils/generate_task_id.py b/lightx2v/utils/generate_task_id.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9d80802be87747816fe3b817f5e1970dba006e --- /dev/null +++ b/lightx2v/utils/generate_task_id.py @@ -0,0 +1,49 @@ +import random +import string +import time +from datetime import datetime + + +def generate_task_id(): + """ + Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX. + Features: + 1. Does not modify the global random state. + 2. Each X is an uppercase letter or digit (0-9). + 3. Combines time factors to ensure high randomness. 
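+ Example format (illustrative only): "A1B2-C3D4-E5F6-G7H8-J9K0".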
+ """ + # Save the current random state (does not affect external randomness) + original_state = random.getstate() + + try: + # Define character set (uppercase letters + digits) + characters = string.ascii_uppercase + string.digits + + # Create an independent random instance + local_random = random.Random(time.perf_counter_ns()) + + # Generate 5 groups of 4-character random strings + groups = [] + for _ in range(5): + # Mix new time factor for each group + time_mix = int(datetime.now().timestamp()) + local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns()) + + groups.append("".join(local_random.choices(characters, k=4))) + + return "-".join(groups) + + finally: + # Restore the original random state + random.setstate(original_state) + + +if __name__ == "__main__": + # Set global random seed + random.seed(42) + + # Test that external randomness is not affected + print("External random number 1:", random.random()) # Always the same + print("Task ID 1:", generate_task_id()) # Different each time + print("External random number 1:", random.random()) # Always the same + print("Task ID 1:", generate_task_id()) # Different each time diff --git a/lightx2v/utils/ggml_tensor.py b/lightx2v/utils/ggml_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..d67ce68c9b9081e2a026b30e451a4405b10bc3dc --- /dev/null +++ b/lightx2v/utils/ggml_tensor.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import ctypes +import os +from pathlib import Path +from typing import Optional, Tuple, Union + +import gguf +import numpy as np +import torch +from loguru import logger + +c_float_p = ctypes.POINTER(ctypes.c_float) +TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16) + + +class GGMLTensor: + def __init__( + self, + data: Union[torch.Tensor, np.ndarray, None] = None, + orig_shape: Tuple[int, ...] 
= None, + dtype: torch.dtype = None, + gguf_type: gguf.GGMLQuantizationType = None, + requires_grad: bool = False, + aligned: bool = True, + pin_memory: bool = False, + preallocated: bool = False, + ): + super().__init__() + + assert orig_shape is not None + assert gguf_type is not None + + if isinstance(data, np.ndarray): + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The given NumPy array is not writable") + torch_data = torch.from_numpy(data) + else: + torch_data = data + + if dtype is not None and torch_data.dtype != dtype: + torch_data = torch_data.to(dtype) + + self.data = torch_data + + self.gguf_type = gguf_type + self._orig_shape = orig_shape + self._aligned = aligned + self._pinned_memory = pin_memory + self._requires_grad = requires_grad + self._preallocated = preallocated + + self._quantized = self._is_quantized_type(gguf_type) + self._q_type = self._get_quant_type_str(gguf_type) + + if aligned: + self._make_aligned() + if pin_memory: + self._pin_memory() + + def _is_quantized_type(self, gguf_type: gguf.GGMLQuantizationType) -> bool: + return gguf_type not in TORCH_COMPATIBLE_QTYPES + + def _get_quant_type_str(self, gguf_type: gguf.GGMLQuantizationType) -> str: + type_mapping = { + gguf.GGMLQuantizationType.F32: "ggml_f32", + gguf.GGMLQuantizationType.F16: "ggml_f16", + gguf.GGMLQuantizationType.Q4_0: "ggml_q4_0", + gguf.GGMLQuantizationType.Q4_1: "ggml_q4_1", + gguf.GGMLQuantizationType.Q5_0: "ggml_q5_0", + gguf.GGMLQuantizationType.Q5_1: "ggml_q5_1", + gguf.GGMLQuantizationType.Q8_0: "ggml_q8_0", + gguf.GGMLQuantizationType.Q8_1: "ggml_q8_1", + gguf.GGMLQuantizationType.Q2_K: "ggml_q2_k", + gguf.GGMLQuantizationType.Q3_K: "ggml_q3_k", + gguf.GGMLQuantizationType.Q4_K: "ggml_q4_k", + gguf.GGMLQuantizationType.Q5_K: "ggml_q5_k", + gguf.GGMLQuantizationType.Q6_K: "ggml_q6_k", + gguf.GGMLQuantizationType.Q8_K: "ggml_q8_k", + } + return type_mapping.get(gguf_type, "unknown") + + @classmethod + def empty_pinned( + cls, shape: Tuple[int, ...], orig_shape: Tuple[int, ...] = None, dtype: torch.dtype = torch.float32, gguf_type: gguf.GGMLQuantizationType = None, aligned: bool = True + ) -> "GGMLTensor": + torch_data = torch.empty(shape, pin_memory=True, dtype=dtype) + return cls(data=torch_data, dtype=dtype, orig_shape=orig_shape, gguf_type=gguf_type, pin_memory=True, aligned=aligned, preallocated=True) + + @classmethod + def empty_aligned( + cls, shape: Tuple[int, ...], orig_shape: Tuple[int, ...] 
= None, dtype: torch.dtype = torch.float32, gguf_type: gguf.GGMLQuantizationType = None, pin_memory: bool = False + ) -> "GGMLTensor": + return cls(dtype=dtype, orig_shape=orig_shape, gguf_type=gguf_type, pin_memory=pin_memory, aligned=True, preallocated=True) + + def copy_from(self, source: Union[torch.Tensor, "GGMLTensor"], transpose: bool = False, non_blocking: bool = False) -> "GGMLTensor": + if not self._preallocated: + raise RuntimeError("copy_from can only be used with preallocated tensors") + + if transpose: + source_data = source.data.t().contiguous() + else: + source_data = source.data.contiguous() + + if self.shape != source_data.shape: + raise ValueError(f"Shape mismatch: target {self.shape} vs source {source_data.shape}") + + self.data.copy_(source_data) + + return self + + def copy_(self, target: Union[torch.Tensor, "GGMLTensor"], transpose: bool = False, non_blocking: bool = False) -> "GGMLTensor": + source_data = self.data + if transpose: + source_data = self.t().contiguous() + + if isinstance(target, GGMLTensor): + target.copy_from(source_data, non_blocking=non_blocking) + else: + target.copy_(source_data) + + return self + + def t(self): + self.data = self.data.t() + return self + + def _make_aligned(self, alignment: int = 32): + if not self.data.is_contiguous(): + self.data = self.data.contiguous().data + + ptr = self.data.data_ptr() + if ptr % alignment == 0: + return + + if self._pinned_memory: + aligned_data = torch.empty(self.data.shape, dtype=self.data.dtype, device=self.data.device, pin_memory=True) + else: + aligned_data = torch.empty(self.data.shape, dtype=self.data.dtype, device=self.data.device) + + aligned_data.copy_(self.data) + self.data = aligned_data.data + + def _pin_memory(self) -> "GGMLTensor": + if self._pinned_memory or self.device.type != "cpu": + return self + + pinned_data = self.data.pin_memory() + self.data = pinned_data.data + self._pinned_memory = True + return self + + def to_torch(self) -> torch.Tensor: + return torch.as_tensor(self.data) + + @property + def shape(self): + return self.data.shape + + @property + def dtype(self): + return self.data.dtype + + @property + def device(self): + return self.data.device + + @property + def tensor_type(self) -> gguf.GGMLQuantizationType: + return self.gguf_type + + @property + def quant_type(self) -> str: + return self._q_type + + @property + def is_quantized(self) -> bool: + return self._quantized + + @property + def orig_shape(self) -> Tuple[int, ...]: + return self._orig_shape + + @property + def blocksize(self) -> Optional[int]: + _blocksize, _ = gguf.GGML_QUANT_SIZES[self.qtype] + return _blocksize + + @property + def is_pinned(self) -> bool: + return self._pinned_memory + + def memory_footprint(self) -> int: + if self._quantized: + return self.data.numel() * self.element_size() + else: + return self.data.numel() * self.element_size() + + def __repr__(self) -> str: + return f"GGMLTensor(shape={self.data.shape}, orig_shape={self.orig_shape}, dtype={self.data.dtype}, quantized={self.is_quantized}, quant_type='{self.quant_type}', pinned={self.is_pinned})" + + def cuda(self, device: Optional[Union[int, torch.device]] = None, non_blocking: bool = False) -> "GGMLTensor": + if device is None: + self.data = self.data.cuda(non_blocking=non_blocking) + else: + self.data = self.data.cuda(device=device, non_blocking=non_blocking) + return self + + def cpu(self, pin_memory: bool = False) -> "GGMLTensor": + self.data = self.data.cpu() + return self + + def to(self, *args, **kwargs) -> "GGMLTensor": + self.data 
= self.data.to(*args, **kwargs) + return self + + +def load_gguf_sd_ckpt(gguf_path, return_arch=False, to_device: Optional[Union[int, torch.device]] = None): + import warnings + + logger.info(f"Loading gguf-quant dit model from {gguf_path}") + + reader = gguf.GGUFReader(gguf_path) + state_dict = {} + for tensor in reader.tensors: + tensor_name = tensor.name + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The given NumPy array is not writable") + torch_tensor = torch.from_numpy(tensor.data) # mmap + + shape = get_orig_shape(reader, tensor_name) + if shape is None: + shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape))) + + if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES: + state_dict[tensor.name] = torch_tensor.to(to_device) + else: + state_dict[tensor.name] = GGMLTensor( + data=torch_tensor, + gguf_type=tensor.tensor_type, + orig_shape=shape, + aligned=True, + pin_memory=False, + ).to(to_device) + + if return_arch: + arch = get_model_architecture(reader) + return state_dict, arch + + return state_dict + + +def get_orig_shape(reader, tensor_name: str) -> Optional[Tuple[int, ...]]: + # TODO 这里正式上线的时候,需要更换 + field_key = f"comfy.gguf.orig_shape.{tensor_name}" + field = reader.get_field(field_key) + if field is None: + return None + # Has original shape metadata, so we try to decode it. + if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32: + raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}") + return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data)) + + +def get_field(reader, field_name, field_type): + field = reader.get_field(field_name) + if field is None: + return None + elif isinstance(field_type, str): + # extra check here as this is used for checking arch string + if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING: + raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}") + return str(field.parts[field.data[-1]], encoding="utf-8") + elif field_type in [int, float, bool]: + return field_type(field.parts[field.data[-1]]) + else: + raise TypeError(f"Unknown field type {field_type}") + + +def get_model_architecture(reader) -> str: + arch_str = get_field(reader, "general.architecture", str) + return arch_str + + +class ggml_init_params(ctypes.Structure): + _fields_ = [ + ("mem_size", ctypes.c_size_t), + ("mem_buffer", ctypes.c_void_p), + ("no_alloc", ctypes.c_bool), + ] + + +class GGMLQuants: + libggml: ctypes.CDLL + + def __init__(self, libggml: Path): + self.libggml = ctypes.CDLL(str(libggml)) + self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t + + self.libggml.ggml_quantize_chunk.argtypes = ( + ctypes.c_int, + ctypes.POINTER(ctypes.c_float), + ctypes.c_void_p, + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ctypes.POINTER(ctypes.c_float), + ) + + self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool + self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,) + + for t in ( + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_K", + "q3_K", + "q4_K", + "q5_K", + "q6_K", + ): + dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t) + dequant_func.restype = None + dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + + self.libggml.ggml_fp16_to_fp32_row.restype = None + self.libggml.ggml_fp16_to_fp32_row.argtypes = 
(ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + self.libggml.ggml_bf16_to_fp32_row.restype = None + self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + + self.libggml.ggml_init.argtypes = (ggml_init_params,) + + self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False)) + + def dequantize(self, tensor: np.ndarray, qtype: gguf.GGMLQuantizationType) -> np.ndarray: + result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C") + if qtype == gguf.GGMLQuantizationType.F32: + # no-op + result = tensor.view(np.float32) + elif qtype == gguf.GGMLQuantizationType.F16: + self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) + elif qtype == gguf.GGMLQuantizationType.BF16: + self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) + else: + lw_qname = qtype.name.lower() + if lw_qname[-1] == "k": + lw_qname = lw_qname[:-1] + "K" + dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname) + dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size) + return result + + +def to_uint32(x): + x = x.view(torch.uint8).to(torch.int32) + return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) + + +def split_block_dims(blocks, *args): + n_max = blocks.shape[1] + dims = list(args) + [n_max - sum(args)] + return torch.split(blocks, dims, dim=1) + + +def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None): + return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) + + +def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None): + d, x = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + x = x.view(torch.int8) + return d * x + + +def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qh, qs = split_block_dims(blocks, 2, 2, 4) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape((n_blocks, -1)) + + qs = ql | (qh << 4) + return (d * qs) + m + + +def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qh, qs = split_block_dims(blocks, 2, 4) + d = d.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return d * qs + + +def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qs = split_block_dims(blocks, 2, 2) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, 
dtype=torch.uint8).reshape(1, 1, 2, 1) + qs = (qs & 0x0F).reshape(n_blocks, -1) + + return (d * qs) + m + + +def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return d * qs + + +# K Quants # +QK_K = 256 +K_SCALE_SIZE = 12 + + +def get_scale_min(scales): + n_blocks = scales.shape[0] + scales = scales.view(torch.uint8) + scales = scales.reshape((n_blocks, 3, 4)) + + d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2) + + sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1) + min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1) + + return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8))) + + +def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + ( + ql, + qh, + scales, + d, + ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16) + + scales = scales.view(torch.int8).to(dtype) + d = d.view(torch.float16).to(dtype) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + qh = (qh & 0x03).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)).to(torch.int8) - 32 + q = q.reshape((n_blocks, QK_K // 16, -1)) + + return (d * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8) + + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1)) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = (qh & 0x01).reshape((n_blocks, -1, 32)) + q = ql | (qh << 4) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12) + d = d.view(torch.float16).to(dtype) + + lscales, hscales = scales[:, :8], scales[:, 8:] + lscales = lscales.reshape((n_blocks, 1, 8)) >> 
torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1)) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) + scales = scales.to(torch.int8) - 32 + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1)) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 + q = ql.to(torch.int8) - (qh << 2).to(torch.int8) + + return (dl * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) + + shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + + qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3 + qs = qs.reshape((n_blocks, QK_K // 16, 16)) + qs = dl * qs - ml + + return qs.reshape((n_blocks, -1)) + + +dequantize_functions = { + gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, + gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, + gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, + gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1, + gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K, + gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K, + gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K, +} + + +try: + import platform + + import llama_cpp + + lib_name = "libggml.so" + if platform.system() == "Darwin": + lib_name = "libggml.dylib" + elif platform.system() == "Windows": + lib_name = "ggml.dll" # Or libggml.dll + + llama_lib_path = os.path.join(os.path.dirname(os.path.abspath(llama_cpp.__file__)), "lib", lib_name) + ggml_quants = GGMLQuants(llama_lib_path) + + def dequantize_c(tensor): + return torch.from_numpy(ggml_quants.dequantize(s.data.numpy(), s.gguf_type)) +except ImportError: + dequantize_c = None + + +def dequantize_tensor(tensor, dtype=None): + qtype = getattr(tensor, "gguf_type", None) + oshape = getattr(tensor, "orig_shape", tensor.data.shape) + + if qtype in TORCH_COMPATIBLE_QTYPES: + return tensor.to(dtype) + else: + if dequantize_c is not None: + return dequantize_c(tensor).to(dtype) + elif qtype in dequantize_functions: + return dequantize(tensor.to_torch().data, qtype, oshape, dtype=dtype).to(dtype) + else: + # this is incredibly slow + logger.warning(f"Falling back to numpy dequant for qtype: {qtype}") + new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype) + return torch.from_numpy(new).to(tensor.device, dtype=dtype) + + +def dequantize(data, qtype, oshape, dtype=None): + block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] + dequantize_blocks = 
dequantize_functions[qtype] + + rows = data.reshape((-1, data.shape[-1])).view(torch.uint8) + + n_blocks = rows.numel() // type_size + blocks = rows.reshape((n_blocks, type_size)) + blocks = dequantize_blocks(blocks, block_size, type_size, dtype) + return blocks.reshape(oshape) diff --git a/lightx2v/utils/global_paras.py b/lightx2v/utils/global_paras.py new file mode 100644 index 0000000000000000000000000000000000000000..cef34ea546b66f440928c576762bd788c4f457a2 --- /dev/null +++ b/lightx2v/utils/global_paras.py @@ -0,0 +1 @@ +CALIB = {"absmax": {}} diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fe63a6b8933d27621514e4add18a35b5324915 --- /dev/null +++ b/lightx2v/utils/input_info.py @@ -0,0 +1,230 @@ +import inspect +from dataclasses import dataclass, field + + +@dataclass +class T2VInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + # shape related + latent_shape: list = field(default_factory=list) + target_shape: int = field(default_factory=int) + + +@dataclass +class I2VInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + image_path: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + # shape related + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: int = field(default_factory=int) + + +@dataclass +class Flf2vInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + image_path: str = field(default_factory=str) + last_frame_path: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + # shape related + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: int = field(default_factory=int) + + +# Need Check +@dataclass +class VaceInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + src_ref_images: str = field(default_factory=str) + src_video: str = field(default_factory=str) + src_mask: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + # shape related + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: int = field(default_factory=int) + + +@dataclass +class S2VInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = 
field(default_factory=str) + image_path: str = field(default_factory=str) + audio_path: str = field(default_factory=str) + audio_num: int = field(default_factory=int) + with_mask: bool = field(default_factory=lambda: False) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + # shape related + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: int = field(default_factory=int) + + +# Need Check +@dataclass +class AnimateInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + image_path: str = field(default_factory=str) + src_pose_path: str = field(default_factory=str) + src_face_path: str = field(default_factory=str) + src_ref_images: str = field(default_factory=str) + src_bg_path: str = field(default_factory=str) + src_mask_path: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + # shape related + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: int = field(default_factory=int) + + +@dataclass +class T2IInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + # shape related + target_shape: int = field(default_factory=int) + + +@dataclass +class I2IInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + image_path: str = field(default_factory=str) + save_result_path: str = field(default_factory=str) + # shape related + target_shape: int = field(default_factory=int) + processed_image_size: int = field(default_factory=list) + original_size: list = field(default_factory=list) + + +def set_input_info(args): + if args.task == "t2v": + input_info = T2VInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + save_result_path=args.save_result_path, + return_result_tensor=args.return_result_tensor, + ) + elif args.task == "i2v": + input_info = I2VInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + image_path=args.image_path, + save_result_path=args.save_result_path, + return_result_tensor=args.return_result_tensor, + ) + elif args.task == "flf2v": + input_info = Flf2vInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + image_path=args.image_path, + last_frame_path=args.last_frame_path, + save_result_path=args.save_result_path, + return_result_tensor=args.return_result_tensor, + ) + elif args.task == "vace": + input_info = VaceInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + src_ref_images=args.src_ref_images, + src_video=args.src_video, + src_mask=args.src_mask, + save_result_path=args.save_result_path, + return_result_tensor=args.return_result_tensor, + ) + elif args.task == "s2v": + input_info = S2VInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + image_path=args.image_path, + audio_path=args.audio_path, + 
save_result_path=args.save_result_path, + return_result_tensor=args.return_result_tensor, + ) + elif args.task == "animate": + input_info = AnimateInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + image_path=args.image_path, + src_pose_path=args.src_pose_path, + src_face_path=args.src_face_path, + src_ref_images=args.src_ref_images, + src_bg_path=args.src_bg_path, + src_mask_path=args.src_mask_path, + save_result_path=args.save_result_path, + return_result_tensor=args.return_result_tensor, + ) + elif args.task == "t2i": + input_info = T2IInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + save_result_path=args.save_result_path, + ) + elif args.task == "i2i": + input_info = I2IInputInfo( + seed=args.seed, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + image_path=args.image_path, + save_result_path=args.save_result_path, + ) + else: + raise ValueError(f"Unsupported task: {args.task}") + return input_info + + +def get_all_input_info_keys(): + all_keys = set() + + current_module = inspect.currentframe().f_globals + + for name, obj in current_module.items(): + if inspect.isclass(obj) and name.endswith("InputInfo") and hasattr(obj, "__dataclass_fields__"): + all_keys.update(obj.__dataclass_fields__.keys()) + + return all_keys + + +# 创建包含所有InputInfo字段的集合 +ALL_INPUT_INFO_KEYS = get_all_input_info_keys() diff --git a/lightx2v/utils/lockable_dict.py b/lightx2v/utils/lockable_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..6f00cc62a379b8f8b70ca87d633931076dd002e9 --- /dev/null +++ b/lightx2v/utils/lockable_dict.py @@ -0,0 +1,177 @@ +from contextlib import contextmanager +from typing import Any, Iterable, Mapping + + +class LockableDict(dict): + """ + A lockable/unlockable dictionary. After locking, any in-place modifications will raise TypeError. + By default auto_wrap=True, which recursively converts nested dict objects in dict/list/tuple/set + to LockableDict, so that recursive locking works consistently both internally and externally. + """ + + def __init__(self, *args, auto_wrap: bool = True, **kwargs): + self._locked: bool = False + self._auto_wrap: bool = auto_wrap + # Build with temporary dict, then wrap uniformly before writing to self, avoiding bypass of __setitem__ + tmp = dict(*args, **kwargs) + for k, v in tmp.items(): + dict.__setitem__(self, k, self._wrap(v)) + + # ========== Public API ========== + @property + def locked(self) -> bool: + return self._locked + + def lock(self, recursive: bool = True) -> None: + """Lock the dictionary. When recursive=True, also recursively locks nested LockableDict objects.""" + self._locked = True + if recursive: + for v in self.values(): + if isinstance(v, LockableDict): + v.lock(True) + + def unlock(self, recursive: bool = True) -> None: + """Unlock the dictionary. When recursive=True, also recursively unlocks nested LockableDict objects.""" + self._locked = False + if recursive: + for v in self.values(): + if isinstance(v, LockableDict): + v.unlock(True) + + @contextmanager + def temporarily_unlocked(self, recursive: bool = True): + """ + Temporarily unlock in context manager form, restoring original state on exit. 
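+        When recursive=True (the default) and the dict is currently locked, nested LockableDict values are unlocked for the duration and locked again on exit.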
+ Typical usage: + with d.temporarily_unlocked(): + d["x"] = 1 + """ + prev = self._locked + if prev and recursive: + # First temporarily unlock all child nodes as well + stack: list[LockableDict] = [] + + def _collect(node: "LockableDict"): + for v in node.values(): + if isinstance(v, LockableDict): + stack.append(v) + _collect(v) + + _collect(self) + self._locked = False + for n in stack: + n._locked = False + try: + yield self + finally: + self._locked = prev + for n in stack: + n._locked = prev + else: + self._locked = False + try: + yield self + finally: + self._locked = prev + + def copy(self) -> "LockableDict": + new = LockableDict(auto_wrap=self._auto_wrap) + for k, v in self.items(): + dict.__setitem__(new, k, v) + new._locked = self._locked + return new + + # ========== In-place modification interception ========== + def __setitem__(self, key, value) -> None: + self._ensure_unlocked() + dict.__setitem__(self, key, self._wrap(value)) + + def __delitem__(self, key) -> None: + self._ensure_unlocked() + dict.__delitem__(self, key) + + def clear(self) -> None: + self._ensure_unlocked() + dict.clear(self) + + def pop(self, k, d: Any = ...): + self._ensure_unlocked() + if d is ...: + return dict.pop(self, k) + return dict.pop(self, k, d) + + def popitem(self): + self._ensure_unlocked() + return dict.popitem(self) + + def setdefault(self, key, default=None): + # If key doesn't exist, setdefault will write, need to check lock + if key not in self: + self._ensure_unlocked() + default = self._wrap(default) + return dict.setdefault(self, key, default) + + def update(self, other: Mapping | Iterable, **kwargs) -> None: + self._ensure_unlocked() + if isinstance(other, Mapping): + items = list(other.items()) + else: + items = list(other) + for k, v in items: + dict.__setitem__(self, k, self._wrap(v)) + for k, v in kwargs.items(): + dict.__setitem__(self, k, self._wrap(v)) + + # Python 3.9 in-place union: d |= x + def __ior__(self, other): + self.update(other) + return self + + # ========== Attribute-style access (EasyDict-like behavior) ========== + def __getattr__(self, key: str): + """Allow attribute-style access: d.key instead of d['key']""" + try: + return self[key] + except KeyError: + raise AttributeError(f"'LockableDict' object has no attribute '{key}'") + + # ========== Internal utilities ========== + def _ensure_unlocked(self) -> None: + if self._locked: + raise TypeError("Dictionary is locked, current operation not allowed.") + + def _wrap(self, value): + if not self._auto_wrap: + return value + if isinstance(value, LockableDict): + return value + if isinstance(value, dict): + return LockableDict(value, auto_wrap=True) + if isinstance(value, list): + return [self._wrap(v) for v in value] + if isinstance(value, tuple): + return tuple(self._wrap(v) for v in value) + if isinstance(value, set): + return {self._wrap(v) for v in value} + return value + + +if __name__ == "__main__": + d = LockableDict({"a": 1, "b": 2}) + d["b"] = 3 + print(d) + d.lock() + print(d) + + # d["a"] = 3 + # print(d) + + # d.unlock() + # print(d) + # d["a"] = 3 + # print(d) + + with d.temporarily_unlocked(): + d["a"] = 3 + print(d) + d["a"] = 4 diff --git a/lightx2v/utils/memory_profiler.py b/lightx2v/utils/memory_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..39d63b21214caa92e86465076c43e4b2388e4061 --- /dev/null +++ b/lightx2v/utils/memory_profiler.py @@ -0,0 +1,29 @@ +import torch +from loguru import logger + + +def peak_memory_decorator(func): + def wrapper(*args, **kwargs): + # 
检查是否在分布式环境中 + rank_info = "" + if torch.distributed.is_available() and torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + rank_info = f"Rank {rank} - " + + # 如果使用GPU,重置显存统计 + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + # 执行目标函数 + result = func(*args, **kwargs) + + # 获取峰值显存 + if torch.cuda.is_available(): + peak_memory = torch.cuda.max_memory_allocated() / (1024**3) # 转换为GB + logger.info(f"{rank_info}Function '{func.__qualname__}' Peak Memory: {peak_memory:.2f} GB") + else: + logger.info(f"{rank_info}Function '{func.__qualname__}' executed without GPU.") + + return result + + return wrapper diff --git a/lightx2v/utils/print_atten_score.py b/lightx2v/utils/print_atten_score.py new file mode 100644 index 0000000000000000000000000000000000000000..33f278fff136010e46fd5593c7e3bfcd4789c1ac --- /dev/null +++ b/lightx2v/utils/print_atten_score.py @@ -0,0 +1,76 @@ +import math + +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F + + +def scaled_dot_product_attention(Q, K, V, mask=None): + """ + Scaled dot-product attention + + Args: + Q: Query tensor [batch_size, num_heads, seq_len, d_k] + K: Key tensor [batch_size, num_heads, seq_len, d_k] + V: Value tensor [batch_size, num_heads, seq_len, d_k] + mask: Attention mask (0 indicates positions to mask, 1 indicates positions to keep) + + Returns: + output: Attention output + attention_weights: Attention weights + """ + d_k = Q.size(-1) + + scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) + + if mask is not None: + mask_value = torch.where(mask == 0, torch.tensor(-float("inf")), torch.tensor(0.0)) + scores = scores + mask_value + + attention_weights = F.softmax(scores, dim=-1) + + output = torch.matmul(attention_weights, V) + return output, scores, attention_weights + + +def draw_matrix(weights, save_path): + plt.imshow(weights, aspect="auto", cmap="viridis") + plt.colorbar() + plt.savefig(save_path) + plt.close() + + +def get_qkv_subset(x, head_index, token_start, token_end): + """ + x : [seq_len, num_heads, head_dim] + + return: [batch_size, num_heads, seq_len, head_dim] + batch_size = 1, num_heads = 1, seq_len = token_end - token_start + """ + x = x[token_start:token_end, head_index, :] # [seq_len, head_dim] + x = x.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, head_dim] + return x + + +def draw_attention_weights(q, k, v, head_index, token_start, token_end, save_path): + """ + q k v : [seq_len, num_heads, head_dim] + """ + q_vis = get_qkv_subset(q, head_index=head_index, token_start=token_start, token_end=token_end) + k_vis = get_qkv_subset(k, head_index=head_index, token_start=token_start, token_end=token_end) + v_vis = get_qkv_subset(v, head_index=head_index, token_start=token_start, token_end=token_end) + output, scores, attention_weights = scaled_dot_product_attention(q_vis, k_vis, v_vis, mask=None) + draw_matrix(scores[0][0].float().cpu().numpy(), save_path) + print(f"Saved to {save_path}") + + +if __name__ == "__main__": + seq_len = 10 + num_heads = 4 + head_dim = 8 + + q = torch.randn(seq_len, num_heads, head_dim) + k = torch.randn(seq_len, num_heads, head_dim) + v = torch.randn(seq_len, num_heads, head_dim) + + draw_attention_weights(q, k, v, head_index=0, token_start=0, token_end=10, save_path="scores.png") diff --git a/lightx2v/utils/profiler.py b/lightx2v/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..8c0126b1d8dd87bb2d8726cfe8bc2f25fe9341f7 --- /dev/null +++ b/lightx2v/utils/profiler.py @@ -0,0 +1,200 @@ 
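+# Usage sketch for the profiling helpers defined below (illustrative only; names
+# such as "load weights" are placeholders). The contexts wrap a code region or
+# decorate a function, and reduce to no-ops unless the PROFILING_DEBUG_LEVEL
+# environment variable enables the corresponding level:
+#
+#     with ProfilingContext4DebugL1("load weights"):
+#         ...  # timed and logged when PROFILING_DEBUG_LEVEL >= 1
+#
+#     @ProfilingContext4DebugL2("denoising step")
+#     def run_step():
+#         ...  # timed and logged when PROFILING_DEBUG_LEVEL >= 2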
+import asyncio +import threading +import time +from functools import wraps + +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) +_excluded_time_local = threading.local() + + +def _get_excluded_time_stack(): + if not hasattr(_excluded_time_local, "stack"): + _excluded_time_local.stack = [] + return _excluded_time_local.stack + + +class _ProfilingContext: + def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None): + """ + recorder_mode = 0: disable recorder + recorder_mode = 1: enable recorder + recorder_mode = 2: enable recorder and force disable logger + """ + self.name = name + if dist.is_initialized(): + self.rank_info = f"Rank {dist.get_rank()}" + else: + self.rank_info = "Single GPU" + self.enable_recorder = recorder_mode > 0 + self.enable_logger = recorder_mode <= 1 + self.metrics_func = metrics_func + self.metrics_labels = metrics_labels + + def __enter__(self): + torch_device_module.synchronize() + self.start_time = time.perf_counter() + _get_excluded_time_stack().append(0.0) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch_device_module.synchronize() + total_elapsed = time.perf_counter() - self.start_time + excluded = _get_excluded_time_stack().pop() + elapsed = total_elapsed - excluded + if self.enable_recorder and self.metrics_func: + if self.metrics_labels: + self.metrics_func.labels(*self.metrics_labels).observe(elapsed) + else: + self.metrics_func.observe(elapsed) + if self.enable_logger: + logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds") + return False + + async def __aenter__(self): + torch_device_module.synchronize() + self.start_time = time.perf_counter() + _get_excluded_time_stack().append(0.0) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + torch_device_module.synchronize() + total_elapsed = time.perf_counter() - self.start_time + excluded = _get_excluded_time_stack().pop() + elapsed = total_elapsed - excluded + if self.enable_recorder and self.metrics_func: + if self.metrics_labels: + self.metrics_func.labels(*self.metrics_labels).observe(elapsed) + else: + self.metrics_func.observe(elapsed) + if self.enable_logger: + logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds") + return False + + def __call__(self, func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with self: + return await func(*args, **kwargs) + + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return sync_wrapper + + +class _NullContext: + # Context manager without decision branch logic overhead + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return False + + def __call__(self, func): + return func + + +class _ExcludedProfilingContext: + """用于标记应该从外层 profiling 中排除的时间段""" + + def __init__(self, name=None): + self.name = name + if dist.is_initialized(): + self.rank_info = f"Rank {dist.get_rank()}" + else: + self.rank_info = "Single GPU" + + def __enter__(self): + torch_device_module.synchronize() + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, 
exc_tb): + torch_device_module.synchronize() + elapsed = time.perf_counter() - self.start_time + stack = _get_excluded_time_stack() + for i in range(len(stack)): + stack[i] += elapsed + if self.name and CHECK_PROFILING_DEBUG_LEVEL(1): + logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)") + return False + + async def __aenter__(self): + torch_device_module.synchronize() + self.start_time = time.perf_counter() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + torch_device_module.synchronize() + elapsed = time.perf_counter() - self.start_time + stack = _get_excluded_time_stack() + for i in range(len(stack)): + stack[i] += elapsed + if self.name and CHECK_PROFILING_DEBUG_LEVEL(1): + logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)") + return False + + def __call__(self, func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with self: + return await func(*args, **kwargs) + + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return sync_wrapper + + +class _ProfilingContextL1(_ProfilingContext): + """Level 1 profiling context with Level1_Log prefix.""" + + def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None): + super().__init__(f"Level1_Log {name}", recorder_mode, metrics_func, metrics_labels) + + +class _ProfilingContextL2(_ProfilingContext): + """Level 2 profiling context with Level2_Log prefix.""" + + def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None): + super().__init__(f"Level2_Log {name}", recorder_mode, metrics_func, metrics_labels) + + +""" +PROFILING_DEBUG_LEVEL=0: [Default] disable all profiling +PROFILING_DEBUG_LEVEL=1: enable ProfilingContext4DebugL1 +PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4DebugL2 +""" +ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext # if user >= 1, enable profiling +ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext # if user >= 2, enable profiling +ExcludedProfilingContext = _ExcludedProfilingContext if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext diff --git a/lightx2v/utils/prompt_enhancer.py b/lightx2v/utils/prompt_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5401da334790d8b19ebc286c53a81a00f5f053 --- /dev/null +++ b/lightx2v/utils/prompt_enhancer.py @@ -0,0 +1,78 @@ +import argparse + +import torch +from loguru import logger +from transformers import AutoModelForCausalLM, AutoTokenizer + +from lightx2v.utils.profiler import * + +sys_prompt = """ +Transform the short prompt into a detailed video-generation caption using this structure: +​​Opening shot type​​ (long/medium/close-up/extreme close-up/full shot) +​​Primary subject(s)​​ with vivid attributes (colors, textures, actions, interactions) +​​Dynamic elements​​ (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...') +​​Scene composition​​ (background, environment, spatial relationships) +​​Lighting/atmosphere​​ (natural/artificial, time of day, mood) +​​Camera motion​​ (zooms, pans, static/handheld shots) if applicable. 
+ +Pattern Summary from Examples: +[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement] + +​One case: +Short prompt: a person is playing football +Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan. + +Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic. + +​​Now expand this short prompt:​​ [{}]. Please only output the final long prompt in English. +""" + + +class PromptEnhancer: + def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct", device_map="cuda:0"): + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map=device_map, + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + def to_device(self, device): + self.model = self.model.to(device) + + @ProfilingContext4DebugL1("Run prompt enhancer") + @torch.no_grad() + def __call__(self, prompt): + prompt = prompt.strip() + prompt = sys_prompt.format(prompt) + messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=8192, + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() + + think_id = self.tokenizer.encode("") + if len(think_id) == 1: + index = len(output_ids) - output_ids[::-1].index(think_id[0]) + else: + index = 0 + + thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n") + logger.info(f"[Enhanced] thinking content: {thinking_content}") + rewritten_prompt = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n") + logger.info(f"[Enhanced] rewritten prompt: {rewritten_prompt}") + return rewritten_prompt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign") + args = parser.parse_args() + + prompt_enhancer = PromptEnhancer() + enhanced_prompt = prompt_enhancer(args.prompt) + logger.info(f"Original prompt: {args.prompt}") + logger.info(f"Enhanced prompt: {enhanced_prompt}") diff --git a/lightx2v/utils/quant_utils.py b/lightx2v/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ece66fa29797e6065a644ff4f1df5399720d3c9 --- /dev/null +++ b/lightx2v/utils/quant_utils.py @@ -0,0 +1,219 @@ +import torch +from loguru import logger + +try: + from qtorch.quant import float_quantize +except Exception: + logger.warning("qtorch not found, please install qtorch.Please install qtorch (pip install qtorch).") + float_quantize = None + + +class BaseQuantizer(object): + def __init__(self, bit, symmetric, granularity, **kwargs): + self.bit = bit + self.sym = symmetric + self.granularity = granularity + self.kwargs = kwargs + if self.granularity == "per_group": + self.group_size = self.kwargs["group_size"] + self.calib_algo = self.kwargs.get("calib_algo", "minmax") + + def get_tensor_range(self, tensor): + if 
self.calib_algo == "minmax": + return self.get_minmax_range(tensor) + elif self.calib_algo == "mse": + return self.get_mse_range(tensor) + else: + raise ValueError(f"Unsupported calibration algorithm: {self.calib_algo}") + + def get_minmax_range(self, tensor): + if self.granularity == "per_tensor": + max_val = torch.max(tensor) + min_val = torch.min(tensor) + else: + max_val = tensor.amax(dim=-1, keepdim=True) + min_val = tensor.amin(dim=-1, keepdim=True) + return (min_val, max_val) + + def get_mse_range(self, tensor): + raise NotImplementedError + + def get_qparams(self, tensor_range, device): + min_val, max_val = tensor_range[0], tensor_range[1] + qmin = self.qmin.to(device) + qmax = self.qmax.to(device) + if self.sym: + abs_max = torch.max(max_val.abs(), min_val.abs()) + abs_max = abs_max.clamp(min=1e-5) + scales = abs_max / qmax + zeros = torch.tensor(0.0) + else: + scales = (max_val - min_val).clamp(min=1e-5) / (qmax - qmin) + zeros = (qmin - torch.round(min_val / scales)).clamp(qmin, qmax) + return scales, zeros, qmax, qmin + + def reshape_tensor(self, tensor, allow_padding=False): + if self.granularity == "per_group": + t = tensor.reshape(-1, self.group_size) + else: + t = tensor + return t + + def restore_tensor(self, tensor, shape): + if tensor.shape == shape: + t = tensor + else: + t = tensor.reshape(shape) + return t + + def get_tensor_qparams(self, tensor): + tensor = self.reshape_tensor(tensor) + tensor_range = self.get_tensor_range(tensor) + scales, zeros, qmax, qmin = self.get_qparams(tensor_range, tensor.device) + return tensor, scales, zeros, qmax, qmin + + def fake_quant_tensor(self, tensor): + org_shape = tensor.shape + org_dtype = tensor.dtype + tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor) + tensor = self.quant_dequant(tensor, scales, zeros, qmax, qmin) + tensor = self.restore_tensor(tensor, org_shape).to(org_dtype) + return tensor + + def real_quant_tensor(self, tensor): + org_shape = tensor.shape + tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor) + tensor = self.quant(tensor, scales, zeros, qmax, qmin) + tensor = self.restore_tensor(tensor, org_shape) + if self.sym: + zeros = None + return tensor, scales, zeros + + +class IntegerQuantizer(BaseQuantizer): + def __init__(self, bit, symmetric, granularity, **kwargs): + super().__init__(bit, symmetric, granularity, **kwargs) + if "int_range" in self.kwargs: + self.qmin = self.kwargs["int_range"][0] + self.qmax = self.kwargs["int_range"][1] + else: + if self.sym: + self.qmin = -(2 ** (self.bit - 1)) + self.qmax = 2 ** (self.bit - 1) - 1 + else: + self.qmin = 0.0 + self.qmax = 2**self.bit - 1 + + self.qmin = torch.tensor(self.qmin) + self.qmax = torch.tensor(self.qmax) + self.dst_nbins = 2**bit + + def quant(self, tensor, scales, zeros, qmax, qmin): + tensor = torch.clamp(torch.round(tensor / scales) + zeros, qmin, qmax) + return tensor + + def dequant(self, tensor, scales, zeros): + tensor = (tensor - zeros) * scales + return tensor + + def quant_dequant( + self, + tensor, + scales, + zeros, + qmax, + qmin, + ): + tensor = self.quant(tensor, scales, zeros, qmax, qmin) + tensor = self.dequant(tensor, scales, zeros) + return tensor + + +class FloatQuantizer(BaseQuantizer): + def __init__(self, bit, symmetric, granularity, **kwargs): + super().__init__(bit, symmetric, granularity, **kwargs) + assert self.bit in ["e4m3", "e5m2"], f"Unsupported bit configuration: {self.bit}" + assert self.sym + + if self.bit == "e4m3": + self.e_bits = 4 + self.m_bits = 3 + self.fp_dtype = 
torch.float8_e4m3fn + elif self.bit == "e5m2": + self.e_bits = 5 + self.m_bits = 2 + self.fp_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported bit configuration: {self.bit}") + + finfo = torch.finfo(self.fp_dtype) + self.qmin, self.qmax = finfo.min, finfo.max + + self.qmax = torch.tensor(self.qmax) + self.qmin = torch.tensor(self.qmin) + + def quant(self, tensor, scales, zeros, qmax, qmin): + scaled_tensor = tensor / scales + zeros + scaled_tensor = torch.clip(scaled_tensor, self.qmin.cuda(), self.qmax.cuda()) + org_dtype = scaled_tensor.dtype + q_tensor = float_quantize(scaled_tensor.float(), self.e_bits, self.m_bits, rounding="nearest") + q_tensor.to(org_dtype) + return q_tensor + + def dequant(self, tensor, scales, zeros): + tensor = (tensor - zeros) * scales + return tensor + + def quant_dequant(self, tensor, scales, zeros, qmax, qmin): + tensor = self.quant(tensor, scales, zeros, qmax, qmin) + tensor = self.dequant(tensor, scales, zeros) + return tensor + + +# 导入 VLLM 的量化函数 +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None + + +def quant_fp8_vllm(input_tensor): + input_tensor_fp8, input_tensor_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + return input_tensor_fp8, input_tensor_scale + + +def dequant_fp8_vllm(input_tensor_fp8, input_tensor_scale, dtype): + return input_tensor_fp8.to(dtype) * input_tensor_scale.to(dtype) + + +if __name__ == "__main__": + weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda() + quantizer = IntegerQuantizer(4, False, "per_group", group_size=128) + q_weight = quantizer.fake_quant_tensor(weight) + logger.info(weight) + logger.info(q_weight) + logger.info(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}") + + realq_weight, scales, zeros = quantizer.real_quant_tensor(weight) + logger.info(f"realq_weight = {realq_weight}, {realq_weight.shape}") + logger.info(f"scales = {scales}, {scales.shape}") + logger.info(f"zeros = {zeros}, {zeros.shape}") + + weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda() + quantizer = FloatQuantizer("e4m3", True, "per_channel") + q_weight = quantizer.fake_quant_tensor(weight) + logger.info(weight) + logger.info(q_weight) + logger.info(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}") + + realq_weight, scales, zeros = quantizer.real_quant_tensor(weight) + logger.info(f"realq_weight = {realq_weight}, {realq_weight.shape}") + logger.info(f"scales = {scales}, {scales.shape}") + logger.info(f"zeros = {zeros}") + + input_tensor = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda() + input_tensor_fp8, input_tensor_scale = quant_fp8_vllm(input_tensor) + dequant_tensor = dequant_fp8_vllm(input_tensor_fp8, input_tensor_scale, input_tensor.dtype) + logger.info(input_tensor) + logger.info(dequant_tensor) + logger.info(f"cosine vllm fp8 quant/dequant = {torch.cosine_similarity(input_tensor.view(1, -1).to(torch.float64), dequant_tensor.view(1, -1).to(torch.float64))}") diff --git a/lightx2v/utils/registry_factory.py b/lightx2v/utils/registry_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5c8e09d921e359cf3650e8543dcd3dba140eca --- /dev/null +++ b/lightx2v/utils/registry_factory.py @@ -0,0 +1,71 @@ +from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER, PLATFORM_MM_WEIGHT_REGISTER + + +class Register(dict): + def __init__(self, *args, 
**kwargs): + super(Register, self).__init__(*args, **kwargs) + self._dict = {} + + def __call__(self, target_or_name): + if callable(target_or_name): + return self.register(target_or_name) + else: + return lambda x: self.register(x, key=target_or_name) + + def register(self, target, key=None): + if not callable(target): + raise Exception(f"Error: {target} must be callable!") + + if key is None: + key = target.__name__ + + if key in self._dict: + raise Exception(f"{key} already exists.") + + self[key] = target + return target + + def __setitem__(self, key, value): + self._dict[key] = value + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def __str__(self): + return str(self._dict) + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + + def get(self, key, default=None): + return self._dict.get(key, default) + + def merge(self, other_register): + for key, value in other_register.items(): + if key in self._dict: + raise Exception(f"{key} already exists in target register.") + self[key] = value + + +MM_WEIGHT_REGISTER = Register() +ATTN_WEIGHT_REGISTER = Register() +RMS_WEIGHT_REGISTER = Register() +LN_WEIGHT_REGISTER = Register() +CONV3D_WEIGHT_REGISTER = Register() +CONV2D_WEIGHT_REGISTER = Register() +TENSOR_REGISTER = Register() +CONVERT_WEIGHT_REGISTER = Register() +EMBEDDING_WEIGHT_REGISTER = Register() +RUNNER_REGISTER = Register() + +ATTN_WEIGHT_REGISTER.merge(PLATFORM_ATTN_WEIGHT_REGISTER) +MM_WEIGHT_REGISTER.merge(PLATFORM_MM_WEIGHT_REGISTER) diff --git a/lightx2v/utils/service_utils.py b/lightx2v/utils/service_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b984050b8dfed62cbfae3c752e5db96f72019acd --- /dev/null +++ b/lightx2v/utils/service_utils.py @@ -0,0 +1,146 @@ +import base64 +import io +import signal +import sys +import threading +from datetime import datetime +from typing import Optional + +import psutil +import torch +from PIL import Image +from loguru import logger +from pydantic import BaseModel + + +class ProcessManager: + @staticmethod + def kill_all_related_processes(): + """Kill the current process and all its child processes""" + current_process = psutil.Process() + children = current_process.children(recursive=True) + for child in children: + try: + child.kill() + except Exception as e: + logger.info(f"Failed to kill child process {child.pid}: {e}") + try: + current_process.kill() + except Exception as e: + logger.info(f"Failed to kill main process: {e}") + + @staticmethod + def signal_handler(sig, frame): + logger.info("\nReceived Ctrl+C, shutting down all related processes...") + ProcessManager.kill_all_related_processes() + sys.exit(0) + + @staticmethod + def register_signal_handler(): + """Register the signal handler for SIGINT""" + signal.signal(signal.SIGINT, ProcessManager.signal_handler) + + +class TaskStatusMessage(BaseModel): + task_id: str + + +class BaseServiceStatus: + _lock = threading.Lock() + _current_task = None + _result_store = {} + + @classmethod + def start_task(cls, message): + with cls._lock: + if cls._current_task is not None: + raise RuntimeError("Service busy") + if message.task_id_must_unique and message.task_id in cls._result_store: + raise RuntimeError(f"Task ID {message.task_id} already exists") + cls._current_task = {"message": message, "start_time": datetime.now()} + return message.task_id + + @classmethod + def complete_task(cls, message): + 
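+        """Store the result of the finished task and mark the service as idle again."""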
with cls._lock: + cls._result_store[message.task_id] = {"success": True, "message": message, "start_time": cls._current_task["start_time"], "completion_time": datetime.now()} + cls._current_task = None + + @classmethod + def record_failed_task(cls, message, error: Optional[str] = None): + """Record a failed task with an error message.""" + with cls._lock: + cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error} + cls._current_task = None + + @classmethod + def clean_stopped_task(cls): + with cls._lock: + if cls._current_task: + message = cls._current_task["message"] + error = "Task stopped by user" + cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error} + cls._current_task = None + + @classmethod + def get_status_task_id(cls, task_id: str): + with cls._lock: + if cls._current_task and cls._current_task["message"].task_id == task_id: + return {"task_status": "processing"} + if task_id in cls._result_store: + return {"task_status": "completed", **cls._result_store[task_id]} + return {"task_status": "not_found"} + + @classmethod + def get_status_service(cls): + with cls._lock: + if cls._current_task: + return {"service_status": "busy", "task_id": cls._current_task["message"].task_id} + return {"service_status": "idle"} + + @classmethod + def get_all_tasks(cls): + with cls._lock: + return cls._result_store + + +class TensorTransporter: + def __init__(self): + self.buffer = io.BytesIO() + + def to_device(self, data, device): + if isinstance(data, dict): + return {key: self.to_device(value, device) for key, value in data.items()} + elif isinstance(data, list): + return [self.to_device(item, device) for item in data] + elif isinstance(data, torch.Tensor): + return data.to(device) + else: + return data + + def prepare_tensor(self, data) -> bytes: + self.buffer.seek(0) + self.buffer.truncate() + torch.save(self.to_device(data, "cpu"), self.buffer) + return base64.b64encode(self.buffer.getvalue()).decode("utf-8") + + def load_tensor(self, tensor_base64: str, device="cuda") -> torch.Tensor: + tensor_bytes = base64.b64decode(tensor_base64) + with io.BytesIO(tensor_bytes) as buffer: + return self.to_device(torch.load(buffer), device) + + +class ImageTransporter: + def __init__(self): + self.buffer = io.BytesIO() + + def prepare_image(self, image: Image.Image) -> bytes: + self.buffer.seek(0) + self.buffer.truncate() + image.save(self.buffer, format="PNG") + return base64.b64encode(self.buffer.getvalue()).decode("utf-8") + + def load_image(self, image_base64: bytes) -> Image.Image: + image_bytes = base64.b64decode(image_base64) + with io.BytesIO(image_bytes) as buffer: + return Image.open(buffer).convert("RGB") diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py new file mode 100644 index 0000000000000000000000000000000000000000..13f94df6f5df14880b5f865d8945809d86f8d2b2 --- /dev/null +++ b/lightx2v/utils/set_config.py @@ -0,0 +1,115 @@ +import json +import os + +import torch +import torch.distributed as dist +from loguru import logger +from torch.distributed.tensor.device_mesh import init_device_mesh + +from lightx2v.utils.input_info import ALL_INPUT_INFO_KEYS +from lightx2v.utils.lockable_dict import LockableDict +from lightx2v_platform.base.global_var import AI_DEVICE + + +def get_default_config(): + default_config = { + "do_mm_calib": False, + "cpu_offload": False, + "max_area": False, + "vae_stride": (4, 8, 8), + 
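+        # (t, h, w) factors: the VAE compresses frames by `vae_stride` and the transformer patchifies latents by `patch_size`.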
"patch_size": (1, 2, 2), + "feature_caching": "NoCaching", # ["NoCaching", "TaylorSeer", "Tea"] + "teacache_thresh": 0.26, + "use_ret_steps": False, + "use_bfloat16": True, + "lora_configs": None, # List of dicts with 'path' and 'strength' keys + "use_prompt_enhancer": False, + "parallel": False, + "seq_parallel": False, + "cfg_parallel": False, + "enable_cfg": False, + "use_image_encoder": True, + } + default_config = LockableDict(default_config) + return default_config + + +def set_config(args): + config = get_default_config() + config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS}) + + if config.get("config_json", None) is not None: + logger.info(f"Loading some config from {config['config_json']}") + with open(config["config_json"], "r") as f: + config_json = json.load(f) + config.update(config_json) + + if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: # Special config for hunyuan video 1.5 model folder structure + config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"]) # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v] + if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")): + with open(os.path.join(config["transformer_model_path"], "config.json"), "r") as f: + model_config = json.load(f) + config.update(model_config) + else: + if os.path.exists(os.path.join(config["model_path"], "config.json")): + with open(os.path.join(config["model_path"], "config.json"), "r") as f: + model_config = json.load(f) + config.update(model_config) + elif os.path.exists(os.path.join(config["model_path"], "low_noise_model", "config.json")): # 需要一个更优雅的update方法 + with open(os.path.join(config["model_path"], "low_noise_model", "config.json"), "r") as f: + model_config = json.load(f) + config.update(model_config) + elif os.path.exists(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json")): # 需要一个更优雅的update方法 + with open(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json"), "r") as f: + model_config = json.load(f) + config.update(model_config) + elif os.path.exists(os.path.join(config["model_path"], "original", "config.json")): + with open(os.path.join(config["model_path"], "original", "config.json"), "r") as f: + model_config = json.load(f) + config.update(model_config) + # load quantized config + if config.get("dit_quantized_ckpt", None) is not None: + config_path = os.path.join(config["dit_quantized_ckpt"], "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + model_config = json.load(f) + config.update(model_config) + + if config["task"] in ["i2v", "s2v"]: + if config["target_video_length"] % config["vae_stride"][0] != 1: + logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. 
Rounding to the nearest number.") + config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1 + + if config["task"] not in ["t2i", "i2i"] and config["model_cls"] not in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: + config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0] + if config["model_cls"] == "seko_talk": + config["attnmap_frame_num"] += 1 + + return config + + +def set_parallel_config(config): + if config["parallel"]: + cfg_p_size = config["parallel"].get("cfg_p_size", 1) + seq_p_size = config["parallel"].get("seq_p_size", 1) + assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size" + config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p")) + + if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1: + config["seq_parallel"] = True + + if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1: + config["cfg_parallel"] = True + # warmup dist + _a = torch.zeros([1]).to(f"{AI_DEVICE}:{dist.get_rank()}") + dist.all_reduce(_a) + + +def print_config(config): + config_to_print = config.copy() + config_to_print.pop("device_mesh", None) + if config["parallel"]: + if dist.get_rank() == 0: + logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}") + else: + logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}") diff --git a/lightx2v/utils/utils.py b/lightx2v/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b63a933ad0f6d3fdd5edb9b90d9ec3a49936856 --- /dev/null +++ b/lightx2v/utils/utils.py @@ -0,0 +1,486 @@ +import os +import random +import subprocess +from typing import Optional + +import imageio +import imageio_ffmpeg as ffmpeg +import numpy as np +import safetensors +import torch +import torch.distributed as dist +import torchvision +from einops import rearrange +from loguru import logger + +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +def seed_all(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch_device_module.manual_seed(seed) + torch_device_module.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): + """save videos by video tensor + copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 + + Args: + videos (torch.Tensor): video tensor predicted by the model + path (str): path to save video + rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False. + n_rows (int, optional): Defaults to 1. + fps (int, optional): video save fps. Defaults to 8. 
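+
+    Example (illustrative usage; shapes are arbitrary):
+        >>> videos = torch.rand(1, 3, 16, 64, 64) * 2 - 1  # [B, C, T, H, W] in [-1, 1]
+        >>> save_videos_grid(videos, "outputs/sample.mp4", rescale=True, fps=24)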
+ """ + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x, 0, 1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + + +def cache_video( + tensor, + save_file: str, + fps=30, + suffix=".mp4", + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5, +): + save_dir = os.path.dirname(save_file) + try: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + except Exception as e: + logger.error(f"Failed to create directory: {save_dir}, error: {e}") + return None + + cache_file = save_file + + # save to cache + error = None + for _ in range(retry): + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) # type: ignore + tensor = torch.stack( + [torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)], + dim=1, + ).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + del tensor + torch.cuda.empty_cache() + return cache_file + except Exception as e: + error = e + continue + else: + logger.info(f"cache_video failed, error: {error}", flush=True) + return None + + +def vae_to_comfyui_image(vae_output: torch.Tensor) -> torch.Tensor: + """ + Convert VAE decoder output to ComfyUI Image format + + Args: + vae_output: VAE decoder output tensor, typically in range [-1, 1] + Shape: [B, C, T, H, W] or [B, C, H, W] + + Returns: + ComfyUI Image tensor in range [0, 1] + Shape: [B, H, W, C] for single frame or [B*T, H, W, C] for video + """ + # Handle video tensor (5D) vs image tensor (4D) + if vae_output.dim() == 5: + # Video tensor: [B, C, T, H, W] + B, C, T, H, W = vae_output.shape + # Reshape to [B*T, C, H, W] for processing + vae_output = vae_output.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W) + + # Normalize from [-1, 1] to [0, 1] + images = (vae_output + 1) / 2 + + # Clamp values to [0, 1] + images = torch.clamp(images, 0, 1) + + # Convert from [B, C, H, W] to [B, H, W, C] + images = images.permute(0, 2, 3, 1).cpu() + + return images + + +def vae_to_comfyui_image_inplace(vae_output: torch.Tensor) -> torch.Tensor: + """ + Convert VAE decoder output to ComfyUI Image format (inplace operation) + + Args: + vae_output: VAE decoder output tensor, typically in range [-1, 1] + Shape: [B, C, T, H, W] or [B, C, H, W] + WARNING: This tensor will be modified in-place! 
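+            (If the original tensor must be preserved, use vae_to_comfyui_image instead.)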
+ + Returns: + ComfyUI Image tensor in range [0, 1] + Shape: [B, H, W, C] for single frame or [B*T, H, W, C] for video + Note: The returned tensor is the same object as input (modified in-place) + """ + # Handle video tensor (5D) vs image tensor (4D) + if vae_output.dim() == 5: + # Video tensor: [B, C, T, H, W] + B, C, T, H, W = vae_output.shape + # Reshape to [B*T, C, H, W] for processing (inplace view) + vae_output = vae_output.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W) + + # Normalize from [-1, 1] to [0, 1] (inplace) + vae_output.add_(1).div_(2) + + # Clamp values to [0, 1] (inplace) + vae_output.clamp_(0, 1) + + # Convert from [B, C, H, W] to [B, H, W, C] and move to CPU + vae_output = vae_output.permute(0, 2, 3, 1).cpu() + + return vae_output + + +def save_to_video( + images: torch.Tensor, + output_path: str, + fps: float = 24.0, + method: str = "imageio", + lossless: bool = False, + output_pix_fmt: Optional[str] = "yuv420p", +) -> None: + """ + Save ComfyUI Image tensor to video file + + Args: + images: ComfyUI Image tensor [N, H, W, C] in range [0, 1] + output_path: Path to save the video + fps: Frames per second + method: Save method - "imageio" or "ffmpeg" + lossless: Whether to use lossless encoding (ffmpeg method only) + output_pix_fmt: Pixel format for output (ffmpeg method only) + """ + assert images.dim() == 4 and images.shape[-1] == 3, "Input must be [N, H, W, C] with C=3" + + # Ensure output directory exists + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + + if method == "imageio": + # Convert to uint8 + # frames = (images * 255).cpu().numpy().astype(np.uint8) + frames = (images * 255).to(torch.uint8).cpu().numpy() + imageio.mimsave(output_path, frames, fps=fps) # type: ignore + + elif method == "ffmpeg": + # Convert to numpy and scale to [0, 255] + # frames = (images * 255).cpu().numpy().clip(0, 255).astype(np.uint8) + frames = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + + # Convert RGB to BGR for OpenCV/FFmpeg + frames = frames[..., ::-1].copy() + + N, height, width, _ = frames.shape + + # Ensure even dimensions for x264 + width += width % 2 + height += height % 2 + + # Get ffmpeg executable from imageio_ffmpeg + ffmpeg_exe = ffmpeg.get_ffmpeg_exe() + + if lossless: + command = [ + ffmpeg_exe, + "-y", # Overwrite output file if it exists + "-f", + "rawvideo", + "-s", + f"{int(width)}x{int(height)}", + "-pix_fmt", + "bgr24", + "-r", + f"{fps}", + "-loglevel", + "error", + "-threads", + "4", + "-i", + "-", # Input from pipe + "-vcodec", + "libx264rgb", + "-crf", + "0", + "-an", # No audio + output_path, + ] + else: + command = [ + ffmpeg_exe, + "-y", # Overwrite output file if it exists + "-f", + "rawvideo", + "-s", + f"{int(width)}x{int(height)}", + "-pix_fmt", + "bgr24", + "-r", + f"{fps}", + "-loglevel", + "error", + "-threads", + "4", + "-i", + "-", # Input from pipe + "-vcodec", + "libx264", + "-pix_fmt", + output_pix_fmt, + "-an", # No audio + output_path, + ] + + # Run FFmpeg + process = subprocess.Popen( + command, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + if process.stdin is None: + raise BrokenPipeError("No stdin buffer received.") + + # Write frames to FFmpeg + for frame in frames: + # Pad frame if needed + if frame.shape[0] < height or frame.shape[1] < width: + padded = np.zeros((height, width, 3), dtype=np.uint8) + padded[: frame.shape[0], : frame.shape[1]] = frame + frame = padded + process.stdin.write(frame.tobytes()) + + process.stdin.close() + process.wait() + + if process.returncode != 0: + 
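+            # Surface FFmpeg's stderr so encoding failures carry a useful message.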
error_output = process.stderr.read().decode() if process.stderr else "Unknown error"
+            raise RuntimeError(f"FFmpeg failed with error: {error_output}")
+
+    else:
+        raise ValueError(f"Unknown save method: {method}")
+
+
+def remove_substrings_from_keys(original_dict, substr):
+    new_dict = {}
+    for key, value in original_dict.items():
+        new_dict[key.replace(substr, "")] = value
+    return new_dict
+
+
+def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
+    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
+        return config.get(ckpt_config_key)
+
+    paths_to_check = [
+        os.path.join(config["model_path"], filename),
+    ]
+    if isinstance(subdir, list):
+        for sub in subdir:
+            paths_to_check.insert(0, os.path.join(config["model_path"], sub, filename))
+    else:
+        paths_to_check.insert(0, os.path.join(config["model_path"], subdir, filename))
+
+    for path in paths_to_check:
+        if os.path.exists(path):
+            return path
+    raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
+
+
+def load_safetensors(in_path, remove_key=None, include_keys=None):
+    """Load a safetensors file or directory, with optional include/exclude filtering by key."""
+    include_keys = include_keys or []
+    if os.path.isdir(in_path):
+        return load_safetensors_from_dir(in_path, remove_key, include_keys)
+    elif os.path.isfile(in_path):
+        return load_safetensors_from_path(in_path, remove_key, include_keys)
+    else:
+        raise ValueError(f"{in_path} does not exist")
+
+
+def load_safetensors_from_path(in_path, remove_key=None, include_keys=None):
+    include_keys = include_keys or []
+    tensors = {}
+    with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
+        for key in f.keys():
+            if include_keys:
+                if any(inc_key in key for inc_key in include_keys):
+                    tensors[key] = f.get_tensor(key)
+            else:
+                if not (remove_key and remove_key in key):
+                    tensors[key] = f.get_tensor(key)
+    return tensors
+
+
+def load_safetensors_from_dir(in_dir, remove_key=None, include_keys=None):
+    """Load all safetensors files from a directory, with optional filtering by key."""
+    include_keys = include_keys or []
+    tensors = {}
+    safetensors_files = os.listdir(in_dir)
+    safetensors_files = [f for f in safetensors_files if f.endswith(".safetensors")]
+    for f in safetensors_files:
+        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f), remove_key, include_keys))
+    return tensors
+
+
+def load_pt_safetensors(in_path, remove_key=None, include_keys=None):
+    """Load .pt/.pth or safetensors weights, with optional filtering by key."""
+    include_keys = include_keys or []
+    ext = os.path.splitext(in_path)[-1]
+    if ext in (".pt", ".pth", ".tar"):
+        state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
+        # Apply the key filtering logic
+        keys_to_keep = []
+        for key in state_dict.keys():
+            if include_keys:
+                if any(inc_key in key for inc_key in include_keys):
+                    keys_to_keep.append(key)
+            else:
+                if not (remove_key and remove_key in key):
+                    keys_to_keep.append(key)
+        # Keep only the keys that pass the filter
+        state_dict = {k: state_dict[k] for k in keys_to_keep}
+    else:
+        state_dict = load_safetensors(in_path, remove_key, include_keys)
+    return state_dict
+
+
+def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False, include_keys=None):
+    if not dist.is_initialized() or not load_from_rank0:
+        # Single GPU mode
+        logger.info(f"Loading weights from {checkpoint_path}")
+        cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key, include_keys)
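+        # load_pt_safetensors keeps the tensors on CPU; callers decide when to move them to the target device.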
+ return cpu_weight_dict + + # Multi-GPU mode + is_weight_loader = False + current_rank = dist.get_rank() + if current_rank == 0: + is_weight_loader = True + + cpu_weight_dict = {} + if is_weight_loader: + logger.info(f"Loading weights from {checkpoint_path}") + cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key) + + meta_dict = {} + if is_weight_loader: + for key, tensor in cpu_weight_dict.items(): + meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype} + + obj_list = [meta_dict] if is_weight_loader else [None] + + src_global_rank = 0 + dist.broadcast_object_list(obj_list, src=src_global_rank) + synced_meta_dict = obj_list[0] + + if cpu_offload: + target_device = "cpu" + distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()} + dist.barrier() + else: + target_device = torch.device(f"cuda:{current_rank}") + distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()} + dist.barrier(device_ids=[torch.cuda.current_device()]) + + for key in sorted(synced_meta_dict.keys()): + tensor_to_broadcast = distributed_weight_dict[key] + if is_weight_loader: + tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True) + + if cpu_offload: + if is_weight_loader: + gpu_tensor = tensor_to_broadcast.cuda() + dist.broadcast(gpu_tensor, src=src_global_rank) + tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True) + del gpu_tensor + torch.cuda.empty_cache() + else: + gpu_tensor = torch.empty_like(tensor_to_broadcast, device="cuda") + dist.broadcast(gpu_tensor, src=src_global_rank) + tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True) + del gpu_tensor + torch.cuda.empty_cache() + else: + dist.broadcast(tensor_to_broadcast, src=src_global_rank) + + if is_weight_loader: + del cpu_weight_dict + + if cpu_offload: + torch.cuda.empty_cache() + + logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}") + return distributed_weight_dict + + +def masks_like(tensor, zero=False, generator=None, p=0.2, prev_len=1): + assert isinstance(tensor, torch.Tensor) + out = torch.ones_like(tensor) + if zero: + if generator is not None: + random_num = torch.rand(1, generator=generator, device=generator.device).item() + if random_num < p: + out[:, :prev_len] = torch.zeros_like(out[:, :prev_len]) + else: + out[:, :prev_len] = torch.zeros_like(out[:, :prev_len]) + return out + + +def best_output_size(w, h, dw, dh, expected_area): + # float output size + ratio = w / h + ow = (expected_area * ratio) ** 0.5 + oh = expected_area / ow + + # process width first + ow1 = int(ow // dw * dw) + oh1 = int(expected_area / ow1 // dh * dh) + assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area + ratio1 = ow1 / oh1 + + # process height first + oh2 = int(oh // dh * dh) + ow2 = int(expected_area / oh2 // dw * dw) + assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area + ratio2 = ow2 / oh2 + + # compare ratios + if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio): + return ow1, oh1 + else: + return ow2, oh2 diff --git a/lightx2v_kernel/CMakeLists.txt b/lightx2v_kernel/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b274606da400e745f308a4871a664319a9d914b --- /dev/null +++ b/lightx2v_kernel/CMakeLists.txt @@ -0,0 +1,103 @@ +cmake_minimum_required(VERSION 3.22 FATAL_ERROR) +project(lightx2v-kernel LANGUAGES 
CXX CUDA) + +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + +# Python +find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED) + +# CXX +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + +# CUDA +enable_language(CUDA) +find_package(CUDAToolkit REQUIRED) +set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON) + +# Torch +find_package(Torch REQUIRED) +# clean Torch Flag +clear_cuda_arches(CMAKE_FLAG) + + +# cutlass +if(CUTLASS_PATH) + set(repo-cutlass_SOURCE_DIR ${CUTLASS_PATH}) + message(STATUS "Using local CUTLASS from: ${CUTLASS_PATH}") +else() + message(FATAL_ERROR "CUTLASS_PATH is not set. Please manually download CUTLASS first.") +endif() + + +# ccache option +option(ENABLE_CCACHE "Whether to use ccache" ON) +find_program(CCACHE_FOUND ccache) +if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR}) + message(STATUS "Building with CCACHE enabled") + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache") + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache") +endif() + + +include_directories( + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/csrc + ${repo-cutlass_SOURCE_DIR}/include + ${repo-cutlass_SOURCE_DIR}/tools/util/include +) + +set(LIGHTX2V_KERNEL_CUDA_FLAGS + "-DNDEBUG" + "-DOPERATOR_NAMESPACE=lightx2v-kernel" + "-O3" + "-Xcompiler" + "-fPIC" + "-std=c++17" + "-DCUTE_USE_PACKED_TUPLE=1" + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" + "-DCUTLASS_VERSIONS_GENERATED" + "-DCUTLASS_TEST_LEVEL=0" + "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1" + "-DCUTLASS_DEBUG_TRACE_LEVEL=0" + "--expt-relaxed-constexpr" + "--expt-extended-lambda" + "--threads=32" + + # Suppress warnings + "-Xcompiler=-Wconversion" + "-Xcompiler=-fno-strict-aliasing" + +) + + +list(APPEND LIGHTX2V_KERNEL_CUDA_FLAGS + # "-gencode=arch=compute_90,code=sm_90" + # "-gencode=arch=compute_90a,code=sm_90a" + # "-gencode=arch=compute_100,code=sm_100" + # "-gencode=arch=compute_100a,code=sm_100a" + # "-gencode=arch=compute_120,code=sm_120" + "-gencode=arch=compute_120a,code=sm_120a" +) + + +set(SOURCES + "csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu" + "csrc/gemm/nvfp4_quant_kernels_sm120.cu" + "csrc/gemm/mxfp4_quant_kernels_sm120.cu" + "csrc/gemm/mxfp8_quant_kernels_sm120.cu" + "csrc/gemm/mxfp6_quant_kernels_sm120.cu" + "csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu" + "csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu" + "csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu" + "csrc/common_extension.cc" +) + +Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) + +message(STATUS "LIGHTX2V_KERNEL_CUDA_FLAGS: ${LIGHTX2V_KERNEL_CUDA_FLAGS}") + +target_compile_options(common_ops PRIVATE $<$:${LIGHTX2V_KERNEL_CUDA_FLAGS}>) +target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt) + +install(TARGETS common_ops LIBRARY DESTINATION lightx2v_kernel) diff --git a/lightx2v_kernel/LICENSE b/lightx2v_kernel/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9c422689c8f5c317c7c65153b1209349ec57007e --- /dev/null +++ b/lightx2v_kernel/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023-2024 SGLang Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/lightx2v_kernel/README.md b/lightx2v_kernel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e16130fee425417ffea851cfffc61f019e730e44 --- /dev/null +++ b/lightx2v_kernel/README.md @@ -0,0 +1,56 @@ +# lightx2v_kernel + +### Preparation +``` +# Install torch, at least version 2.7 + +pip install scikit_build_core uv +``` + +### Build whl + +``` +git clone https://github.com/NVIDIA/cutlass.git + +git clone https://github.com/ModelTC/LightX2V.git + +cd LightX2V/lightx2v_kernel + +# Set the /path/to/cutlass below to the absolute path of cutlass you download. + +MAX_JOBS=$(nproc) && CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) \ +uv build --wheel \ + -Cbuild-dir=build . \ + -Ccmake.define.CUTLASS_PATH=/path/to/cutlass \ + --verbose \ + --color=always \ + --no-build-isolation +``` + + +### Install whl +``` +pip install dist/*whl --force-reinstall --no-deps +``` + +### Test + +##### cos and speed test, mm without bias +``` +python test/nvfp4_nvfp4/test_bench2.py +``` + +##### cos and speed test, mm with bias +``` +python test/nvfp4_nvfp4/test_bench3_bias.py +``` + +##### Bandwidth utilization test for quant +``` +python test/nvfp4_nvfp4/test_quant_mem_utils.py +``` + +##### tflops test for mm +``` +python test/nvfp4_nvfp4/test_mm_tflops.py +``` diff --git a/lightx2v_kernel/cmake/utils.cmake b/lightx2v_kernel/cmake/utils.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0eaa7a61acfa0bdaca34a9b306a06318b5b39c6c --- /dev/null +++ b/lightx2v_kernel/cmake/utils.cmake @@ -0,0 +1,21 @@ +# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake +# +# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in +# `CUDA_ARCH_FLAGS`. +# +# Example: +# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" +# clear_cuda_arches(CUDA_ARCH_FLAGS) +# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" +# CMAKE_CUDA_FLAGS="-Wall" +# +macro(clear_cuda_arches CUDA_ARCH_FLAGS) + # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` + string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified + # and passed back via the `CUDA_ARCHITECTURES` property. + string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS + ${CMAKE_CUDA_FLAGS}) +endmacro() diff --git a/lightx2v_kernel/csrc/common_extension.cc b/lightx2v_kernel/csrc/common_extension.cc new file mode 100644 index 0000000000000000000000000000000000000000..18b15679a93a4cc51e38eed865bdc7fda7d14523 --- /dev/null +++ b/lightx2v_kernel/csrc/common_extension.cc @@ -0,0 +1,51 @@ +#include +#include +#include + +#include "lightx2v_kernel_ops.h" + +TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) { + + m.def( + "cutlass_scaled_nvfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor " + "alpha, Tensor? bias) -> ()"); + m.impl("cutlass_scaled_nvfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_nvfp4_mm_sm120); + + m.def( + "scaled_nvfp4_quant_sm120(Tensor! output, Tensor! input," + " Tensor! output_scale, Tensor! input_scale) -> ()"); + m.impl("scaled_nvfp4_quant_sm120", torch::kCUDA, &scaled_nvfp4_quant_sm120); + + m.def( + "scaled_mxfp4_quant_sm120(Tensor! output, Tensor! input," + " Tensor! output_scale) -> ()"); + m.impl("scaled_mxfp4_quant_sm120", torch::kCUDA, &scaled_mxfp4_quant_sm120); + + m.def( + "scaled_mxfp8_quant_sm120(Tensor! 
output, Tensor! input," + " Tensor! output_scale) -> ()"); + m.impl("scaled_mxfp8_quant_sm120", torch::kCUDA, &scaled_mxfp8_quant_sm120); + + m.def( + "scaled_mxfp6_quant_sm120(Tensor! output, Tensor! input," + " Tensor! output_scale) -> ()"); + m.impl("scaled_mxfp6_quant_sm120", torch::kCUDA, &scaled_mxfp6_quant_sm120); + + m.def( + "cutlass_scaled_mxfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor " + "alpha, Tensor? bias) -> ()"); + m.impl("cutlass_scaled_mxfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp4_mm_sm120); + + m.def( + "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor " + "alpha, Tensor? bias) -> ()"); + m.impl("cutlass_scaled_mxfp6_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp6_mxfp8_mm_sm120); + + m.def( + "cutlass_scaled_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor " + "alpha, Tensor? bias) -> ()"); + m.impl("cutlass_scaled_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_sm120); + +} + +REGISTER_EXTENSION(common_ops) diff --git a/lightx2v_kernel/csrc/gemm/mxfp4_quant_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp4_quant_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..d89560301a026c43f69b486f71dbff492fda25c2 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/mxfp4_quant_kernels_sm120.cu @@ -0,0 +1,324 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 32; + + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + // PTX instructions used here requires sm100a. +// #if CUDA_VERSION >= 12080 +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +// #else +// return 0; +// #endif +// #endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 4); + + // One pair of threads write one SF to global memory. 
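+  // (Here CVT_FP4_NUM_THREADS_PER_SF == 4: with 8 elements per thread and a 32-element SF vector, four threads cooperate on one scale factor.)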
+ // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +// #endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, uint8_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 32 values (four threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = vecMax * 0.16666666666666666f; + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + fp8SFVal = tmp.__x; + + // Get the output scale. + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +// #else +// return 0; +// #endif +} + +// Use UE4M3 by default. 
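+// Each thread packs 8 fp16/bf16 inputs into one uint32_t of e2m1 nibbles, and every 32-element group shares one e8m0 scale factor.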
+template +__global__ void +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(256, 6) cvt_fp16_to_fp4( +// #else +// cvt_fp16_to_fp4( +// #endif + int32_t numRows, int32_t numCols, Type const* in, uint32_t* out, uint32_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + get_sf_out_address(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, sf_out); + } + } +// #endif +} + +template +void invokeFP4Quantization( + int m, + int n, + T const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream) { + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 256)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 1536 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // Launch the cvt kernel. + cvt_fp16_to_fp4<<>>( + m, n, input, reinterpret_cast(output), reinterpret_cast(SFOuput)); +} + +// Instantiate the function. 
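+// Explicit instantiations for the two supported input element types (half and __nv_bfloat16).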
+template void invokeFP4Quantization( + int m, + int n, + half const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeFP4Quantization( + int m, + int n, + __nv_bfloat16 const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +inline int getMultiProcessorCount() { + static int multi_processor_count = []() { + int device_id = 0; + int count = 0; + + // Get the current CUDA device ID + CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); + + // Get the number of multiprocessors for the current device + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id)); + + return count; // Initialize the static variable + }(); + + return multi_processor_count; // Return the cached value on subsequent calls +} + +void scaled_mxfp4_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + + int multiProcessorCount = getMultiProcessorCount(); + + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + default: { + std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid"; + throw std::runtime_error("Unsupported input data type for quantize_to_fp4."); + } + } +} diff --git a/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..18d03912aa6fb86d3e0dcb9dd88260d038da71c1 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu @@ -0,0 +1,323 @@ +#include +#include +#include + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + + +using namespace cute; + + +struct Mxfp4GemmSm120 { + ///////////////////////////////////////////////////////////////////////////////////////////////// + /// GEMM kernel configurations + 
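+  /// (MXFP4 A/B operands with E8M0 block scales, BF16 output, SM120 block-scaled tensor ops, optional per-column bias epilogue.)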
///////////////////////////////////////////////////////////////////////////////////////////////// + + // A matrix configuration + using ElementA = cutlass::mx_float4_t; // Element type for A matrix operand + using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = 128; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = cutlass::mx_float4_t; // Element type for B matrix operand + using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand + using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand + using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand + using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + // Kernel functional config + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + + // Kernel Perf config + using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size + using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + + // use per-column bias, i.e. every column has different bias + using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + EVTOp + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
+ using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +}; + + +// Populates a Gemm::Arguments structure from the given commandline options +typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t M, + int64_t N, + int64_t K) { + using Sm1xxBlkScaledConfig = typename Mxfp4GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + if (bias){ + using StrideBias = Stride; + + typename Mxfp4GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + fusion_args.bias_ptr = static_cast(bias->data_ptr()); + fusion_args.dBias = StrideBias{}; + return arguments; + } else { + typename Mxfp4GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + return arguments; + } +} + + +void runGemmMxfp4Sm120( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename Mxfp4GemmSm120::Gemm gemm; + + auto arguments = args_from_options_mxp4_mxfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k); + size_t workspace_size = Mxfp4GemmSm120::Gemm::get_workspace_size(arguments); + auto const workspace_options = 
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); +} + + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu; + +void cutlass_scaled_mxfp4_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias) { + + CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); + + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + + + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + TORCH_CHECK( + A.sizes()[1] == B.sizes()[1], + "a and b shapes cannot be multiplied (", + A.sizes()[0], + "x", + A.sizes()[1], + " and ", + B.sizes()[0], + "x", + B.sizes()[1], + ")"); + + auto const m = A.sizes()[0]; + auto const n = B.sizes()[0]; + auto const k = A.sizes()[1] * 2; + + constexpr int alignment = 128; + TORCH_CHECK( + k % alignment == 0, + "Expected k to be divisible by ", + alignment, + ", but got a shape: (", + A.sizes()[0], + "x", + A.sizes()[1], + "), k: ", + k, + "."); + TORCH_CHECK( + n % alignment == 0, + "Expected n to be divisible by ", + alignment, + ", but got b shape: (", + B.sizes()[0], + "x", + B.sizes()[1], + ")."); + + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; + int rounded_m = round_up(m, 128); + int rounded_n = round_up(n, 128); + // Since k is divisible by 128 (alignment), k / 32 is guaranteed to be an + // integer. 
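  // Illustrative note: k counts unpacked FP4 elements (A stores two per byte, hence
  // k = A.sizes()[1] * 2 above). The swizzled scale-factor operands pad rows to a
  // multiple of 128 and the per-32-element scale columns to a multiple of 4; e.g.
  // m = 100, k = 256 gives an expected scale_a shape of 128 x 8.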
+ int rounded_k = round_up(k / 32, 4); + + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + TORCH_CHECK( + A_sf.sizes()[1] == B_sf.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + " and ", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + TORCH_CHECK( + A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", + rounded_m, + "x", + rounded_k, + "), but got a shape (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + ")"); + TORCH_CHECK( + B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", + rounded_n, + "x", + rounded_k, + "), but got a shape (", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + + auto out_dtype = D.dtype(); + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + + runGemmMxfp4Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream); +} diff --git a/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..7b2e178ecc9c918ab043e5bf7e73ea31521ef5b0 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu @@ -0,0 +1,324 @@ +#include +#include +#include + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + + +using namespace cute; + + +struct Mxfp6Mxfp8GemmSm120 { + ///////////////////////////////////////////////////////////////////////////////////////////////// + /// GEMM kernel configurations + ///////////////////////////////////////////////////////////////////////////////////////////////// + + // A matrix configuration + using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand + using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = cutlass::mx_float6_t; // Element type for B matrix operand + using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand + using ElementC = cutlass::bfloat16_t; // Element type for 
C matrix operand + using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand + using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + // Kernel functional config + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + + // Kernel Perf config + using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size + using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + + // use per-column bias, i.e. every column has different bias + using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + EVTOp + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
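  // Illustrative aside on the mixed-precision mainloop (general MX-format facts, not a
  // statement about this exact configuration): A is e4m3 (one byte per element), so 16
  // elements already form a full 16-byte access, while the packed e3m2 B operand is held
  // to a stricter 128-element alignment; both operands still use 32-element ue8m0 scale
  // blocks.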
+ using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +}; + + +// Populates a Gemm::Arguments structure from the given commandline options +typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t M, + int64_t N, + int64_t K) { + using Sm1xxBlkScaledConfig = typename Mxfp6Mxfp8GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + if (bias){ + using StrideBias = Stride; + + typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + fusion_args.bias_ptr = static_cast(bias->data_ptr()); + fusion_args.dBias = StrideBias{}; + return arguments; + } else { + typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + return arguments; + } +} + + +void runGemmMxfp6Mxfp8Sm120( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename Mxfp6Mxfp8GemmSm120::Gemm gemm; + + auto arguments = args_from_options_mxfp6_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k); + size_t workspace_size = Mxfp6Mxfp8GemmSm120::Gemm::get_workspace_size(arguments); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), 
stream)); + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); +} + + +constexpr auto FP6_FP8_TYPE = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu; + +void cutlass_scaled_mxfp6_mxfp8_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias) { + + CHECK_INPUT(A, FP6_FP8_TYPE, "a"); + CHECK_INPUT(B, FP6_FP8_TYPE, "b"); + + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + + + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); +// TORCH_CHECK( +// A.sizes()[1] == B.sizes()[1], +// "a and b shapes cannot be multiplied (", +// A.sizes()[0], +// "x", +// A.sizes()[1], +// " and ", +// B.sizes()[0], +// "x", +// B.sizes()[1], +// ")"); + + auto const m = A.sizes()[0]; + auto const n = B.sizes()[0]; + auto const k = A.sizes()[1]; + + constexpr int alignment_a = 16; + constexpr int alignment_b = 128; + TORCH_CHECK( + k % alignment_a == 0, + "Expected k to be divisible by ", + alignment_a, + ", but got a shape: (", + A.sizes()[0], + "x", + A.sizes()[1], + "), k: ", + k, + "."); + TORCH_CHECK( + n % alignment_b == 0, + "Expected n to be divisible by ", + alignment_b, + ", but got b shape: (", + B.sizes()[0], + "x", + B.sizes()[1], + ")."); + + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; + int rounded_m = round_up(m, 128); + int rounded_n = round_up(n, 128); + // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an + // integer. + int rounded_k = round_up(k / 32, 4); + + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + TORCH_CHECK( + A_sf.sizes()[1] == B_sf.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + " and ", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + TORCH_CHECK( + A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", + rounded_m, + "x", + rounded_k, + "), but got a shape (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + ")"); + TORCH_CHECK( + B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", + rounded_n, + "x", + rounded_k, + "), but got a shape (", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + + auto out_dtype = D.dtype(); + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + + runGemmMxfp6Mxfp8Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream); +} diff --git a/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..99b684bdfb6065bcdd5c9132a13103bdc58bc397 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu @@ -0,0 +1,348 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> 
+struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP6_ELTS_PER_THREAD = 8; +constexpr int CVT_FP6_SF_VEC_SIZE = 32; + +struct uint8x6_t { + uint8_t elts[6]; +}; + +// Convert 4 float2 values into 8 e3m2 values (represented as one uint8x6_t). +inline __device__ uint8x6_t fp32_vec_to_e3m2(float2 (&array)[4]) { + uint64_t val; + asm volatile( + "{\n" + ".reg .b16 pack0;\n" + ".reg .b16 pack1;\n" + ".reg .b16 pack2;\n" + ".reg .b16 pack3;\n" + "cvt.rn.satfinite.e3m2x2.f32 pack0, %2, %1;\n" + "cvt.rn.satfinite.e3m2x2.f32 pack1, %4, %3;\n" + "cvt.rn.satfinite.e3m2x2.f32 pack2, %6, %5;\n" + "cvt.rn.satfinite.e3m2x2.f32 pack3, %8, %7;\n" + "mov.b64 %0, {pack0, pack1, pack2, pack3};\n" + "}" + : "=l"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + + uint8x6_t result; + + // pack 8 uint8_t into 6 uint8_t + // here is how to pack: + // 4个fp6 a b c d. a:[a5 a4 a3 a2 a1 a0], b..., c..., d... + // 3个unint8 pack0 pack1 pack2 + // packed0: [b1 b0][a5 a4 a3 a2 a1 a0] + // packed1: [c3 c2 c1 c0][b5 b4 b3 b2] + // packed2: [d5 d4 d3 d2 d1 d0][c5 c4] + + // lower 4 uint8_t + uint8_t l_val_0 = val & 0xFF; + uint8_t l_val_1 = (val >> 8) & 0xFF; + uint8_t l_val_2 = (val >> 16) & 0xFF; + uint8_t l_val_3 = (val >> 24) & 0xFF; + // higher 4 uint8_t + uint8_t h_val_0 = (val >> 32) & 0xFF; + uint8_t h_val_1 = (val >> 40) & 0xFF; + uint8_t h_val_2 = (val >> 48) & 0xFF; + uint8_t h_val_3 = (val >> 56) & 0xFF; + + // pack result + result.elts[0] = (l_val_1 << 6) | l_val_0; + result.elts[1] = (l_val_2 << 4) | (l_val_1 >> 2); + result.elts[2] = (l_val_3 << 2) | (l_val_2 >> 4); + result.elts[3] = (h_val_1 << 6) | h_val_0; + result.elts[4] = (h_val_2 << 4) | (h_val_1 >> 2); + result.elts[5] = (h_val_3 << 2) | (h_val_2 >> 4); + + return result; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP6_NUM_THREADS_PER_SF == 4); + + // one of 4 threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP6_NUM_THREADS_PER_SF == 0) { + // SF vector index (32 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP6_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 32. + int factor = CVT_FP6_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); // same as (mIdx % 128) % 32 + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. 
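    // Equivalently (illustrative expansion of the strides used below):
    //   SFOffset = (((mTileIdx * numKTiles + kTileIdx) * 32 + outerMIdx) * 4 + innerMIdx) * 4 + innerKIdx
    // i.e. a row-major walk over the [numMTiles, numKTiles, 32, 4, 4] scale-factor layout.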
+ int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } else { + // Other threads do not write to SFout. + return nullptr; + } +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +// template <> +// struct PackedVec<__nv_fp8_e4m3> { +// __nv_fp8x2_e4m3 elts[8]; +// }; + +template // Type can be half or bfloat16 +__device__ uint8x6_t cvt_warp_fp16_to_fp6(PackedVec& vec, uint8_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 32 values (four threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e3m2). + // maximum value of e3m2 = 28.0. + // TODO: use half as compute data type. + float SFValue = (vecMax / 28.0f); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + fp8SFVal = tmp.__x; + + + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP6_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e3m2 values. + uint8x6_t e3m2Vec = fp32_vec_to_e3m2(fp2Vals); + + return e3m2Vec; +} + + +template // Type can be half or bfloat16 +__global__ void +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(256, 6) cvt_fp16_to_fp6( +// #else +// cvt_fp16_to_fp6( +// #endif + int32_t numRows, int32_t numCols, Type const* in, uint8x6_t* out, uint32_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP6_NUM_THREADS_PER_SF = (CVT_FP6_SF_VEC_SIZE / CVT_FP6_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP6_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP6_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP6_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements(E3M2) are packed into one uint8x6_t. 
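      // (Illustrative sizing note: 8 x 6-bit e3m2 values occupy 48 bits = 6 bytes, so each
      // packed output row holds numCols / 8 uint8x6_t entries, and every group of 4 threads
      // covering 32 input elements emits a single e8m0 scale factor.)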
+ int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + get_sf_out_address(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_fp16_to_fp6(in_vec, sf_out); + } + } +// #endif +} + +template +void invokeFP6Quantization( + int m, + int n, + T const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream) { + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 256)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 1536 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // Launch the cvt kernel. + cvt_fp16_to_fp6 + <<>>( + m, n, input, reinterpret_cast(output), reinterpret_cast(SFOuput)); +} + +// Instantiate the function. +template void invokeFP6Quantization( + int m, + int n, + half const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeFP6Quantization( + int m, + int n, + __nv_bfloat16 const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +inline int getMultiProcessorCount() { + static int multi_processor_count = []() { + int device_id = 0; + int count = 0; + + // Get the current CUDA device ID + CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); + + // Get the number of multiprocessors for the current device + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id)); + + return count; // Initialize the static variable + }(); + + return multi_processor_count; // Return the cached value on subsequent calls +} + +void scaled_mxfp6_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32."); + + int multiProcessorCount = getMultiProcessorCount(); + + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + default: { + std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid"; + throw std::runtime_error("Unsupported input data type for quantize_to_fp6."); + } + } +} diff --git a/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..2c8cd810691adec9187b6c39c66c5dd524a6fb20 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu @@ -0,0 +1,315 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { 
+ using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP8_ELTS_PER_THREAD = 8; +constexpr int CVT_FP8_SF_VEC_SIZE = 32; + + +// Convert 4 float2 values into 8 e4m3 values (represented as one uint64_t). +inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) { + uint64_t val; + asm volatile( + "{\n" + ".reg .b16 pack0;\n" + ".reg .b16 pack1;\n" + ".reg .b16 pack2;\n" + ".reg .b16 pack3;\n" + "cvt.rn.satfinite.e4m3x2.f32 pack0, %2, %1;\n" + "cvt.rn.satfinite.e4m3x2.f32 pack1, %4, %3;\n" + "cvt.rn.satfinite.e4m3x2.f32 pack2, %6, %5;\n" + "cvt.rn.satfinite.e4m3x2.f32 pack3, %8, %7;\n" + "mov.b64 %0, {pack0, pack1, pack2, pack3};\n" + "}" + : "=l"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP8_NUM_THREADS_PER_SF == 4); + + // one of 4 threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP8_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP8_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 32. + int factor = CVT_FP8_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); // same as (mIdx % 128) % 32 + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } else { + // Other threads do not write to SFout. + return nullptr; + } +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint64_t output +template // Type can be half or bfloat16 +__device__ uint64_t cvt_warp_fp16_to_fp8(PackedVec& vec, uint8_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. 
+#pragma unroll + for (int i = 1; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 32 values (four threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e4m3). + // maximum value of e4m3 = 448.0. + // TODO: use half as compute data type. + float SFValue = (vecMax / 448.0f); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + fp8SFVal = tmp.__x; + + + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP8_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e4m3 values. + uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals); + + return e4m3Vec; +} + + +template // Type can be half or bfloat16 +__global__ void +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(256, 6) cvt_fp16_to_fp8( +// #else +// cvt_fp16_to_fp8( +// #endif + int32_t numRows, int32_t numCols, Type const* in, uint64_t* out, uint32_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP8_NUM_THREADS_PER_SF = (CVT_FP8_SF_VEC_SIZE / CVT_FP8_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP8_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP8_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP8_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements(E4M3) are packed into one uint64_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + get_sf_out_address(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_fp16_to_fp8(in_vec, sf_out); + } + } +// #endif +} + +template +void invokeFP8Quantization( + int m, + int n, + T const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream) { + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 256)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 1536 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // Launch the cvt kernel. + cvt_fp16_to_fp8 + <<>>( + m, n, input, reinterpret_cast(output), reinterpret_cast(SFOuput)); +} + +// Instantiate the function. 
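// Illustrative host-side sketch (hypothetical helper, not used anywhere in this file):
// how a caller could size the swizzled scale-factor tensor that scaled_mxfp8_quant_sm120
// fills below, assuming an (m x n) half/bfloat16 input with one e8m0 scale per 32 elements.
inline void mxfp8_sf_shape_sketch(int64_t m, int64_t n, int64_t& sf_rows, int64_t& sf_cols) {
  auto round_up = [](int64_t x, int64_t y) { return (x + y - 1) / y * y; };
  sf_rows = round_up(m, 128);     // rows padded to 128 (the [32, 4] M-tile layout)
  sf_cols = round_up(n / 32, 4);  // one scale per 32 elements, padded to a multiple of 4
}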
+template void invokeFP8Quantization( + int m, + int n, + half const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeFP8Quantization( + int m, + int n, + __nv_bfloat16 const* input, + int64_t* output, + int32_t* SFOuput, + int multiProcessorCount, + cudaStream_t stream); + +inline int getMultiProcessorCount() { + static int multi_processor_count = []() { + int device_id = 0; + int count = 0; + + // Get the current CUDA device ID + CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); + + // Get the number of multiprocessors for the current device + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id)); + + return count; // Initialize the static variable + }(); + + return multi_processor_count; // Return the cached value on subsequent calls +} + +void scaled_mxfp8_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32."); + + int multiProcessorCount = getMultiProcessorCount(); + + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeFP8Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeFP8Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream); + break; + } + default: { + std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid"; + throw std::runtime_error("Unsupported input data type for quantize_to_fp8."); + } + } +} diff --git a/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..f3a1558b680429d3a64220ddc38c79c447977b82 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu @@ -0,0 +1,325 @@ +#include +#include +#include + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + + +using namespace cute; + + +struct Mxfp8GemmSm120 { + ///////////////////////////////////////////////////////////////////////////////////////////////// + /// GEMM kernel configurations + 
///////////////////////////////////////////////////////////////////////////////////////////////// + + // A matrix configuration + using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand + using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = cutlass::mx_float8_t; // Element type for B matrix operand + using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand + using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand + using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand + using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + // Kernel functional config + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + + // Kernel Perf config + using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size + using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + + // use per-column bias, i.e. every column has different bias + using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + EVTOp + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
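  // Illustrative aside on the epilogue wiring (based on the per-column-bias fusion chosen
  // above; the exact functor signature is left as in the original): the epilogue computes
  //   D[m][n] = alpha * acc[m][n] + beta * C[m][n] + bias[n]
  // and args_from_options_mxfp8 below feeds alpha from a device tensor, pins beta to 0,
  // and passes the optional per-column bias pointer only when it is provided.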
+ using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +}; + + +// Populates a Gemm::Arguments structure from the given commandline options +typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t M, + int64_t N, + int64_t K) { + using Sm1xxBlkScaledConfig = typename Mxfp8GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + if (bias){ + using StrideBias = Stride; + + typename Mxfp8GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + fusion_args.bias_ptr = static_cast(bias->data_ptr()); + fusion_args.dBias = StrideBias{}; + return arguments; + } else { + typename Mxfp8GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + return arguments; + } +} + + +void runGemmMxfp8Sm120( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename Mxfp8GemmSm120::Gemm gemm; + + auto arguments = args_from_options_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k); + size_t workspace_size = Mxfp8GemmSm120::Gemm::get_workspace_size(arguments); + auto const workspace_options = 
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); +} + + +constexpr auto FP6_FP8_TYPE = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu; + +void cutlass_scaled_mxfp8_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias) { + + CHECK_INPUT(A, FP6_FP8_TYPE, "a"); + CHECK_INPUT(B, FP6_FP8_TYPE, "b"); + + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + + + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + + TORCH_CHECK( + A.sizes()[1] == B.sizes()[1], + "a and b shapes cannot be multiplied (", + A.sizes()[0], + "x", + A.sizes()[1], + " and ", + B.sizes()[0], + "x", + B.sizes()[1], + ")"); + + auto const m = A.sizes()[0]; + auto const n = B.sizes()[0]; + auto const k = A.sizes()[1]; + + constexpr int alignment_a = 16; + constexpr int alignment_b = 128; + TORCH_CHECK( + k % alignment_a == 0, + "Expected k to be divisible by ", + alignment_a, + ", but got a shape: (", + A.sizes()[0], + "x", + A.sizes()[1], + "), k: ", + k, + "."); + TORCH_CHECK( + n % alignment_b == 0, + "Expected n to be divisible by ", + alignment_b, + ", but got b shape: (", + B.sizes()[0], + "x", + B.sizes()[1], + ")."); + + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; + int rounded_m = round_up(m, 128); + int rounded_n = round_up(n, 128); + // Since k is divisible by 32 (alignment), k / 32 is guaranteed to be an + // integer. 
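  // Illustrative note: unlike the FP4 path, k here counts e4m3 elements directly (one byte
  // each); the padding rule for the swizzled scales is the same, e.g. m = 96, k = 4096
  // gives an expected scale_a shape of 128 x 128.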
+ int rounded_k = round_up(k / 32, 4); + + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + TORCH_CHECK( + A_sf.sizes()[1] == B_sf.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + " and ", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + TORCH_CHECK( + A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", + rounded_m, + "x", + rounded_k, + "), but got a shape (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + ")"); + TORCH_CHECK( + B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", + rounded_n, + "x", + rounded_k, + "), but got a shape (", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + + auto out_dtype = D.dtype(); + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + + runGemmMxfp8Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream); +} diff --git a/lightx2v_kernel/csrc/gemm/nvfp4_quant_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/nvfp4_quant_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..ad50950baa407ccf5efd3bc0b5fff010f49368a5 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/nvfp4_quant_kernels_sm120.cu @@ -0,0 +1,387 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + // PTX instructions used here requires sm100a. +// #if CUDA_VERSION >= 12080 +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), + "f"(array[1]), + "f"(array[2]), + "f"(array[3]), + "f"(array[4]), + "f"(array[5]), + "f"(array[6]), + "f"(array[7])); + return val; +// #else +// return 0; +// #endif +// #endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + // PTX instructions used here requires sm100a. 
+// #if CUDA_VERSION >= 12080 +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +// #else +// return 0; +// #endif +// #endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +// #endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. 
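  // (Added note: 0.16666666666666666f below is 1/6, the reciprocal of the e2m1 maximum, so
  // SFValue = SFScaleVal * vecMax / 6; the per-16-element scale is then re-encoded as FP8
  // before being written out and used to normalize the inputs.)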
+ float SFValue = SFScaleVal * (vecMax * 0.16666666666666666f); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + fp8SFVal = tmp.__x; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + fp8SFVal = tmp.__x; + SFValue = static_cast(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) +// float outputScale = +// SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; + + float outputScale = + SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +// #else +// return 0; +// #endif +} + +// Use UE4M3 by default. +template +__global__ void +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(256, 6) cvt_fp16_to_fp4( +// #else +// cvt_fp16_to_fp4( +// #endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) { +// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +// #endif +} + +template +void invokeFP4Quantization( + int m, + int n, + T const* input, + float const* SFScale, + int64_t* output, + int32_t* SFOuput, + bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream) { + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 256)); + // Get number of blocks per SM (assume we can fully utilize the SM). 
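  // (Added note: 1536 is taken here as the resident-thread budget per SM; when the
  // 256-thread block cap above binds, this evaluates to 6 blocks per SM, matching the
  // __launch_bounds__(256, 6) annotation on cvt_fp16_to_fp4.)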
+ int const numBlocksPerSM = 1536 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // Launch the cvt kernel. + if (useUE8M0) { + cvt_fp16_to_fp4<<>>( + m, n, input, SFScale, reinterpret_cast(output), reinterpret_cast(SFOuput)); + } else { + cvt_fp16_to_fp4<<>>( + m, n, input, SFScale, reinterpret_cast(output), reinterpret_cast(SFOuput)); + } +} + +// Instantiate the function. +template void invokeFP4Quantization( + int m, + int n, + half const* input, + float const* SFScale, + int64_t* output, + int32_t* SFOuput, + bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeFP4Quantization( + int m, + int n, + __nv_bfloat16 const* input, + float const* SFScale, + int64_t* output, + int32_t* SFOuput, + bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream); + +inline int getMultiProcessorCount() { + static int multi_processor_count = []() { + int device_id = 0; + int count = 0; + + // Get the current CUDA device ID + CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); + + // Get the number of multiprocessors for the current device + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id)); + + return count; // Initialize the static variable + }(); + + return multi_processor_count; // Return the cached value on subsequent calls +} + +void scaled_nvfp4_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + + int multiProcessorCount = getMultiProcessorCount(); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + // We don't support e8m0 scales at this moment. 
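  // (Added note: with useUE8M0 left false, the per-16-element scale factors are stored as
  // e4m3 FP8 values, unlike the MX kernels above which always emit e8m0 scales; input_sf
  // supplies the single global scale consumed as SFScale by the kernel.)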
+ bool useUE8M0 = false; + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream); + break; + } + default: { + std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid"; + throw std::runtime_error("Unsupported input data type for quantize_to_fp4."); + } + } +} diff --git a/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu new file mode 100644 index 0000000000000000000000000000000000000000..8dd2838a2e5c84855ae6226f1d96048c9a6d1772 --- /dev/null +++ b/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu @@ -0,0 +1,323 @@ +#include +#include +#include + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + + +using namespace cute; + + +struct Fp4GemmSm120 { + ///////////////////////////////////////////////////////////////////////////////////////////////// + /// GEMM kernel configurations + ///////////////////////////////////////////////////////////////////////////////////////////////// + + // A matrix configuration + using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand + using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand + using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand + using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand + using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand + using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access 
granularity/alignment of C matrix in units of elements (up to 16 bytes) + // Kernel functional config + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + + // Kernel Perf config + using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size + using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + + // use per-column bias, i.e. every column has different bias + using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + EVTOp + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
+ using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +}; + + +// Populates a Gemm::Arguments structure from the given commandline options +typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t M, + int64_t N, + int64_t K) { + using Sm1xxBlkScaledConfig = typename Fp4GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + if (bias){ + using StrideBias = Stride; + + typename Fp4GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + fusion_args.bias_ptr = static_cast(bias->data_ptr()); + fusion_args.dBias = StrideBias{}; + return arguments; + } else { + typename Fp4GemmSm120::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + static const float beta_zero = 0.0f; + fusion_args.beta_ptr = &beta_zero; + return arguments; + } +} + + +void runGemmNvfp4Sm120( + at::Tensor& D, + at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + at::Tensor const& alpha, + c10::optional const& bias, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename Fp4GemmSm120::Gemm gemm; + + auto arguments = args_from_options_nvfp4_nvfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k); + size_t workspace_size = Fp4GemmSm120::Gemm::get_workspace_size(arguments); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), 
stream)); +} + + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +void cutlass_scaled_nvfp4_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias) { + + CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); + + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + + + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + TORCH_CHECK( + A.sizes()[1] == B.sizes()[1], + "a and b shapes cannot be multiplied (", + A.sizes()[0], + "x", + A.sizes()[1], + " and ", + B.sizes()[0], + "x", + B.sizes()[1], + ")"); + + auto const m = A.sizes()[0]; + auto const n = B.sizes()[0]; + auto const k = A.sizes()[1] * 2; + + constexpr int alignment = 32; + TORCH_CHECK( + k % alignment == 0, + "Expected k to be divisible by ", + alignment, + ", but got a shape: (", + A.sizes()[0], + "x", + A.sizes()[1], + "), k: ", + k, + "."); + TORCH_CHECK( + n % alignment == 0, + "Expected n to be divisible by ", + alignment, + ", but got b shape: (", + B.sizes()[0], + "x", + B.sizes()[1], + ")."); + + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; + int rounded_m = round_up(m, 128); + int rounded_n = round_up(n, 128); + // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an + // integer. + int rounded_k = round_up(k / 16, 4); + + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + TORCH_CHECK( + A_sf.sizes()[1] == B_sf.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + " and ", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + TORCH_CHECK( + A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", + rounded_m, + "x", + rounded_k, + "), but got a shape (", + A_sf.sizes()[0], + "x", + A_sf.sizes()[1], + ")"); + TORCH_CHECK( + B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", + rounded_n, + "x", + rounded_k, + "), but got a shape (", + B_sf.sizes()[0], + "x", + B_sf.sizes()[1], + ")"); + + auto out_dtype = D.dtype(); + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + + runGemmNvfp4Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream); +} diff --git a/lightx2v_kernel/docs/en_US/mx_formats_quantization_basics.md b/lightx2v_kernel/docs/en_US/mx_formats_quantization_basics.md new file mode 100644 index 0000000000000000000000000000000000000000..8241eacde4f845dd2b39743780705b25a9083375 --- /dev/null +++ b/lightx2v_kernel/docs/en_US/mx_formats_quantization_basics.md @@ -0,0 +1,35 @@ +# MX-Formats Quantization Basics + +**Note: The following focuses on sharing the differences between MX-Formats quantization and Per-Row/Per-Column quantization, as well as the layout requirements for compatibility with Cutlass Block Scaled GEMMs.** + +### Data Formats and Quantization Factors +Target data format reference: [MX-Formats](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). Note that we do not need to pack raw data and scale factors together here. 
+ +Source data format: fp16/bf16 + +Target data format: mxfp4/6/8 + +Quantization factor data format: E8M0, Per-Row/Per-Column quantization typically stores quantization factors in fp32, whereas E8M0 has the same numerical range as fp32. After rounding, the quantization factors can be stored directly, though the loss of mantissa bits may affect precision. + +Quantization granularity: \[1X32\] + +Quantization dimension: Following Cutlass GEMM conventions, where M, N, K represent the three dimensions of matrix multiplication, we should quantize along K dimension. + +### Rounding and Clamp +Unlike software emulation, CUDA can efficiently handle complex rounding and clamping operations using PTX or built-in functions. +For example, `cvt.rn.satfinite.e2m1x2.f32` can convert two fp32 inputs into two fp4 outputs. +Rounding mode: `rn` (round-to-nearest-even) +Clamp mode: `satfinite` (clamped to the maximum finite value within the target range, excluding infinities and NaN) +For more data types and modes, refer to: [PTX cvt Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt) + +### Data Layout and Quantization Factor Layout +**Data Layout** +- mxfp4 requires packing two values into a uint8. +- mxfp6 requires packing every four values into three uint8s. For the format, refer to: [mxfp6 cutlass mm format packing](https://github.com/ModelTC/LightX2V/blob/main/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu#L74). + +**Quantization Factor Layout** +Cutlass Block Scaled GEMMs impose special swizzle requirements on quantization factor layouts to optimize matrix operations. +Reference: [Scale Factor Layouts](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts) + +### Quantization Method +After understanding the above, the calculation of the target data and quantization factor values can refer to [nvfp4 Quantization Basics](https://github.com/theNiemand/lightx2v/blob/main/lightx2v_kernel/docs/zh_CN/nvfp4%E9%87%8F%E5%8C%96%E5%9F%BA%E7%A1%80.md). Note that MX-Formats do not require quantizing the scale itself. diff --git a/lightx2v_kernel/docs/en_US/nvfp4_quantization_basics.md b/lightx2v_kernel/docs/en_US/nvfp4_quantization_basics.md new file mode 100644 index 0000000000000000000000000000000000000000..80a4ba0ffd4d3a18cdd0953725fc508b3efc4766 --- /dev/null +++ b/lightx2v_kernel/docs/en_US/nvfp4_quantization_basics.md @@ -0,0 +1,80 @@ +# nvfp4 Quantization Basics + +### Data Format + +The calculation method for fp is: + +`ans = (-1)^s * 2^(p-b) * (1 + d1/2 + d2/4 + d3/8 + ...)` + +Where `b = 2^(e-1) - 1`, p represents the value of the exponent bits, d1, d2, d3 represent the values of the mantissa bits + +For fp4, the format is E2M1, and the above formula is simplified to: + +`b = 2^(e-1) - 1 = 2^(2-1) - 1 = 1` + +`ans = (-1)^s * 2^(p-1) * (1 + d1/2)` + +Example: 0101 + +`s=0, p=(10)=2, d1=1` + +`ans = 2^0 * 2^(2-1) * (1 + 1/2) = 3` + +In normal fp data format, some data represents inf and nan, with a maximum representation of ±3. Specialized for nvfp4, inf and nan are removed, allowing a maximum representation of ±6. + +Specifically, 0000 represents +0, 1000 represents -0, 0001 represents 0.5, and 1001 represents -0.5. 
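+
+As a quick sanity check (an illustrative sketch, not code from this repository), the E2M1 decoding rule above can be written in a few lines of Python and reproduces the positive half of the summary table below:
+
+```python
+def decode_e2m1(code: int) -> float:
+    """Decode a 4-bit nvfp4 (E2M1) code; no inf/nan, and 0001/1001 encode the 0.5 subnormals."""
+    sign = -1.0 if (code >> 3) & 1 else 1.0
+    p = (code >> 1) & 0b11   # 2 exponent bits
+    d1 = code & 0b1          # 1 mantissa bit
+    if p == 0:               # subnormal codes: +-0 and +-0.5
+        return sign * 0.5 * d1
+    return sign * 2.0 ** (p - 1) * (1 + d1 / 2)
+
+print([decode_e2m1(c) for c in range(8)])  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
+```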
+ +In summary: + +| E2M1 | 0000 | 0001 | 0010 | 0011 | 0100 | 0101 | 0110 | 0111 | 1000 | 1001 | 1010 | 1011 | 1100 | 1101 | 1110 | 1111 | +|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------| +| ans | +0 | 0.5 | 1.0 | 1.5 | 2.0 | 3.0 | 4.0 | 6.0 | +0 | -0.5 | -1.0 | -1.5 | -2.0 | -3.0 | -4.0 | -6.0 | + + +### Quantization Process + +**Both weight and activation use per-group quantization, with a group size of 16, and quantization scales are stored in fp8(e4m3) format** + +Since the quantization scale needs to be stored in fp8, the scale also needs to be rescaled, so the fp4 quantization process differs somewhat from the common w8a8-int8 process. + +The quantization process is as follows: + +Given a set of numbers, denoted as `X` + +#### Calculate scale + +`scale1 = max(abs(Xg)) / 6.0` + +Where Xg represents a group of numbers, and 6.0 represents the maximum value of nvfp4 + +#### Quantize scale + +`global_scale = 6.0 * 448.0 / max(abs(X))` + +`scale2 = global_scale * scale1` + +That is `scale2 = 6.0 * 448.0 / max(abs(X)) * max(abs(Xg)) / 6.0` + +That is `scale2 = max(abs(Xg)) / max(abs(X)) * 448.0` + +At this point, scale2 is rescaled to the range of fp8(e4m3), then scale2 is quantized to fp8 + +`scale2_fp8 = quant_fp8(scale2)` + +`scale2_fp8` serves as the final quantization scale parameter required for matrix multiplication + +#### Quantize X + +`scale2_fp32 = cvt2fp32(scale2_fp8)` + +`Xquant = quant_fp4(X * global_scale / scale2_fp32)` + +Then `Xquant ≈ quant_fp4(X / scale1)` + +#### fp4 Matrix Multiplication + +`ans = Aquant * Bquant * Ascale2 * Bscale2 / Aglobal_scale / Bglobal_scale` + +That is `ans ≈ Aquant * Bquant * Aglobal_scale * Ascale1 * Bglobal_scale * Bscale1 / Aglobal_scale / Bglobal_scale` + +That is `ans ≈ Aquant * Bquant * Ascale1 * Bscale1` diff --git "a/lightx2v_kernel/docs/zh_CN/mx_formats\351\207\217\345\214\226\345\237\272\347\241\200.md" "b/lightx2v_kernel/docs/zh_CN/mx_formats\351\207\217\345\214\226\345\237\272\347\241\200.md" new file mode 100644 index 0000000000000000000000000000000000000000..cdacde35342a9259a64fc1f87a888475b6ec4a8e --- /dev/null +++ "b/lightx2v_kernel/docs/zh_CN/mx_formats\351\207\217\345\214\226\345\237\272\347\241\200.md" @@ -0,0 +1,35 @@ +# MX-Formats量化基础 + +**注:下文关注于分享MX-Formats量化相对于Per-Row/Per-Column量化的区别,以及与Cutlass Block Scaled GEMMs配合使用需要满足的一些布局要求。** + +### 数据格式与量化因子 +目标数据格式参考:[MX-Formats](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf),需要注意的是,我们这里不需要将raw data和scale factor打包在一起 + +源数据格式:fp16/bf16 + +目标数据格式:mxfp4/6/8 + +量化因子数据格式:E8M0, Per-Row/Per-Column量化的量化因子一般以fp32进行存储,而E8M0与fp32数值范围一致,经过rounding后可直接存储量化因子,缺点是尾数的丢失会影响精度。 + +量化粒度:\[1X32\] + +量化维度:以Cutlass GEMM的规范,M N K表示矩阵乘的三个维度,需要沿着K维度量化 + +### Rounding与Clamp +不同于软件模拟,CUDA可以通过PTX或者内置函数高性能地便捷地来完成繁琐的Rouding和Clamp操作。 +例如,`cvt.rn.satfinite.e2m1x2.f32` 可以将两个fp32类型的输入,转换为​两个fp4类型的输出 +Rounding模式为:`rn`,​round-to-nearest-even​ +Clamp模式为:`satfinite`,钳制到目标范围内的最大有限值,​排除无穷和 NaN +更多数据类型和模式参考:[PTX cvt指令](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt) + +### 数据布局与量化因子布局 +数据布局 +- mxfp4需要两两打包为uint8 +- mxfp6需要每4个打包为3个uint8,格式参考:[mxfp6 cutlass mm 格式打包](https://github.com/ModelTC/LightX2V/blob/main/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu#L74) + +量化因子布局 +Cutlass Block Scaled GEMMs为了满足矩阵运算加速,对量化因子布局有特殊的swizzle要求 +参考:[Scale Factor 
Layouts](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts) + +### 量化方法 +了解完上述后,目标数据和量化因子两者自身数值的求解,可参考[nvfp4量化基础](https://github.com/theNiemand/lightx2v/blob/main/lightx2v_kernel/docs/zh_CN/nvfp4%E9%87%8F%E5%8C%96%E5%9F%BA%E7%A1%80.md),注意MX-Formats无需量化scale本身 diff --git "a/lightx2v_kernel/docs/zh_CN/nvfp4\351\207\217\345\214\226\345\237\272\347\241\200.md" "b/lightx2v_kernel/docs/zh_CN/nvfp4\351\207\217\345\214\226\345\237\272\347\241\200.md" new file mode 100644 index 0000000000000000000000000000000000000000..be9b17559d77e6c6835f54042602ac57f7f12262 --- /dev/null +++ "b/lightx2v_kernel/docs/zh_CN/nvfp4\351\207\217\345\214\226\345\237\272\347\241\200.md" @@ -0,0 +1,80 @@ +# nvfp4量化基础 + +### 数据格式 + +fp的计算方式是: + +`ans = (-1)^s * 2^(p-b) * (1 + d1/2 + d2/4 + d3/8 + ...)` + +其中,`b = 2^(e-1) - 1`,p表示指数位的值,d1, d2, d3表示尾数位的值 + +对于fp4,格式是E2M1,上述的式子简化为: + +`b = 2^(e-1) - 1 = 2^(2-1) - 1 = 1` + +`ans = (-1)^s * 2^(p-1) * (1 + d1/2)` + +举例:0101 + +`s=0, p=(10)=2, d1=1` + +`ans = 2^0 * 2^(2-1) * (1 + 1/2) = 3` + +正常的fp数据格式,还会有部分数据表示inf和nan,最大只能表示±3,特化到nvfp4,取消了inf和nan,最大可以表示±6 + +特殊的,其中0000表示+0,1000表示-0,0001表示0.5,1001表示-0.5 + +综上: + +| E2M1 | 0000 | 0001 | 0010 | 0011 | 0100 | 0101 | 0110 | 0111 | 1000 | 1001 | 1010 | 1011 | 1100 | 1101 | 1110 | 1111 | +|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------| +| ans | +0 | 0.5 | 1.0 | 1.5 | 2.0 | 3.0 | 4.0 | 6.0 | +0 | -0.5 | -1.0 | -1.5 | -2.0 | -3.0 | -4.0 | -6.0 | + + +### 量化过程 + +**weight和act都是per group量化,group size都是16,量化scale以fp8(e4m3)格式存储** + +由于量化scale要用fp8存储,需要对scale也进行放缩,所以fp4量化的过程和常见的w8a8-int8过程,有一些不同 + +量化过程如下: + +给定一组数,记作`X` + +#### 计算scale + +`scale1 = max(abs(Xg)) / 6.0` + +其中Xg表示一个group的数,6.0表示nvfp4的最大值 + +#### 量化scale + +`global_scale = 6.0 * 448.0 / max(abs(X))` + +`scale2 = global_scale * scale1` + +即 `scale2 = 6.0 * 448.0 / max(abs(X)) * max(abs(Xg)) / 6.0` + +即 `scale2 = max(abs(Xg)) / max(abs(X)) * 448.0` + +此时scale2被放缩到fp8(e4m3)的范围,然后对scale2进行量化到fp8 + +`scale2_fp8 = quant_fp8(scale2)` + +`scale2_fp8`则作为最终的矩阵乘法所需的量化scale参数 + +#### 量化X + +`scale2_fp32 = cvt2fp32(scale2_fp8)` + +`Xquant = quant_fp4(X * global_scale / scale2_fp32)` + +则 `Xquant ≈ quant_fp4(X / scale1)` + +#### fp4矩阵乘法 + +`ans = Aquant * Bquant * Ascale2 * Bscale2 / Aglobal_scale / Bglobal_scale` + +即 `ans ≈ Aquant * Bquant * Aglobal_scale * Ascale1 * Bglobal_scale * Bscale1 / Aglobal_scale / Bglobal_scale` + +即 `ans ≈ Aquant * Bquant * Ascale1 * Bscale1` diff --git a/lightx2v_kernel/include/lightx2v_kernel_ops.h b/lightx2v_kernel/include/lightx2v_kernel_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b937971a8678567dfc4fb41e02d8ceeaa97afe5d --- /dev/null +++ b/lightx2v_kernel/include/lightx2v_kernel_ops.h @@ -0,0 +1,92 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + + +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + + +/* + * From csrc/gemm + */ +void scaled_nvfp4_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf); + +void scaled_mxfp4_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf); + +void scaled_mxfp6_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf); + +void scaled_mxfp8_quant_sm120( + torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf); + +void cutlass_scaled_nvfp4_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias); + +void cutlass_scaled_mxfp4_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias); + +void cutlass_scaled_mxfp6_mxfp8_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias); + + +void cutlass_scaled_mxfp8_mm_sm120( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + c10::optional const& bias); diff --git a/lightx2v_kernel/include/utils.h b/lightx2v_kernel/include/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..229c6e9c4b8e351d35e8a12b8ef729d82468cb0d --- /dev/null +++ b/lightx2v_kernel/include/utils.h @@ -0,0 +1,348 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include + +#ifndef USE_ROCM +// Adapt from FlashInfer +#ifdef FLASHINFER_ENABLE_F16 +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_F16(c_type, ...) +#endif + +#ifdef FLASHINFER_ENABLE_BF16 +#define _DISPATCH_CASE_BF16(c_type, ...) \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_BF16(c_type, ...) +#endif + +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) 
\ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) +#endif + +#ifdef FLASHINFER_ENABLE_FP8_E5M2 +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) +#endif + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \ + << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ + TORCH_CHECK( \ + num_qo_heads % num_kv_heads == 0, \ + "num_qo_heads(", \ + num_qo_heads, \ + ") must be divisible by num_kv_heads(", \ + num_kv_heads, \ + ")") + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + +inline bool is_float8_tensor(const at::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} +#endif + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. 
+ * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) + +inline int getSMVersion() { + int device{-1}; + CHECK_CUDA_SUCCESS(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28 +#ifndef USE_ROCM +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width)) +#else +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width)) +#endif + +#ifndef USE_ROCM +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#endif + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) +#define WARP_SIZE 32 + +#ifndef USE_ROCM +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else +#include + +using FP8_TYPE = c10::Float8_e4m3fnuz; +constexpr auto FP8_E4M3_MAX = 224.0f; +#endif + +#ifndef USE_ROCM +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +} + +__device__ __forceinline__ float warpReduceMax(float max_value) { + max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16)); + max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8)); + max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4)); + max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2)); + max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1)); + return max_value; +} + +__device__ __forceinline__ float blockReduceMax(float max_value) { + static __shared__ float warpLevelMaxs[WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + + max_value = warpReduceMax(max_value); + + if (laneId == 0) warpLevelMaxs[warpId] = max_value; + __syncthreads(); + + max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0; + if (warpId == 0) max_value = warpReduceMax(max_value); + + return max_value; +} +#endif + +// Pads to a multiple of `alignment` rows. +inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) { + int64_t rows = tensor.size(0); + int64_t cols = tensor.size(1); + int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size + + if (pad_rows == 0) { + return tensor; // Already aligned + } + + torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options()); + torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows + + // Ensure column-major layout + if (is_column_major) { + return tensor_padded.t().contiguous().t(); + } + return tensor_padded; +} diff --git a/lightx2v_kernel/pyproject.toml b/lightx2v_kernel/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..058ee7e6e64d5fca7c675a941a1087d790c32943 --- /dev/null +++ b/lightx2v_kernel/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = [ + "scikit-build-core>=0.10", + "torch>=2.7.0", + "wheel", +] +build-backend = "scikit_build_core.build" + +[project] +name = "lightx2v-kernel" +version = "0.0.1" +description = "Kernel Library for lightx2v" +readme = "README.md" +requires-python = ">=3.9" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA" +] +dependencies = [] + +[project.urls] +"Homepage" = "" +"Bug Tracker" = "" + +[tool.wheel] +exclude = [ + "dist*", + "tests*", +] + +[tool.scikit-build] +cmake.build-type = "Release" +minimum-version = "build-system.requires" + +wheel.py-api = "cp39" +wheel.license-files = [] +wheel.packages = ["python/lightx2v_kernel"] diff --git a/lightx2v_kernel/python/lightx2v_kernel/__init__.py b/lightx2v_kernel/python/lightx2v_kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f78cf8863811e22ec16d4b3bb214e63063e44db --- /dev/null +++ b/lightx2v_kernel/python/lightx2v_kernel/__init__.py @@ -0,0 +1,15 @@ +import ctypes +import os +import platform +from lightx2v_kernel import common_ops # noqa: F401 +from lightx2v_kernel.version import __version__ + + +SYSTEM_ARCH = platform.machine() + +cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12" +if os.path.exists(cuda_path): + ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL) + + 
+build_tree_kernel = None diff --git a/lightx2v_kernel/python/lightx2v_kernel/gemm.py b/lightx2v_kernel/python/lightx2v_kernel/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..d79ca213f7a4b642d6bfc560abe457a33269ab41 --- /dev/null +++ b/lightx2v_kernel/python/lightx2v_kernel/gemm.py @@ -0,0 +1,115 @@ +import torch + + +def cutlass_scaled_nvfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None): + m, n = mat_a.shape[0], mat_b.shape[0] + out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device) + torch.ops.lightx2v_kernel.cutlass_scaled_nvfp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias) + return out + + +def scaled_nvfp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor): + """ + Quantize input tensor to FP4 and return quantized tensor and scale. + + This function quantizes the last dimension of the given tensor `input`. For + every 16 consecutive elements, a single dynamically computed scaling factor + is shared. This scaling factor is quantized using the `input_global_scale` + and is stored in a swizzled layout (see + https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x). + + Args: + input: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every + two values are packed into a uint8 and float8_e4m3 scaling factors + in a sizzled layout. + """ + # assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." + # other_dims = 1 if input.ndim == 1 else -1 + # input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + # assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." + # assert input.dtype in ( + # torch.float16, + # torch.bfloat16, + # ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." + + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Then, the scaling + # factors in float8_e4m3fn are packed into an int32 for every 4 values. 
+ # rounded_m = ((m + 128 - 1) // 128) * 128 + # scale_n = n // block_size + # rounded_n = ((scale_n + 4 - 1) // 4) * 4 + output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) + + torch.ops.lightx2v_kernel.scaled_nvfp4_quant_sm120.default(output, input, output_scale, input_global_scale) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale + + +def scaled_mxfp4_quant(input: torch.Tensor): + m, n = input.shape + block_size = 32 + device = input.device + + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) + + torch.ops.lightx2v_kernel.scaled_mxfp4_quant_sm120.default(output, input, output_scale) + output_scale = output_scale.view(torch.float8_e8m0fnu) + return output, output_scale + + +def scaled_mxfp6_quant(input: torch.Tensor): + m, n = input.shape + block_size = 32 + device = input.device + + output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8) + output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) + + torch.ops.lightx2v_kernel.scaled_mxfp6_quant_sm120.default(output, input, output_scale) + output_scale = output_scale.view(torch.float8_e8m0fnu) + return output, output_scale + + +def scaled_mxfp8_quant(input: torch.Tensor): + m, n = input.shape + block_size = 32 + device = input.device + + output = torch.empty((m, n), device=device, dtype=torch.uint8) + output_scale = torch.empty(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) + + torch.ops.lightx2v_kernel.scaled_mxfp8_quant_sm120.default(output, input, output_scale) + output_scale = output_scale.view(torch.float8_e8m0fnu) + return output, output_scale + + +def cutlass_scaled_mxfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None): + m, n = mat_a.shape[0], mat_b.shape[0] + out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device) + torch.ops.lightx2v_kernel.cutlass_scaled_mxfp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias) + return out + + +def cutlass_scaled_mxfp6_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None): + m, n = mat_a.shape[0], mat_b.shape[0] + out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device) + torch.ops.lightx2v_kernel.cutlass_scaled_mxfp6_mxfp8_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias) + return out + + +def cutlass_scaled_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None): + m, n = mat_a.shape[0], mat_b.shape[0] + out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device) + torch.ops.lightx2v_kernel.cutlass_scaled_mxfp8_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias) + return out diff --git a/lightx2v_kernel/python/lightx2v_kernel/utils.py b/lightx2v_kernel/python/lightx2v_kernel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c01be93c47c2bb4e2bc8487cedf62de5fe3dba4b --- /dev/null +++ b/lightx2v_kernel/python/lightx2v_kernel/utils.py @@ -0,0 +1,157 @@ +import functools +from typing import Dict, Tuple, Callable, List + +import torch + + +def get_cuda_stream() -> int: + return torch.cuda.current_stream().cuda_stream + + +_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} + + +def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: + key = (name, 
device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) + _cache_buf[key] = buf + return buf + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) + + +@functools.lru_cache(maxsize=1) +def is_hopper_arch() -> bool: + # Hopper arch's compute capability == 9.0 + device = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(device) + return major == 9 + + +def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> torch.Tensor: + """ + Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calcualted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + + + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + torch.Tensor: _description_ + """ + y_pred = torch.flatten(y_pred).float() + y_real = torch.flatten(y_real).float() + + if y_pred.shape != y_real.shape: + raise ValueError(f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})") + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + return snr.item() + + +def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs): + """ + A decorator function to assist in performance testing of CUDA operations. + + This function will: + 1. Automatically determine whether any parameters in the argument list, + or the output of the `func`, are of type `torch.Tensor`. + 2. If so, calculate the memory usage of the input and output tensors + on the GPU (based on their data type and `torch.numel()`). + 3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations. + 4. Record the execution time during these iterations. + 5. Use the information above to compute the compute performance (TFLOPS) and memory throughput. + + Args: + func (function): The function to benchmark. + shape (list of int): The problem shape. + tflops (float): The computational workload (in TFLOPS) per call of `func`. + steps (int): The number of times the function is executed during benchmarking. + *args: Positional arguments to be passed to the `func`. + **kwargs: Keyword arguments to be passed to the `func`. 
+ + Returns: + function result + """ + + # Ensure CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for benchmarking.") + + # Check for torch.Tensor in inputs and outputs + input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)] + input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)] + + def calculate_memory(tensor: torch.Tensor): + """Calculate memory usage in bytes for a tensor.""" + return tensor.numel() * tensor.element_size() + + input_memory = sum(calculate_memory(t) for t in input_tensors) + + # Execute the function to inspect outputs + with torch.no_grad(): + output = func(*args, **kwargs) + + output_memory = 0 + if isinstance(output, torch.Tensor): + output_memory = calculate_memory(output) + elif isinstance(output, (list, tuple)): + output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor)) + + total_memory = input_memory + output_memory + + # Warm-up and CUDA graph creation + for _ in range(10): # Warm-up + func(*args, **kwargs) + + torch.cuda.synchronize() # Ensure no pending operations + + # Benchmark the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(steps): + func(*args, **kwargs) + end_event.record() + + torch.cuda.synchronize() # Ensure all operations are finished + elapsed_time_ms = start_event.elapsed_time(end_event) # Time in milliseconds + + # Calculate performance metrics + elapsed_time_s = elapsed_time_ms / 1000 # Convert to seconds + avg_time_per_step = elapsed_time_s / steps + compute_performance = tflops / avg_time_per_step # TFLOPS + memory_throughput = (total_memory * steps / (1024**3)) / elapsed_time_s # GB/s + + # Print performance metrics + print(f"Function: {func.__name__}{shape}") + # print(f"Function: {func.__ne__}{shape}") + print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds") + print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms") + print(f"Compute Performance: {compute_performance:.2f} TFLOPS") + print(f"Memory Throughput: {memory_throughput:.2f} GB/s") + print("") # print a blank line. 
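+
+
+# Example usage (an illustrative sketch; `m`, `k`, `n` are placeholder sizes, not values used
+# elsewhere in this repository). A bf16 matmul of an (m, k) matrix by a (k, n) matrix performs
+# 2*m*k*n FLOPs per call, so it could be benchmarked roughly as:
+#
+#   a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+#   b = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
+#   benchmark(torch.matmul, [m, k, n], 2 * m * k * n / 1e12, 100, a, b)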
diff --git a/lightx2v_kernel/python/lightx2v_kernel/version.py b/lightx2v_kernel/python/lightx2v_kernel/version.py new file mode 100644 index 0000000000000000000000000000000000000000..f102a9cadfa89ce554b3b26d2b90bfba2e05273c --- /dev/null +++ b/lightx2v_kernel/python/lightx2v_kernel/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/lightx2v_kernel/test/mxfp4_mxfp4/test_bench.py b/lightx2v_kernel/test/mxfp4_mxfp4/test_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..5495b6fdb3523657d6a7eaee486735ceebfe0455 --- /dev/null +++ b/lightx2v_kernel/test/mxfp4_mxfp4/test_bench.py @@ -0,0 +1,121 @@ +import torch +from lightx2v_kernel.gemm import scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm +import time + + +class MMWeightMxfp4ActMxfp4: + def __init__(self, weight, bias): + self.load_fp4_weight(weight, bias) + self.act_quant_func = self.act_quant_fp4 + self.set_alpha() + + @torch.no_grad() + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + @torch.no_grad() + def load_fp4_weight(self, weight, bias): + self.weight, self.weight_scale = scaled_mxfp4_quant(weight) + self.bias = bias + + def set_alpha(self): + self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device) + + @torch.no_grad() + def act_quant_fp4(self, x): + return scaled_mxfp4_quant(x) + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + mm = MMWeightMxfp4ActMxfp4(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightMxfp4ActMxfp4(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), 
output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/mxfp4_mxfp4/test_bench3_bias.py b/lightx2v_kernel/test/mxfp4_mxfp4/test_bench3_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a2b16096d73befed5dc6e2ebb91f144a044d89 --- /dev/null +++ b/lightx2v_kernel/test/mxfp4_mxfp4/test_bench3_bias.py @@ -0,0 +1,94 @@ +import torch +import time +from test_bench import MMWeightMxfp4ActMxfp4 + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50 + + mm = MMWeightMxfp4ActMxfp4(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50 + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightMxfp4ActMxfp4(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/mxfp4_mxfp4/test_mxfp4_quant.py b/lightx2v_kernel/test/mxfp4_mxfp4/test_mxfp4_quant.py new file mode 100644 index 
0000000000000000000000000000000000000000..34762b15d3ccca5bf52527ca0ba3b3746e132f53 --- /dev/null +++ b/lightx2v_kernel/test/mxfp4_mxfp4/test_mxfp4_quant.py @@ -0,0 +1,52 @@ +import unittest +import torch +from lightx2v_kernel.gemm import cutlass_scaled_mxfp4_mm +from lightx2v_kernel.gemm import scaled_mxfp4_quant +from torch.nn.functional import linear +from lightx2v_kernel.utils import error, benchmark + + +class TestQuantBF162MXFP4(unittest.TestCase): + def setUp(self): + self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760] # , 75348 + self.channels = [128, 1536, 5120, 8960] # , 13824 + self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800] # , 13824 + + self.device = "cuda" + self.dtype = torch.bfloat16 + + def test_accuracy(self): + """Test the accuracy of quantization from BF16 to MXFP4.""" + for m in self.tokens: + for k in self.hiddenDims: + for n in self.channels: + with self.subTest(shape=[m, k, n]): + activation = torch.randn(m, k, dtype=self.dtype, device=self.device) + activation_quant_pred, activation_scale_pred = scaled_mxfp4_quant(activation) + + weight = torch.randn(n, k, dtype=self.dtype, device=self.device) + weight_quant_pred, weight_scale_pred = scaled_mxfp4_quant(weight) + + bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10 + + alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32) + mm_pred = cutlass_scaled_mxfp4_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias) + + mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16) + + # mxfp4_mxfp4 mm have very low accuracy, so we set the threshold to 3e-2. + self.assertTrue(error(mm_pred, mm_real) < 3e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.") + + def test_performance(self): + """Benchmark the performance of Activation quantization from BF16 to MXFP4.""" + for m in self.tokens: + for k in self.hiddenDims: + with self.subTest(shape=[m, k]): + input = torch.randn(m, k, dtype=self.dtype, device=self.device) + shape = [m, k] + tflops = 2 * (m * k / 1024**4) + benchmark(scaled_mxfp4_quant, shape, tflops, 100, input) + + +if __name__ == "__main__": + unittest.main() diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test.py b/lightx2v_kernel/test/mxfp6_mxfp8/test.py new file mode 100644 index 0000000000000000000000000000000000000000..62993f66a5a7b9326d889ef0219e4031ac1e17d4 --- /dev/null +++ b/lightx2v_kernel/test/mxfp6_mxfp8/test.py @@ -0,0 +1,29 @@ +import torch +from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm + + +def test_cutlass_scaled_mxfp6_mxfp8_mm_sm120(): + m, k, n = 1024, 2048, 4096 + + input_shape = (m, k) + weight_shape = (n, k) + + input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8) + weight = (torch.rand((weight_shape[0], weight_shape[1] * 3 // 4), device="cuda") * 10).to(torch.uint8) + + print(f"shape: {input_tensor_quant.shape}, {weight.shape}") + + input_tensor_scale = torch.rand((input_shape[0], input_shape[1] // 32), device="cuda").to(torch.float8_e8m0fnu) + weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu) + + print(f"shape: {input_tensor_scale.shape}, {weight_scale.shape}") + + alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32) + bias = None + + out = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + print(f"out: {out}, shape: {out.shape}") + + +if 
__name__ == "__main__": + test_cutlass_scaled_mxfp6_mxfp8_mm_sm120() diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py b/lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..77779ccb7d956c23f380660da6a17441981f2317 --- /dev/null +++ b/lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py @@ -0,0 +1,121 @@ +import torch +from lightx2v_kernel.gemm import scaled_mxfp8_quant, scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm +import time + + +class MMWeightMxfp6ActMxfp8: + def __init__(self, weight, bias): + self.load_fp6_weight(weight, bias) + self.act_quant_func = self.act_quant_fp8 + self.set_alpha() + + @torch.no_grad() + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + @torch.no_grad() + def load_fp6_weight(self, weight, bias): + self.weight, self.weight_scale = scaled_mxfp6_quant(weight) + self.bias = bias + + def set_alpha(self): + self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device) + + @torch.no_grad() + def act_quant_fp8(self, x): + return scaled_mxfp8_quant(x) + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + mm = MMWeightMxfp6ActMxfp8(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightMxfp6ActMxfp8(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 
5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test_bench3_bias.py b/lightx2v_kernel/test/mxfp6_mxfp8/test_bench3_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..878c21e307710918d1c86b430da649016b3e25e1 --- /dev/null +++ b/lightx2v_kernel/test/mxfp6_mxfp8/test_bench3_bias.py @@ -0,0 +1,94 @@ +import torch +import time +from test_bench import MMWeightMxfp6ActMxfp8 + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + + mm = MMWeightMxfp6ActMxfp8(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50 + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightMxfp6ActMxfp8(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py b/lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..5b83e54152bce20280f24b57523ac6a40564d7d4 --- /dev/null +++ b/lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py @@ -0,0 +1,181 @@ +import torch +from torchao.prototype.mx_formats.constants import 
DTYPE_FP6_E3M2 +from torchao.prototype.mx_formats.mx_tensor import to_mx, pack_uint6 + + +def quant2mxfp8(x: torch.Tensor): + block_size = 32 + m, _ = x.shape + scale, output = to_mx(x, torch.float8_e4m3fn, block_size=block_size) + return scale.reshape(m, -1), output + + +def quant2mxfp6(x: torch.Tensor): + block_size = 32 + m, _ = x.shape + scale, output = to_mx(x, DTYPE_FP6_E3M2, block_size=block_size, pack_fp6=False) + return scale.reshape(m, -1), output + + +def scale_pad_and_swizzle(scale: torch.Tensor): + m, s = scale.shape + + # pad the m up to 128, s up to 4 + padded_m = (m + 127) // 128 * 128 + padded_s = (s + 3) // 4 * 4 + padded_scale = torch.empty(padded_m, padded_s, device=scale.device, dtype=scale.dtype) + padded_scale[:m, :s] = scale + + # swizzle the padded scale + swizzled_scale = padded_scale.reshape(padded_m // 128, 128, padded_s // 4, 4).reshape(padded_m // 128, 4, 32, padded_s // 4, 4).permute(0, 3, 2, 1, 4) + + return swizzled_scale.reshape(padded_m, padded_s) + + +############################################################### +# Packing kernel and func +############################################################### + +import triton # noqa: E402 +import triton.language as tl # noqa: E402 + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), + ], + key=["n_mx_blocks"], +) +@triton.jit +def triton_pack_uint6_kernel( + input_ptr, + output_ptr, + n_mx_blocks, + MX_BLOCK_SIZE: tl.constexpr, + PACKED_MX_BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_IN: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_IN + + # input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE] + # Load BLOCK_SIZE rows of input_ptr + offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) + offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4) + offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :]) + mask = (offsets_rows[:, None] < n_mx_blocks) & (offsets_cols[None, :] < MX_BLOCK_SIZE // 4) + + # x is shape [BLOCK_SIZE, MX_BLOCK_SIZE] + x_0 = tl.load(input_ptr + offsets, mask=mask) + x_1 = tl.load(input_ptr + offsets + 1, mask=mask) + x_2 = tl.load(input_ptr + offsets + 2, mask=mask) + x_3 = tl.load(input_ptr + offsets + 3, mask=mask) + + # 4个fp6 a b c d. a:[a5 a4 a3 a2 a1 a0], b..., c..., d... 
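# English restatement of the packing comment above, plus a worked example
# (annotation only; a, b, c, d are four fp6 values with bit 5 as the MSB, and
# the expressions mirror bits_packed0/1/2 computed below):
#   packed0 = (b << 6) | a        -> [b1 b0 | a5 a4 a3 a2 a1 a0]
#   packed1 = (c << 4) | (b >> 2) -> [c3 c2 c1 c0 | b5 b4 b3 b2]
#   packed2 = (d << 2) | (c >> 4) -> [d5 d4 d3 d2 d1 d0 | c5 c4]
# For a=0b000001, b=0b000010, c=0b000011, d=0b000100:
#   packed0 = 0b10000001 (0x81), packed1 = 0b00110000 (0x30), packed2 = 0b00010000 (0x10)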
+ # 3个unint8 pack0 pack1 pack2 + # cutlass需要的: + # packed0: [b1 b0][a5 a4 a3 a2 a1 a0] + # packed1: [c3 c2 c1 c0][b5 b4 b3 b2] + # packed2: [d5 d4 d3 d2 d1 d0][c5 c4] + bits_packed0 = (x_1 << 6) | x_0 + bits_packed1 = (x_2 << 4) | (x_1 >> 2) + bits_packed2 = (x_3 << 2) | (x_2 >> 4) + + # Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4` + offsets_out_4_a = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + offsets_out_4_b = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 1 + offsets_out_2 = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 2 + + # Store into output tensor + tl.store( + output_ptr + offsets_out_4_a, + bits_packed0, + mask=mask, + ) + + tl.store( + output_ptr + offsets_out_4_b, + bits_packed1, + mask=mask, + ) + + tl.store( + output_ptr + offsets_out_2, + bits_packed2, + mask=mask, + ) + + +def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: + # ensure input data is contiguous before passing to kernel + assert uint8_data.is_contiguous() + + # tensor should already be of shape [..., mx_block_size] + mx_block_size = uint8_data.shape[-1] + assert mx_block_size % 4 == 0 + + # effective mx block size since we're packing 2 fp4 into 1 uint8 + packed_mx_block_size = 3 * mx_block_size // 4 + packed_shape = [uint8_data.shape[0], packed_mx_block_size] + n_mx_blocks = uint8_data.numel() // mx_block_size + + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) # noqa: E731 + + # contiguous uint8 container in which we can store the unpacked tensor + packed_uint8_data = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device) + + triton_pack_uint6_kernel[grid]( + uint8_data, + packed_uint8_data, + n_mx_blocks, + MX_BLOCK_SIZE=mx_block_size, + PACKED_MX_BLOCK_SIZE=packed_mx_block_size, + ) + + return packed_uint8_data + + +M = [257, 512, 1024, 13325, 32130, 32760] # , 75348 +N = [1536, 5120, 8960] # , 13824 +K = [128, 256, 512, 1024, 2048, 4096] # , 13824 + + +for m in M: + for n in N: + for k in K: + x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16) + # excute quant + x_scale, x_quant = quant2mxfp8(x) + w_scale, w_quant = quant2mxfp6(w) + + # pack fp6 for cutlass + w_quant_packed = pack_uint6(w_quant.reshape(-1, 32)) + + # pad and swizzle scale + padded_and_swizzled_x_scale = scale_pad_and_swizzle(x_scale) + padded_and_swizzled_w_scale = scale_pad_and_swizzle(w_scale) + + # ref mm result + ref_mm = torch.nn.functional.linear(x, w).to(torch.bfloat16) + + # custom scaled mm + from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm + + alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = None + x_quant = x_quant.reshape(m, k).view(torch.uint8) + w_quant_packed = w_quant_packed.reshape(n, 3 * k // 4) + custom_mm = cutlass_scaled_mxfp6_mxfp8_mm(x_quant, w_quant_packed, padded_and_swizzled_x_scale, padded_and_swizzled_w_scale, alpha, bias) + + # cal snr + from lightx2v_kernel.utils import error + + print(f"m: {m}, n: {n}, k: {k}, error: {error(ref_mm, custom_mm)}") + + # cal cos + cos_sim = torch.nn.functional.cosine_similarity(ref_mm.flatten(), custom_mm.flatten(), dim=0) + print(f"m: {m}, n: {n}, k: {k}, cos_sim: {cos_sim}") diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py b/lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py new file mode 100644 index 0000000000000000000000000000000000000000..b62efe52a6a20082c7eea43132507ff26bc78ea1 --- /dev/null +++ 
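scale_pad_and_swizzle above pads the per-32-element scales up to 128-row by 4-column tiles and permutes them into the swizzled layout the CUTLASS mx GEMM expects. A quick property check, sketched under the assumption that the helper is in scope exactly as written, is that the swizzle only relocates scale values and never alters them:

import torch


def check_swizzle_is_permutation(m=300, s=10):
    # Use distinct values so presence after swizzling is easy to verify.
    scale = torch.arange(m * s, dtype=torch.float32).reshape(m, s)
    swizzled = scale_pad_and_swizzle(scale)
    padded_m, padded_s = (m + 127) // 128 * 128, (s + 3) // 4 * 4
    assert swizzled.shape == (padded_m, padded_s)
    # Padding comes from torch.empty, so only check that every original scale
    # value is still present somewhere in the swizzled buffer.
    assert torch.isin(scale.flatten(), swizzled.flatten()).all()


check_swizzle_is_permutation()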
b/lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py @@ -0,0 +1,115 @@ +import torch +from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm + + +""" +input_shape = (1024, 2048) +weight_shape = (4096, 2048) + +input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8) +weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8) +input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu) +weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu) +alpha = torch.tensor(1.0, device="cuda").to(torch.float32) +bias = None +""" + + +def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias): + output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias) + return output_tensor + + +def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100): + """ + 测试test_mm函数的TFLOPS性能 + """ + + # 创建输入数据 + input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8) + weight = (torch.rand((weight_shape[0], 3 * weight_shape[1] // 4), device="cuda") * 10).to(torch.uint8) + + input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu) + weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu) + alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = None + + # 预热GPU + for _ in range(num_warmup): + test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + + # 同步GPU + torch.cuda.synchronize() + + # 创建GPU事件用于精确计时 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 测量时间 + start_event.record() + for _ in range(num_runs): + result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + end_event.record() + + # 同步并计算时间 + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + elapsed_time_s = elapsed_time_ms / 1000.0 + + # 计算FLOPS + # 矩阵乘法 A(M x K) @ B(K x N) = C(M x N) + # M = batch_size, K = input_dim, N = output_dim + M = input_shape[0] + K = input_shape[1] + N = weight_shape[0] + + # 每次矩阵乘法的FLOPS = 2 * M * N * K (每个输出元素需要K次乘法和K次加法) + flops_per_run = 2 * M * N * K + total_flops = flops_per_run * num_runs + + # 计算TFLOPS (万亿次浮点运算每秒) + tflops = total_flops / (elapsed_time_s * 1e12) + + print(f"测试结果:") + print(f" 输入形状: {input_shape} (M={M}, K={K})") + print(f" 权重形状: {weight_shape} (N={N}, K={K})") + print(f" 输出形状: ({M}, {N})") + print(f" 运行次数: {num_runs}") + print(f" 总执行时间: {elapsed_time_ms:.2f} ms") + print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms") + print(f" 每次运行FLOPS: {flops_per_run / 1e9:.2f} GFLOPS") + print(f" 总FLOPS: {total_flops / 1e12:.2f} TFLOPS") + print(f" 计算性能: {tflops:.2f} TFLOPS") + + return tflops + + +if __name__ == "__main__": + # 测试不同大小的矩阵乘法 + # (m,k) (n,k) + test_cases = [ + ((32130, 5120), (5120, 5120)), + ((512, 1536), (1536, 1536)), + ((512, 5120), (5120, 5120)), + ((257, 5120), (5120, 5120)), + ((32130, 5120), (13824, 5120)), + ((32130, 13824), (5120, 13824)), + ((75348, 5120), (5120, 5120)), + ((75348, 5120), (13824, 5120)), + ((75348, 13824), (5120, 13824)), + ((32760, 1536), (1536, 1536)), + ((512, 1536), (1536, 1536)), + ((32760, 1536), (8960, 1536)), + ((32760, 8960), (1536, 8960)), + ] + + print("=== test_mm TFLOPS性能测试 ===\n") + + for i, (input_shape, 
weight_shape) in enumerate(test_cases): + print(f"测试 {i + 1}: 输入形状 {input_shape}, 权重形状 {weight_shape}") + print("-" * 60) + + tflops = test_tflops(input_shape, weight_shape) + print(f"✓ 成功完成测试,性能: {tflops:.2f} TFLOPS\n") + + print("=== 测试完成 ===") diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py b/lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..8daecedd247c54f7931f27bb183fe525011ea92f --- /dev/null +++ b/lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py @@ -0,0 +1,51 @@ +import unittest +import torch +from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm +from lightx2v_kernel.gemm import scaled_mxfp6_quant, scaled_mxfp8_quant +from torch.nn.functional import linear +from lightx2v_kernel.utils import error, benchmark + + +class TestQuantBF162MXFP6(unittest.TestCase): + def setUp(self): + self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760] # , 75348 + self.channels = [128, 1536, 5120, 8960] # , 13824 + self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800] # , 13824 + + self.device = "cuda" + self.dtype = torch.bfloat16 + + def test_accuracy(self): + """Test the accuracy of quantization from BF16 to MXFP6.""" + for m in self.tokens: + for k in self.hiddenDims: + for n in self.channels: + with self.subTest(shape=[m, k, n]): + activation = torch.randn(m, k, dtype=self.dtype, device=self.device) + activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation) + + weight = torch.randn(n, k, dtype=self.dtype, device=self.device) + weight_quant_pred, weight_scale_pred = scaled_mxfp6_quant(weight) + + bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10 + + alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32) + mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias) + + mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16) + + self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.") + + def test_performance(self): + """Benchmark the performance of Activation quantization from BF16 to MXFP6.""" + for m in self.tokens: + for k in self.hiddenDims: + with self.subTest(shape=[m, k]): + input = torch.randn(m, k, dtype=self.dtype, device=self.device) + shape = [m, k] + tflops = 2 * (m * k / 1024**4) + benchmark(scaled_mxfp6_quant, shape, tflops, 100, input) + + +if __name__ == "__main__": + unittest.main() diff --git a/lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py b/lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77d60e6b23504bde5b22b3a97f9bf0abd5926f1e --- /dev/null +++ b/lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py @@ -0,0 +1,158 @@ +import torch +from lightx2v_kernel.gemm import scaled_mxfp6_quant + + +def quantize_fp6(x): + return scaled_mxfp6_quant(x) + + +def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100): + """ + 测试函数的显存带宽 + """ + # 预热GPU + for _ in range(num_warmup): + func(x) + + # 同步GPU + torch.cuda.synchronize() + + # 创建GPU事件用于精确计时 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 测量时间 + start_event.record() + for _ in range(num_runs): + result = func(x) + end_event.record() + + # 同步并计算时间 + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + elapsed_time_s 
= elapsed_time_ms / 1000.0 + + # 计算数据量 + input_bytes = x.numel() * x.element_size() # 输入数据字节数 + + # FP6量化后,每个元素占用 3/ 4字节 + output_bytes = x.numel() * (3 / 4) # FP6输出数据字节数 + + scale_bytes = x.numel() / 32 # group_size = 32 + + # 总数据传输量(读取输入 + 写入输出 + scale) + total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs + + # 计算带宽 + bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3) # GB/s + + print(f"测试结果:") + print(f" 输入张量形状: {x.shape}") + print(f" 输入数据类型: {x.dtype}") + print(f" 运行次数: {num_runs}") + print(f" 总执行时间: {elapsed_time_ms:.2f} ms") + print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms") + print(f" 输入数据大小: {input_bytes / (1024**2):.2f} MB") + print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB") + print(f" 总数据传输量: {total_bytes / (1024**3):.2f} GB") + print(f" 显存带宽: {bandwidth_gbps:.2f} GB/s") + + return bandwidth_gbps + + +if __name__ == "__main__": + # 测试不同大小的张量 + test_sizes = [ + # (1, 1024), + # (1, 2048), + # (1, 4096), + # (1, 8192), + # (1, 16384), + # (1, 32768), + # (2, 1024), + # (2, 2048), + # (2, 4096), + # (2, 8192), + # (2, 16384), + # (2, 32768), + # (4, 1024), + # (4, 2048), + # (4, 4096), + # (4, 8192), + # (4, 16384), + # (4, 32768), + # (128, 1024), + # (128, 2048), + # (128, 4096), + # (128, 8192), + # (128, 16384), + # (128, 32768), + # (512, 1024), + # (512, 2048), + # (512, 4096), + # (512, 8192), + # (512, 16384), + # (512, 32768), + # (1024, 1024), + # (1024, 2048), + # (1024, 4096), + # (1024, 8192), + # (1024, 16384), + # (1024, 32768), + # (2048, 1024), + # (2048, 2048), + # (2048, 4096), + # (2048, 8192), + # (2048, 16384), + # (2048, 32768), + # (4096, 1024), + # (4096, 2048), + # (4096, 4096), + # (4096, 8192), + # (4096, 16384), + # (4096, 32768), + # (8192, 1024), + # (8192, 2048), + # (8192, 4096), + # (8192, 8192), + # (8192, 16384), + # (8192, 32768), + # (16384, 1024), + # (16384, 2048), + # (16384, 4096), + # (16384, 8192), + # (16384, 16384), + # (16384, 32768), + # (32768, 1024), + # (32768, 2048), + # (32768, 4096), + # (32768, 8192), + # (32768, 16384), + # (32768, 32768), + (32130, 5120), + (512, 5120), + (257, 5120), + (32130, 13824), + (75348, 5120), + (75348, 13824), + (32760, 1536), + (512, 3072), + (512, 1536), + (32760, 8960), + ] + + print("=== quantize_fp8 显存带宽测试 ===\n") + + for i, (h, w) in enumerate(test_sizes): + print(f"测试 {i + 1}: 张量大小 ({h}, {w})") + print("-" * 50) + + x = torch.randn(h, w, dtype=torch.bfloat16).cuda() + + try: + bandwidth = test_memory_bandwidth(quantize_fp6, x) + print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n") + except Exception as e: + print(f"✗ 测试失败: {e}\n") + + print("=== 测试完成 ===") diff --git a/lightx2v_kernel/test/mxfp8_mxfp8/test_bench.py b/lightx2v_kernel/test/mxfp8_mxfp8/test_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..085f75c31131e52ed7748f212fa03f38d5e739ac --- /dev/null +++ b/lightx2v_kernel/test/mxfp8_mxfp8/test_bench.py @@ -0,0 +1,121 @@ +import torch +from lightx2v_kernel.gemm import scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm +import time + + +class MMWeightMxfp8: + def __init__(self, weight, bias): + self.load_fp8_weight(weight, bias) + self.act_quant_func = self.act_quant_fp8 + self.set_alpha() + + @torch.no_grad() + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + @torch.no_grad() + def load_fp8_weight(self, 
weight, bias): + self.weight, self.weight_scale = scaled_mxfp8_quant(weight) + self.bias = bias + + def set_alpha(self): + self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device) + + @torch.no_grad() + def act_quant_fp8(self, x): + return scaled_mxfp8_quant(x) + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + mm = MMWeightMxfp8(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightMxfp8(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/mxfp8_mxfp8/test_bench3_bias.py b/lightx2v_kernel/test/mxfp8_mxfp8/test_bench3_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..19a4f95187b2654a3bbb053f83df6fe9a33a5fca --- /dev/null +++ b/lightx2v_kernel/test/mxfp8_mxfp8/test_bench3_bias.py @@ -0,0 +1,94 @@ +import torch +import time +from test_bench import MMWeightMxfp8 + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + + mm = MMWeightMxfp8(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = 
time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightMxfp8(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/mxfp8_mxfp8/test_mm_tflops.py b/lightx2v_kernel/test/mxfp8_mxfp8/test_mm_tflops.py new file mode 100644 index 0000000000000000000000000000000000000000..627c5a8ae9447585d66583138132e73708293b0d --- /dev/null +++ b/lightx2v_kernel/test/mxfp8_mxfp8/test_mm_tflops.py @@ -0,0 +1,115 @@ +import torch +from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm + + +""" +input_shape = (1024, 2048) +weight_shape = (4096, 2048) + +input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8) +weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8) +input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu) +weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu) +alpha = torch.tensor(1.0, device="cuda").to(torch.float32) +bias = None +""" + + +def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias): + output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias) + return output_tensor + + +def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100): + """ + 测试test_mm函数的TFLOPS性能 + """ + + # 创建输入数据 + input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8) + weight = (torch.rand((weight_shape[0], weight_shape[1]), device="cuda") * 10).to(torch.uint8) + + input_tensor_scale = 
torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu) + weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu) + alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = None + + # 预热GPU + for _ in range(num_warmup): + test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + + # 同步GPU + torch.cuda.synchronize() + + # 创建GPU事件用于精确计时 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 测量时间 + start_event.record() + for _ in range(num_runs): + result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + end_event.record() + + # 同步并计算时间 + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + elapsed_time_s = elapsed_time_ms / 1000.0 + + # 计算FLOPS + # 矩阵乘法 A(M x K) @ B(K x N) = C(M x N) + # M = batch_size, K = input_dim, N = output_dim + M = input_shape[0] + K = input_shape[1] + N = weight_shape[0] + + # 每次矩阵乘法的FLOPS = 2 * M * N * K (每个输出元素需要K次乘法和K次加法) + flops_per_run = 2 * M * N * K + total_flops = flops_per_run * num_runs + + # 计算TFLOPS (万亿次浮点运算每秒) + tflops = total_flops / (elapsed_time_s * 1e12) + + print(f"测试结果:") + print(f" 输入形状: {input_shape} (M={M}, K={K})") + print(f" 权重形状: {weight_shape} (N={N}, K={K})") + print(f" 输出形状: ({M}, {N})") + print(f" 运行次数: {num_runs}") + print(f" 总执行时间: {elapsed_time_ms:.2f} ms") + print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms") + print(f" 每次运行FLOPS: {flops_per_run / 1e9:.2f} GFLOPS") + print(f" 总FLOPS: {total_flops / 1e12:.2f} TFLOPS") + print(f" 计算性能: {tflops:.2f} TFLOPS") + + return tflops + + +if __name__ == "__main__": + # 测试不同大小的矩阵乘法 + # (m,k) (n,k) + test_cases = [ + ((32130, 5120), (5120, 5120)), + ((512, 1536), (1536, 1536)), + ((512, 5120), (5120, 5120)), + ((257, 5120), (5120, 5120)), + ((32130, 5120), (13824, 5120)), + ((32130, 13824), (5120, 13824)), + ((75348, 5120), (5120, 5120)), + ((75348, 5120), (13824, 5120)), + ((75348, 13824), (5120, 13824)), + ((32760, 1536), (1536, 1536)), + ((512, 1536), (1536, 1536)), + ((32760, 1536), (8960, 1536)), + ((32760, 8960), (1536, 8960)), + ] + + print("=== test_mm TFLOPS性能测试 ===\n") + + for i, (input_shape, weight_shape) in enumerate(test_cases): + print(f"测试 {i + 1}: 输入形状 {input_shape}, 权重形状 {weight_shape}") + print("-" * 60) + + tflops = test_tflops(input_shape, weight_shape) + print(f"✓ 成功完成测试,性能: {tflops:.2f} TFLOPS\n") + + print("=== 测试完成 ===") diff --git a/lightx2v_kernel/test/mxfp8_mxfp8/test_mxfp8_quant.py b/lightx2v_kernel/test/mxfp8_mxfp8/test_mxfp8_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..41af9bda35c060b56e10a94b941d8dda0f0ad28f --- /dev/null +++ b/lightx2v_kernel/test/mxfp8_mxfp8/test_mxfp8_quant.py @@ -0,0 +1,51 @@ +import unittest +import torch +from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm +from lightx2v_kernel.gemm import scaled_mxfp8_quant +from torch.nn.functional import linear +from lightx2v_kernel.utils import error, benchmark + + +class TestQuantBF162MXFP8(unittest.TestCase): + def setUp(self): + self.tokens = [257, 512, 1024, 13325, 32130, 32760] # , 75348 + self.channels = [1536, 5120, 8960] # , 13824 + self.hiddenDims = [1536, 3072, 5120, 8960, 12800] # , 13824 + + self.device = "cuda" + self.dtype = torch.bfloat16 + + def test_accuracy(self): + """Test the accuracy of quantization from BF16 to MXFP8.""" + for m in self.tokens: + for k in 
self.hiddenDims: + for n in self.channels: + with self.subTest(shape=[m, k, n]): + activation = torch.randn(m, k, dtype=self.dtype, device=self.device) + activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation) + + weight = torch.randn(n, k, dtype=self.dtype, device=self.device) + weight_quant_pred, weight_scale_pred = scaled_mxfp8_quant(weight) + + bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10 + + alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32) + mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias) + + mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16) + + self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.") + + def test_performance(self): + """Benchmark the performance of Activation quantization from BF16 to MXFP8.""" + for m in self.tokens: + for k in self.hiddenDims: + with self.subTest(shape=[m, k]): + input = torch.randn(m, k, dtype=self.dtype, device=self.device) + shape = [m, k] + tflops = 2 * (m * k / 1024**4) + benchmark(scaled_mxfp8_quant, shape, tflops, 100, input) + + +if __name__ == "__main__": + unittest.main() diff --git a/lightx2v_kernel/test/mxfp8_mxfp8/test_quant_mem_utils.py b/lightx2v_kernel/test/mxfp8_mxfp8/test_quant_mem_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13400d46c693c1cdfa21ee48a0dc9638c6f1fd23 --- /dev/null +++ b/lightx2v_kernel/test/mxfp8_mxfp8/test_quant_mem_utils.py @@ -0,0 +1,158 @@ +import torch +from lightx2v_kernel.gemm import scaled_mxfp8_quant + + +def quantize_fp8(x): + return scaled_mxfp8_quant(x) + + +def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100): + """ + 测试函数的显存带宽 + """ + # 预热GPU + for _ in range(num_warmup): + func(x) + + # 同步GPU + torch.cuda.synchronize() + + # 创建GPU事件用于精确计时 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 测量时间 + start_event.record() + for _ in range(num_runs): + result = func(x) + end_event.record() + + # 同步并计算时间 + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + elapsed_time_s = elapsed_time_ms / 1000.0 + + # 计算数据量 + input_bytes = x.numel() * x.element_size() # 输入数据字节数 + + # FP8量化后,每个元素占用1字节 + output_bytes = x.numel() * 1 # FP8输出数据字节数 + + scale_bytes = x.numel() / 32 # group_size = 32 + + # 总数据传输量(读取输入 + 写入输出 + scale) + total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs + + # 计算带宽 + bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3) # GB/s + + print(f"测试结果:") + print(f" 输入张量形状: {x.shape}") + print(f" 输入数据类型: {x.dtype}") + print(f" 运行次数: {num_runs}") + print(f" 总执行时间: {elapsed_time_ms:.2f} ms") + print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms") + print(f" 输入数据大小: {input_bytes / (1024**2):.2f} MB") + print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB") + print(f" 总数据传输量: {total_bytes / (1024**3):.2f} GB") + print(f" 显存带宽: {bandwidth_gbps:.2f} GB/s") + + return bandwidth_gbps + + +if __name__ == "__main__": + # 测试不同大小的张量 + test_sizes = [ + # (1, 1024), + # (1, 2048), + # (1, 4096), + # (1, 8192), + # (1, 16384), + # (1, 32768), + # (2, 1024), + # (2, 2048), + # (2, 4096), + # (2, 8192), + # (2, 16384), + # (2, 32768), + # (4, 1024), + # (4, 2048), + # (4, 4096), + # (4, 8192), + # (4, 16384), + # (4, 32768), + # (128, 1024), + # (128, 2048), + # (128, 4096), + # (128, 8192), + # (128, 16384), + 
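# Byte accounting used by test_memory_bandwidth above, per bf16 element
# quantized to MXFP8 (illustrative arithmetic):
#   read  input : 2    bytes (bf16)
#   write output: 1    byte  (fp8)
#   write scale : 1/32 byte  (one e8m0 scale per 32-element group)
# ~= 3.03 bytes of traffic per element, so a (32130, 5120) activation moves
# about 32130 * 5120 * 3.03125 / 2**30 ~= 0.46 GB per quantization call.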
# (128, 32768), + # (512, 1024), + # (512, 2048), + # (512, 4096), + # (512, 8192), + # (512, 16384), + # (512, 32768), + # (1024, 1024), + # (1024, 2048), + # (1024, 4096), + # (1024, 8192), + # (1024, 16384), + # (1024, 32768), + # (2048, 1024), + # (2048, 2048), + # (2048, 4096), + # (2048, 8192), + # (2048, 16384), + # (2048, 32768), + # (4096, 1024), + # (4096, 2048), + # (4096, 4096), + # (4096, 8192), + # (4096, 16384), + # (4096, 32768), + # (8192, 1024), + # (8192, 2048), + # (8192, 4096), + # (8192, 8192), + # (8192, 16384), + # (8192, 32768), + # (16384, 1024), + # (16384, 2048), + # (16384, 4096), + # (16384, 8192), + # (16384, 16384), + # (16384, 32768), + # (32768, 1024), + # (32768, 2048), + # (32768, 4096), + # (32768, 8192), + # (32768, 16384), + # (32768, 32768), + (32130, 5120), + (512, 5120), + (257, 5120), + (32130, 13824), + (75348, 5120), + (75348, 13824), + (32760, 1536), + (512, 3072), + (512, 1536), + (32760, 8960), + ] + + print("=== quantize_fp8 显存带宽测试 ===\n") + + for i, (h, w) in enumerate(test_sizes): + print(f"测试 {i + 1}: 张量大小 ({h}, {w})") + print("-" * 50) + + x = torch.randn(h, w, dtype=torch.bfloat16).cuda() + + try: + bandwidth = test_memory_bandwidth(quantize_fp8, x) + print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n") + except Exception as e: + print(f"✗ 测试失败: {e}\n") + + print("=== 测试完成 ===") diff --git a/lightx2v_kernel/test/nvfp4_nvfp4/fake_quant.py b/lightx2v_kernel/test/nvfp4_nvfp4/fake_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..36f419c254c68e6bdadcfffe445ddb978e231cc3 --- /dev/null +++ b/lightx2v_kernel/test/nvfp4_nvfp4/fake_quant.py @@ -0,0 +1,55 @@ +import torch + + +BLOCK_SIZE = 16 + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +def ref_nvfp4_quant(x, global_scale): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE)) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + # output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + output_scale = global_scale * get_reciprocal(scale) + + scaled_x = x.to(torch.float32) * output_scale + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + return cast_to_fp4(clipped_x), scale.squeeze(-1) + + +if __name__ == "__main__": + x = torch.randn(1, 16, dtype=torch.bfloat16).cuda() + print(f"x: {x}, {x.shape}") + global_scale = (6.0 * 448.0 / torch.max(torch.abs(x))).to(torch.float32).cuda() + quant_x, scale = ref_nvfp4_quant(x, global_scale) + print(f"quant_x: {quant_x}, {quant_x.shape}") + print(f"scale: {scale}, {scale.shape}") diff --git a/lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py b/lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py new file mode 100644 index 
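ref_nvfp4_quant above implements NVFP4's two-level scaling: a single fp32 global scale chosen so the largest per-block scale still fits in e4m3, plus one e4m3 scale per 16-element block. A small worked example using the same formulas (numbers are illustrative only):

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0

x_absmax = 448.0                                             # assumed tensor-wide max(|x|)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / x_absmax  # = 6.0
block_max = 112.0                                            # max(|x|) inside one 16-wide block
block_scale = global_scale * block_max / FLOAT4_E2M1_MAX     # = 112.0, representable in e4m3
output_scale = global_scale / block_scale                    # ~= 0.0536, applied before rounding
q = max(min(112.0 * output_scale, 6.0), -6.0)                # = 6.0, the largest e2m1 magnitude
dequant = q * block_scale / global_scale                     # = 112.0, recovering the input
print(global_scale, block_scale, q, dequant)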
0000000000000000000000000000000000000000..e5071f5b60abaf74d9f226104eecf83ceaac5d62 --- /dev/null +++ b/lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py @@ -0,0 +1,142 @@ +import torch +from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm + + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +kE2M1ToFloatArray = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, +] + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + sf_m, sf_k = a_sf_swizzled.shape + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def e2m1_to_fp32(int4_value): + signBit = int4_value & 0x8 + int4_absValue = int4_value & 0x7 + float_result = kE2M1ToFloatArray[int4_absValue] + if signBit: + float_result = -float_result + return float_result + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + a = a.flatten() + # Get upper 4 bits + highHalfByte = (a & 0xF0) >> 4 + # Get lower 4 bits + lowHalfByte = a & 0x0F + fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device) + fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device) + # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC] + out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2) + return out + + +def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out + + +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): + _, m_k = a_fp4.shape + _, n_k = b_fp4.shape + assert m_k == n_k + a_in_dtype = dequantize_to_dtype(a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size) + b_in_dtype = dequantize_to_dtype(b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size) + return torch.matmul(a_in_dtype, b_in_dtype.t()) + + +@torch.inference_mode() +def test_nvfp4_gemm( + dtype: torch.dtype, + shape: tuple[int, int], +) -> None: + m, n, packed_k = shape + k = packed_k * 2 + block_size = 16 + a_dtype = torch.randn((m, k), dtype=dtype, device="cuda") + b_dtype = torch.randn((n, k), dtype=dtype, device="cuda") + bias = torch.randn((1, n), dtype=dtype, device="cuda") + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) + b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + + print(f"a_global_scale : {a_global_scale}, {a_global_scale.shape}") + print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}") + + alpha = 1.0 / (a_global_scale * b_global_scale) + a_fp4, a_scale_interleaved = scaled_nvfp4_quant(a_dtype, a_global_scale) + b_fp4, 
b_scale_interleaved = scaled_nvfp4_quant(b_dtype, b_global_scale) + + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + "cuda", + ) + expected_out = expected_out + bias + + print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}") + + out = cutlass_scaled_nvfp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias) + + print(f"out : {out}, {out.shape}, {out.dtype}") + print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}") + + torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + test_nvfp4_gemm(torch.bfloat16, (128, 512, 128)) diff --git a/lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py b/lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py new file mode 100644 index 0000000000000000000000000000000000000000..1af8296a43351f837583e77edea0b48cabccc33c --- /dev/null +++ b/lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py @@ -0,0 +1,126 @@ +import torch +from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm +import time + + +class MMWeightFp4: + def __init__(self, weight, bias): + self.load_fp4_weight(weight, bias) + self.act_quant_func = self.act_quant_fp4 + + # calibrate x_max + self.calibrate_x_absmax() + + @torch.no_grad() + def apply(self, input_tensor): + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias) + return output_tensor + + @torch.no_grad() + def load_fp4_weight(self, weight, bias): + self.weight_global_scale = (2688.0 / torch.max(torch.abs(weight))).to(torch.float32) + self.weight, self.weight_scale = scaled_nvfp4_quant(weight, self.weight_global_scale) + self.bias = bias + + def calibrate_x_absmax(self): + self.x_absmax = torch.tensor(5.0, dtype=torch.float32, device=self.weight.device) # need to be calibrated + self.input_global_scale = (2688.0 / self.x_absmax).to(torch.float32) + self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale) + + @torch.no_grad() + def act_quant_fp4(self, x): + return scaled_nvfp4_quant(x, self.input_global_scale) + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + mm = MMWeightFp4(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + 
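# Note on MMWeightFp4 above: the activation range is calibrated statically.
# x_absmax is hard-coded to 5.0 here as a placeholder (it would come from real
# calibration data), input_global_scale = 2688.0 / x_absmax where
# 2688 = 448 * 6 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, and
# alpha = 1 / (input_global_scale * weight_global_scale) removes both global
# scales from the cutlass_scaled_nvfp4_mm output.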
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda() + bias = None + + linear = torch.nn.Linear(k, n, bias=False).cuda() + linear.weight.data = weight + # linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightFp4(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py b/lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..5997b57ffc136080234510a1f92a38fa86bf4d9a --- /dev/null +++ b/lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py @@ -0,0 +1,94 @@ +import torch +import time +from test_bench2 import MMWeightFp4 + + +def test_speed(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50 + + mm = MMWeightFp4(weight, bias) + + # warmup + output_tensor = mm.apply(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + output_tensor = mm.apply(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + lightx2v_kernel_time = (end_time - start_time) / 100 + print(f"lightx2v-kernel time: {lightx2v_kernel_time}") + + input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda() + weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(1, k, dtype=torch.bfloat16).cuda() + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + # warmup + ref_output_tensor = linear(input_tensor) + + torch.cuda.synchronize() + start_time = time.time() + for i in range(100): + ref_output_tensor = linear(input_tensor) + torch.cuda.synchronize() + end_time = time.time() + + ref_time = (end_time - start_time) / 100 + print(f"ref time: {ref_time}") + + print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}") + + +def test_accuracy(m, k, n): + with torch.no_grad(): + input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda() + weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50 + + linear = torch.nn.Linear(k, n, bias=True).cuda() + linear.weight.data = weight + linear.bias.data = bias + + ref_output_tensor = linear(input_tensor) + + mm = MMWeightFp4(weight, bias) + + output_tensor = mm.apply(input_tensor) + + # print(f"ref_output_tensor: {ref_output_tensor}") + # print(f"output_tensor: {output_tensor}") + + # cosine + cos = 
torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0) + print(f"cos : {cos}") + + +if __name__ == "__main__": + test_sizes = [ + (32130, 5120, 5120), + (512, 5120, 5120), + (257, 5120, 5120), + (32130, 5120, 13824), + (32130, 13824, 5120), + (75348, 5120, 5120), + (75348, 13824, 5120), + (32760, 1536, 1536), + (512, 1536, 1536), + (32760, 1536, 8960), + (32760, 8960, 1536), + ] + + for i, (m, k, n) in enumerate(test_sizes): + print("-" * 30) + print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})") + test_accuracy(m, k, n) + test_speed(m, k, n) diff --git a/lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py b/lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py new file mode 100644 index 0000000000000000000000000000000000000000..3fd9994ccd300cd436abbb095b4f0bf383a5e033 --- /dev/null +++ b/lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py @@ -0,0 +1,114 @@ +import torch +from lightx2v_kernel.gemm import cutlass_scaled_nvfp4_mm + + +""" +input_shape = (1024, 2048) +weight_shape = (4096, 2048) + +input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8) +weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8) +input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e4m3fn) +weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e4m3fn) +alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32) +bias = None +""" + + +def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias): + output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias) + return output_tensor + + +def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100): + """ + 测试test_mm函数的TFLOPS性能 + """ + + # 创建输入数据 + input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8) + weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8) + + input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn) + weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn) + alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32) + bias = None + + # 预热GPU + for _ in range(num_warmup): + test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + + # 同步GPU + torch.cuda.synchronize() + + # 创建GPU事件用于精确计时 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 测量时间 + start_event.record() + for _ in range(num_runs): + result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias) + end_event.record() + + # 同步并计算时间 + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + elapsed_time_s = elapsed_time_ms / 1000.0 + + # 计算FLOPS + # 矩阵乘法 A(M x K) @ B(K x N) = C(M x N) + # M = batch_size, K = input_dim, N = output_dim + M = input_shape[0] + K = input_shape[1] + N = weight_shape[0] + + # 每次矩阵乘法的FLOPS = 2 * M * N * K (每个输出元素需要K次乘法和K次加法) + flops_per_run = 2 * M * N * K + total_flops = flops_per_run * num_runs + + # 计算TFLOPS (万亿次浮点运算每秒) + tflops = total_flops / (elapsed_time_s * 1e12) + + print(f"测试结果:") + print(f" 输入形状: {input_shape} (M={M}, K={K})") + print(f" 权重形状: {weight_shape} (N={N}, K={K})") + print(f" 输出形状: ({M}, {N})") + print(f" 运行次数: {num_runs}") 
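# Worked example of the FLOP accounting above (illustrative numbers): for
# input (m, k) = (32130, 5120) against weight (n, k) = (5120, 5120),
#   flops_per_run = 2 * 32130 * 5120 * 5120 ~= 1.68e12 FLOPs,
# so an average kernel time of 2 ms corresponds to roughly 840 TFLOPS.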
+ print(f" 总执行时间: {elapsed_time_ms:.2f} ms") + print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms") + print(f" 每次运行FLOPS: {flops_per_run / 1e9:.2f} GFLOPS") + print(f" 总FLOPS: {total_flops / 1e12:.2f} TFLOPS") + print(f" 计算性能: {tflops:.2f} TFLOPS") + + return tflops + + +if __name__ == "__main__": + # 测试不同大小的矩阵乘法 + # (m,k) (n,k) + test_cases = [ + ((32130, 5120), (5120, 5120)), + ((512, 5120), (5120, 5120)), + ((257, 5120), (5120, 5120)), + ((32130, 5120), (13824, 5120)), + ((32130, 13824), (5120, 13824)), + ((75348, 5120), (5120, 5120)), + ((75348, 5120), (13824, 5120)), + ((75348, 13824), (5120, 13824)), + ((32760, 1536), (1536, 1536)), + ((512, 1536), (1536, 1536)), + ((32760, 1536), (8960, 1536)), + ((32760, 8960), (1536, 8960)), + ] + + print("=== test_mm TFLOPS性能测试 ===\n") + + for i, (input_shape, weight_shape) in enumerate(test_cases): + print(f"测试 {i + 1}: 输入形状 {input_shape}, 权重形状 {weight_shape}") + print("-" * 60) + + tflops = test_tflops(input_shape, weight_shape) + print(f"✓ 成功完成测试,性能: {tflops:.2f} TFLOPS\n") + + print("=== 测试完成 ===") diff --git a/lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py b/lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d56f8e80801ac5d1f27bef12f6a7bdc0c03f3b --- /dev/null +++ b/lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py @@ -0,0 +1,160 @@ +import torch +from lightx2v_kernel.gemm import scaled_nvfp4_quant + + +input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda() + + +def quantize_fp4(x): + return scaled_nvfp4_quant(x, input_global_scale) + + +def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100): + """ + 测试函数的显存带宽 + """ + # 预热GPU + for _ in range(num_warmup): + func(x) + + # 同步GPU + torch.cuda.synchronize() + + # 创建GPU事件用于精确计时 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 测量时间 + start_event.record() + for _ in range(num_runs): + result = func(x) + end_event.record() + + # 同步并计算时间 + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + elapsed_time_s = elapsed_time_ms / 1000.0 + + # 计算数据量 + input_bytes = x.numel() * x.element_size() # 输入数据字节数 + + # FP4量化后,每个元素占用0.5字节 + output_bytes = x.numel() * 0.5 # FP4输出数据字节数 + + scale_bytes = x.numel() / 16 # group_size = 16 + + # 总数据传输量(读取输入 + 写入输出 + scale) + total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs + + # 计算带宽 + bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3) # GB/s + + print(f"测试结果:") + print(f" 输入张量形状: {x.shape}") + print(f" 输入数据类型: {x.dtype}") + print(f" 运行次数: {num_runs}") + print(f" 总执行时间: {elapsed_time_ms:.2f} ms") + print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms") + print(f" 输入数据大小: {input_bytes / (1024**2):.2f} MB") + print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB") + print(f" 总数据传输量: {total_bytes / (1024**3):.2f} GB") + print(f" 显存带宽: {bandwidth_gbps:.2f} GB/s") + + return bandwidth_gbps + + +if __name__ == "__main__": + # 测试不同大小的张量 + test_sizes = [ + # (1, 1024), + # (1, 2048), + # (1, 4096), + # (1, 8192), + # (1, 16384), + # (1, 32768), + # (2, 1024), + # (2, 2048), + # (2, 4096), + # (2, 8192), + # (2, 16384), + # (2, 32768), + # (4, 1024), + # (4, 2048), + # (4, 4096), + # (4, 8192), + # (4, 16384), + # (4, 32768), + # (128, 1024), + # (128, 2048), + # (128, 4096), + # (128, 8192), + # (128, 16384), + # (128, 32768), + # (512, 1024), + # (512, 2048), + # (512, 4096), + # (512, 8192), + # (512, 16384), + # (512, 32768), + # (1024, 
+
+
+if __name__ == "__main__":
+    # Test tensors of different sizes
+    test_sizes = [
+        # (1, 1024),
+        # (1, 2048),
+        # (1, 4096),
+        # (1, 8192),
+        # (1, 16384),
+        # (1, 32768),
+        # (2, 1024),
+        # (2, 2048),
+        # (2, 4096),
+        # (2, 8192),
+        # (2, 16384),
+        # (2, 32768),
+        # (4, 1024),
+        # (4, 2048),
+        # (4, 4096),
+        # (4, 8192),
+        # (4, 16384),
+        # (4, 32768),
+        # (128, 1024),
+        # (128, 2048),
+        # (128, 4096),
+        # (128, 8192),
+        # (128, 16384),
+        # (128, 32768),
+        # (512, 1024),
+        # (512, 2048),
+        # (512, 4096),
+        # (512, 8192),
+        # (512, 16384),
+        # (512, 32768),
+        # (1024, 1024),
+        # (1024, 2048),
+        # (1024, 4096),
+        # (1024, 8192),
+        # (1024, 16384),
+        # (1024, 32768),
+        # (2048, 1024),
+        # (2048, 2048),
+        # (2048, 4096),
+        # (2048, 8192),
+        # (2048, 16384),
+        # (2048, 32768),
+        # (4096, 1024),
+        # (4096, 2048),
+        # (4096, 4096),
+        # (4096, 8192),
+        # (4096, 16384),
+        # (4096, 32768),
+        # (8192, 1024),
+        # (8192, 2048),
+        # (8192, 4096),
+        # (8192, 8192),
+        # (8192, 16384),
+        # (8192, 32768),
+        # (16384, 1024),
+        # (16384, 2048),
+        # (16384, 4096),
+        # (16384, 8192),
+        # (16384, 16384),
+        # (16384, 32768),
+        # (32768, 1024),
+        # (32768, 2048),
+        # (32768, 4096),
+        # (32768, 8192),
+        # (32768, 16384),
+        # (32768, 32768),
+        (32130, 5120),
+        (512, 5120),
+        (257, 5120),
+        (32130, 13824),
+        (75348, 5120),
+        (75348, 13824),
+        (32760, 1536),
+        (512, 1536),
+        (32760, 8960),
+    ]
+
+    print("=== quantize_fp4 memory bandwidth benchmark ===\n")
+
+    for i, (h, w) in enumerate(test_sizes):
+        print(f"Test {i + 1}: tensor size ({h}, {w})")
+        print("-" * 50)
+
+        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
+
+        try:
+            bandwidth = test_memory_bandwidth(quantize_fp4, x)
+            print(f"✓ Test completed, bandwidth: {bandwidth:.2f} GB/s\n")
+        except Exception as e:
+            print(f"✗ Test failed: {e}\n")
+
+    print("=== All tests finished ===")
diff --git a/lightx2v_platform/__init__.py b/lightx2v_platform/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5ed21c9e3f65438f966db7f1913bc34b7ffa97
--- /dev/null
+++ b/lightx2v_platform/__init__.py
@@ -0,0 +1 @@
+from .base import *
diff --git a/lightx2v_platform/base/__init__.py b/lightx2v_platform/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ff25e3361172566a80cde6c9980a066a47b8e79
--- /dev/null
+++ b/lightx2v_platform/base/__init__.py
@@ -0,0 +1,7 @@
+from lightx2v_platform.base.base import check_ai_device, init_ai_device
+from lightx2v_platform.base.cambricon_mlu import MluDevice
+from lightx2v_platform.base.dcu import DcuDevice
+from lightx2v_platform.base.metax import MetaxDevice
+from lightx2v_platform.base.nvidia import CudaDevice
+
+__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]
diff --git a/lightx2v_platform/base/base.py b/lightx2v_platform/base/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..84df1517208bd30b9769cd70b19816498f7eba30
--- /dev/null
+++ b/lightx2v_platform/base/base.py
@@ -0,0 +1,33 @@
+import os
+
+from loguru import logger
+
+from lightx2v_platform.base import global_var
+from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
+
+
+def init_ai_device(platform="cuda"):
+    platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
+    if platform_device is None:
+        available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
+        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
+    global_var.AI_DEVICE = platform_device.get_device()
+    logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
+    return global_var.AI_DEVICE
+
+
+def check_ai_device(platform="cuda"):
+    platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
+    if platform_device is None:
+        available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
+        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
+    is_available = platform_device.is_available()
+    if not is_available:
+        skip_platform_check = os.getenv("SKIP_PLATFORM_CHECK", "False") in ["1", "True"]
+        error_msg = f"AI device for platform '{platform}' is not available.
Please check your runtime environment." + if skip_platform_check: + logger.warning(error_msg) + return True + raise RuntimeError(error_msg) + logger.info(f"AI device for platform '{platform}' is available.") + return True diff --git a/lightx2v_platform/base/cambricon_mlu.py b/lightx2v_platform/base/cambricon_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..90c80879d604309b3a368fd7eb4e40025a0e69d1 --- /dev/null +++ b/lightx2v_platform/base/cambricon_mlu.py @@ -0,0 +1,27 @@ +import torch +import torch.distributed as dist + +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + + +@PLATFORM_DEVICE_REGISTER("mlu") +class MluDevice: + name = "mlu" + + @staticmethod + def is_available() -> bool: + try: + import torch_mlu + + return torch_mlu.mlu.is_available() + except ImportError: + return False + + @staticmethod + def get_device() -> str: + return "mlu" + + @staticmethod + def init_parallel_env(): + dist.init_process_group(backend="cncl") + torch.mlu.set_device(dist.get_rank()) diff --git a/lightx2v_platform/base/dcu.py b/lightx2v_platform/base/dcu.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a896e77867acb5965f1afcc878d3e5b17d40f0 --- /dev/null +++ b/lightx2v_platform/base/dcu.py @@ -0,0 +1,55 @@ +import torch +import torch.distributed as dist + +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + + +@PLATFORM_DEVICE_REGISTER("dcu") +class DcuDevice: + """ + DCU (AMD GPU) Device implementation for LightX2V. + + DCU uses ROCm which provides CUDA-compatible APIs through HIP. + Most PyTorch operations work transparently through the ROCm backend. + """ + + name = "dcu" + + @staticmethod + def is_available() -> bool: + """ + Check if DCU is available. + + DCU uses the standard CUDA API through ROCm's HIP compatibility layer. + Returns: + bool: True if DCU/CUDA is available + """ + try: + return torch.cuda.is_available() + except ImportError: + return False + + @staticmethod + def get_device() -> str: + """ + Get the device type string. + + Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm. + This allows seamless integration with existing PyTorch code. + + Returns: + str: "cuda" for ROCm compatibility + """ + return "cuda" + + @staticmethod + def init_parallel_env(): + """ + Initialize distributed parallel environment for DCU. + + Uses RCCL (ROCm Collective Communications Library) which is + compatible with NCCL APIs for multi-GPU communication. 
+ """ + # RCCL is compatible with NCCL backend + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) diff --git a/lightx2v_platform/base/global_var.py b/lightx2v_platform/base/global_var.py new file mode 100644 index 0000000000000000000000000000000000000000..01d393a46b40f26125b412c10263c0975b69bc07 --- /dev/null +++ b/lightx2v_platform/base/global_var.py @@ -0,0 +1 @@ +AI_DEVICE = None diff --git a/lightx2v_platform/base/metax.py b/lightx2v_platform/base/metax.py new file mode 100644 index 0000000000000000000000000000000000000000..c43161a5f1cd68ac902bf42bd345d24f66addb6a --- /dev/null +++ b/lightx2v_platform/base/metax.py @@ -0,0 +1,7 @@ +from lightx2v_platform.base.nvidia import CudaDevice +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + + +@PLATFORM_DEVICE_REGISTER("metax") +class MetaxDevice(CudaDevice): + name = "cuda" diff --git a/lightx2v_platform/base/nvidia.py b/lightx2v_platform/base/nvidia.py new file mode 100644 index 0000000000000000000000000000000000000000..75a625e72be7a80fb2c723ac1a425ac658d183df --- /dev/null +++ b/lightx2v_platform/base/nvidia.py @@ -0,0 +1,36 @@ +import torch +import torch.distributed as dist + +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + +try: + from torch.distributed import ProcessGroupNCCL +except ImportError: + ProcessGroupNCCL = None + + +@PLATFORM_DEVICE_REGISTER("cuda") +class CudaDevice: + name = "cuda" + + @staticmethod + def is_available() -> bool: + try: + import torch + + return torch.cuda.is_available() + except ImportError: + return False + + @staticmethod + def get_device() -> str: + return "cuda" + + @staticmethod + def init_parallel_env(): + if ProcessGroupNCCL is None: + raise RuntimeError("ProcessGroupNCCL is not available. 
Please check your runtime environment.") + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = True + dist.init_process_group(backend="nccl", pg_options=pg_options) + torch.cuda.set_device(dist.get_rank()) diff --git a/lightx2v_platform/ops/__init__.py b/lightx2v_platform/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a435af23f8703ec13a49bc7cc17cb35590ee10b8 --- /dev/null +++ b/lightx2v_platform/ops/__init__.py @@ -0,0 +1,12 @@ +import os + +from lightx2v_platform.base.global_var import AI_DEVICE + +if AI_DEVICE == "mlu": + from .attn.cambricon_mlu import * + from .mm.cambricon_mlu import * +elif AI_DEVICE == "cuda": + # Check if running on DCU platform + if os.getenv("PLATFORM") == "dcu": + from .attn.dcu import * + from .mm.dcu import * diff --git a/lightx2v_platform/ops/attn/__init__.py b/lightx2v_platform/ops/attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v_platform/ops/attn/cambricon_mlu/__init__.py b/lightx2v_platform/ops/attn/cambricon_mlu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..848e74353c431ca289c290cb6f6fe8a56dd6a1f0 --- /dev/null +++ b/lightx2v_platform/ops/attn/cambricon_mlu/__init__.py @@ -0,0 +1,2 @@ +from .flash_attn import * +from .sage_attn import * diff --git a/lightx2v_platform/ops/attn/cambricon_mlu/flash_attn.py b/lightx2v_platform/ops/attn/cambricon_mlu/flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..19b916b5bba11d03e9371016d2cc867cd43ac3ed --- /dev/null +++ b/lightx2v_platform/ops/attn/cambricon_mlu/flash_attn.py @@ -0,0 +1,42 @@ +import math + +from lightx2v_platform.ops.attn.template import AttnWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER + +try: + import torch_mlu_ops as tmo +except ImportError: + tmo = None + + +@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_flash_attn") +class MluFlashAttnWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + assert tmo is not None, "torch_mlu_ops is not installed." 
+ + def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None): + if len(q.shape) == 3: + bs = 1 + q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) + elif len(q.shape) == 4: + bs = q.shape[0] + softmax_scale = 1 / math.sqrt(q.shape[-1]) + x = tmo.flash_attention( + q=q, + k=k, + v=v, + cu_seq_lens_q=cu_seqlens_q, + cu_seq_lens_kv=cu_seqlens_kv, + max_seq_len_q=max_seqlen_q, + max_seq_len_kv=max_seqlen_kv, + softmax_scale=softmax_scale, + return_lse=False, + out_dtype=q.dtype, + is_causal=False, + out=None, + alibi_slope=None, + attn_bias=None, + ) + x = x.reshape(bs * max_seqlen_q, -1) + return x diff --git a/lightx2v_platform/ops/attn/cambricon_mlu/sage_attn.py b/lightx2v_platform/ops/attn/cambricon_mlu/sage_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..687e65478c4493ef58335780ee3f04cf354053b8 --- /dev/null +++ b/lightx2v_platform/ops/attn/cambricon_mlu/sage_attn.py @@ -0,0 +1,31 @@ +import math + +import torch + +from lightx2v_platform.ops.attn.template import AttnWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER + +try: + import torch_mlu_ops as tmo +except ImportError: + tmo = None + + +@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_sage_attn") +class MluSageAttnWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + assert tmo is not None, "torch_mlu_ops is not installed." + + def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None): + if len(q.shape) == 3: + bs = 1 + q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0) + elif len(q.shape) == 4: + bs = q.shape[0] + softmax_scale = 1 / math.sqrt(q.shape[-1]) + x = tmo.sage_attn( + q=q, k=k, v=v, cu_seq_lens_q=None, cu_seq_lens_kv=None, max_seq_len_kv=max_seqlen_kv, max_seq_len_q=max_seqlen_q, is_causal=False, compute_dtype=torch.bfloat16, softmax_scale=softmax_scale + ) + x = x.reshape(bs * max_seqlen_q, -1) + return x diff --git a/lightx2v_platform/ops/attn/dcu/__init__.py b/lightx2v_platform/ops/attn/dcu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecd9913147f5fd2c090b4c7f9044e296b996856 --- /dev/null +++ b/lightx2v_platform/ops/attn/dcu/__init__.py @@ -0,0 +1 @@ +from .flash_attn import * diff --git a/lightx2v_platform/ops/attn/dcu/flash_attn.py b/lightx2v_platform/ops/attn/dcu/flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b174bb7c6ffda9899ac5c74702191ee9ea9fc9bd --- /dev/null +++ b/lightx2v_platform/ops/attn/dcu/flash_attn.py @@ -0,0 +1,146 @@ +import torch +from loguru import logger + +from lightx2v_platform.ops.attn.template import AttnWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER + +# Try to import Flash Attention (ROCm version 2.6.1) +try: + from flash_attn import flash_attn_varlen_func + + FLASH_ATTN_AVAILABLE = True + logger.info(f"Flash Attention (ROCm) is available") +except ImportError: + logger.warning("Flash Attention not found. Will use PyTorch SDPA as fallback.") + flash_attn_varlen_func = None + FLASH_ATTN_AVAILABLE = False + + +@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_dcu") +class FlashAttnDcu(AttnWeightTemplate): + """ + DCU Flash Attention implementation. + + Uses AMD ROCm version of Flash Attention 2.6.1 when available. + Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed. 
+ + Tested Environment: + - PyTorch: 2.7.1 + - Python: 3.10 + - Flash Attention: 2.6.1 (ROCm) + Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py + """ + + def __init__(self, weight_name="flash_attn_dcu"): + super().__init__(weight_name) + self.use_flash_attn = FLASH_ATTN_AVAILABLE + + if self.use_flash_attn: + logger.info("Flash Attention 2.6.1 (ROCm) is available and will be used.") + else: + logger.warning("Flash Attention not available. Using PyTorch SDPA fallback.") + + def apply( + self, + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + ): + """ + Execute Flash Attention computation. + Args: + q: [B, Lq, Nq, C1] Query tensor + k: [B, Lk, Nk, C1] Key tensor + v: [B, Lk, Nk, C2] Value tensor + q_lens: [B] Optional sequence lengths for queries + k_lens: [B] Optional sequence lengths for keys + dropout_p: Dropout probability + softmax_scale: Scaling factor for QK^T before softmax + causal: Whether to apply causal mask + window_size: Sliding window size tuple (left, right) + deterministic: Whether to use deterministic algorithm + Returns: + Output tensor: [B, Lq, Nq, C2] + """ + if not self.use_flash_attn: + # Fallback to PyTorch SDPA + return self._sdpa_fallback(q, k, v, causal, dropout_p) + + # Ensure data types are half precision + half_dtypes = (torch.float16, torch.bfloat16) + dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16 + out_dtype = q.dtype + + b, lq, lk = q.size(0), q.size(1), k.size(1) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # Preprocess query + if q_lens is None: + q_flat = half(q.flatten(0, 1)) + q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device) + else: + q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # Preprocess key/value + if k_lens is None: + k_flat = half(k.flatten(0, 1)) + v_flat = half(v.flatten(0, 1)) + k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device) + else: + k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + # Compute cumulative sequence lengths + cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) + cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) + + # Use Flash Attention 2.6.1 (ROCm version) + output = flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + ) + + # Reshape back to batch dimension + output = output.unflatten(0, (b, lq)) + return output.to(out_dtype) + + def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0): + """ + Fallback to PyTorch Scaled Dot Product Attention. 
+ Args: + q: [B, Lq, Nq, C] Query tensor + k: [B, Lk, Nk, C] Key tensor + v: [B, Lk, Nk, C] Value tensor + causal: Whether to apply causal mask + dropout_p: Dropout probability + Returns: + Output tensor: [B, Lq, Nq, C] + """ + # Transpose to [B, Nq, Lq, C] for SDPA + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p) + + # Transpose back to [B, Lq, Nq, C] + return out.transpose(1, 2).contiguous() diff --git a/lightx2v_platform/ops/attn/template.py b/lightx2v_platform/ops/attn/template.py new file mode 100644 index 0000000000000000000000000000000000000000..10b236e4217475248fb2499f8d0288689356ea98 --- /dev/null +++ b/lightx2v_platform/ops/attn/template.py @@ -0,0 +1,32 @@ +from abc import ABCMeta, abstractmethod + + +class AttnWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name): + self.weight_name = weight_name + self.config = {} + + def load(self, weight_dict): + pass + + @abstractmethod + def apply(self, input_tensor): + pass + + def set_config(self, config=None): + if config is not None: + self.config = config + + def to_cpu(self, non_blocking=False): + pass + + def to_cuda(self, non_blocking=False): + pass + + def state_dict(self, destination=None): + if destination is None: + destination = {} + return destination + + def load_state_dict(self, destination, block_index, adapter_block_inde=None): + return {} diff --git a/lightx2v_platform/ops/mm/__init__.py b/lightx2v_platform/ops/mm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v_platform/ops/mm/cambricon_mlu/__init__.py b/lightx2v_platform/ops/mm/cambricon_mlu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9898101b00e91590a78d636e2066ff1234a4d8 --- /dev/null +++ b/lightx2v_platform/ops/mm/cambricon_mlu/__init__.py @@ -0,0 +1 @@ +from .mm_weight import * diff --git a/lightx2v_platform/ops/mm/cambricon_mlu/mm_weight.py b/lightx2v_platform/ops/mm/cambricon_mlu/mm_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..da2b9809b47be112d1d7102532994b63f38099ca --- /dev/null +++ b/lightx2v_platform/ops/mm/cambricon_mlu/mm_weight.py @@ -0,0 +1,37 @@ +from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate +from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER + +try: + import torch_mlu_ops as tmo +except ImportError: + tmo = None + + +@PLATFORM_MM_WEIGHT_REGISTER("int8-tmo") +class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate): + """ + Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu + + Quant MM: + Weight: int8 perchannel sym + Act: int8 perchannel dynamic sym + Kernel: mlu + """ + + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.load_func = self.load_int8_perchannel_sym + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo + + def act_quant_int8_perchannel_sym_tmo(self, x): + input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x) + return input_tensor_quant, input_tensor_scale + + def apply(self, input_tensor): + dtype = input_tensor.dtype + input_tensor_quant, input_tensor_scale = 
self.act_quant_func(input_tensor) + output_tensor = tmo.scaled_matmul( + input_tensor_quant, self.weight.contiguous(), input_tensor_scale, self.weight_scale.squeeze(-1), bias=self.bias if self.bias is not None else None, output_dtype=dtype, use_hp_active=True + ) + return output_tensor diff --git a/lightx2v_platform/ops/mm/cambricon_mlu/q_linear.py b/lightx2v_platform/ops/mm/cambricon_mlu/q_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5037a89971c44ade0f85e094cede02818eeba1cb --- /dev/null +++ b/lightx2v_platform/ops/mm/cambricon_mlu/q_linear.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn + +try: + import torch_mlu_ops as tmo +except ImportError: + tmo = None + + +class MluQuantLinearInt8(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x) + return input_tensor_quant, input_tensor_scale + + def forward(self, input_tensor): + input_tensor = input_tensor.squeeze(0) + dtype = input_tensor.dtype + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output_tensor = tmo.scaled_matmul(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.squeeze(-1), output_dtype=dtype) + return output_tensor.unsqueeze(0) + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + return self diff --git a/lightx2v_platform/ops/mm/template.py b/lightx2v_platform/ops/mm/template.py new file mode 100644 index 0000000000000000000000000000000000000000..f418ca0b825f921a1de34afe0db5d1441dc09d71 --- /dev/null +++ b/lightx2v_platform/ops/mm/template.py @@ -0,0 +1,466 @@ +from abc import ABCMeta, abstractmethod + +import torch + +from lightx2v_platform.base.global_var import AI_DEVICE + + +class MMWeightTemplate(metaclass=ABCMeta): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + self.weight_name = weight_name + self.bias_name = bias_name + self.create_cuda_buffer = create_cuda_buffer + self.create_cpu_buffer = create_cpu_buffer + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.is_post_adapter = is_post_adapter + self.config = {} + + @abstractmethod + def load(self, weight_dict): + pass + + @abstractmethod + def apply(self): + pass + + def set_config(self, config={}): + self.config = config + + def to_cuda(self, non_blocking=False): + self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_weight_scale"): + self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking) + if hasattr(self, "pin_bias") and self.pin_bias is not None: + self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking) + + def to_cpu(self, non_blocking=False): + if hasattr(self, 
"pin_weight"): + self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu() + if hasattr(self, "weight_scale_name"): + self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu() + if self.bias is not None: + self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu() + else: + self.weight = self.weight.to("cpu", non_blocking=non_blocking) + if hasattr(self, "weight_scale"): + self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking) + if hasattr(self, "bias") and self.bias is not None: + self.bias = self.bias.to("cpu", non_blocking=non_blocking) + + +class MMWeightQuantTemplate(MMWeightTemplate): + def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter) + self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale" + self.load_func = None + self.weight_need_transpose = True + self.act_quant_func = None + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.infer_dtype = torch.bfloat16 # bias dtype + self.bias_force_fp32 = False + + # ========================= + # weight load functions + # ========================= + def load(self, weight_dict): + self.load_quantized(weight_dict) + if self.weight_need_transpose: + if hasattr(self, "weight") and self.weight is not None: + self.weight = self.weight.t() + if hasattr(self, "pin_weight") and self.pin_weight is not None: + self.pin_weight = self.pin_weight.t() + if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None: + self.weight_cuda_buffer = self.weight_cuda_buffer.t() + + def load_quantized(self, weight_dict): + if self.create_cuda_buffer: + self._load_cuda_buffers(weight_dict) + elif self.create_cpu_buffer: + self._load_cpu_pin_buffers() + else: + self._load_default_tensors(weight_dict) + + def _load_cuda_buffers(self, weight_dict): + source = self.lazy_load_file if self.lazy_load else weight_dict + self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load) + self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load) + + def _get_cuda_tensor_pair(self, source, is_lazy): + if is_lazy: + weight = source.get_tensor(self.weight_name).to(AI_DEVICE) + scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE) + else: + weight = source[self.weight_name].to(AI_DEVICE) + scale = source[self.weight_scale_name].float().to(AI_DEVICE) + return weight, scale + + def _get_cuda_bias_tensor(self, source, is_lazy): + if self.bias_name is None: + return None + if is_lazy: + bias = source.get_tensor(self.bias_name) + dtype = self.infer_dtype + else: + bias = source[self.bias_name] + dtype = bias.dtype + if self.bias_force_fp32: + bias = bias.to(torch.float32) + else: + bias = bias.to(dtype) + return bias.to(AI_DEVICE) + + def _load_cpu_pin_buffers(self): + self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True) + self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True) + self.bias = None + + def _get_cpu_pin_tensor_pair(self, source, is_lazy): + if is_lazy: + weight_tensor = source.get_tensor(self.weight_name) + scale_tensor = source.get_tensor(self.weight_scale_name) + scale_dtype = torch.float + else: + weight_tensor = 
source[self.weight_name] + scale_tensor = source[self.weight_scale_name] + scale_dtype = torch.float + + pin_weight = self._create_pin_tensor(weight_tensor) + pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype) + return pin_weight, pin_scale + + def _get_cpu_pin_bias_tensor(self, source, is_lazy): + if self.bias_name is None: + return None + if is_lazy: + bias_tensor = source.get_tensor(self.bias_name) + if not self.bias_force_fp32: + bias_tensor = bias_tensor.to(self.infer_dtype) + else: + bias_tensor = source[self.bias_name] + if self.bias_force_fp32: + bias_tensor = bias_tensor.to(torch.float32) + return self._create_pin_tensor(bias_tensor) + + def _create_pin_tensor(self, tensor, dtype=None): + dtype = dtype or tensor.dtype + pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype) + pin_tensor.copy_(tensor) + del tensor + return pin_tensor + + def _load_default_tensors(self, weight_dict): + if not self.lazy_load: + self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict) + self._load_default_bias(weight_dict) + else: + self.bias = None + self.pin_bias = None + + def _get_device_tensor_pair(self, source): + device = source[self.weight_name].device + if device.type == "cpu": + pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False) + return None, None, pin_weight, pin_scale + else: + return source[self.weight_name], source[self.weight_scale_name].float(), None, None + + def _load_default_bias(self, source): + if self.bias_name is None: + self.bias = None + self.pin_bias = None + self.bias_cuda_buffer = None + return + + if self.create_cuda_buffer: + self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False) + self.bias = None + self.pin_bias = None + else: + bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name] + device = bias_tensor.device + if device.type == "cpu": + self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False) + self.bias = None + else: + self.bias = bias_tensor + self.pin_bias = None + + def load_fp8_perchannel_sym(self, weight_dict): + if self.config.get("weight_auto_quant", False): + self.weight = weight_dict[self.weight_name].to(torch.float32) + w_quantizer = FloatQuantizer("e4m3", True, "per_channel") + self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight) + self.weight = self.weight.to(torch.float8_e4m3fn) + self.weight_scale = self.weight_scale.to(torch.float32) + else: + self.load_quantized(weight_dict) + + def load_int8_perchannel_sym(self, weight_dict): + if self.config.get("weight_auto_quant", False): + self.weight = weight_dict[self.weight_name].to(torch.float32) + w_quantizer = IntegerQuantizer(8, True, "per_channel") + self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight) + self.weight = self.weight.to(torch.int8) + self.weight_scale = self.weight_scale.to(torch.float32) + else: + self.load_quantized(weight_dict) + + def load_mxfp4(self, weight_dict): + if self.config.get("weight_auto_quant", False): + device = weight_dict[self.weight_name].device + self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16) + self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight) + self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = 
weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + + def load_mxfp6(self, weight_dict): + if self.config.get("weight_auto_quant", False): + device = weight_dict[self.weight_name].device + self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16) + self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight) + self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + + def load_mxfp8(self, weight_dict): + if self.config.get("weight_auto_quant", False): + device = weight_dict[self.weight_name].device + self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16) + self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight) + self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device) + else: + device = weight_dict[self.weight_name].device + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) + + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + + def load_nvfp4(self, weight_dict): + device = weight_dict[self.weight_name].device + + input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")] + input_global_scale = (2688.0 / input_absmax).to(torch.float32) + weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"] + alpha = 1.0 / (input_global_scale * weight_global_scale) + + if device.type == "cpu": + weight_shape = weight_dict[self.weight_name].shape + weight_dtype = weight_dict[self.weight_name].dtype + self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype) + self.pin_weight.copy_(weight_dict[self.weight_name]) 
+ + weight_scale_shape = weight_dict[self.weight_scale_name].shape + weight_scale_dtype = weight_dict[self.weight_scale_name].dtype + self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype) + self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name]) + + input_global_scale_shape = input_global_scale.shape + input_global_scale_dtype = input_global_scale.dtype + self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype) + self.pin_input_global_scale.copy_(input_global_scale) + + alpha_shape = alpha.shape + alpha_dtype = alpha.dtype + self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype) + self.pin_alpha.copy_(alpha) + + del weight_dict[self.weight_name] + else: + self.weight = weight_dict[self.weight_name] + self.weight_scale = weight_dict[self.weight_scale_name] + self.input_global_scale = input_global_scale + self.alpha = alpha + + if self.bias_name is not None: + if self.create_cuda_buffer: + self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE) + else: + device = weight_dict[self.bias_name].device + if device.type == "cpu": + bias_shape = weight_dict[self.bias_name].shape + bias_dtype = weight_dict[self.bias_name].dtype + self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype) + self.pin_bias.copy_(weight_dict[self.bias_name]) + else: + self.bias = weight_dict[self.bias_name] + else: + self.bias = None + self.pin_bias = None + + def load_fp8_perblock128_sym(self, weight_dict): + if self.config.get("weight_auto_quant", False): + self.weight = weight_dict[self.weight_name] + self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight) + else: + self.load_quantized(weight_dict) + + def per_block_cast_to_fp8(self, x): + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128), + dtype=x.dtype, + device=x.device, + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + # ========================= + # act quant kernels + # ========================= + def act_quant_int8_perchannel_sym_torchao(self, x): + input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x) + return input_tensor_quant, input_tensor_scale + + def act_quant_fp8_perchannel_sym_vllm(self, x): + input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True) + return input_tensor_quant, input_tensor_scale + + def act_quant_fp8_perchannel_sym_sgl(self, x): + m, k = x.shape + input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False) + input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False) + sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale) + return input_tensor_quant, input_tensor_scale + + def act_quant_int8_perchannel_sym_vllm(self, x): + input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True) + return input_tensor_quant, input_tensor_scale + + def act_quant_nvfp4(self, x): + input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, 
self.input_global_scale) + return input_tensor_quant, input_tensor_scale + + def act_quant_mxfp4(self, x): + input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x) + return input_tensor_quant, input_tensor_scale + + def act_quant_mxfp8(self, x): + input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x) + return input_tensor_quant, input_tensor_scale + + def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x): + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + def act_quant_fp8_perchannelgroup128_sym_sgl(self, x): + m, k = x.shape + input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False) + input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False) + sgl_kernel.sgl_per_token_group_quant_fp8( + x, + input_tensor_quant, + input_tensor_scale, + group_size=128, + eps=1e-10, + fp8_min=-448.0, + fp8_max=448.0, + ) + return input_tensor_quant, input_tensor_scale + + def state_dict(self, destination=None): + if destination is None: + destination = {} + destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight + if self.bias_name is not None: + destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias + destination[self.weight_scale_name] = self.pin_weight_scale if hasattr(self, "pin_weight_scale") else self.weight_scale + return destination + + def load_state_dict(self, destination, block_index, adapter_block_index=None): + if self.is_post_adapter: + weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1) + else: + weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1) + + if weight_name not in destination: + self.weight = None + return + + self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True) + self.weight_scale = self.weight_scale_cuda_buffer.copy_(destination[weight_scale_name], non_blocking=True) + + if self.bias_name is not None: + bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True) + else: + self.bias = None + + def load_state_dict_from_disk(self, block_index, adapter_block_index=None): + if self.is_post_adapter: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1) + self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1) + else: + self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1) + self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1) + + if self.weight_need_transpose: + weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t() + else: + weight_tensor = self.lazy_load_file.get_tensor(self.weight_name) + self.pin_weight = self.pin_weight.copy_(weight_tensor) + + weight_scale_tensor = self.lazy_load_file.get_tensor(self.weight_scale_name) + self.pin_weight_scale = 
self.pin_weight_scale.copy_(weight_scale_tensor) + + del weight_tensor + + if self.bias_name is not None: + if self.is_post_adapter: + assert adapter_block_index is not None + self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1) + else: + self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1) + + bias_tensor = self.lazy_load_file.get_tensor(self.bias_name) + self.pin_bias.copy_(bias_tensor) + del bias_tensor diff --git a/lightx2v_platform/ops/norm/__init__.py b/lightx2v_platform/ops/norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v_platform/ops/rope/__init__.py b/lightx2v_platform/ops/rope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lightx2v_platform/registry_factory.py b/lightx2v_platform/registry_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..777cf5270bcdb87e9e0e908718cbec0e30791809 --- /dev/null +++ b/lightx2v_platform/registry_factory.py @@ -0,0 +1,58 @@ +class Register(dict): + def __init__(self, *args, **kwargs): + super(Register, self).__init__(*args, **kwargs) + self._dict = {} + + def __call__(self, target_or_name): + if callable(target_or_name): + return self.register(target_or_name) + else: + return lambda x: self.register(x, key=target_or_name) + + def register(self, target, key=None): + if not callable(target): + raise Exception(f"Error: {target} must be callable!") + + if key is None: + key = target.__name__ + + if key in self._dict: + raise Exception(f"{key} already exists.") + + self[key] = target + return target + + def __setitem__(self, key, value): + self._dict[key] = value + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def __str__(self): + return str(self._dict) + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + + def get(self, key, default=None): + return self._dict.get(key, default) + + def merge(self, other_register): + for key, value in other_register.items(): + if key in self._dict: + raise Exception(f"{key} already exists in target register.") + self[key] = value + + +PLATFORM_DEVICE_REGISTER = Register() +PLATFORM_ATTN_WEIGHT_REGISTER = Register() +PLATFORM_MM_WEIGHT_REGISTER = Register() diff --git a/lightx2v_platform/set_ai_device.py b/lightx2v_platform/set_ai_device.py new file mode 100644 index 0000000000000000000000000000000000000000..e52b4ddb69c59173d8db9a1baa509e51a258caec --- /dev/null +++ b/lightx2v_platform/set_ai_device.py @@ -0,0 +1,15 @@ +import os + +from lightx2v_platform import * + + +def set_ai_device(): + platform = os.getenv("PLATFORM", "cuda") + init_ai_device(platform) + from lightx2v_platform.base.global_var import AI_DEVICE + + check_ai_device(AI_DEVICE) + + +set_ai_device() +from lightx2v_platform.ops import * # noqa: E402 diff --git a/lightx2v_platform/test/test_device.py b/lightx2v_platform/test/test_device.py new file mode 100644 index 0000000000000000000000000000000000000000..c1047bf7b650eb3595acd6b233724f2cf325603e --- /dev/null +++ b/lightx2v_platform/test/test_device.py @@ -0,0 +1,11 @@ +""" +PYTHONPATH=/path-to-LightX2V PLATFORM=cuda python test_device.py +PYTHONPATH=/path-to-LightX2V PLATFORM=mlu python test_device.py +PYTHONPATH=/path-to-LightX2V PLATFORM=metax python 
test_device.py +""" + +# This import will initialize the AI device +import lightx2v_platform.set_ai_device # noqa: F401 +from lightx2v_platform.base.global_var import AI_DEVICE + +print(f"AI_DEVICE: {AI_DEVICE}") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..884766cd1ed3fda0842b128fb4531f9eb9db0da1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,114 @@ +[build-system] +requires = [ + "setuptools>=61.0", + "wheel", + "packaging", + "ninja", +] +build-backend = "setuptools.build_meta" + +[project] +name = "lightx2v" +version = "0.1.0" +authors = [ + {name = "LightX2V Contributors"}, +] +description = "LightX2V: Light Video Generation Inference Framework" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Multimedia :: Video", +] + +dependencies = [ + "numpy", + "scipy", + "torch<=2.8.0", + "torchvision<=0.23.0", + "torchaudio<=2.8.0", + "diffusers", + "transformers", + "tokenizers", + "tqdm", + "accelerate", + "safetensors", + "opencv-python", + "imageio", + "imageio-ffmpeg", + "einops", + "loguru", + "qtorch", + "ftfy", + "gradio", + "aiohttp", + "pydantic", + "prometheus-client", + "gguf", + "fastapi", + "uvicorn", + "PyJWT", + "requests", + "aio-pika", + "asyncpg>=0.27.0", + "aioboto3>=12.0.0", + "alibabacloud_dypnsapi20170525==1.2.2", + "redis==6.4.0", + "tos", + "decord", + "av", +] + +[project.urls] +Homepage = "https://github.com/ModelTC/LightX2V" +Documentation = "https://lightx2v-en.readthedocs.io/en/latest/" +Repository = "https://github.com/ModelTC/LightX2V" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +include = ["lightx2v*"] +exclude = ["lightx2v_kernel*"] + +[tool.ruff] +exclude = [ + ".git", + ".mypy_cache", + ".ruff_cache", + ".venv", + "dist", + "build", + "__pycache__", + "*.egg-info", + ".pytest_cache", + ".cluade", + ".cursor", + "lightx2v_kernel", +] +target-version = "py311" +line-length = 200 +indent-width = 4 + + +[tool.ruff.lint] +extend-select = ["I", "F401"] +ignore = ["F"] + +[tool.ruff.lint.per-file-ignores] +"**/__init__.py" = ["F401"] +"**/lightx2v_kernel/*" = ["F401"] +"**/{cookbook,docs}/*" = ["E402", "F401", "F811", "F841"] + +[tool.ruff.lint.isort] +known-first-party = ["lightx2v"] +case-sensitive = true diff --git a/requirements-docs.txt b/requirements-docs.txt new file mode 100644 index 0000000000000000000000000000000000000000..1c8eec42f870e441abe5a533083ab9ce7ec8b9c4 --- /dev/null +++ b/requirements-docs.txt @@ -0,0 +1,7 @@ +sphinx == 6.2.1 +sphinx-book-theme == 1.0.1 +sphinx-copybutton == 0.5.2 +myst-parser == 2.0.0 +sphinx-argparse +sphinxcontrib.redoc +sphinxcontrib.openapi diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9cb43688b81f02502badf9168a86794c944a572 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,40 @@ +packaging +ninja +numpy +scipy +torch<=2.8.0 +torchvision<=0.23.0 +torchaudio<=2.8.0 +diffusers +transformers +tokenizers +tqdm +accelerate +safetensors +opencv-python +imageio +imageio-ffmpeg +einops +loguru +sgl-kernel +qtorch 
+ftfy +gradio +aiohttp +pydantic +aio-pika +asyncpg>=0.27.0 +aioboto3>=12.0.0 +prometheus-client +gguf +fastapi +uvicorn +PyJWT +requests +alibabacloud_dypnsapi20170525==1.2.2 +redis==6.4.0 +tos +decord +zmq +jsonschema +pymongo diff --git a/requirements_animate.txt b/requirements_animate.txt new file mode 100644 index 0000000000000000000000000000000000000000..13d6093796cd9f135a6868de7cfd6b6f0f3cac85 --- /dev/null +++ b/requirements_animate.txt @@ -0,0 +1,8 @@ +decord +peft +onnxruntime +pandas +matplotlib +-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2 +loguru +sentencepiece diff --git a/requirements_win.txt b/requirements_win.txt new file mode 100644 index 0000000000000000000000000000000000000000..1cada8f7cb1d8e37716c76ca9fa2975869ee1770 --- /dev/null +++ b/requirements_win.txt @@ -0,0 +1,18 @@ +packaging +ninja +diffusers +transformers +tokenizers +accelerate +safetensors +opencv-python +numpy +imageio +imageio-ffmpeg +einops +loguru +qtorch +ftfy +gradio +aiohttp +pydantic diff --git a/save_results/.gitkeep b/save_results/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/base/base.sh b/scripts/base/base.sh new file mode 100644 index 0000000000000000000000000000000000000000..4ae01675e6d357c127047756a6d1b0614c83a796 --- /dev/null +++ b/scripts/base/base.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +# always set false to avoid some warnings +export TOKENIZERS_PARALLELISM=false +# set expandable_segments to True to avoid OOM +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# ===================================================================================== +# ⚠️ IMPORTANT CONFIGURATION PARAMETERS - READ CAREFULLY AND MODIFY WITH CAUTION ⚠️ +# ===================================================================================== + +# Model Inference Data Type Setting (IMPORTANT!) +# Key parameter affecting model accuracy and performance +# Available options: [BF16, FP16] +# If not set, default value: BF16 +export DTYPE=BF16 + +# Sensitive Layer Data Type Setting (IMPORTANT!) +# Used for layers requiring higher precision +# Available options: [FP32, None] +# If not set, default value: None (follows DTYPE setting) +export SENSITIVE_LAYER_DTYPE=None + +# Performance Profiling Debug Level (Debug Only) +# Enables detailed performance analysis output, such as time cost and memory usage +# Available options: [0, 1, 2] +# If not set, default value: 0 +# Note: This option can be set to 0 for production. 
+export PROFILING_DEBUG_LEVEL=2 + + +echo "===============================================================================" +echo "LightX2V Base Environment Variables Summary:" +echo "-------------------------------------------------------------------------------" +echo "lightx2v_path: ${lightx2v_path}" +echo "model_path: ${model_path}" +echo "-------------------------------------------------------------------------------" +echo "Model Inference Data Type: ${DTYPE}" +echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}" +echo "Performance Profiling Debug Level: ${PROFILING_DEBUG_LEVEL}" +echo "===============================================================================" diff --git a/scripts/bench/run_lightx2v_1.sh b/scripts/bench/run_lightx2v_1.sh new file mode 100644 index 0000000000000000000000000000000000000000..a44ca395a9208c1c4e59a40ec6e05002eeeb0976 --- /dev/null +++ b/scripts/bench/run_lightx2v_1.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH +export DTYPE=BF16 +export SENSITIVE_LAYER_DTYPE=FP32 +export PROFILING_DEBUG_LEVEL=2 + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_1.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_1.mp4 diff --git a/scripts/bench/run_lightx2v_2.sh b/scripts/bench/run_lightx2v_2.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a15c63d51882a795125647929b1b020631e8ac3 --- /dev/null +++ b/scripts/bench/run_lightx2v_2.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." 
+ export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +export PROFILING_DEBUG_LEVEL=2 +export DTYPE=BF16 + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_2.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_2.mp4 diff --git a/scripts/bench/run_lightx2v_3.sh b/scripts/bench/run_lightx2v_3.sh new file mode 100644 index 0000000000000000000000000000000000000000..30d6a8c4068cf5aa0de2ecad46e91997ec665c06 --- /dev/null +++ b/scripts/bench/run_lightx2v_3.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +export PROFILING_DEBUG_LEVEL=2 +export DTYPE=BF16 + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_3.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_3.mp4 diff --git a/scripts/bench/run_lightx2v_3_distill.sh b/scripts/bench/run_lightx2v_3_distill.sh new file mode 100644 index 0000000000000000000000000000000000000000..b99136f387b0fe8ddd3c7a9b4d586c41bffcc843 --- /dev/null +++ b/scripts/bench/run_lightx2v_3_distill.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +export PROFILING_DEBUG_LEVEL=2 +export DTYPE=BF16 + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_3_distill.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_3_distill.mp4 diff --git a/scripts/bench/run_lightx2v_4.sh b/scripts/bench/run_lightx2v_4.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6c57c59a81cca02c636a21524b0003cd27e0a8a --- /dev/null +++ b/scripts/bench/run_lightx2v_4.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." 
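+# Note: model_path must point to the checkpoint directory configured at the top of this script (e.g. a local copy of Wan2.1-I2V-14B-480P-Lightx2v); the script aborts here otherwise.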
+ exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +export PROFILING_DEBUG_LEVEL=2 +export DTYPE=BF16 + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_4.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_4.mp4 diff --git a/scripts/bench/run_lightx2v_5.sh b/scripts/bench/run_lightx2v_5.sh new file mode 100644 index 0000000000000000000000000000000000000000..0be7a05ed17f65c03a16ec1f83c3ab194fef34e8 --- /dev/null +++ b/scripts/bench/run_lightx2v_5.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH +export DTYPE=BF16 +export SENSITIVE_LAYER_DTYPE=FP32 +export PROFILING_DEBUG_LEVEL=2 + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_5.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_5.mp4 diff --git a/scripts/bench/run_lightx2v_5_distill.sh b/scripts/bench/run_lightx2v_5_distill.sh new file mode 100644 index 0000000000000000000000000000000000000000..108c99029d1a353aff2c46b38a5b2af86ffad5ff --- /dev/null +++ b/scripts/bench/run_lightx2v_5_distill.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH +export DTYPE=BF16 +export SENSITIVE_LAYER_DTYPE=FP32 +export PROFILING_DEBUG_LEVEL=2 + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_5_distill.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_5_distill.mp4 diff --git a/scripts/bench/run_lightx2v_6.sh b/scripts/bench/run_lightx2v_6.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d08af121dcece66fa300df5f1fbc78f95806baf --- /dev/null +++ b/scripts/bench/run_lightx2v_6.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." 
+ exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +export PROFILING_DEBUG_LEVEL=2 +export DTYPE=BF16 + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_6.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_6.mp4 diff --git a/scripts/bench/run_lightx2v_6_distill.sh b/scripts/bench/run_lightx2v_6_distill.sh new file mode 100644 index 0000000000000000000000000000000000000000..1f6f223f89de7ae1685238a1b4c48dfd9306eb05 --- /dev/null +++ b/scripts/bench/run_lightx2v_6_distill.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/lightx2v +model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v +# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v + +# check section +if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then + cuda_devices=0 + echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." + export CUDA_VISIBLE_DEVICES=${cuda_devices} +fi + +if [ -z "${lightx2v_path}" ]; then + echo "Error: lightx2v_path is not set. Please set this variable first." + exit 1 +fi + +if [ -z "${model_path}" ]; then + echo "Error: model_path is not set. Please set this variable first." + exit 1 +fi + +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${lightx2v_path}:$PYTHONPATH + +export PROFILING_DEBUG_LEVEL=2 +export DTYPE=BF16 + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/bench/lightx2v_6_distill.json \ +--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \ +--save_result_path ${lightx2v_path}/save_results/lightx2v_6_distill.mp4 diff --git a/scripts/cache/readme.md b/scripts/cache/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..05c2756c8e72ddbf6ed2bffc2cf9e0ab50386e50 --- /dev/null +++ b/scripts/cache/readme.md @@ -0,0 +1,11 @@ +# Feature Caching + +The config files for feature caching are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching). + +By pointing --config_json at a specific config file, you can test the different cache algorithms. + +Please refer to our feature caching docs: + +[English doc: Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html) + +[中文文档: 特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) diff --git a/scripts/cache/run_wan_i2v_dist_cfg_ulysses_mag.sh b/scripts/cache/run_wan_i2v_dist_cfg_ulysses_mag.sh new file mode 100644 index 0000000000000000000000000000000000000000..20cd5e807a76640b7f4d26c185543fad53f9be9e --- /dev/null +++ b/scripts/cache/run_wan_i2v_dist_cfg_ulysses_mag.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/magcache/wan_i2v_dist_cfg_ulysses_mag_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_dist_cfg_ulysses_mag.mp4 diff --git a/scripts/cache/run_wan_i2v_mag.sh b/scripts/cache/run_wan_i2v_mag.sh new file mode 100644 index 0000000000000000000000000000000000000000..8cb0930edb73fa16eaec6d0e19a01e1c9cd95042 --- /dev/null +++ b/scripts/cache/run_wan_i2v_mag.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/magcache/wan_i2v_mag_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. 
The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_mag.mp4 diff --git a/scripts/cache/run_wan_i2v_mag_calibration.sh b/scripts/cache/run_wan_i2v_mag_calibration.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ac834cc180d9ee8598602c75bd9cd4dd23482b1 --- /dev/null +++ b/scripts/cache/run_wan_i2v_mag_calibration.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/magcache/wan_i2v_mag_calibration_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_mag.mp4 diff --git a/scripts/cache/run_wan_i2v_tea.sh b/scripts/cache/run_wan_i2v_tea.sh new file mode 100644 index 0000000000000000000000000000000000000000..85901d466adfe008469b15f942f68ad2ed69f7d5 --- /dev/null +++ b/scripts/cache/run_wan_i2v_tea.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/teacache/wan_i2v_tea_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_tea.mp4 diff --git a/scripts/cache/run_wan_t2v_dist_cfg_ulysses_mag.sh b/scripts/cache/run_wan_t2v_dist_cfg_ulysses_mag.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0fcffbce7de56a18378551984d892a04031313c --- /dev/null +++ b/scripts/cache/run_wan_t2v_dist_cfg_ulysses_mag.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/magcache/wan_t2v_dist_cfg_ulysses_mag_1_3b.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_dist_cfg_ulysses_mag.mp4 diff --git a/scripts/cache/run_wan_t2v_mag.sh b/scripts/cache/run_wan_t2v_mag.sh new file mode 100644 index 0000000000000000000000000000000000000000..6bdae93ec22102f15f9e9917c5a0b181b03cbd71 --- /dev/null +++ b/scripts/cache/run_wan_t2v_mag.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/magcache/wan_t2v_mag_1_3b.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_mag.mp4 diff --git a/scripts/cache/run_wan_t2v_mag_calibration.sh b/scripts/cache/run_wan_t2v_mag_calibration.sh new file mode 100644 index 0000000000000000000000000000000000000000..22ec2338935d7eb39873797085117d8b6b413011 --- /dev/null +++ b/scripts/cache/run_wan_t2v_mag_calibration.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/magcache/wan_t2v_mag_calibration_1_3b.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_mag.mp4 diff --git a/scripts/cache/run_wan_t2v_tea.sh b/scripts/cache/run_wan_t2v_tea.sh new file mode 100644 index 0000000000000000000000000000000000000000..74c28e62f75e8701b5b7ae0e35e432d6a477e518 --- /dev/null +++ b/scripts/cache/run_wan_t2v_tea.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_tea.mp4 diff --git a/scripts/changing_resolution/run_wan_i2v_changing_resolution.sh b/scripts/changing_resolution/run_wan_i2v_changing_resolution.sh new file mode 100644 index 0000000000000000000000000000000000000000..010497f99b626adce7760e21fbff03bd8db1def5 --- /dev/null +++ b/scripts/changing_resolution/run_wan_i2v_changing_resolution.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/changing_resolution/wan_i2v_U.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_changing_resolution.mp4 diff --git a/scripts/changing_resolution/run_wan_t2v_changing_resolution.sh b/scripts/changing_resolution/run_wan_t2v_changing_resolution.sh new file mode 100644 index 0000000000000000000000000000000000000000..3e4881d86c64dd5b6a784ec4d8c3e88b753712c4 --- /dev/null +++ b/scripts/changing_resolution/run_wan_t2v_changing_resolution.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/changing_resolution/wan_t2v_U.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_changing_resolution.mp4 diff --git a/scripts/dist_infer/run_wan22_moe_i2v_cfg.sh b/scripts/dist_infer/run_wan22_moe_i2v_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..3369b95054082282a2e6ffbf06384393b1c0c728 --- /dev/null +++ b/scripts/dist_infer/run_wan22_moe_i2v_cfg.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_cfg.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_cfg.mp4 diff --git a/scripts/dist_infer/run_wan22_moe_i2v_cfg_ulysses.sh b/scripts/dist_infer/run_wan22_moe_i2v_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..ccee8f4cc95718b7c4fa13ce0533a2d31226b331 --- /dev/null +++ b/scripts/dist_infer/run_wan22_moe_i2v_cfg_ulysses.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_cfg_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_moe_i2v_ulysses.sh b/scripts/dist_infer/run_wan22_moe_i2v_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..614c08636498e76de4b5488cbcaefc14b0302376 --- /dev/null +++ b/scripts/dist_infer/run_wan22_moe_i2v_ulysses.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=4 -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_ulysses.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_moe_t2v_cfg.sh b/scripts/dist_infer/run_wan22_moe_t2v_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..b7e4f773843580702259bc876f160c80f2f66393 --- /dev/null +++ b/scripts/dist_infer/run_wan22_moe_t2v_cfg.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_t2v_cfg.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_parallel_cfg.mp4 diff --git a/scripts/dist_infer/run_wan22_moe_t2v_cfg_ulysses.sh b/scripts/dist_infer/run_wan22_moe_t2v_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..dd6b01c6a5f3d1e952546437048eaf41c664b984 --- /dev/null +++ b/scripts/dist_infer/run_wan22_moe_t2v_cfg_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_t2v_cfg_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_parallel_cfg_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_moe_t2v_ulysses.sh b/scripts/dist_infer/run_wan22_moe_t2v_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..90cbec06e94fb39afdde9a088d0ade365cfc892a --- /dev/null +++ b/scripts/dist_infer/run_wan22_moe_t2v_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=4 -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_t2v_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_parallel_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_ti2v_i2v_cfg.sh b/scripts/dist_infer/run_wan22_ti2v_i2v_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..0f640eba71a47f2a0b517de990ce0b70d8fce49f --- /dev/null +++ b/scripts/dist_infer/run_wan22_ti2v_i2v_cfg.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls wan2.2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_ti2v_i2v_cfg.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_i2v_parallel_cfg.mp4 diff --git a/scripts/dist_infer/run_wan22_ti2v_i2v_cfg_ulysses.sh b/scripts/dist_infer/run_wan22_ti2v_i2v_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..3111ada52accff59a1062a0c7ca9348f92a04396 --- /dev/null +++ b/scripts/dist_infer/run_wan22_ti2v_i2v_cfg_ulysses.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_ti2v_i2v_cfg_ulysses.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_i2v_parallel_cfg_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_ti2v_i2v_ulysses.sh b/scripts/dist_infer/run_wan22_ti2v_i2v_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..ef4ebc8ff6efe124ae67aebb13936f7c29952770 --- /dev/null +++ b/scripts/dist_infer/run_wan22_ti2v_i2v_ulysses.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=4 -m lightx2v.infer \ +--model_cls wan2.2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_ti2v_i2v_ulysses.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_i2v_parallel_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_ti2v_t2v_cfg.sh b/scripts/dist_infer/run_wan22_ti2v_t2v_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..0e6d3d2f9ca29ed7afb30c9cb49b7369062cb4a1 --- /dev/null +++ b/scripts/dist_infer/run_wan22_ti2v_t2v_cfg.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls wan2.2 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_ti2v_t2v_cfg.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_t2v_parallel_cfg.mp4 diff --git a/scripts/dist_infer/run_wan22_ti2v_t2v_cfg_ulysses.sh b/scripts/dist_infer/run_wan22_ti2v_t2v_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..d1aa07ba825a662657391fb1278a2ef3602a0534 --- /dev/null +++ b/scripts/dist_infer/run_wan22_ti2v_t2v_cfg_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.2 \ +--task t2v \ +--model_path $model_path \ +--config_json 
${lightx2v_path}/configs/dist_infer/wan22_ti2v_t2v_cfg_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_t2v_parallel_cfg_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan22_ti2v_t2v_ulysses.sh b/scripts/dist_infer/run_wan22_ti2v_t2v_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..82599c8a6f2d5048f00e99669a7eae01003bc220 --- /dev/null +++ b/scripts/dist_infer/run_wan22_ti2v_t2v_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=4 -m lightx2v.infer \ +--model_cls wan2.2 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan22_ti2v_t2v_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_t2v_parallel_ulysses.mp4 diff --git a/scripts/dist_infer/run_wan_i2v_dist_cfg_ulysses.sh b/scripts/dist_infer/run_wan_i2v_dist_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..843b120935a0f36c4e85ed2a2e08f39611504cbd --- /dev/null +++ b/scripts/dist_infer/run_wan_i2v_dist_cfg_ulysses.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan_i2v_dist_cfg_ulysses.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/scripts/dist_infer/run_wan_i2v_dist_ulysses.sh b/scripts/dist_infer/run_wan_i2v_dist_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..b6eed52a063e492a206144d3c504d0af450ba891 --- /dev/null +++ b/scripts/dist_infer/run_wan_i2v_dist_ulysses.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=4 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan_i2v_dist_ulysses.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/scripts/dist_infer/run_wan_t2v_dist_cfg.sh b/scripts/dist_infer/run_wan_t2v_dist_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..a38f57f00ebbac5f65331dc4ce8c8dc27a104306 --- /dev/null +++ b/scripts/dist_infer/run_wan_t2v_dist_cfg.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan_t2v_dist_cfg.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/scripts/dist_infer/run_wan_t2v_dist_cfg_ulysses.sh b/scripts/dist_infer/run_wan_t2v_dist_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..f849940be9bc6389d0912748713352105b0168fd --- /dev/null +++ b/scripts/dist_infer/run_wan_t2v_dist_cfg_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/scripts/dist_infer/run_wan_t2v_dist_ulysses.sh b/scripts/dist_infer/run_wan_t2v_dist_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..898b7f7291046b12dc0c9c8ef72a0ed6267f2f94 --- /dev/null +++ b/scripts/dist_infer/run_wan_t2v_dist_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=4 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan_t2v_dist_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/scripts/hunyuan_video_15/README.md b/scripts/hunyuan_video_15/README.md new file mode 100644 index 0000000000000000000000000000000000000000..108b5d65866d05d721e166c06d11705baa5b2d4c --- /dev/null +++ b/scripts/hunyuan_video_15/README.md @@ -0,0 +1,103 @@ +# HunyuanVideo1.5 + +## Quick Start + +1. Prepare docker environment: + +```bash +docker pull lightx2v/lightx2v:25111101-cu128 +``` + +2. Run the container: +```bash +docker run --gpus all -itd --ipc=host --name [container_name] -v [mount_settings] --entrypoint /bin/bash [image_id] +``` + +3. Prepare the models + +Please follow the instructions in [HunyuanVideo1.5 Github](https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/checkpoints-download.md) to download and place the model files. + +4. Running + +Running using bash script +```bash +# enter the docker container + +git clone https://github.com/ModelTC/LightX2V.git +cd LightX2V/scripts/hunyuan_video_15 + +# set LightX2V path and model path in the script +bash run_hy15_t2v_480p.sh +``` + +Running using Python code +```python +""" +HunyuanVideo-1.5 text-to-video generation example. 
+This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for T2V generation. +""" + +from lightx2v import LightX2VPipeline + +# Initialize pipeline for HunyuanVideo-1.5 +pipe = LightX2VPipeline( + model_path="/path/to/ckpts/hunyuanvideo-1.5/", + model_cls="hunyuan_video_1.5", + transformer_model_name="720p_t2v", + task="t2v", +) + +# Alternative: create generator from config JSON file +# pipe.create_generator(config_json="configs/hunyuan_video_15/hunyuan_video_t2v_720p.json") + +# Enable offloading to significantly reduce VRAM usage with minimal speed impact +# Suitable for RTX 30/40/50 consumer GPUs +pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported + text_encoder_offload=True, + image_encoder_offload=False, + vae_offload=False, +) + +# Use lighttae +pipe.enable_lightvae( + use_tae=True, + tae_path="/path/to/lighttaehy1_5.safetensors", + use_lightvae=False, + vae_path=None, +) + +# Create generator with specified parameters +pipe.create_generator( + attn_mode="sage_attn2", + infer_steps=50, + num_frames=121, + guidance_scale=6.0, + sample_shift=9.0, + aspect_ratio="16:9", + fps=24, +) + +# Generation parameters +seed = 123 +prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." +negative_prompt = "" +save_result_path = "/path/to/save_results/output.mp4" + +# Generate video +pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) +``` + +5. Check results + +You can find the generated video files in the `save_results` folder. + +6. Modify detailed configurations + +You can refer to the config file pointed to by `--config_json` in the script and modify its parameters as needed. diff --git a/scripts/hunyuan_video_15/run_hy15_i2v_480p.sh b/scripts/hunyuan_video_15/run_hy15_i2v_480p.sh new file mode 100644 index 0000000000000000000000000000000000000000..add987c2ada605523b3c9f2b3f107a82be772bd9 --- /dev/null +++ b/scripts/hunyuan_video_15/run_hy15_i2v_480p.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--seed 123 \ +--model_cls hunyuan_video_1.5 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_i2v_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_i2v.mp4 diff --git a/scripts/hunyuan_video_15/run_hy15_i2v_480p_vsr.sh b/scripts/hunyuan_video_15/run_hy15_i2v_480p_vsr.sh new file mode 100644 index 0000000000000000000000000000000000000000..eaeb1c9cb98e688fbc333e9ef7fede4917d8e78b --- /dev/null +++ b/scripts/hunyuan_video_15/run_hy15_i2v_480p_vsr.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--seed 123 \ +--model_cls hunyuan_video_1.5 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/vsr/hy15_i2v_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_i2v.mp4 diff --git a/scripts/hunyuan_video_15/run_hy15_i2v_720p.sh b/scripts/hunyuan_video_15/run_hy15_i2v_720p.sh new file mode 100644 index 0000000000000000000000000000000000000000..e398174c8f545e9a51c448af34939919fad2539c --- /dev/null +++ b/scripts/hunyuan_video_15/run_hy15_i2v_720p.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--seed 123 \ +--model_cls hunyuan_video_1.5 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_i2v_720p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_i2v.mp4 diff --git a/scripts/hunyuan_video_15/run_hy15_t2v_480p.sh b/scripts/hunyuan_video_15/run_hy15_t2v_480p.sh new file mode 100644 index 0000000000000000000000000000000000000000..c533b69b43c1101153ecf80e1e1252f8a6261f7e --- /dev/null +++ b/scripts/hunyuan_video_15/run_hy15_t2v_480p.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--seed 123 \ +--model_cls hunyuan_video_1.5 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_t2v_480p.json \ +--prompt "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." \ +--negative_prompt "" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_t2v.mp4 diff --git a/scripts/hunyuan_video_15/run_hy15_t2v_480p_distill.sh b/scripts/hunyuan_video_15/run_hy15_t2v_480p_distill.sh new file mode 100644 index 0000000000000000000000000000000000000000..060978a864e15bb0b0dce314043c0d455d85eb7d --- /dev/null +++ b/scripts/hunyuan_video_15/run_hy15_t2v_480p_distill.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--seed 123 \ +--model_cls hunyuan_video_1.5_distill \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json \ +--prompt "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." 
\ +--negative_prompt "" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_t2v_distill.mp4 diff --git a/scripts/hunyuan_video_15/run_hy15_t2v_720p.sh b/scripts/hunyuan_video_15/run_hy15_t2v_720p.sh new file mode 100644 index 0000000000000000000000000000000000000000..16902baa8b9bc6bd84daabe888cb360a7b18e117 --- /dev/null +++ b/scripts/hunyuan_video_15/run_hy15_t2v_720p.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=2 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--seed 123 \ +--model_cls hunyuan_video_1.5 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_t2v_720p.json \ +--prompt "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." \ +--negative_prompt "" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_t2v.mp4 diff --git a/scripts/matrix_game2/run_matrix_game2_gta_drive.sh b/scripts/matrix_game2/run_matrix_game2_gta_drive.sh new file mode 100644 index 0000000000000000000000000000000000000000..43c41ffe83714a622b62f049ad1019a5c1e170fc --- /dev/null +++ b/scripts/matrix_game2/run_matrix_game2_gta_drive.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=path to Lightx2v +model_path=path to Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive.json \ +--prompt '' \ +--image_path gta_drive/0003.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive.mp4 \ +--seed 42 diff --git a/scripts/matrix_game2/run_matrix_game2_gta_drive_streaming.sh b/scripts/matrix_game2/run_matrix_game2_gta_drive_streaming.sh new file mode 100644 index 0000000000000000000000000000000000000000..5b5b130517e444a73300a69b224f787c328b30fc --- /dev/null +++ b/scripts/matrix_game2/run_matrix_game2_gta_drive_streaming.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/data/nvme2/wushuo/LightX2V +model_path=/data/nvme2/wushuo/hf_models/Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive_streaming.json \ +--prompt '' \ +--image_path gta_drive/0003.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive_streaming.mp4 \ +--seed 42 diff --git 
a/scripts/matrix_game2/run_matrix_game2_templerun.sh b/scripts/matrix_game2/run_matrix_game2_templerun.sh new file mode 100644 index 0000000000000000000000000000000000000000..0768aa9207fb96acb26ce9b9d1a668bb5cad0789 --- /dev/null +++ b/scripts/matrix_game2/run_matrix_game2_templerun.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=path to Lightx2v +model_path=path to Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_templerun.json \ +--prompt '' \ +--image_path templerun/0005.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_templerun.mp4 \ +--seed 42 diff --git a/scripts/matrix_game2/run_matrix_game2_templerun_streaming.sh b/scripts/matrix_game2/run_matrix_game2_templerun_streaming.sh new file mode 100644 index 0000000000000000000000000000000000000000..6765ce431e8d5488bcc557c3289a68f0ca961f49 --- /dev/null +++ b/scripts/matrix_game2/run_matrix_game2_templerun_streaming.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=path to Lightx2v +model_path=path to Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_templerun_streaming.json \ +--prompt '' \ +--image_path templerun/0005.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_templerun_streaming.mp4 \ +--seed 42 diff --git a/scripts/matrix_game2/run_matrix_game2_universal.sh b/scripts/matrix_game2/run_matrix_game2_universal.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb464c4a9a71afcdcb5c7c4c2f353c758fb4db3e --- /dev/null +++ b/scripts/matrix_game2/run_matrix_game2_universal.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=path to Lightx2v +model_path=path to Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES= + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_universal.json \ +--prompt '' \ +--image_path universal/0007.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_universal.mp4 \ +--seed 42 diff --git a/scripts/matrix_game2/run_matrix_game2_universal_streaming.sh b/scripts/matrix_game2/run_matrix_game2_universal_streaming.sh new file mode 100644 index 0000000000000000000000000000000000000000..30f9d0eb91e1827f325c417eaabd74505cc487b6 --- /dev/null +++ b/scripts/matrix_game2/run_matrix_game2_universal_streaming.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=path to Lightx2v +model_path=path to Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_universal_streaming.json \ +--prompt '' \ +--image_path universal/0007.png \ +--save_result_path 
${lightx2v_path}/save_results/output_lightx2v_matrix_game2_universal_streaming.mp4 \ +--seed 42 diff --git a/scripts/quantization/gguf/run_wan_i2v_gguf_q4_k.sh b/scripts/quantization/gguf/run_wan_i2v_gguf_q4_k.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0799084708e97a0e24d8973b5b8ad17a4d600b8 --- /dev/null +++ b/scripts/quantization/gguf/run_wan_i2v_gguf_q4_k.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/quantization/gguf/wan_i2v_q4_k.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/scripts/quantization/readme.md b/scripts/quantization/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..2835da0f1890377dfcc87f85ec642b73e1e249cb --- /dev/null +++ b/scripts/quantization/readme.md @@ -0,0 +1,11 @@ +# Model Quantization + +The config files for model quantization are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization). + +By specifying --config_json with the desired config file, you can test quantized inference. + +Please refer to our model quantization docs: + +[English doc: Model Quantization](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/quantization.html) + +[中文文档: 模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) diff --git a/scripts/quantization/run_wan_i2v_quantization.sh b/scripts/quantization/run_wan_i2v_quantization.sh new file mode 100644 index 0000000000000000000000000000000000000000..33a8138e463940176183e5c81d460386d82c2297 --- /dev/null +++ b/scripts/quantization/run_wan_i2v_quantization.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/quantization/wan_i2v.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/scripts/qwen_image/qwen_image_i2i.sh b/scripts/qwen_image/qwen_image_i2i.sh new file mode 100644 index 0000000000000000000000000000000000000000..2893150a70683d159590a8832dbcff478fb1ae47 --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \ + --prompt "turn the style of the photo to vintage comic book" \ + --negative_prompt " " \ + --image_path pie.png \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \ + --seed 0 diff --git a/scripts/qwen_image/qwen_image_i2i_2509.sh b/scripts/qwen_image/qwen_image_i2i_2509.sh new file mode 100644 index 0000000000000000000000000000000000000000..f127d2dac10b3f036ac6a954488e62a39f6179be --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i_2509.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2509.json \ + --prompt "Have the two characters swap clothes and stand in front of the castle." \ + --negative_prompt " " \ + --image_path 1.jpeg,2.jpeg \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2509.png \ + --seed 0 diff --git a/scripts/qwen_image/qwen_image_i2i_2509_block.sh b/scripts/qwen_image/qwen_image_i2i_2509_block.sh new file mode 100644 index 0000000000000000000000000000000000000000..e437d77289084eb0f6eb5069a95e45be77ab35d3 --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i_2509_block.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES= + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/offload/block/qwen_image_i2i_2509_block.json \ + --prompt "Have the two characters swap clothes and stand in front of the castle." 
\ + --negative_prompt " " \ + --image_path 1.jpeg,2.jpeg \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2509.png \ + --seed 0 diff --git a/scripts/qwen_image/qwen_image_i2i_block.sh b/scripts/qwen_image/qwen_image_i2i_block.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b1c67511b80ad4bb4b9a1166745447261979bfa --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i_block.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/offload/block/qwen_image_i2i_block.json \ + --prompt "turn the style of the photo to vintage comic book" \ + --negative_prompt " " \ + --image_path pie.png \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \ + --seed 0 diff --git a/scripts/qwen_image/qwen_image_i2i_lora.sh b/scripts/qwen_image/qwen_image_i2i_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..e5312b05bbd5065dafe468057c392f63c32ac3e2 --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i_lora.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_lora.json \ + --prompt "Change the person to a standing position, bending over to hold the dog's front paws." \ + --negative_prompt " " \ + --image_path qwen_image_edit/qwen_edit1.webp \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \ + --seed 0 diff --git a/scripts/qwen_image/qwen_image_t2i.sh b/scripts/qwen_image/qwen_image_t2i.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb0b75e95132931bd84dda5d464f8b6ce09bdfb0 --- /dev/null +++ b/scripts/qwen_image/qwen_image_t2i.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls qwen_image \ +--task t2i \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \ +--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.' 
\ +--negative_prompt " " \ +--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png \ +--seed 42 diff --git a/scripts/qwen_image/qwen_image_t2i_block.sh b/scripts/qwen_image/qwen_image_t2i_block.sh new file mode 100644 index 0000000000000000000000000000000000000000..15f64934f2dd2a96481617a6dda9cbd39c7be923 --- /dev/null +++ b/scripts/qwen_image/qwen_image_t2i_block.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls qwen_image \ +--task t2i \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/offload/block/qwen_image_t2i_block.json \ +--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.' \ +--negative_prompt " " \ +--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png \ +--seed 42 diff --git a/scripts/seko_talk/multi_person/01_base.sh b/scripts/seko_talk/multi_person/01_base.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f8e6654b4239617fca9e19d62fe9a9cbd37c86d --- /dev/null +++ b/scripts/seko_talk/multi_person/01_base.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/multi_person/01_base.json \ +--prompt "The video features a man and a woman standing by a bench in the park, their expressions tense and voices raised as they argue. The man gestures with both hands, his arms swinging slightly as if to emphasize each heated word, while the woman stands with her hands on her waist, her brows furrowed in frustration. The background is a wide expanse of sunlit grass, the golden light contrasting with the sharp energy of their quarrel. Their voices seem to clash in the air, and the rhythm of their hand movements and body postures interweaves with the rising tension, creating a vivid scene of confrontation." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/multi_person/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/multi_person \ +--save_result_path ${lightx2v_path}/save_results/seko_talk_multi_person.mp4 diff --git a/scripts/seko_talk/multi_person/03_dist.sh b/scripts/seko_talk/multi_person/03_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..faab432ae6da33cd9c4da8fc66ff0f182960bd3d --- /dev/null +++ b/scripts/seko_talk/multi_person/03_dist.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 4 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/multi_person/03_dist.json \ +--prompt "The video features a man and a woman standing by a bench in the park, their expressions tense and voices raised as they argue. The man gestures with both hands, his arms swinging slightly as if to emphasize each heated word, while the woman stands with her hands on her waist, her brows furrowed in frustration. The background is a wide expanse of sunlit grass, the golden light contrasting with the sharp energy of their quarrel. Their voices seem to clash in the air, and the rhythm of their hand movements and body postures interweaves with the rising tension, creating a vivid scene of confrontation." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/multi_person/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/multi_person \ +--save_result_path ${lightx2v_path}/save_results/seko_talk_multi_person_dist.mp4 diff --git a/scripts/seko_talk/run_seko_talk_01_base.sh b/scripts/seko_talk/run_seko_talk_01_base.sh new file mode 100644 index 0000000000000000000000000000000000000000..de3970ef4fbde4017b5ddf10fd8ef19fa4069713 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_01_base.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_01_base.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_02_fp8.sh b/scripts/seko_talk/run_seko_talk_02_fp8.sh new file mode 100644 index 0000000000000000000000000000000000000000..dc32d3ae81e7475eeec86419107dc83a30d8a520 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_02_fp8.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_02_fp8.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_03_dist.sh b/scripts/seko_talk/run_seko_talk_03_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..b4e8c6e2cdf9261ab7e389d0f8ca288ba00fadbe --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_03_dist.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_03_dist.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_04_fp8_dist.sh b/scripts/seko_talk/run_seko_talk_04_fp8_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..69a768baabaea81788c840babcb3606a2409ea86 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_04_fp8_dist.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_04_fp8_dist.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_05_offload_fp8_4090.sh b/scripts/seko_talk/run_seko_talk_05_offload_fp8_4090.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a2a1ccbd36947c58e3d2469e3e79b09a4c11516 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_05_offload_fp8_4090.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/LightX2V +model_path=/path/to/SekoTalk-Distill-fp8/ + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_06_offload_fp8_H100.sh b/scripts/seko_talk/run_seko_talk_06_offload_fp8_H100.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f46fa6cf10adfed14729d0bfe8debd0650ff4fa --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_06_offload_fp8_H100.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/LightX2V +model_path=/path/to/SekoTalk-Distill-fp8/ + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_06_offload_fp8_H100.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_07_dist_offload.sh b/scripts/seko_talk/run_seko_talk_07_dist_offload.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1d7c5963afdabeed922770012afc16fbae7ac2f --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_07_dist_offload.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 4 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_07_dist_offload.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_08_5B_base.sh b/scripts/seko_talk/run_seko_talk_08_5B_base.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f348603d8aeb00919e222cc0582281b908d24dc --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_08_5B_base.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-5B + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_08_5B_base.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_09_base_fixed_min_area.sh b/scripts/seko_talk/run_seko_talk_09_base_fixed_min_area.sh new file mode 100644 index 0000000000000000000000000000000000000000..36616d0adbc23c0f223f1d383a9bd2609b9c770b --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_09_base_fixed_min_area.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_09_base_fixed_min_area.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_10_fp8_dist_fixed_min_area.sh b/scripts/seko_talk/run_seko_talk_10_fp8_dist_fixed_min_area.sh new file mode 100644 index 0000000000000000000000000000000000000000..48829fe68eef1d3fb11a7b8358ebbe2e5eb0c524 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_10_fp8_dist_fixed_min_area.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 4 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_10_fp8_dist_fixed_min_area.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_11_fp8_dist_fixed_shape.sh b/scripts/seko_talk/run_seko_talk_11_fp8_dist_fixed_shape.sh new file mode 100644 index 0000000000000000000000000000000000000000..252d1e9f88ec758253f321fd61bab39ee0138a4b --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_11_fp8_dist_fixed_shape.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 4 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.sh b/scripts/seko_talk/run_seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.sh new file mode 100644 index 0000000000000000000000000000000000000000..fcf61b1202786672007c26716000842087796958 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_12_fp8_dist_fixed_shape_8gpus_1s.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.sh b/scripts/seko_talk/run_seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.sh new file mode 100644 index 0000000000000000000000000000000000000000..e2c7be6fb05d2a10d78957027ffdbccbd0235304 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_13_fp8_dist_bucket_shape_8gpus_5s_realtime.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.sh b/scripts/seko_talk/run_seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.sh new file mode 100644 index 0000000000000000000000000000000000000000..115bb70242ea47ec172bf8b9399c5a5255e1322a --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_14_fp8_dist_bucket_shape_8gpus_1s_realtime.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_15_base_compile.sh b/scripts/seko_talk/run_seko_talk_15_base_compile.sh new file mode 100644 index 0000000000000000000000000000000000000000..4c2210ad9d2c34ddba5e0b074228cd3d7cd81d5e --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_15_base_compile.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_15_base_compile.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_16_fp8_dist_compile.sh b/scripts/seko_talk/run_seko_talk_16_fp8_dist_compile.sh new file mode 100644 index 0000000000000000000000000000000000000000..b7d662c17ef703c8e05598ddccc05b5e96bf41e8 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_16_fp8_dist_compile.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_16_fp8_dist_compile.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_17_vsr.sh b/scripts/seko_talk/run_seko_talk_17_vsr.sh new file mode 100644 index 0000000000000000000000000000000000000000..91922530c9bf96830997aa469e4e2f0eaa02b8b5 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_17_vsr.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path="" +model_path="" + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_17_base_vsr.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_18_5090_base.sh b/scripts/seko_talk/run_seko_talk_18_5090_base.sh new file mode 100644 index 0000000000000000000000000000000000000000..d3a9bde906e9b5ae4ade0c88099c90fae771396c --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_18_5090_base.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/LightX2V +model_path=/path/to/SekoTalk-Distill/ + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/5090/seko_talk_5090_bf16.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_19_A800_int8_dist.sh b/scripts/seko_talk/run_seko_talk_19_A800_int8_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a36475feaa6ad04dbcce7525fd370e3f45bfeb6 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_19_A800_int8_dist.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-int8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/A800/seko_talk_A800_int8_dist_8gpu.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_20_A800_int8.sh b/scripts/seko_talk/run_seko_talk_20_A800_int8.sh new file mode 100644 index 0000000000000000000000000000000000000000..c52a18ca649604275ba95d2270d43e667251aa06 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_20_A800_int8.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-int8 + + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/A800/seko_talk_A800_int8.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_21_5090_int8.sh b/scripts/seko_talk/run_seko_talk_21_5090_int8.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b06a1fc4b09168747bb9f72df5f9d01907f1ccb --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_21_5090_int8.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-int8 + + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/5090/seko_talk_5090_int8.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_21_5090_int8_dist.sh b/scripts/seko_talk/run_seko_talk_21_5090_int8_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..e0c6c7d1a6fa0b97f55db0e0e04e85b0a2858335 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_21_5090_int8_dist.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-int8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/5090/seko_talk_5090_int8_8gpu.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_22_nbhd_attn.sh b/scripts/seko_talk/run_seko_talk_22_nbhd_attn.sh new file mode 100644 index 0000000000000000000000000000000000000000..fef9ba5d8810597aba7f59a6bec4aa1c387e5167 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_22_nbhd_attn.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_22_nbhd_attn.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_23_fp8_dist_nbhd_attn.sh b/scripts/seko_talk/run_seko_talk_23_fp8_dist_nbhd_attn.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7a32a40ee10f43ab1e202197339a6a174e3ce65 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_23_fp8_dist_nbhd_attn.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_23_fp8_dist_nbhd_attn.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_24_fp8_dist_compile_nbhd_attn.sh b/scripts/seko_talk/run_seko_talk_24_fp8_dist_compile_nbhd_attn.sh new file mode 100644 index 0000000000000000000000000000000000000000..de447faf55295740ad51c46a519524b702599f07 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_24_fp8_dist_compile_nbhd_attn.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-fp8 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_24_fp8_dist_compile_nbhd_attn.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_25_mlu_bf16.sh b/scripts/seko_talk/run_seko_talk_25_mlu_bf16.sh new file mode 100644 index 0000000000000000000000000000000000000000..4005fa44ca1e620779e2dade176ea1c7a049ab6f --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_25_mlu_bf16.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export PLATFORM=mlu +export MLU_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/mlu/seko_talk_bf16.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_26_mlu_int8.sh b/scripts/seko_talk/run_seko_talk_26_mlu_int8.sh new file mode 100644 index 0000000000000000000000000000000000000000..205942ef04a48ceb5b0429f15724fe7b1f0b3b54 --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_26_mlu_int8.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-int8 + +export PLATFORM=mlu +export MLU_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/mlu/seko_talk_int8.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_27_mlu_int8_dist.sh b/scripts/seko_talk/run_seko_talk_27_mlu_int8_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..935bf48eeb69b4b75ebef18fd3cc1e8bcf6a80cd --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_27_mlu_int8_dist.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill-int8 + +export PLATFORM=mlu +export MLU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +torchrun --nproc-per-node 8 -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/mlu/seko_talk_int8_dist.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/seko_talk/run_seko_talk_28_f2v.sh b/scripts/seko_talk/run_seko_talk_28_f2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..50e04a89fde41c77b8284cabac979fe4a998d33b --- /dev/null +++ b/scripts/seko_talk/run_seko_talk_28_f2v.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_28_f2v.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." 
\ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/scripts/self_forcing/run_wan_t2v_sf.sh b/scripts/self_forcing/run_wan_t2v_sf.sh new file mode 100644 index 0000000000000000000000000000000000000000..bd8c4e9b73b333e73785e93e87881025ff3d3539 --- /dev/null +++ b/scripts/self_forcing/run_wan_t2v_sf.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= # path to Wan2.1-T2V-1.3B +sf_model_path= # path to gdhe17/Self-Forcing + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf \ +--task t2v \ +--model_path $model_path \ +--sf_model_path $sf_model_path \ +--config_json ${lightx2v_path}/configs/self_forcing/wan_t2v_sf.json \ +--prompt 'A stylish woman strolls down a bustling Tokyo street, the warm glow of neon lights and animated city signs casting vibrant reflections. She wears a sleek black leather jacket paired with a flowing red dress and black boots, her black purse slung over her shoulder. Sunglasses perched on her nose and a bold red lipstick add to her confident, casual demeanor. The street is damp and reflective, creating a mirror-like effect that enhances the colorful lights and shadows. Pedestrians move about, adding to the lively atmosphere. The scene is captured in a dynamic medium shot with the woman walking slightly to one side, highlighting her graceful strides.' \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4 diff --git a/scripts/server/check_status.py b/scripts/server/check_status.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2166e69539c164d66658e9f6af13cdce9a1a24 --- /dev/null +++ b/scripts/server/check_status.py @@ -0,0 +1,13 @@ +import requests +from loguru import logger + +response = requests.get("http://localhost:8000/v1/service/status") +logger.info(response.json()) + + +response = requests.get("http://localhost:8000/v1/tasks/") +logger.info(response.json()) + + +response = requests.get("http://localhost:8000/v1/tasks/test_task_001/status") +logger.info(response.json()) diff --git a/scripts/server/post.py b/scripts/server/post.py new file mode 100644 index 0000000000000000000000000000000000000000..0abe32cbb4c5c055770994e21f72a28a2c230618 --- /dev/null +++ b/scripts/server/post.py @@ -0,0 +1,17 @@ +import requests +from loguru import logger + +if __name__ == "__main__": + url = "http://localhost:8000/v1/tasks/" + + message = { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "", + } + + logger.info(f"message: {message}") + + response = requests.post(url, json=message) + + logger.info(f"response: {response.json()}") diff --git a/scripts/server/post_enhancer.py b/scripts/server/post_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..49a350f4e99e7b307eaaacd35aa6828249d88a49 --- /dev/null +++ b/scripts/server/post_enhancer.py @@ -0,0 +1,17 @@ +import requests +from 
loguru import logger + +url = "http://localhost:8000/v1/tasks/" + +message = { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "", + "use_prompt_enhancer": True, +} + +logger.info(f"message: {message}") + +response = requests.post(url, json=message) + +logger.info(f"response: {response.json()}") diff --git a/scripts/server/post_i2v.py b/scripts/server/post_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..3500f2e7d6697eda265c06d83bb50c1d8ee7bcef --- /dev/null +++ b/scripts/server/post_i2v.py @@ -0,0 +1,27 @@ +import base64 + +import requests +from loguru import logger + + +def image_to_base64(image_path): + """Convert an image file to base64 string""" + with open(image_path, "rb") as f: + image_data = f.read() + return base64.b64encode(image_data).decode("utf-8") + + +if __name__ == "__main__": + url = "http://localhost:8000/v1/tasks/" + + message = { + "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": image_to_base64("assets/inputs/imgs/img_0.jpg"), # 图片地址 + } + + logger.info(f"message: {message}") + + response = requests.post(url, json=message) + + logger.info(f"response: {response.json()}") diff --git a/scripts/server/post_multi_servers.py b/scripts/server/post_multi_servers.py new file mode 100644 index 0000000000000000000000000000000000000000..eb18dbe6dc02316569de874157a0ce7e3585db25 --- /dev/null +++ b/scripts/server/post_multi_servers.py @@ -0,0 +1,148 @@ +import base64 +import os +import threading +import time +from typing import Any + +import requests +from loguru import logger +from tqdm import tqdm + + +def image_to_base64(image_path): + """Convert an image file to base64 string""" + with open(image_path, "rb") as f: + image_data = f.read() + return base64.b64encode(image_data).decode("utf-8") + + +def process_image_path(image_path) -> Any | str: + """Process image_path: convert to base64 if local path, keep unchanged if HTTP link""" + if not image_path: + return image_path + + if image_path.startswith(("http://", "https://")): + return image_path + + if os.path.exists(image_path): + return image_to_base64(image_path) + else: + logger.warning(f"Image path not found: {image_path}") + return image_path + + +def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock): + """Send task to server and monitor until completion""" + try: + if "image_path" in message and message["image_path"]: + message["image_path"] = process_image_path(message["image_path"]) + + response = requests.post(f"{url}/v1/tasks/", json=message) + response_data = response.json() + task_id = response_data.get("task_id") + + if not task_id: + logger.error(f"No task_id received from {url}") 
+ return False + + # Step 2: Monitor task status until completion + while True: + try: + status_response = requests.get(f"{url}/v1/tasks/{task_id}/status") + status_data = status_response.json() + task_status = status_data.get("status") + + if task_status == "completed": + # Update completion bar safely + if complete_bar and complete_lock: + with complete_lock: + complete_bar.update(1) + return True + elif task_status == "failed": + logger.error(f"Task {task_index + 1} (task_id: {task_id}) failed") + if complete_bar and complete_lock: + with complete_lock: + complete_bar.update(1) # Still update progress even if failed + return False + else: + time.sleep(0.5) + + except Exception as e: + logger.error(f"Failed to check status for task_id {task_id}: {e}") + time.sleep(0.5) + + except Exception as e: + logger.error(f"Failed to send task to {url}: {e}") + return False + + +def get_available_urls(urls): + """Check which URLs are available and return the list""" + available_urls = [] + for url in urls: + try: + _ = requests.get(f"{url}/v1/service/status").json() + available_urls.append(url) + except Exception as e: + continue + + if not available_urls: + logger.error("No available urls.") + return None + + logger.info(f"available_urls: {available_urls}") + return available_urls + + +def find_idle_server(available_urls): + """Find an idle server from available URLs""" + while True: + for url in available_urls: + try: + response = requests.get(f"{url}/v1/service/status").json() + if response["service_status"] == "idle": + return url + except Exception as e: + continue + time.sleep(3) + + +def process_tasks_async(messages, available_urls, show_progress=True): + """Process a list of tasks asynchronously across multiple servers""" + if not available_urls: + logger.error("No available servers to process tasks.") + return False + + active_threads = [] + + logger.info(f"Sending {len(messages)} tasks to available servers...") + + complete_bar = None + complete_lock = None + if show_progress: + complete_bar = tqdm(total=len(messages), desc="Completing tasks") + complete_lock = threading.Lock() # Thread-safe updates to completion bar + + for idx, message in enumerate(messages): + # Find an idle server + server_url = find_idle_server(available_urls) + + # Create and start thread for sending and monitoring task + thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar, complete_lock)) + thread.daemon = False + thread.start() + active_threads.append(thread) + + # Small delay to let thread start + time.sleep(0.5) + + # Wait for all threads to complete + for thread in active_threads: + thread.join() + + # Close completion bar + if complete_bar: + complete_bar.close() + + logger.info("All tasks processing completed!") + return True diff --git a/scripts/server/post_multi_servers_i2v.py b/scripts/server/post_multi_servers_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..7d13026f54aed3b24870529114d239f61a45b580 --- /dev/null +++ b/scripts/server/post_multi_servers_i2v.py @@ -0,0 +1,40 @@ +import base64 + +from loguru import logger +from post_multi_servers import get_available_urls, process_tasks_async + + +def image_to_base64(image_path): + """Convert an image file to base64 string""" + with open(image_path, "rb") as f: + image_data = f.read() + return base64.b64encode(image_data).decode("utf-8") + + +if __name__ == "__main__": + urls = [f"http://localhost:{port}" for port in range(8000, 8008)] + img_prompts = { + 
"assets/inputs/imgs/img_0.jpg": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "assets/inputs/imgs/img_2.jpg": "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting.", + } + negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + messages = [] + for i, (image_path, prompt) in enumerate(img_prompts.items()): + messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"}) + + logger.info(f"urls: {urls}") + + # Get available servers + available_urls = get_available_urls(urls) + if not available_urls: + exit(1) + + # Process tasks asynchronously + success = process_tasks_async(messages, available_urls, show_progress=True) + + if success: + logger.info("All tasks completed successfully!") + else: + logger.error("Some tasks failed.") + exit(1) diff --git a/scripts/server/post_multi_servers_t2v.py b/scripts/server/post_multi_servers_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9f1e5c8f80524df262573ab58105c3636071ba --- /dev/null +++ b/scripts/server/post_multi_servers_t2v.py @@ -0,0 +1,166 @@ +import argparse +from pathlib import Path + +from loguru import logger +from post_multi_servers import get_available_urls, process_tasks_async + + +def load_prompts_from_folder(folder_path): + """Load prompts from all files in the specified folder. + + Returns: + tuple: (prompts, filenames) where prompts is a list of prompt strings + and filenames is a list of corresponding filenames + """ + prompts = [] + filenames = [] + folder = Path(folder_path) + + if not folder.exists() or not folder.is_dir(): + logger.error(f"Prompt folder does not exist or is not a directory: {folder_path}") + return prompts, filenames + + # Get all files in the folder and sort them + files = sorted(folder.glob("*")) + files = [f for f in files if f.is_file()] + + for file_path in files: + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read().strip() + if content: # Only add non-empty prompts + prompts.append(content) + filenames.append(file_path.name) + # logger.info(f"Loaded prompt from {file_path.name}") + except Exception as e: + logger.warning(f"Failed to read file {file_path}: {e}") + + return prompts, filenames + + +def load_prompts_from_file(file_path): + """Load prompts from a file, one prompt per line. 
+ + Returns: + list: prompts, where each element is a prompt string + """ + prompts = [] + file = Path(file_path) + + if not file.exists() or not file.is_file(): + logger.error(f"Prompt file does not exist or is not a file: {file_path}") + return prompts + + try: + with open(file, "r", encoding="utf-8") as f: + for line in f: + prompt = line.strip() + if prompt: # Only add non-empty prompts + prompts.append(prompt) + except Exception as e: + logger.error(f"Failed to read prompt file {file_path}: {e}") + + return prompts + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Post prompts to multiple T2V servers") + parser.add_argument("--prompt-folder", type=str, default=None, help="Folder containing prompt files. If not specified, use default prompts.") + parser.add_argument("--prompt-file", type=str, default=None, help="File containing prompts, one prompt per line. Cannot be used together with --prompt-folder.") + parser.add_argument("--save-folder", type=str, default="./", help="Folder to save output videos. Default is current directory.") + args = parser.parse_args() + + # Check that --prompt-folder and --prompt-file are not used together + if args.prompt_folder and args.prompt_file: + logger.error("Cannot use --prompt-folder and --prompt-file together. Please choose one.") + exit(1) + + # Generate URLs from IPs (each IP has 8 ports: 8000-8007) + ips = ["localhost"] + urls = [f"http://{ip}:{port}" for ip in ips for port in range(8000, 8008)] + # urls = ["http://localhost:8007"] + + logger.info(f"urls: {urls}") + + # Get available servers + available_urls = get_available_urls(urls) + if not available_urls: + exit(1) + + logger.info(f"Total {len(available_urls)} available servers.") + + # Load prompts from folder, file, or use default prompts + prompt_filenames = None + if args.prompt_folder: + logger.info(f"Loading prompts from folder: {args.prompt_folder}") + prompts, prompt_filenames = load_prompts_from_folder(args.prompt_folder) + if not prompts: + logger.error("No valid prompts loaded from folder.") + exit(1) + elif args.prompt_file: + logger.info(f"Loading prompts from file: {args.prompt_file}") + prompts = load_prompts_from_file(args.prompt_file) + if not prompts: + logger.error("No valid prompts loaded from file.") + exit(1) + else: + logger.info("Using default prompts") + prompts = [ + "A cat walks on the grass, realistic style.", + "A person is riding a bike. Realistic, Natural lighting, Casual.", + "A car turns a corner. Realistic, Natural lighting, Casual.", + "An astronaut is flying in space, Van Gogh style. Dark, Mysterious.", + "A beautiful coastal beach in spring, waves gently lapping on the sand, the camera movement is Zoom In. 
Realistic, Natural lighting, Peaceful.", + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + ] + + negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + # Prepare save folder + save_folder = Path(args.save_folder) + save_folder.mkdir(parents=True, exist_ok=True) + + messages = [] + total_count = len(prompts) + skipped_count = 0 + + for i, prompt in enumerate(prompts): + # Generate output filename + if prompt_filenames: + # Use prompt filename, replace extension with .mp4 + filename = Path(prompt_filenames[i]).stem + ".mp4" + else: + # Use default naming + filename = f"output_lightx2v_wan_t2v_{i + 1}.mp4" + + save_path = save_folder / filename + + # Skip if file already exists (only when using prompt_filenames) + if prompt_filenames and save_path.exists(): + logger.info(f"Skipping {filename} - file already exists") + skipped_count += 1 + continue + + messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": "", "save_result_path": str(save_path)}) + + # Log statistics + to_process_count = len(messages) + logger.info("=" * 80) + logger.info("Task Statistics:") + logger.info(f" Total prompts: {total_count}") + logger.info(f" Skipped (already exists): {skipped_count}") + logger.info(f" To process: {to_process_count}") + logger.info("=" * 80) + + if to_process_count == 0: + logger.info("No tasks to process. All files already exist.") + exit(0) + + # Process tasks asynchronously + success = process_tasks_async(messages, available_urls, show_progress=True) + + if success: + logger.info("All tasks completed successfully!") + else: + logger.error("Some tasks failed.") + exit(1) diff --git a/scripts/server/post_vbench_i2v.py b/scripts/server/post_vbench_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..c971b7ecffb2f4c1de2568be21b4c5390b06fab8 --- /dev/null +++ b/scripts/server/post_vbench_i2v.py @@ -0,0 +1,68 @@ +import argparse +import glob +import os + +from loguru import logger +from post_multi_servers import get_available_urls, process_tasks_async + + +def create_i2v_messages(img_files, output_path): + """Create messages for image-to-video tasks""" + messages = [] + negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + for img_path in img_files: + file_name = os.path.basename(img_path) + prompt = os.path.splitext(file_name)[0] + save_result_path = os.path.join(output_path, f"{prompt}.mp4") + + message = { + "seed": 42, + "prompt": prompt, + "negative_prompt": negative_prompt, + "image_path": img_path, + "save_result_path": save_result_path, + } + messages.append(message) + + return messages + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, required=True, help="path to img files.") + parser.add_argument("--output_path", type=str, default="./vbench_i2v", help="output video path.") + args = parser.parse_args() + + # Create server URLs + urls = [f"http://localhost:{port}" for port in range(8000, 8008)] + + # Get available servers + available_urls = get_available_urls(urls) + if not available_urls: + exit(1) + + # Find image files + if os.path.exists(args.data_path): + img_files = glob.glob(os.path.join(args.data_path, "*.jpg")) + logger.info(f"Found {len(img_files)} image files.") + + if not 
img_files: + logger.error("No image files found.") + exit(1) + + # Create messages for all images + messages = create_i2v_messages(img_files, args.output_path) + logger.info(f"Created {len(messages)} tasks.") + + # Process tasks asynchronously + success = process_tasks_async(messages, available_urls, show_progress=True) + + if success: + logger.info("All image-to-video tasks completed successfully!") + else: + logger.error("Some tasks failed.") + exit(1) + else: + logger.error(f"Data path does not exist: {args.data_path}") + exit(1) diff --git a/scripts/server/readme.md b/scripts/server/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..8a6fb13f07ecd75fc8d4f4f8f497c98ea9a9e05e --- /dev/null +++ b/scripts/server/readme.md @@ -0,0 +1 @@ +## todo diff --git a/scripts/server/start_multi_servers.sh b/scripts/server/start_multi_servers.sh new file mode 100644 index 0000000000000000000000000000000000000000..809e32cb3d48524f3210b5ec51f8440be00e1646 --- /dev/null +++ b/scripts/server/start_multi_servers.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v +model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8 + + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# Start multiple servers +torchrun --nproc_per_node 4 -m lightx2v.server \ + --model_cls seko_talk \ + --task i2v \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \ + --port 8000 diff --git a/scripts/server/start_server.sh b/scripts/server/start_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..3b1644a5aa81658e6790510f69945bb2011738a9 --- /dev/null +++ b/scripts/server/start_server.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +# Start API server with distributed inference service +python -m lightx2v.server \ +--model_cls hunyuan_video_1.5_distill \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json \ +--port 8000 + +echo "Service stopped" diff --git a/scripts/server/start_server_i2i.sh b/scripts/server/start_server_i2i.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba814b2b4cb171173bfdab7722be32aa9c0359e9 --- /dev/null +++ b/scripts/server/start_server_i2i.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# Start API server with distributed inference service +python -m lightx2v.server \ +--model_cls qwen_image \ +--task i2i \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \ +--port 8000 + +echo "Service stopped" + +# { +# "prompt": "turn the style of the photo to vintage comic book", +# "image_path": "assets/inputs/imgs/snake.png", +# "infer_steps": 50 +# } diff --git a/scripts/server/start_server_t2i.sh b/scripts/server/start_server_t2i.sh new file mode 100644 index 0000000000000000000000000000000000000000..d740778f96259eccd67b76efcb6558743f8e2d3b --- /dev/null +++ b/scripts/server/start_server_t2i.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + 
+export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +# Start API server with distributed inference service +python -m lightx2v.server \ +--model_cls qwen_image \ +--task t2i \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \ +--port 8000 + +echo "Service stopped" + + +# { +# "prompt": "a beautiful sunset over the ocean", +# "aspect_ratio": "16:9", +# "infer_steps": 50 +# } diff --git a/scripts/server/stop_running_task.py b/scripts/server/stop_running_task.py new file mode 100644 index 0000000000000000000000000000000000000000..65f643c966a985704703d1733c5f344106aaf094 --- /dev/null +++ b/scripts/server/stop_running_task.py @@ -0,0 +1,5 @@ +import requests +from loguru import logger + +response = requests.get("http://localhost:8000/v1/local/video/generate/stop_running_task") +logger.info(response.json()) diff --git a/scripts/sparse_attn/spas_sage_attn/run_wan_i2v.sh b/scripts/sparse_attn/spas_sage_attn/run_wan_i2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..67e5a714b600abe206ca4204e7c2dfdffd68b02e --- /dev/null +++ b/scripts/sparse_attn/spas_sage_attn/run_wan_i2v.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/sparse_attn/spas_sage_attn/wan_i2v.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_spas_sage_attn.mp4 diff --git a/scripts/sparse_attn/spas_sage_attn/run_wan_t2v.sh b/scripts/sparse_attn/spas_sage_attn/run_wan_t2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..13ad8bf1624ff738f1401e63ecf3740baa4e04f6 --- /dev/null +++ b/scripts/sparse_attn/spas_sage_attn/run_wan_t2v.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/sparse_attn/spas_sage_attn/wan_t2v.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_spas_sage_attn.mp4 diff --git a/scripts/video_frame_interpolation/run_wan_t2v_video_frame_interpolation.sh b/scripts/video_frame_interpolation/run_wan_t2v_video_frame_interpolation.sh new file mode 100644 index 0000000000000000000000000000000000000000..3ce5131b725866c7665222f16fc5307721e9a8ee --- /dev/null +++ b/scripts/video_frame_interpolation/run_wan_t2v_video_frame_interpolation.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# Run inference with VFI enabled through config file +# The wan_t2v.json config contains video_frame_interpolation settings +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/video_frame_interpolation/wan_t2v.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_video_frame_interpolation.mp4 diff --git a/scripts/wan/run_wan_flf2v.sh b/scripts/wan/run_wan_flf2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..35ff6400c6e2ca73fb975674247342d7288e2599 --- /dev/null +++ b/scripts/wan/run_wan_flf2v.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task flf2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan/wan_flf2v.json \ +--prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/flf2v_input_first_frame-fs8.png \ +--last_frame_path ${lightx2v_path}/assets/inputs/imgs/flf2v_input_last_frame-fs8.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_flf2v.mp4 diff --git a/scripts/wan/run_wan_i2v.sh b/scripts/wan/run_wan_i2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..f8384121192af73b7f07dd3d6f70ed55ab54ef65 --- /dev/null +++ b/scripts/wan/run_wan_i2v.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan/wan_i2v.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. 
The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/scripts/wan/run_wan_i2v_causvid.sh b/scripts/wan/run_wan_i2v_causvid.sh new file mode 100644 index 0000000000000000000000000000000000000000..0b93f19ac2addd8bc8e924d5002e885e78b99e09 --- /dev/null +++ b/scripts/wan/run_wan_i2v_causvid.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_causvid \ +--task i2v \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/causvid/wan_i2v_causvid.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_causvid.mp4 diff --git a/scripts/wan/run_wan_i2v_distill_4step_cfg.sh b/scripts/wan/run_wan_i2v_distill_4step_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..e6cbb781516bf460e73b3d08fa01f422a0fd8eef --- /dev/null +++ b/scripts/wan/run_wan_i2v_distill_4step_cfg.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_distill.mp4 diff --git a/scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh b/scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..8e54a8879edb646b8f6e7968037b1d4c03f35feb --- /dev/null +++ b/scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg_lora.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_distill.mp4 diff --git a/scripts/wan/run_wan_i2v_lazy_load.sh b/scripts/wan/run_wan_i2v_lazy_load.sh new file mode 100644 index 0000000000000000000000000000000000000000..07ac97107567dbfb78838c3cf54dff4c14d1bb8e --- /dev/null +++ b/scripts/wan/run_wan_i2v_lazy_load.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh +export DTYPE=FP16 +export SENSITIVE_LAYER_DTYPE=FP16 +export PROFILING_DEBUG_LEVEL=2 + +echo "===============================================================================" +echo "LightX2V Lazyload Environment Variables Summary:" +echo "-------------------------------------------------------------------------------" +echo "lightx2v_path: ${lightx2v_path}" +echo "model_path: ${model_path}" +echo "-------------------------------------------------------------------------------" +echo "Model Inference Data Type: ${DTYPE}" +echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}" +echo "Performance Profiling Debug Level: ${PROFILING_DEBUG_LEVEL}" +echo "===============================================================================" + + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. 
The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/scripts/wan/run_wan_i2v_nbhd_attn_480p.sh b/scripts/wan/run_wan_i2v_nbhd_attn_480p.sh new file mode 100644 index 0000000000000000000000000000000000000000..c91dd8e764fe91059afbdf293651c9e412798f1c --- /dev/null +++ b/scripts/wan/run_wan_i2v_nbhd_attn_480p.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/attentions/wan_i2v_nbhd_480p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_nbhd_attn_480p.mp4 diff --git a/scripts/wan/run_wan_i2v_nbhd_attn_720p.sh b/scripts/wan/run_wan_i2v_nbhd_attn_720p.sh new file mode 100644 index 0000000000000000000000000000000000000000..1fb35149aac021b6ec62f3c81585056ebf81fc76 --- /dev/null +++ b/scripts/wan/run_wan_i2v_nbhd_attn_720p.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/attentions/wan_i2v_nbhd_720p.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_nbhd_attn_720p.mp4 diff --git a/scripts/wan/run_wan_t2v.sh b/scripts/wan/run_wan_t2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..a539ebef79ebf7dae8cd47a2642862e83dc338f9 --- /dev/null +++ b/scripts/wan/run_wan_t2v.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan/wan_t2v.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/scripts/wan/run_wan_t2v_causvid.sh b/scripts/wan/run_wan_t2v_causvid.sh new file mode 100644 index 0000000000000000000000000000000000000000..a4672582a1af92d1d801f9e9fc9fee0075cc5b7e --- /dev/null +++ b/scripts/wan/run_wan_t2v_causvid.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_causvid \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/causvid/wan_t2v_causvid.json \ +--prompt "Two anthropomorphic cats fight intensely on a spotlighted stage; the left cat wearing blue boxing gear with matching gloves, the right cat in bright red boxing attire and gloves." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_causvid.mp4 diff --git a/scripts/wan/run_wan_t2v_distill_4step_cfg.sh b/scripts/wan/run_wan_t2v_distill_4step_cfg.sh new file mode 100644 index 0000000000000000000000000000000000000000..586cc8234c0024cfe7672344ad4e09edfe62ad73 --- /dev/null +++ b/scripts/wan/run_wan_t2v_distill_4step_cfg.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--use_prompt_enhancer \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/scripts/wan/run_wan_t2v_distill_4step_cfg_dynamic.sh b/scripts/wan/run_wan_t2v_distill_4step_cfg_dynamic.sh new file mode 100644 index 0000000000000000000000000000000000000000..f15cfd0bb2635dd8dc022331d9aae2d45c9c156c --- /dev/null +++ b/scripts/wan/run_wan_t2v_distill_4step_cfg_dynamic.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--use_prompt_enhancer \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_cfg_4.mp4 diff --git a/scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh b/scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..09bb104a6fd8a3ef3166782563a5c38590b42a3f --- /dev/null +++ b/scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_distill \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--use_prompt_enhancer \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/scripts/wan/run_wan_vace.sh b/scripts/wan/run_wan_vace.sh new file mode 100644 index 0000000000000000000000000000000000000000..8d6a733f2ad5182e818d338d12356dfb9ae3f6a7 --- /dev/null +++ b/scripts/wan/run_wan_vace.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_vace \ +--task vace \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan/wan_vace.json \ +--prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--src_ref_images ${lightx2v_path}/assets/inputs/imgs/girl.png,${lightx2v_path}/assets/inputs/imgs/snake.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_vace.mp4\ diff --git a/scripts/wan22/run_wan22_animate.sh b/scripts/wan22/run_wan22_animate.sh new file mode 100644 index 0000000000000000000000000000000000000000..89e2f03c53cbbba08da97c38877c1242935824b1 --- /dev/null +++ b/scripts/wan22/run_wan22_animate.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= +video_path= +refer_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# process +python ${lightx2v_path}/tools/preprocess/preprocess_data.py \ + --ckpt_path ${model_path}/process_checkpoint \ + --video_path $video_path \ + --refer_path $refer_path \ + --save_path ${lightx2v_path}/save_results/animate/process_results \ + --resolution_area 1280 720 \ + --retarget_flag \ + +python -m lightx2v.infer \ +--model_cls wan2.2_animate \ +--task animate \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_animate.json \ +--src_pose_path ${lightx2v_path}/save_results/animate/process_results/src_pose.mp4 \ +--src_face_path ${lightx2v_path}/save_results/animate/process_results/src_face.mp4 \ +--src_ref_images ${lightx2v_path}/save_results/animate/process_results/src_ref.png \ +--prompt "视频中的人在做动作" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_animate.mp4 diff --git a/scripts/wan22/run_wan22_animate_lora.sh b/scripts/wan22/run_wan22_animate_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..8919630778ea83de70eaa5c79a82487254ca6301 --- /dev/null +++ b/scripts/wan22/run_wan22_animate_lora.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= +video_path= +refer_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# process +python ${lightx2v_path}/tools/preprocess/preprocess_data.py \ + --ckpt_path ${model_path}/process_checkpoint \ + --video_path $video_path \ + --refer_path $refer_path \ + 
--save_path ${lightx2v_path}/save_results/animate/process_results \ + --resolution_area 1280 720 \ + --retarget_flag \ + +python -m lightx2v.infer \ +--model_cls wan2.2_animate \ +--task animate \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_animate_lora.json \ +--src_pose_path ${lightx2v_path}/save_results/animate/process_results/src_pose.mp4 \ +--src_face_path ${lightx2v_path}/save_results/animate/process_results/src_face.mp4 \ +--src_ref_images ${lightx2v_path}/save_results/animate/process_results/src_ref.png \ +--prompt "视频中的人在做动作" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_animate_lora.mp4 diff --git a/scripts/wan22/run_wan22_animate_replace.sh b/scripts/wan22/run_wan22_animate_replace.sh new file mode 100644 index 0000000000000000000000000000000000000000..8932c48e28777a8168610cc2c06445d93929321c --- /dev/null +++ b/scripts/wan22/run_wan22_animate_replace.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= +video_path= +refer_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# process +python ${lightx2v_path}/tools/preprocess/preprocess_data.py \ + --ckpt_path ${model_path}/process_checkpoint \ + --video_path $video_path \ + --refer_path $refer_path \ + --save_path ${lightx2v_path}/save_results/replace/process_results \ + --resolution_area 1280 720 \ + --iterations 3 \ + --k 7 \ + --w_len 1 \ + --h_len 1 \ + --replace_flag + +# inference reads from the same replace/process_results directory written by the preprocess step above +python -m lightx2v.infer \ +--model_cls wan2.2_animate \ +--task animate \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_animate_replace_4090.json \ +--src_pose_path ${lightx2v_path}/save_results/replace/process_results/src_pose.mp4 \ +--src_face_path ${lightx2v_path}/save_results/replace/process_results/src_face.mp4 \ +--src_ref_images ${lightx2v_path}/save_results/replace/process_results/src_ref.png \ +--src_bg_path ${lightx2v_path}/save_results/replace/process_results/src_bg.mp4 \ +--src_mask_path ${lightx2v_path}/save_results/replace/process_results/src_mask.mp4 \ +--prompt "视频中的人在做动作" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_replace.mp4 diff --git a/scripts/wan22/run_wan22_distill_moe_flf2v.sh b/scripts/wan22/run_wan22_distill_moe_flf2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..1741ce29c4870232bb274755e21aa9f457d8a5a6 --- /dev/null +++ b/scripts/wan22/run_wan22_distill_moe_flf2v.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + + +python -m lightx2v.infer \ +--model_cls wan2.2_moe_distill \ +--task flf2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_distill_moe_flf2v.json \ +--prompt "A bearded man with red facial hair wearing a yellow straw hat and dark coat in Van Gogh's self-portrait style, slowly and continuously transforms into a space astronaut.
The transformation flows like liquid paint - his beard fades away strand by strand, the yellow hat melts and reforms smoothly into a silver space helmet, dark coat gradually lightens and restructures into a white spacesuit. The background swirling brushstrokes slowly organize and clarify into realistic stars and space, with Earth appearing gradually in the distance. Every change happens in seamless waves, maintaining visual continuity throughout the metamorphosis.\n\nConsistent soft lighting throughout, medium close-up maintaining same framing, central composition stays fixed, gentle color temperature shift from warm to cool, gradual contrast increase, smooth style transition from painterly to photorealistic. Static camera with subtle slow zoom, emphasizing the flowing transformation process without abrupt changes." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path /mtc/gushiqiao/llmc_workspace/wan22_14B_flf2v_start_image.png \ +--last_frame_path /mtc/gushiqiao/llmc_workspace/wan22_14B_flf2v_end_image.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_flf2v.mp4 diff --git a/scripts/wan22/run_wan22_moe_flf2v.sh b/scripts/wan22/run_wan22_moe_flf2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..24419ac7fe9c6350073445e88f0534b6f542f648 --- /dev/null +++ b/scripts/wan22/run_wan22_moe_flf2v.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task flf2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_moe_flf2v.json \ +--prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/flf2v_input_first_frame-fs8.png \ +--last_frame_path ${lightx2v_path}/assets/inputs/imgs/flf2v_input_last_frame-fs8.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_flf2v.mp4 diff --git a/scripts/wan22/run_wan22_moe_i2v.sh b/scripts/wan22/run_wan22_moe_i2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf974d12d58f27c9d2a7da050424c5de99d53600 --- /dev/null +++ b/scripts/wan22/run_wan22_moe_i2v.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_moe_i2v.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. 
The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v.mp4 diff --git a/scripts/wan22/run_wan22_moe_i2v_distill.sh b/scripts/wan22/run_wan22_moe_i2v_distill.sh new file mode 100644 index 0000000000000000000000000000000000000000..edd1c6f8adcbbf186d31d6fd64cef85ec5117261 --- /dev/null +++ b/scripts/wan22/run_wan22_moe_i2v_distill.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2_moe_distill \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_moe_i2v_distill.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_distill.mp4 diff --git a/scripts/wan22/run_wan22_moe_t2v.sh b/scripts/wan22/run_wan22_moe_t2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..9b1494f6e2eabf67c79fda7ae86de9aafd7132c3 --- /dev/null +++ b/scripts/wan22/run_wan22_moe_t2v.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2_moe \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_moe_t2v.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v.mp4 diff --git a/scripts/wan22/run_wan22_moe_t2v_distill.sh b/scripts/wan22/run_wan22_moe_t2v_distill.sh new file mode 100644 index 0000000000000000000000000000000000000000..f412ab4c74fdaad529a964585946dbb82e88bdb5 --- /dev/null +++ b/scripts/wan22/run_wan22_moe_t2v_distill.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2_moe_distill \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_moe_t2v_distill.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_distill.mp4 diff --git a/scripts/wan22/run_wan22_ti2v_i2v.sh b/scripts/wan22/run_wan22_ti2v_i2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..045d3b0c64ad9d52a57161df14b13b74176142bf --- /dev/null +++ b/scripts/wan22/run_wan22_ti2v_i2v.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_ti2v_i2v.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_i2v.mp4 diff --git a/scripts/wan22/run_wan22_ti2v_t2v.sh b/scripts/wan22/run_wan22_ti2v_t2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..13d9b5fc2f66e444628a6db250e00acc971db563 --- /dev/null +++ b/scripts/wan22/run_wan22_ti2v_t2v.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.2 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan22/wan_ti2v_t2v.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \ +--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_ti2v_t2v.mp4 diff --git a/scripts/win/run_wan_i2v.bat b/scripts/win/run_wan_i2v.bat new file mode 100644 index 0000000000000000000000000000000000000000..bc998a645ba614dc81496ae5fb81610982f0dad5 --- /dev/null +++ b/scripts/win/run_wan_i2v.bat @@ -0,0 +1,53 @@ +@echo off +chcp 65001 >nul +echo 启动LightX2V I2V推理... + +:: 设置路径 +set lightx2v_path=D:\LightX2V +set model_path=D:\models\Wan2.1-I2V-14B-480P-Lightx2v + +:: 检查CUDA_VISIBLE_DEVICES +if "%CUDA_VISIBLE_DEVICES%"=="" ( + set cuda_devices=0 + echo Warn: CUDA_VISIBLE_DEVICES is not set, using default value: %cuda_devices%, change at shell script or set env variable. + set CUDA_VISIBLE_DEVICES=%cuda_devices% +) + +:: 检查路径 +if "%lightx2v_path%"=="" ( + echo Error: lightx2v_path is not set. Please set this variable first. + exit /b 1 +) + +if "%model_path%"=="" ( + echo Error: model_path is not set. Please set this variable first. + exit /b 1 +) + +:: 设置环境变量 +set TOKENIZERS_PARALLELISM=false +set PYTHONPATH=%lightx2v_path%;%PYTHONPATH% +set PROFILING_DEBUG_LEVEL=2 +set DTYPE=BF16 + +echo 环境变量设置完成! +echo PYTHONPATH: %PYTHONPATH% +echo CUDA_VISIBLE_DEVICES: %CUDA_VISIBLE_DEVICES% +echo 模型路径: %model_path% + +:: 切换到项目目录 +cd /d %lightx2v_path% + +:: 运行推理 +python -m lightx2v.infer ^ +--model_cls wan2.1 ^ +--task i2v ^ +--model_path %model_path% ^ +--config_json %lightx2v_path%/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json ^ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." ^ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" ^ +--image_path %lightx2v_path%/assets/inputs/imgs/img_0.jpg ^ +--save_result_path %lightx2v_path%/save_results/output_lightx2v_wan_i2v.mp4 + +echo 推理完成! 
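The Windows batch script above expects a visible CUDA device and sets DTYPE=BF16 before calling `python -m lightx2v.infer`. A minimal pre-flight check, run from the same Python environment, could look like the sketch below; it is illustrative only and not part of LightX2V:

import torch

# Confirm a CUDA device is visible and report free VRAM and BF16 support
if not torch.cuda.is_available():
    raise SystemExit("No CUDA device visible - check CUDA_VISIBLE_DEVICES")
free, total = torch.cuda.mem_get_info(0)  # bytes
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Free VRAM: {free / 1024**3:.1f} / {total / 1024**3:.1f} GiB")
print(f"BF16 supported: {torch.cuda.is_bf16_supported()}")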
+pause diff --git a/scripts/win/run_wan_t2v.bat b/scripts/win/run_wan_t2v.bat new file mode 100644 index 0000000000000000000000000000000000000000..52caedecf6db283978eeaa65003caacdc44fad75 --- /dev/null +++ b/scripts/win/run_wan_t2v.bat @@ -0,0 +1,52 @@ +@echo off +chcp 65001 >nul +echo 启动LightX2V T2V推理... + +:: 设置路径 +set lightx2v_path=D:\LightX2V +set model_path=D:\models\Wan2.1-T2V-1.3B-Lightx2v + +:: 检查CUDA_VISIBLE_DEVICES +if "%CUDA_VISIBLE_DEVICES%"=="" ( + set cuda_devices=0 + echo Warn: CUDA_VISIBLE_DEVICES is not set, using default value: %cuda_devices%, change at shell script or set env variable. + set CUDA_VISIBLE_DEVICES=%cuda_devices% +) + +:: 检查路径 +if "%lightx2v_path%"=="" ( + echo Error: lightx2v_path is not set. Please set this variable first. + exit /b 1 +) + +if "%model_path%"=="" ( + echo Error: model_path is not set. Please set this variable first. + exit /b 1 +) + +:: 设置环境变量 +set TOKENIZERS_PARALLELISM=false +set PYTHONPATH=%lightx2v_path%;%PYTHONPATH% +set PROFILING_DEBUG_LEVEL=2 +set DTYPE=BF16 + +echo 环境变量设置完成! +echo PYTHONPATH: %PYTHONPATH% +echo CUDA_VISIBLE_DEVICES: %CUDA_VISIBLE_DEVICES% +echo 模型路径: %model_path% + +:: 切换到项目目录 +cd /d %lightx2v_path% + +:: 运行推理 +python -m lightx2v.infer ^ +--model_cls wan2.1 ^ +--task t2v ^ +--model_path %model_path% ^ +--config_json %lightx2v_path%/configs/offload/block/wan_t2v_1_3b.json ^ +--prompt "A beautiful sunset over a calm ocean, with golden rays of light reflecting on the water surface. The sky is painted with vibrant orange and pink clouds. A peaceful and serene atmosphere." ^ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" ^ +--save_result_path %lightx2v_path%/save_results/output_lightx2v_wan_t2v.mp4 + +echo 推理完成! 
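setup_vae.py, added just below, keeps a minimal VAE-only dependency list in install_requires and exposes the complete stack through a "full" extra. Because pip only auto-discovers setup.py or pyproject.toml, the file has to be renamed or invoked explicitly; a hedged usage sketch (the rename step is an assumption, not a documented workflow):

# Illustrative only: pip will not pick up a file named setup_vae.py on its own
cp setup_vae.py setup.py
pip install -e .            # VAE-only dependencies (torch, numpy, einops, loguru)
pip install -e ".[full]"    # full LightX2V dependency set via the "full" extra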
+pause diff --git a/setup_vae.py b/setup_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..3814252d2e4ca3cb4905b7ad8ed2599f37aefb4d --- /dev/null +++ b/setup_vae.py @@ -0,0 +1,86 @@ +""" +LightX2V Setup Script +Minimal installation for VAE models only +""" + +import os + +from setuptools import find_packages, setup + + +# Read the README file +def read_readme(): + readme_path = os.path.join(os.path.dirname(__file__), "README.md") + if os.path.exists(readme_path): + with open(readme_path, "r", encoding="utf-8") as f: + return f.read() + return "" + + +# Core dependencies for VAE models +vae_dependencies = [ + "torch>=2.0.0", + "numpy>=1.20.0", + "einops>=0.6.0", + "loguru>=0.6.0", +] + +# Full dependencies for complete LightX2V +full_dependencies = [ + "packaging", + "ninja", + "torch", + "torchvision", + "diffusers", + "transformers", + "tokenizers", + "tqdm", + "accelerate", + "safetensors", + "opencv-python", + "numpy", + "imageio", + "imageio-ffmpeg", + "einops", + "loguru", + "ftfy", + "gradio", + "aiohttp", + "pydantic", + "fastapi", + "uvicorn", + "requests", + "decord", +] + +setup( + name="lightx2v", + version="1.0.0", + author="LightX2V Team", + author_email="", + description="LightX2V: High-performance video generation models with optimized VAE", + long_description=read_readme(), + long_description_content_type="text/markdown", + url="https://github.com/ModelTC/LightX2V", + packages=find_packages(include=["lightx2v", "lightx2v.*"]), + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + python_requires=">=3.8", + install_requires=vae_dependencies, + extras_require={ + "full": full_dependencies, + "vae": vae_dependencies, + }, + include_package_data=True, + zip_safe=False, +) diff --git a/test_cases/run_matrix_game2_gta_drive.sh b/test_cases/run_matrix_game2_gta_drive.sh new file mode 100644 index 0000000000000000000000000000000000000000..43c41ffe83714a622b62f049ad1019a5c1e170fc --- /dev/null +++ b/test_cases/run_matrix_game2_gta_drive.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path=path to Lightx2v +model_path=path to Skywork/Matrix-Game-2.0 + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf_mtxg2 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive.json \ +--prompt '' \ +--image_path gta_drive/0003.png \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive.mp4 \ +--seed 42 diff --git a/test_cases/run_qwen_image_i2i_2509.sh b/test_cases/run_qwen_image_i2i_2509.sh new file mode 100644 index 0000000000000000000000000000000000000000..f127d2dac10b3f036ac6a954488e62a39f6179be --- /dev/null +++ b/test_cases/run_qwen_image_i2i_2509.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + 
--model_path $model_path \ + --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2509.json \ + --prompt "Have the two characters swap clothes and stand in front of the castle." \ + --negative_prompt " " \ + --image_path 1.jpeg,2.jpeg \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2509.png \ + --seed 0 diff --git a/test_cases/run_seko_talk_01_base.sh b/test_cases/run_seko_talk_01_base.sh new file mode 100644 index 0000000000000000000000000000000000000000..cd9258f42bfb2188be71b6e05d9597e64f51435a --- /dev/null +++ b/test_cases/run_seko_talk_01_base.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +lightx2v_path=/path/to/Lightx2v +model_path=/path/to/SekoTalk-Distill + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +export SENSITIVE_LAYER_DTYPE=None + +python -m lightx2v.infer \ +--model_cls seko_talk \ +--task s2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_01_base.json \ +--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \ +--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ +--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \ +--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4 diff --git a/test_cases/run_wan_i2v.sh b/test_cases/run_wan_i2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..f8384121192af73b7f07dd3d6f70ed55ab54ef65 --- /dev/null +++ b/test_cases/run_wan_i2v.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan/wan_i2v.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/test_cases/run_wan_i2v_offload.sh b/test_cases/run_wan_i2v_offload.sh new file mode 100644 index 0000000000000000000000000000000000000000..c657dfa550947f88fd1b6972a3b8295bc924ac42 --- /dev/null +++ b/test_cases/run_wan_i2v_offload.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/offload/phase/wan_i2v_phase.json \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 diff --git a/test_cases/run_wan_t2v.sh b/test_cases/run_wan_t2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..a539ebef79ebf7dae8cd47a2642862e83dc338f9 --- /dev/null +++ b/test_cases/run_wan_t2v.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/wan/wan_t2v.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/test_cases/run_wan_t2v_dist_cfg_ulysses.sh b/test_cases/run_wan_t2v_dist_cfg_ulysses.sh new file mode 100644 index 0000000000000000000000000000000000000000..f849940be9bc6389d0912748713352105b0168fd --- /dev/null +++ b/test_cases/run_wan_t2v_dist_cfg_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=8 -m lightx2v.infer \ +--model_cls wan2.1 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/dist_infer/wan_t2v_dist_cfg_ulysses.json \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ +--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 diff --git a/test_cases/run_wan_t2v_sf.sh b/test_cases/run_wan_t2v_sf.sh new file mode 100644 index 0000000000000000000000000000000000000000..bd8c4e9b73b333e73785e93e87881025ff3d3539 --- /dev/null +++ b/test_cases/run_wan_t2v_sf.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# set path and first +lightx2v_path= +model_path= # path to Wan2.1-T2V-1.3B +sf_model_path= # path to gdhe17/Self-Forcing + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls wan2.1_sf \ +--task t2v \ +--model_path $model_path \ +--sf_model_path $sf_model_path \ +--config_json ${lightx2v_path}/configs/self_forcing/wan_t2v_sf.json \ +--prompt 'A stylish woman strolls down a bustling Tokyo street, the warm glow of neon lights and animated city signs casting vibrant reflections. She wears a sleek black leather jacket paired with a flowing red dress and black boots, her black purse slung over her shoulder. Sunglasses perched on her nose and a bold red lipstick add to her confident, casual demeanor. The street is damp and reflective, creating a mirror-like effect that enhances the colorful lights and shadows. Pedestrians move about, adding to the lively atmosphere. The scene is captured in a dynamic medium shot with the woman walking slightly to one side, highlighting her graceful strides.' \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4 diff --git a/tools/convert/converter.py b/tools/convert/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5faaf67bf43ab584e3b6014bd893437b7a55d2d --- /dev/null +++ b/tools/convert/converter.py @@ -0,0 +1,883 @@ +import argparse +import gc +import glob +import json +import multiprocessing +import os +import re +import shutil +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed + +import torch +from loguru import logger + +try: + from lora_loader import LoRALoader +except ImportError: + pass +import sys +from pathlib import Path + +from safetensors import safe_open +from safetensors import torch as st +from tqdm import tqdm + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER +from tools.convert.quant import * + +dtype_mapping = { + "int8": torch.int8, + "fp8": torch.float8_e4m3fn, +} + + +def get_key_mapping_rules(direction, model_type): + if model_type == "wan_dit": + unified_rules = [ + { + "forward": (r"^head\.head$", "proj_out"), + "backward": (r"^proj_out$", "head.head"), + }, + { + "forward": (r"^head\.modulation$", "scale_shift_table"), + "backward": (r"^scale_shift_table$", "head.modulation"), + }, + { + "forward": ( + r"^text_embedding\.0\.", + "condition_embedder.text_embedder.linear_1.", + ), + "backward": ( + r"^condition_embedder.text_embedder.linear_1\.", + "text_embedding.0.", + ), + }, + { + "forward": ( + r"^text_embedding\.2\.", + "condition_embedder.text_embedder.linear_2.", + ), + "backward": ( + r"^condition_embedder.text_embedder.linear_2\.", + "text_embedding.2.", + ), + }, + { + "forward": ( + r"^time_embedding\.0\.", + "condition_embedder.time_embedder.linear_1.", + ), + "backward": ( + r"^condition_embedder.time_embedder.linear_1\.", + 
"time_embedding.0.", + ), + }, + { + "forward": ( + r"^time_embedding\.2\.", + "condition_embedder.time_embedder.linear_2.", + ), + "backward": ( + r"^condition_embedder.time_embedder.linear_2\.", + "time_embedding.2.", + ), + }, + { + "forward": (r"^time_projection\.1\.", "condition_embedder.time_proj."), + "backward": (r"^condition_embedder.time_proj\.", "time_projection.1."), + }, + { + "forward": (r"blocks\.(\d+)\.self_attn\.q\.", r"blocks.\1.attn1.to_q."), + "backward": ( + r"blocks\.(\d+)\.attn1\.to_q\.", + r"blocks.\1.self_attn.q.", + ), + }, + { + "forward": (r"blocks\.(\d+)\.self_attn\.k\.", r"blocks.\1.attn1.to_k."), + "backward": ( + r"blocks\.(\d+)\.attn1\.to_k\.", + r"blocks.\1.self_attn.k.", + ), + }, + { + "forward": (r"blocks\.(\d+)\.self_attn\.v\.", r"blocks.\1.attn1.to_v."), + "backward": ( + r"blocks\.(\d+)\.attn1\.to_v\.", + r"blocks.\1.self_attn.v.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.self_attn\.o\.", + r"blocks.\1.attn1.to_out.0.", + ), + "backward": ( + r"blocks\.(\d+)\.attn1\.to_out\.0\.", + r"blocks.\1.self_attn.o.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.q\.", + r"blocks.\1.attn2.to_q.", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.to_q\.", + r"blocks.\1.cross_attn.q.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.k\.", + r"blocks.\1.attn2.to_k.", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.to_k\.", + r"blocks.\1.cross_attn.k.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.v\.", + r"blocks.\1.attn2.to_v.", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.to_v\.", + r"blocks.\1.cross_attn.v.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.o\.", + r"blocks.\1.attn2.to_out.0.", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.to_out\.0\.", + r"blocks.\1.cross_attn.o.", + ), + }, + { + "forward": (r"blocks\.(\d+)\.norm3\.", r"blocks.\1.norm2."), + "backward": (r"blocks\.(\d+)\.norm2\.", r"blocks.\1.norm3."), + }, + { + "forward": (r"blocks\.(\d+)\.ffn\.0\.", r"blocks.\1.ffn.net.0.proj."), + "backward": ( + r"blocks\.(\d+)\.ffn\.net\.0\.proj\.", + r"blocks.\1.ffn.0.", + ), + }, + { + "forward": (r"blocks\.(\d+)\.ffn\.2\.", r"blocks.\1.ffn.net.2."), + "backward": (r"blocks\.(\d+)\.ffn\.net\.2\.", r"blocks.\1.ffn.2."), + }, + { + "forward": ( + r"blocks\.(\d+)\.modulation\.", + r"blocks.\1.scale_shift_table.", + ), + "backward": ( + r"blocks\.(\d+)\.scale_shift_table(?=\.|$)", + r"blocks.\1.modulation", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.k_img\.", + r"blocks.\1.attn2.add_k_proj.", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.add_k_proj\.", + r"blocks.\1.cross_attn.k_img.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.v_img\.", + r"blocks.\1.attn2.add_v_proj.", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.add_v_proj\.", + r"blocks.\1.cross_attn.v_img.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.norm_k_img\.weight", + r"blocks.\1.attn2.norm_added_k.weight", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.norm_added_k\.weight", + r"blocks.\1.cross_attn.norm_k_img.weight", + ), + }, + { + "forward": ( + r"img_emb\.proj\.0\.", + r"condition_embedder.image_embedder.norm1.", + ), + "backward": ( + r"condition_embedder\.image_embedder\.norm1\.", + r"img_emb.proj.0.", + ), + }, + { + "forward": ( + r"img_emb\.proj\.1\.", + r"condition_embedder.image_embedder.ff.net.0.proj.", + ), + "backward": ( + r"condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", + r"img_emb.proj.1.", + ), + }, + { + "forward": ( + r"img_emb\.proj\.3\.", + 
r"condition_embedder.image_embedder.ff.net.2.", + ), + "backward": ( + r"condition_embedder\.image_embedder\.ff\.net\.2\.", + r"img_emb.proj.3.", + ), + }, + { + "forward": ( + r"img_emb\.proj\.4\.", + r"condition_embedder.image_embedder.norm2.", + ), + "backward": ( + r"condition_embedder\.image_embedder\.norm2\.", + r"img_emb.proj.4.", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.self_attn\.norm_q\.weight", + r"blocks.\1.attn1.norm_q.weight", + ), + "backward": ( + r"blocks\.(\d+)\.attn1\.norm_q\.weight", + r"blocks.\1.self_attn.norm_q.weight", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.self_attn\.norm_k\.weight", + r"blocks.\1.attn1.norm_k.weight", + ), + "backward": ( + r"blocks\.(\d+)\.attn1\.norm_k\.weight", + r"blocks.\1.self_attn.norm_k.weight", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.norm_q\.weight", + r"blocks.\1.attn2.norm_q.weight", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.norm_q\.weight", + r"blocks.\1.cross_attn.norm_q.weight", + ), + }, + { + "forward": ( + r"blocks\.(\d+)\.cross_attn\.norm_k\.weight", + r"blocks.\1.attn2.norm_k.weight", + ), + "backward": ( + r"blocks\.(\d+)\.attn2\.norm_k\.weight", + r"blocks.\1.cross_attn.norm_k.weight", + ), + }, + # head projection mapping + { + "forward": (r"^head\.head\.", "proj_out."), + "backward": (r"^proj_out\.", "head.head."), + }, + ] + + if direction == "forward": + return [rule["forward"] for rule in unified_rules] + elif direction == "backward": + return [rule["backward"] for rule in unified_rules] + else: + raise ValueError(f"Invalid direction: {direction}") + else: + raise ValueError(f"Unsupported model type: {model_type}") + + +def quantize_model( + weights, + w_bit=8, + target_keys=["attn", "ffn"], + adapter_keys=None, + key_idx=2, + ignore_key=None, + linear_type="int8", + non_linear_dtype=torch.float, + comfyui_mode=False, + comfyui_keys=[], +): + """ + Quantize model weights in-place + + Args: + weights: Model state dictionary + w_bit: Quantization bit width + target_keys: List of module names to quantize + + Returns: + Modified state dictionary with quantized weights and scales + """ + total_quantized = 0 + original_size = 0 + quantized_size = 0 + non_quantized_size = 0 + keys = list(weights.keys()) + + with tqdm(keys, desc="Quantizing weights") as pbar: + for key in pbar: + pbar.set_postfix(current_key=key, refresh=False) + + if ignore_key is not None and any(ig_key in key for ig_key in ignore_key): + del weights[key] + continue + + tensor = weights[key] + + # Skip non-tensors and non-2D tensors + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + if tensor.dtype != non_linear_dtype: + weights[key] = tensor.to(non_linear_dtype) + non_quantized_size += weights[key].numel() * weights[key].element_size() + else: + non_quantized_size += tensor.numel() * tensor.element_size() + continue + + # Check if key matches target modules + parts = key.split(".") + + if comfyui_mode and (comfyui_keys is not None and key in comfyui_keys): + pass + elif len(parts) < key_idx + 1 or parts[key_idx] not in target_keys: + if adapter_keys is None: + if tensor.dtype != non_linear_dtype: + weights[key] = tensor.to(non_linear_dtype) + non_quantized_size += weights[key].numel() * weights[key].element_size() + else: + non_quantized_size += tensor.numel() * tensor.element_size() + elif not any(adapter_key in parts for adapter_key in adapter_keys): + if tensor.dtype != non_linear_dtype: + weights[key] = tensor.to(non_linear_dtype) + non_quantized_size += weights[key].numel() * 
weights[key].element_size() + else: + non_quantized_size += tensor.numel() * tensor.element_size() + else: + non_quantized_size += tensor.numel() * tensor.element_size() + continue + + # try: + original_tensor_size = tensor.numel() * tensor.element_size() + original_size += original_tensor_size + + # Quantize tensor and store results + quantizer = CONVERT_WEIGHT_REGISTER[linear_type](tensor) + w_q, scales, extra = quantizer.weight_quant_func(tensor, comfyui_mode) + weight_global_scale = extra.get("weight_global_scale", None) # For nvfp4 + + # Replace original tensor and store scales + weights[key] = w_q + if comfyui_mode: + weights[key.replace(".weight", ".scale_weight")] = scales + else: + weights[key + "_scale"] = scales + if weight_global_scale: + weights[key + "_global_scale"] = weight_global_scale + + quantized_tensor_size = w_q.numel() * w_q.element_size() + scale_size = scales.numel() * scales.element_size() + quantized_size += quantized_tensor_size + scale_size + + total_quantized += 1 + del w_q, scales + + # except Exception as e: + # logger.error(f"Error quantizing {key}: {str(e)}") + + gc.collect() + + original_size_mb = original_size / (1024**2) + quantized_size_mb = quantized_size / (1024**2) + non_quantized_size_mb = non_quantized_size / (1024**2) + total_final_size_mb = (quantized_size + non_quantized_size) / (1024**2) + size_reduction_mb = original_size_mb - quantized_size_mb + + logger.info(f"Quantized {total_quantized} tensors") + logger.info(f"Original quantized tensors size: {original_size_mb:.2f} MB") + logger.info(f"After quantization size: {quantized_size_mb:.2f} MB (includes scales)") + logger.info(f"Non-quantized tensors size: {non_quantized_size_mb:.2f} MB") + logger.info(f"Total final model size: {total_final_size_mb:.2f} MB") + logger.info(f"Size reduction in quantized tensors: {size_reduction_mb:.2f} MB ({size_reduction_mb / original_size_mb * 100:.1f}%)") + + if comfyui_mode: + weights["scaled_fp8"] = torch.zeros(2, dtype=torch.float8_e4m3fn) + + return weights + + +def load_loras(lora_path, weight_dict, alpha, key_mapping_rules=None, strength=1.0): + """ + Load and apply LoRA weights to model weights using the LoRALoader class. 
+ + Args: + lora_path: Path to LoRA safetensors file + weight_dict: Model weights dictionary (will be modified in place) + alpha: Global alpha scaling factor + key_mapping_rules: Optional list of (pattern, replacement) regex rules for key mapping + strength: Additional strength factor for LoRA deltas + """ + logger.info(f"Loading LoRA from: {lora_path} with alpha={alpha}, strength={strength}") + + # Load LoRA weights from safetensors file + with safe_open(lora_path, framework="pt") as f: + lora_weights = {k: f.get_tensor(k) for k in f.keys()} + + # Create LoRA loader with key mapping rules + lora_loader = LoRALoader(key_mapping_rules=key_mapping_rules) + + # Apply LoRA weights to model + lora_loader.apply_lora( + weight_dict=weight_dict, + lora_weights=lora_weights, + alpha=alpha, + strength=strength, + ) + + +def convert_weights(args): + if os.path.isdir(args.source): + src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True) + elif args.source.endswith((".pth", ".safetensors", "pt")): + src_files = [args.source] + else: + raise ValueError("Invalid input path") + + merged_weights = {} + logger.info(f"Processing source files: {src_files}") + + # Optimize loading for better memory usage + for file_path in tqdm(src_files, desc="Loading weights"): + logger.info(f"Loading weights from: {file_path}") + if file_path.endswith(".pt") or file_path.endswith(".pth"): + weights = torch.load(file_path, map_location=args.device, weights_only=True) + if args.model_type == "hunyuan_dit": + weights = weights["module"] + elif file_path.endswith(".safetensors"): + # Use lazy loading for safetensors to reduce memory usage + with safe_open(file_path, framework="pt") as f: + # Only load tensors when needed (lazy loading) + weights = {} + keys = f.keys() + + # For large files, show progress + if len(keys) > 100: + for k in tqdm(keys, desc=f"Loading {os.path.basename(file_path)}", leave=False): + weights[k] = f.get_tensor(k) + else: + weights = {k: f.get_tensor(k) for k in keys} + + duplicate_keys = set(weights.keys()) & set(merged_weights.keys()) + if duplicate_keys: + raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}") + + # Update weights more efficiently + merged_weights.update(weights) + + # Clear weights dict to free memory + del weights + if len(src_files) > 1: + gc.collect() # Force garbage collection between files + + if args.direction is not None: + rules = get_key_mapping_rules(args.direction, args.model_type) + converted_weights = {} + logger.info("Converting keys...") + + # Pre-compile regex patterns for better performance + compiled_rules = [(re.compile(pattern), replacement) for pattern, replacement in rules] + + def convert_key(key): + """Convert a single key using compiled rules""" + new_key = key + for pattern, replacement in compiled_rules: + new_key = pattern.sub(replacement, new_key) + return new_key + + # Batch convert keys using list comprehension (faster than loop) + keys_list = list(merged_weights.keys()) + + # Use parallel processing for large models + if len(keys_list) > 1000 and args.parallel: + logger.info(f"Using parallel processing for {len(keys_list)} keys") + # Use ThreadPoolExecutor for I/O bound regex operations + num_workers = min(8, multiprocessing.cpu_count()) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + # Submit all conversion tasks + future_to_key = {executor.submit(convert_key, key): key for key in keys_list} + + # Process results as they complete with progress bar + for future in 
tqdm(as_completed(future_to_key), total=len(keys_list), desc="Converting keys (parallel)"): + original_key = future_to_key[future] + new_key = future.result() + converted_weights[new_key] = merged_weights[original_key] + else: + # For smaller models, use simple loop with less overhead + for key in tqdm(keys_list, desc="Converting keys"): + new_key = convert_key(key) + converted_weights[new_key] = merged_weights[key] + else: + converted_weights = merged_weights + + # Apply LoRA AFTER key conversion to ensure proper key matching + if args.lora_path is not None: + # Handle alpha list - if single alpha, replicate for all LoRAs + if args.lora_alpha is not None: + if len(args.lora_alpha) == 1 and len(args.lora_path) > 1: + args.lora_alpha = args.lora_alpha * len(args.lora_path) + elif len(args.lora_alpha) != len(args.lora_path): + raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1") + + # Normalize strength list + if args.lora_strength is not None: + if len(args.lora_strength) == 1 and len(args.lora_path) > 1: + args.lora_strength = args.lora_strength * len(args.lora_path) + elif len(args.lora_strength) != len(args.lora_path): + raise ValueError(f"Number of strength ({len(args.lora_strength)}) must match number of lora_path ({len(args.lora_path)}) or be 1") + + # Determine if we should apply key mapping rules to LoRA keys + key_mapping_rules = None + if args.lora_key_convert == "convert" and args.direction is not None: + # Apply same conversion as model + key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type) + logger.info("Applying key conversion to LoRA weights") + elif args.lora_key_convert == "same": + # Don't convert LoRA keys + logger.info("Using original LoRA keys without conversion") + else: # auto + # Auto-detect: if model was converted, try with conversion first + if args.direction is not None: + key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type) + logger.info("Auto mode: will try with key conversion first") + + for idx, path in enumerate(args.lora_path): + # Pass key mapping rules to handle converted keys properly + strength = args.lora_strength[idx] if args.lora_strength is not None else 1.0 + alpha = args.lora_alpha[idx] if args.lora_alpha is not None else None + load_loras(path, converted_weights, alpha, key_mapping_rules, strength=strength) + + if args.quantized: + if args.full_quantized and args.comfyui_mode: + logger.info("Quant all tensors...") + assert args.linear_dtype, f"Error: only support 'torch.int8' and 'torch.float8_e4m3fn'." 
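# In comfyui full-quantized mode every tensor is cast directly to args.linear_dtype
# (torch.int8 or torch.float8_e4m3fn from dtype_mapping); no per-tensor scales are
# computed here, unlike the quantize_model() path taken in the else branch below.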
+ for k in converted_weights.keys(): + converted_weights[k] = converted_weights[k].float().to(args.linear_dtype) + else: + converted_weights = quantize_model( + converted_weights, + w_bit=args.bits, + target_keys=args.target_keys, + adapter_keys=args.adapter_keys, + key_idx=args.key_idx, + ignore_key=args.ignore_key, + linear_type=args.linear_type, + non_linear_dtype=args.non_linear_dtype, + comfyui_mode=args.comfyui_mode, + comfyui_keys=args.comfyui_keys, + ) + + os.makedirs(args.output, exist_ok=True) + + if args.output_ext == ".pth": + torch.save(converted_weights, os.path.join(args.output, args.output_name + ".pth")) + + else: + index = {"metadata": {"total_size": 0}, "weight_map": {}} + if args.single_file: + output_filename = f"{args.output_name}.safetensors" + output_path = os.path.join(args.output, output_filename) + logger.info(f"Saving model to single file: {output_path}") + + # For memory efficiency with large models + try: + # If model is very large (over threshold), consider warning + total_size = sum(tensor.numel() * tensor.element_size() for tensor in converted_weights.values()) + total_size_gb = total_size / (1024**3) + + if total_size_gb > 10: # Warn if model is larger than 10GB + logger.warning(f"Model size is {total_size_gb:.2f}GB. This will require significant memory to save as a single file.") + logger.warning("Consider using --save_by_block or default chunked saving for better memory efficiency.") + + # Save the entire model as a single file + st.save_file(converted_weights, output_path) + logger.info(f"Model saved successfully to: {output_path} ({total_size_gb:.2f}GB)") + + except MemoryError: + logger.error("Memory error while saving. The model is too large to save as a single file.") + logger.error("Please use --save_by_block or remove --single_file to use chunked saving.") + raise + except Exception as e: + logger.error(f"Error saving model: {e}") + raise + elif args.save_by_block: + logger.info("Backward conversion: grouping weights by block") + block_groups = defaultdict(dict) + non_block_weights = {} + block_pattern = re.compile(r"blocks\.(\d+)\.") + + for key, tensor in converted_weights.items(): + match = block_pattern.search(key) + if match: + block_idx = match.group(1) + if args.model_type == "wan_animate_dit" and "face_adapter" in key: + block_idx = str(int(block_idx) * 5) + block_groups[block_idx][key] = tensor + else: + non_block_weights[key] = tensor + + for block_idx, weights_dict in tqdm(block_groups.items(), desc="Saving block chunks"): + output_filename = f"block_{block_idx}.safetensors" + output_path = os.path.join(args.output, output_filename) + st.save_file(weights_dict, output_path) + for key in weights_dict: + index["weight_map"][key] = output_filename + index["metadata"]["total_size"] += os.path.getsize(output_path) + + if non_block_weights: + output_filename = f"non_block.safetensors" + output_path = os.path.join(args.output, output_filename) + st.save_file(non_block_weights, output_path) + for key in non_block_weights: + index["weight_map"][key] = output_filename + index["metadata"]["total_size"] += os.path.getsize(output_path) + + else: + chunk_idx = 0 + current_chunk = {} + for idx, (k, v) in tqdm(enumerate(converted_weights.items()), desc="Saving chunks"): + current_chunk[k] = v + if args.chunk_size > 0 and (idx + 1) % args.chunk_size == 0: + output_filename = f"{args.output_name}_part{chunk_idx}.safetensors" + output_path = os.path.join(args.output, output_filename) + logger.info(f"Saving chunk to: {output_path}") + 
st.save_file(current_chunk, output_path) + for key in current_chunk: + index["weight_map"][key] = output_filename + index["metadata"]["total_size"] += os.path.getsize(output_path) + current_chunk = {} + chunk_idx += 1 + + if current_chunk: + output_filename = f"{args.output_name}_part{chunk_idx}.safetensors" + output_path = os.path.join(args.output, output_filename) + logger.info(f"Saving final chunk to: {output_path}") + st.save_file(current_chunk, output_path) + for key in current_chunk: + index["weight_map"][key] = output_filename + index["metadata"]["total_size"] += os.path.getsize(output_path) + + # Save index file + if not args.single_file: + index_path = os.path.join(args.output, "diffusion_pytorch_model.safetensors.index.json") + with open(index_path, "w", encoding="utf-8") as f: + json.dump(index, f, indent=2) + logger.info(f"Index file written to: {index_path}") + + if os.path.isdir(args.source) and args.copy_no_weight_files: + copy_non_weight_files(args.source, args.output) + + +def copy_non_weight_files(source_dir, target_dir): + ignore_extensions = [".pth", ".pt", ".safetensors", ".index.json"] + + logger.info(f"Start copying non-weighted files and subdirectories...") + + for item in tqdm(os.listdir(source_dir), desc="copy non-weighted file"): + source_item = os.path.join(source_dir, item) + target_item = os.path.join(target_dir, item) + + try: + if os.path.isdir(source_item): + os.makedirs(target_item, exist_ok=True) + copy_non_weight_files(source_item, target_item) + elif os.path.isfile(source_item) and not any(source_item.endswith(ext) for ext in ignore_extensions): + shutil.copy2(source_item, target_item) + logger.debug(f"copy file: {source_item} -> {target_item}") + except Exception as e: + logger.error(f"copy {source_item} : {str(e)}") + + logger.info(f"Non-weight files and subdirectories copied") + + +def main(): + parser = argparse.ArgumentParser(description="Model weight format converter") + parser.add_argument("-s", "--source", required=True, help="Input path (file or directory)") + parser.add_argument("-o_e", "--output_ext", default=".safetensors", choices=[".pth", ".safetensors"]) + parser.add_argument("-o_n", "--output_name", type=str, default="converted", help="Output file name") + parser.add_argument("-o", "--output", required=True, help="Output directory path") + parser.add_argument( + "-d", + "--direction", + choices=[None, "forward", "backward"], + default=None, + help="Conversion direction: forward = 'lightx2v' -> 'Diffusers', backward = reverse", + ) + parser.add_argument( + "-c", + "--chunk-size", + type=int, + default=100, + help="Chunk size for saving (only applies to forward), 0 = no chunking", + ) + parser.add_argument( + "-t", + "--model_type", + choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip", "wan_animate_dit", "qwen_image_dit", "qwen25vl_llm"], + default="wan_dit", + help="Model type", + ) + parser.add_argument("-b", "--save_by_block", action="store_true") + + # Quantization + parser.add_argument("--comfyui_mode", action="store_true") + parser.add_argument("--full_quantized", action="store_true") + parser.add_argument("--quantized", action="store_true") + parser.add_argument("--bits", type=int, default=8, choices=[8], help="Quantization bit width") + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for quantization (cpu/cuda)", + ) + parser.add_argument( + "--linear_type", + type=str, + choices=["int8", "fp8", "nvfp4", "mxfp4", "mxfp6", "mxfp8"], + help="Quant type for linear", + ) + 
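    # The int8/fp8 choices map to torch dtypes via dtype_mapping at the top of this file;
    # the nvfp4/mxfp4/mxfp6/mxfp8 choices rely on the optional lightx2v_kernel GEMM
    # kernels imported (inside a try/except) in tools/convert/quant/quant.py.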
parser.add_argument( + "--non_linear_dtype", + type=str, + default="torch.float32", + choices=["torch.bfloat16", "torch.float16"], + help="Data type for non-linear", + ) + parser.add_argument("--lora_path", type=str, nargs="*", help="Path(s) to LoRA file(s). Can specify multiple paths separated by spaces.") + parser.add_argument( + "--lora_alpha", + type=float, + nargs="*", + default=None, + help="Alpha for LoRA weight scaling, Default non scaling. ", + ) + parser.add_argument( + "--lora_strength", + type=float, + nargs="*", + help="Additional strength factor(s) for LoRA deltas; default 1.0", + ) + parser.add_argument("--copy_no_weight_files", action="store_true") + parser.add_argument("--single_file", action="store_true", help="Save as a single safetensors file instead of chunking (warning: requires loading entire model in memory)") + parser.add_argument( + "--lora_key_convert", + choices=["auto", "same", "convert"], + default="auto", + help="How to handle LoRA key conversion: 'auto' (detect from LoRA), 'same' (use original keys), 'convert' (apply same conversion as model)", + ) + parser.add_argument("--parallel", action="store_true", default=True, help="Use parallel processing for faster conversion (default: True)") + parser.add_argument("--no-parallel", dest="parallel", action="store_false", help="Disable parallel processing") + args = parser.parse_args() + + # Validate conflicting arguments + if args.single_file and args.save_by_block: + parser.error("--single_file and --save_by_block cannot be used together. Choose one saving strategy.") + + if args.single_file and args.chunk_size > 0 and args.chunk_size != 100: + logger.warning("--chunk_size is ignored when using --single_file option.") + + if args.quantized: + args.linear_dtype = dtype_mapping.get(args.linear_type, None) + args.non_linear_dtype = eval(args.non_linear_dtype) + + model_type_keys_map = { + "qwen_image_dit": { + "key_idx": 2, + "target_keys": ["attn", "img_mlp", "txt_mlp", "txt_mod", "img_mod"], + "ignore_key": None, + "comfyui_keys": [ + "time_text_embed.timestep_embedder.linear_1.weight", + "time_text_embed.timestep_embedder.linear_2.weight", + "img_in.weight", + "txt_in.weight", + "norm_out.linear.weight", + "proj_out.weight", + ], + }, + "wan_dit": { + "key_idx": 2, + "target_keys": ["self_attn", "cross_attn", "ffn"], + "ignore_key": ["ca", "audio"], + }, + "wan_animate_dit": {"key_idx": 2, "target_keys": ["self_attn", "cross_attn", "ffn"], "adapter_keys": ["linear1_kv", "linear1_q", "linear2"], "ignore_key": None}, + "hunyuan_dit": { + "key_idx": 2, + "target_keys": [ + "img_mod", + "img_attn_q", + "img_attn_k", + "img_attn_v", + "img_attn_proj", + "img_mlp", + "txt_mod", + "txt_attn_q", + "txt_attn_k", + "txt_attn_v", + "txt_attn_proj", + "txt_mlp", + ], + "ignore_key": None, + }, + "wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None}, + "wan_clip": { + "key_idx": 3, + "target_keys": ["attn", "mlp"], + "ignore_key": ["textual"], + }, + "qwen25vl_llm": { + "key_idx": 3, + "target_keys": ["self_attn", "mlp"], + "ignore_key": ["visual"], + }, + } + + args.target_keys = model_type_keys_map[args.model_type]["target_keys"] + args.adapter_keys = model_type_keys_map[args.model_type]["adapter_keys"] if "adapter_keys" in model_type_keys_map[args.model_type] else None + args.key_idx = model_type_keys_map[args.model_type]["key_idx"] + args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"] + args.comfyui_keys = model_type_keys_map[args.model_type]["comfyui_keys"] if "comfyui_keys" in 
model_type_keys_map[args.model_type] else None + + if os.path.isfile(args.output): + raise ValueError("Output path must be a directory, not a file") + + logger.info("Starting model weight conversion...") + convert_weights(args) + logger.info(f"Conversion completed! Files saved to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/tools/convert/lora_loader.py b/tools/convert/lora_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..61a61f852fefaf6af6bf80d2eb916d1c3f328f07 --- /dev/null +++ b/tools/convert/lora_loader.py @@ -0,0 +1,448 @@ +""" +LoRA (Low-Rank Adaptation) loader with support for multiple format patterns. + +Supported formats: +- Standard: {key}.lora_up.weight and {key}.lora_down.weight +- Diffusers: {key}_lora.up.weight and {key}_lora.down.weight +- Diffusers v2: {key}.lora_B.weight and {key}.lora_A.weight (B=up, A=down) +- Diffusers v3: {key}.lora.up.weight and {key}.lora.down.weight +- Mochi: {key}.lora_B and {key}.lora_A (no .weight suffix) +- Transformers: {key}.lora_linear_layer.up.weight and {key}.lora_linear_layer.down.weight +- Qwen: {key}.lora_B.default.weight and {key}.lora_A.default.weight +""" + +import re +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch +from loguru import logger + + +class LoRAFormat(Enum): + """Enum for different LoRA format patterns.""" + + STANDARD = "standard" + DIFFUSERS = "diffusers" + DIFFUSERS_V2 = "diffusers_v2" + DIFFUSERS_V3 = "diffusers_v3" + MOCHI = "mochi" + TRANSFORMERS = "transformers" + QWEN = "qwen" + + +class LoRAPatternDefinition: + """Defines a single LoRA format pattern and how to extract its components.""" + + def __init__( + self, + format_name: LoRAFormat, + up_suffix: str, + down_suffix: str, + has_weight_suffix: bool = True, + mid_suffix: Optional[str] = None, + ): + """ + Args: + format_name: The LoRA format type + up_suffix: Suffix for the up (B) weight matrix (e.g., ".lora_up.weight") + down_suffix: Suffix for the down (A) weight matrix (e.g., ".lora_down.weight") + has_weight_suffix: Whether the format includes .weight suffix + mid_suffix: Optional suffix for mid weight (only used in standard format) + """ + self.format_name = format_name + self.up_suffix = up_suffix + self.down_suffix = down_suffix + self.has_weight_suffix = has_weight_suffix + self.mid_suffix = mid_suffix + + def get_base_key(self, key: str, detected_suffix: str) -> Optional[str]: + """Extract base key by removing the detected suffix.""" + if key.endswith(detected_suffix): + return key[: -len(detected_suffix)] + return None + + +class LoRAPatternMatcher: + """Detects and matches LoRA format patterns in state dicts.""" + + def __init__(self): + """Initialize the pattern matcher with all supported formats.""" + self.patterns: Dict[LoRAFormat, LoRAPatternDefinition] = { + LoRAFormat.STANDARD: LoRAPatternDefinition( + LoRAFormat.STANDARD, + up_suffix=".lora_up.weight", + down_suffix=".lora_down.weight", + mid_suffix=".lora_mid.weight", + ), + LoRAFormat.DIFFUSERS: LoRAPatternDefinition( + LoRAFormat.DIFFUSERS, + up_suffix="_lora.up.weight", + down_suffix="_lora.down.weight", + ), + LoRAFormat.DIFFUSERS_V2: LoRAPatternDefinition( + LoRAFormat.DIFFUSERS_V2, + up_suffix=".lora_B.weight", + down_suffix=".lora_A.weight", + ), + LoRAFormat.DIFFUSERS_V3: LoRAPatternDefinition( + LoRAFormat.DIFFUSERS_V3, + up_suffix=".lora.up.weight", + down_suffix=".lora.down.weight", + ), + LoRAFormat.MOCHI: LoRAPatternDefinition( + LoRAFormat.MOCHI, + up_suffix=".lora_B", + 
down_suffix=".lora_A", + has_weight_suffix=False, + ), + LoRAFormat.TRANSFORMERS: LoRAPatternDefinition( + LoRAFormat.TRANSFORMERS, + up_suffix=".lora_linear_layer.up.weight", + down_suffix=".lora_linear_layer.down.weight", + ), + LoRAFormat.QWEN: LoRAPatternDefinition( + LoRAFormat.QWEN, + up_suffix=".lora_B.default.weight", + down_suffix=".lora_A.default.weight", + ), + } + + def detect_format(self, key: str, lora_weights: Dict) -> Optional[Tuple[LoRAFormat, str]]: + """ + Detect the LoRA format of a given key. + + Args: + key: The weight key to check + lora_weights: The full LoRA weights dictionary + + Returns: + Tuple of (LoRAFormat, detected_suffix) if format detected, None otherwise + """ + for format_type, pattern in self.patterns.items(): + if key.endswith(pattern.up_suffix): + return (format_type, pattern.up_suffix) + return None + + def extract_lora_pair( + self, + key: str, + lora_weights: Dict, + lora_alphas: Dict, + ) -> Optional[Dict]: + """ + Extract a complete LoRA pair (up and down weights) from the state dict. + + Args: + key: The up weight key + lora_weights: The full LoRA weights dictionary + lora_alphas: Dictionary of alpha values by base key + + Returns: + Dictionary with extracted LoRA information, or None if pair is incomplete + """ + format_detected = self.detect_format(key, lora_weights) + if format_detected is None: + return None + + format_type, up_suffix = format_detected + pattern = self.patterns[format_type] + + # Extract base key + base_key = pattern.get_base_key(key, up_suffix) + if base_key is None: + return None + + # Check if down weight exists + down_key = base_key + pattern.down_suffix + if down_key not in lora_weights: + return None + + # Check for mid weight (only for standard format) + mid_key = None + if pattern.mid_suffix: + mid_key = base_key + pattern.mid_suffix + if mid_key not in lora_weights: + mid_key = None + + # Get alpha value + alpha = lora_alphas.get(base_key, None) + + return { + "format": format_type, + "base_key": base_key, + "up_key": key, + "down_key": down_key, + "mid_key": mid_key, + "alpha": alpha, + } + + +class LoRALoader: + """Loads and applies LoRA weights to model weights using pattern matching.""" + + def __init__(self, key_mapping_rules: Optional[List[Tuple[str, str]]] = None): + """ + Args: + key_mapping_rules: Optional list of (pattern, replacement) regex rules for key mapping + """ + self.pattern_matcher = LoRAPatternMatcher() + self.key_mapping_rules = key_mapping_rules or [] + self._compile_rules() + + def _compile_rules(self): + """Pre-compile regex patterns for better performance.""" + self.compiled_rules = [(re.compile(pattern), replacement) for pattern, replacement in self.key_mapping_rules] + + def _apply_key_mapping(self, key: str) -> str: + """Apply key mapping rules to a key.""" + for pattern, replacement in self.compiled_rules: + key = pattern.sub(replacement, key) + return key + + def _get_model_key( + self, + lora_key: str, + base_key: str, + suffix_to_remove: str, + suffix_to_add: str = ".weight", + ) -> Optional[str]: + """ + Extract the model weight key from LoRA key with proper prefix handling. 
+ + Args: + lora_key: The original LoRA key + base_key: The base key after removing LoRA suffix + suffix_to_remove: The suffix that was removed + suffix_to_add: The suffix to add for model key + + Returns: + The model key, or None if extraction fails + """ + # For Qwen models, keep transformer_blocks prefix + if base_key.startswith("transformer_blocks.") and len(base_key.split(".")) > 1: + if base_key.split(".")[1].isdigit(): + # Keep the full path for Qwen models + model_key = base_key + suffix_to_add + else: + # Remove common prefixes for other models + model_key = self._remove_prefixes(base_key) + suffix_to_add + else: + # Remove common prefixes for other models + model_key = self._remove_prefixes(base_key) + suffix_to_add + + # Apply key mapping rules if provided + if self.compiled_rules: + model_key = self._apply_key_mapping(model_key) + + return model_key + + @staticmethod + def _remove_prefixes(key: str) -> str: + """Remove common model prefixes from a key.""" + prefixes_to_remove = ["diffusion_model.", "model.", "unet."] + for prefix in prefixes_to_remove: + if key.startswith(prefix): + return key[len(prefix) :] + return key + + def extract_lora_alphas(self, lora_weights: Dict) -> Dict: + """Extract LoRA alpha values from the state dict.""" + lora_alphas = {} + for key in lora_weights.keys(): + if key.endswith(".alpha"): + base_key = key[:-6] # Remove .alpha + lora_alphas[base_key] = lora_weights[key].item() + return lora_alphas + + def extract_lora_pairs(self, lora_weights: Dict) -> Dict[str, Dict]: + """ + Extract all LoRA pairs from the state dict, mapping to model keys. + + Args: + lora_weights: The LoRA state dictionary + + Returns: + Dictionary mapping model keys to LoRA pair information + """ + lora_alphas = self.extract_lora_alphas(lora_weights) + lora_pairs = {} + + for key in lora_weights.keys(): + # Skip alpha parameters + if key.endswith(".alpha"): + continue + + # Try to extract LoRA pair + pair_info = self.pattern_matcher.extract_lora_pair(key, lora_weights, lora_alphas) + if pair_info is None: + continue + + # Determine the suffix to remove and add based on format + format_type = pair_info["format"] + pattern = self.pattern_matcher.patterns[format_type] + + # Get the model key + model_key = self._get_model_key( + pair_info["up_key"], + pair_info["base_key"], + pattern.up_suffix, + ".weight", + ) + + if model_key is None: + logger.warning(f"Failed to extract model key from LoRA key: {key}") + continue + + lora_pairs[model_key] = pair_info + + return lora_pairs + + def extract_lora_diffs(self, lora_weights: Dict) -> Dict[str, Dict]: + """ + Extract diff-style LoRA weights (direct addition, not matrix multiplication). 
+ + Args: + lora_weights: The LoRA state dictionary + + Returns: + Dictionary mapping model keys to diff information + """ + lora_diffs = {} + + # Define diff patterns: (suffix_to_check, suffix_to_remove, suffix_to_add) + diff_patterns = [ + (".diff", ".diff", ".weight"), + (".diff_b", ".diff_b", ".bias"), + (".diff_m", ".diff_m", ".modulation"), + ] + + for key in lora_weights.keys(): + for check_suffix, remove_suffix, add_suffix in diff_patterns: + if key.endswith(check_suffix): + base_key = key[: -len(remove_suffix)] + model_key = self._get_model_key(key, base_key, remove_suffix, add_suffix) + + if model_key: + lora_diffs[model_key] = { + "diff_key": key, + "type": check_suffix, + } + break + + return lora_diffs + + def apply_lora( + self, + weight_dict: Dict[str, torch.Tensor], + lora_weights: Dict[str, torch.Tensor], + alpha: float = None, + strength: float = 1.0, + ) -> int: + """ + Apply LoRA weights to model weights. + + Args: + weight_dict: The model weights dictionary (will be modified in place) + lora_weights: The LoRA weights dictionary + alpha: Global alpha scaling factor + strength: Additional strength factor for LoRA deltas + + Returns: + Number of LoRA weights successfully applied + """ + # Extract LoRA pairs, diffs, and alphas + lora_pairs = self.extract_lora_pairs(lora_weights) + lora_diffs = self.extract_lora_diffs(lora_weights) + + applied_count = 0 + used_lora_keys = set() + + # Apply LoRA pairs (matrix multiplication) + for model_key, pair_info in lora_pairs.items(): + if model_key not in weight_dict: + logger.debug(f"Model key not found: {model_key}") + continue + + param = weight_dict[model_key] + up_key = pair_info["up_key"] + down_key = pair_info["down_key"] + + # Track used keys + used_lora_keys.add(up_key) + used_lora_keys.add(down_key) + if pair_info["mid_key"]: + used_lora_keys.add(pair_info["mid_key"]) + + try: + lora_up = lora_weights[up_key].to(param.device, param.dtype) + lora_down = lora_weights[down_key].to(param.device, param.dtype) + + # Get LoRA-specific alpha if available, otherwise use global alpha + # Apply LoRA: W' = W + (alpha/rank) * B @ A + # where B = up (out_features, rank), A = down (rank, in_features) + if pair_info["alpha"]: + lora_scale = pair_info["alpha"] / lora_down.shape[0] + elif alpha is not None: + lora_scale = alpha / lora_down.shape[0] + else: + lora_scale = 1 + + if len(lora_down.shape) == 2 and len(lora_up.shape) == 2: + lora_delta = torch.mm(lora_up, lora_down) * lora_scale + if strength is not None: + lora_delta = lora_delta * float(strength) + + param.data += lora_delta + applied_count += 1 + logger.debug(f"Applied LoRA to {model_key} with lora_scale={lora_scale}") + else: + logger.warning(f"Unexpected LoRA shape for {model_key}: down={lora_down.shape}, up={lora_up.shape}") + + except Exception as e: + logger.warning(f"Failed to apply LoRA pair for {model_key}: {e}") + logger.warning(f" Shapes - param: {param.shape}, down: {lora_weights[down_key].shape}, up: {lora_weights[up_key].shape}") + + # Apply diff weights (direct addition) + for model_key, diff_info in lora_diffs.items(): + if model_key not in weight_dict: + logger.debug(f"Model key not found for diff: {model_key}") + continue + + param = weight_dict[model_key] + diff_key = diff_info["diff_key"] + + # Track used keys + used_lora_keys.add(diff_key) + + try: + lora_diff = lora_weights[diff_key].to(param.device, param.dtype) + if alpha is not None: + param.data += lora_diff * alpha * (float(strength) if strength is not None else 1.0) + else: + param.data += lora_diff * 
(float(strength) if strength is not None else 1.0) + applied_count += 1 + logger.debug(f"Applied LoRA diff to {model_key} (type: {diff_info['type']})") + except Exception as e: + logger.warning(f"Failed to apply LoRA diff for {model_key}: {e}") + + # Warn about unused keys + all_lora_keys = set(k for k in lora_weights.keys() if not k.endswith(".alpha")) + unused_lora_keys = all_lora_keys - used_lora_keys + + if unused_lora_keys: + logger.warning(f"Found {len(unused_lora_keys)} unused LoRA weights - this may indicate a key mismatch:") + for key in list(unused_lora_keys)[:10]: # Show first 10 + logger.warning(f" Unused: {key}") + if len(unused_lora_keys) > 10: + logger.warning(f" ... and {len(unused_lora_keys) - 10} more") + + logger.info(f"Applied {applied_count} LoRA weight adjustments out of {len(lora_pairs) + len(lora_diffs)} possible") + + if applied_count == 0 and (lora_pairs or lora_diffs): + logger.error("No LoRA weights were applied! Check for key name mismatches.") + logger.info("Model weight keys sample: " + str(list(weight_dict.keys())[:5])) + logger.info("LoRA pairs keys sample: " + str(list(lora_pairs.keys())[:5])) + logger.info("LoRA diffs keys sample: " + str(list(lora_diffs.keys())[:5])) + + return applied_count diff --git a/tools/convert/quant/__init__.py b/tools/convert/quant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79b37f98f422125cd56ba051c3ddd088cf29ee98 --- /dev/null +++ b/tools/convert/quant/__init__.py @@ -0,0 +1 @@ +from .quant import * diff --git a/tools/convert/quant/quant.py b/tools/convert/quant/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a6c93927b56ba8cc974f7572d92b27729246a1 --- /dev/null +++ b/tools/convert/quant/quant.py @@ -0,0 +1,141 @@ +from abc import ABCMeta + +import torch + +# float_quantize is required by the fp8 quantizer below; qtorch stays optional for the other paths +try: + from qtorch.quant import float_quantize +except ImportError: + float_quantize = None + +from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER + +try: + from lightx2v_kernel.gemm import scaled_mxfp4_quant, scaled_mxfp6_quant, scaled_mxfp8_quant, scaled_nvfp4_quant +except ImportError: + pass + + +class QuantTemplate(metaclass=ABCMeta): + def __init__(self, weight): + if weight.dim() != 2: + raise ValueError(f"Only 2D tensors supported. 
Got {weight.dim()}D tensor") + if torch.isnan(weight).any(): + raise ValueError("Tensor contains NaN values") + + self.weight_quant_func = None + self.extra = {} + + +@CONVERT_WEIGHT_REGISTER("int8") +class QuantWeightINT8(QuantTemplate): + def __init__(self, weight): + super().__init__(weight) + self.weight_quant_func = self.load_int8_weight + + @torch.no_grad() + def load_int8_weight(self, w, comfyui_mode=False): + org_w_shape = w.shape + if not comfyui_mode: + max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) + else: + max_val = w.abs().max() + qmin, qmax = -128, 127 + scales = max_val / qmax + w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8) + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w_q).sum() == 0 + + if not comfyui_mode: + scales = scales.view(org_w_shape[0], -1) + w_q = w_q.reshape(org_w_shape) + + return w_q, scales, self.extra + + +@CONVERT_WEIGHT_REGISTER("fp8") +class QuantWeightFP8(QuantTemplate): + def __init__(self, weight): + super().__init__(weight) + self.weight_quant_func = self.load_fp8_weight + + @torch.no_grad() + def load_fp8_weight(self, w, comfyui_mode=False): + org_w_shape = w.shape + if not comfyui_mode: + max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) + else: + max_val = w.abs().max() + finfo = torch.finfo(torch.float8_e4m3fn) + qmin, qmax = finfo.min, finfo.max + scales = max_val / qmax + scaled_tensor = w / scales + scaled_tensor = torch.clip(scaled_tensor, qmin, qmax) + w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(torch.float8_e4m3fn) + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w_q).sum() == 0 + + if not comfyui_mode: + scales = scales.view(org_w_shape[0], -1) + w_q = w_q.reshape(org_w_shape) + + return w_q, scales, self.extra + + +@CONVERT_WEIGHT_REGISTER("mxfp4") +class QuantWeightMxFP4(QuantTemplate): + def __init__(self, weight): + super().__init__(weight) + self.weight_quant_func = self.load_mxfp4_weight + + @torch.no_grad() + def load_mxfp4_weight(self, w, comfyui_mode=False): + device = w.device + w = w.cuda().to(torch.bfloat16) + w_q, scales = scaled_mxfp4_quant(w) + w_q, scales = w_q.to(device), scales.to(device) + return w_q, scales, self.extra + + +@CONVERT_WEIGHT_REGISTER("mxfp6") +class QuantWeightMxFP6(QuantTemplate): + def __init__(self, weight): + super().__init__(weight) + self.weight_quant_func = self.load_mxfp6_weight + + @torch.no_grad() + def load_mxfp6_weight(self, w, comfyui_mode=False): + device = w.device + w = w.cuda().to(torch.bfloat16) + w_q, scales = scaled_mxfp6_quant(w) + w_q, scales = w_q.to(device), scales.to(device) + return w_q, scales, self.extra + + +@CONVERT_WEIGHT_REGISTER("mxfp8") +class QuantWeightMxFP8(QuantTemplate): + def __init__(self, weight): + super().__init__(weight) + self.weight_quant_func = self.load_mxfp8_weight + + @torch.no_grad() + def load_mxfp8_weight(self, w, comfyui_mode=False): + device = w.device + w = w.cuda().to(torch.bfloat16) + w_q, scales = scaled_mxfp8_quant(w) + w_q, scales = w_q.to(device), scales.to(device) + return w_q, scales, self.extra + + +@CONVERT_WEIGHT_REGISTER("nvfp4") +class QuantWeightNVFP4(QuantTemplate): + def __init__(self, weight): + super().__init__(weight) + self.weight_quant_func = self.load_fp4_weight + + @torch.no_grad() + def load_fp4_weight(self, w, comfyui_mode=False): + device = w.device + w = w.cuda().to(torch.bfloat16) + weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32) + w_q, scales = scaled_nvfp4_quant(w, weight_global_scale) + 
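# NOTE: the 2688.0 constant above appears to be FLOAT8_E4M3_MAX (448) * FLOAT4_E2M1_MAX (6), the usual NVFP4 per-tensor global-scale convention (an assumption, not stated in this file). +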
w_q, scales = w_q.to(device), scales.to(device) + self.extra["weight_global_scale"] = weight_global_scale.to(device) + return w_q, scales, self.extra diff --git a/tools/convert/quant_adapter.py b/tools/convert/quant_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..835591918fc2604153ad5ef89bbcfebf8c05d3dc --- /dev/null +++ b/tools/convert/quant_adapter.py @@ -0,0 +1,75 @@ +import argparse +import sys +from pathlib import Path + +import safetensors +import torch +from safetensors.torch import save_file + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from lightx2v.utils.quant_utils import FloatQuantizer +from tools.convert.quant import * + + +def main(): + # Resolve the script directory and the project root + script_dir = Path(__file__).parent + project_root = script_dir.parent.parent + + parser = argparse.ArgumentParser(description="Quantize audio adapter model to FP8") + parser.add_argument( + "--model_path", + type=str, + default=str(project_root / "models" / "SekoTalk-Distill" / "audio_adapter_model.safetensors"), + help="Path to input model file", + ) + parser.add_argument( + "--output_path", + type=str, + default=str(project_root / "models" / "SekoTalk-Distill-fp8" / "audio_adapter_model_fp8.safetensors"), + help="Path to output quantized model file", + ) + args = parser.parse_args() + + model_path = Path(args.model_path) + output_path = Path(args.output_path) + + output_path.parent.mkdir(parents=True, exist_ok=True) + + state_dict = {} + with safetensors.safe_open(model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + new_state_dict = {} + + for key in state_dict.keys(): + if key.startswith("ca") and ".to" in key and "weight" in key: + print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}") + + ## fp8 + weight = state_dict[key].to(torch.float32).cuda() + w_quantizer = FloatQuantizer("e4m3", True, "per_channel") + weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight) + weight = weight.to(torch.float8_e4m3fn) + weight_scale = weight_scale.to(torch.float32) + + ## Alternatively, use QuantWeightMxFP4 / QuantWeightMxFP6 / QuantWeightMxFP8 for mxfp4, mxfp6, mxfp8 + # weight = state_dict[key].to(torch.bfloat16).cuda() + # quantizer = QuantWeightMxFP4(weight) + # weight, weight_scale, _ = quantizer.weight_quant_func(weight) + + new_state_dict[key] = weight.cpu() + new_state_dict[key + "_scale"] = weight_scale.cpu() + else: + # Weights that do not match are converted to BF16 + print(f"Converting {key} to BF16, dtype: {state_dict[key].dtype}") + new_state_dict[key] = state_dict[key].to(torch.bfloat16) + + save_file(new_state_dict, str(output_path)) + print(f"Quantized model saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/convert/readme.md b/tools/convert/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..16f03c25938a67eecaa12c11b1fbd0e93c12351c --- /dev/null +++ b/tools/convert/readme.md @@ -0,0 +1,445 @@ +# Model Conversion Tool + +A powerful model weight conversion tool that supports format conversion, quantization, LoRA merging, and more. 
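+ +## Quick Start + +As a minimal first run (the paths below are placeholders; see the Usage Examples section for complete workflows), a single `.pth` checkpoint can be converted to SafeTensors like this: + +```bash +python converter.py \ + --source /path/to/model.pth \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name model \ + --single_file +```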
+ +## Main Features + +- **Format Conversion**: Convert between PyTorch (.pth) and SafeTensors (.safetensors) formats +- **Model Quantization**: INT8, FP8, NVFP4, MXFP4, MXFP6, and MXFP8 quantization to significantly reduce model size +- **Architecture Conversion**: Convert between the LightX2V and Diffusers architectures +- **LoRA Merging**: Load and merge LoRAs in multiple formats +- **Multi-Model Support**: Wan DiT, Qwen Image DiT, T5, CLIP, and more +- **Flexible Saving**: Single-file, per-block, and chunked saving +- **Parallel Processing**: Parallel acceleration for large-model conversion + +## Supported Model Types + +- `hunyuan_dit`: Hunyuan DiT 1.5 models +- `wan_dit`: Wan DiT series models (default) +- `wan_animate_dit`: Wan Animate DiT models +- `qwen_image_dit`: Qwen Image DiT models +- `wan_t5`: Wan T5 text encoder +- `wan_clip`: Wan CLIP vision encoder + +## Core Parameters + +### Basic Parameters + +- `-s, --source`: Input path (file or directory) +- `-o, --output`: Output directory path +- `-o_e, --output_ext`: Output format, `.pth` or `.safetensors` (default) +- `-o_n, --output_name`: Output file name (default: `converted`) +- `-t, --model_type`: Model type (default: `wan_dit`) + +### Architecture Conversion Parameters + +- `-d, --direction`: Conversion direction + - `None`: No architecture conversion (default) + - `forward`: LightX2V → Diffusers + - `backward`: Diffusers → LightX2V + +### Quantization Parameters + +- `--quantized`: Enable quantization +- `--bits`: Quantization bit width; currently only 8-bit is supported +- `--linear_type`: Linear layer quantization type + - `int8`: INT8 quantization (torch.int8) + - `fp8`: FP8 quantization (torch.float8_e4m3fn) + - `nvfp4`: NVFP4 quantization + - `mxfp4`: MXFP4 quantization + - `mxfp6`: MXFP6 quantization + - `mxfp8`: MXFP8 quantization +- `--non_linear_dtype`: Non-linear layer data type + - `torch.bfloat16`: BF16 + - `torch.float16`: FP16 + - `torch.float32`: FP32 (default) +- `--device`: Device for quantization, `cpu` or `cuda` (default) +- `--comfyui_mode`: ComfyUI-compatible mode (int8 and fp8 only) +- `--full_quantized`: Full quantization mode (effective in ComfyUI mode) + +For nvfp4, mxfp4, mxfp6, and mxfp8, install the required kernels by following LightX2V/lightx2v_kernel/README.md. + +### LoRA Parameters + +- `--lora_path`: LoRA file path(s), supports multiple (separated by spaces) +- `--lora_strength`: LoRA strength coefficients, supports multiple (default: 1.0) +- `--alpha`: LoRA alpha parameters, supports multiple +- `--lora_key_convert`: LoRA key conversion mode + - `auto`: Auto-detect (default) + - `same`: Use original key names + - `convert`: Apply the same conversion as the model + +### Saving Parameters + +- `--single_file`: Save as a single file (note: large models consume significant memory) +- `-b, --save_by_block`: Save by blocks (recommended for backward conversion) +- `-c, --chunk-size`: Chunk size (default: 100, 0 means no chunking) +- `--copy_no_weight_files`: Copy non-weight files from the source directory + +### Performance Parameters + +- `--parallel`: Enable parallel processing (default: True) +- `--no-parallel`: Disable parallel processing + +## Supported LoRA Formats + +The tool automatically detects and supports the following LoRA formats: + +1. **Standard**: `{key}.lora_up.weight` and `{key}.lora_down.weight` +2. **Diffusers**: `{key}_lora.up.weight` and `{key}_lora.down.weight` +3. **Diffusers V2**: `{key}.lora_B.weight` and `{key}.lora_A.weight` +4. 
**Diffusers V3**: `{key}.lora.up.weight` and `{key}.lora.down.weight` +5. **Mochi**: `{key}.lora_B` and `{key}.lora_A` (no .weight suffix) +6. **Transformers**: `{key}.lora_linear_layer.up.weight` and `{key}.lora_linear_layer.down.weight` +7. **Qwen**: `{key}.lora_B.default.weight` and `{key}.lora_A.default.weight` + +Additionally supports diff formats: +- `.diff`: Weight diff +- `.diff_b`: Bias diff +- `.diff_m`: Modulation diff + +## Usage Examples + +### 1. Model Quantization + +#### 1.1 Wan DiT Quantization to INT8 + +**Multiple safetensors, saved by dit blocks** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan_int8 \ + --linear_type int8 \ + --model_type wan_dit \ + --quantized \ + --save_by_block +``` + +**Single safetensor file** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_int8_lightx2v \ + --linear_type int8 \ + --model_type wan_dit \ + --quantized \ + --single_file +``` + +#### 1.2 Wan DiT Quantization to FP8 + +**Multiple safetensors, saved by dit blocks** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan_fp8 \ + --linear_type fp8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --save_by_block +``` + +**Single safetensor file** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \ + --linear_type fp8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --single_file +``` + +**ComfyUI scaled_fp8 format** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \ + --linear_type fp8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --single_file \ + --comfyui_mode +``` + +**ComfyUI full FP8 format** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \ + --linear_type fp8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --single_file \ + --comfyui_mode \ + --full_quantized +``` + +> **Tip**: For other DIT models, simply switch the `--model_type` parameter + +#### 1.3 T5 Encoder Quantization + +**INT8 Quantization** +```bash +python converter.py \ + --source /path/to/models_t5_umt5-xxl-enc-bf16.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_t5_umt5-xxl-enc-int8 \ + --linear_type int8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_t5 \ + --quantized +``` + +**FP8 Quantization** +```bash +python converter.py \ + --source /path/to/models_t5_umt5-xxl-enc-bf16.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_t5_umt5-xxl-enc-fp8 \ + --linear_type fp8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_t5 \ + --quantized +``` + +#### 1.4 CLIP Encoder Quantization + +**INT8 Quantization** +```bash +python converter.py \ + --source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \ + --output /path/to/output \ + 
--output_ext .pth \ + --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \ + --linear_type int8 \ + --non_linear_dtype torch.float16 \ + --model_type wan_clip \ + --quantized +``` + +**FP8 Quantization** +```bash +python converter.py \ + --source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \ + --linear_type fp8 \ + --non_linear_dtype torch.float16 \ + --model_type wan_clip \ + --quantized +``` + +#### 1.5 Qwen25_vl LLM Quantization + +**INT8 Quantization** +```bash +python converter.py \ + --source /path/to/hunyuanvideo-1.5/text_encoder/llm \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name qwen25vl-llm-int8 \ + --linear_type int8 \ + --non_linear_dtype torch.float16 \ + --model_type qwen25vl_llm \ + --quantized \ + --single_file +``` + +**FP8 Quantization** +```bash +python converter.py \ + --source /path/to/hunyuanvideo-1.5/text_encoder/llm \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name qwen25vl-llm-fp8 \ + --linear_type fp8 \ + --non_linear_dtype torch.float16 \ + --model_type qwen25vl_llm \ + --quantized \ + --single_file +``` + +### 2. LoRA Merging + +#### 2.1 Merge Single LoRA + +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_model \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --single_file +``` + +#### 2.2 Merge Multiple LoRAs + +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_model \ + --model_type wan_dit \ + --lora_path /path/to/lora1.safetensors /path/to/lora2.safetensors \ + --lora_strength 1.0 0.8 \ + --single_file +``` + +#### 2.3 LoRA Merging with Quantization + +**LoRA Merge → FP8 Quantization** +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_quantized \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --quantized \ + --linear_type fp8 \ + --single_file +``` + +**LoRA Merge → ComfyUI scaled_fp8** +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_quantized \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --quantized \ + --linear_type fp8 \ + --single_file \ + --comfyui_mode +``` + +**LoRA Merge → ComfyUI Full FP8** +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_quantized \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --quantized \ + --linear_type fp8 \ + --single_file \ + --comfyui_mode \ + --full_quantized +``` + +#### 2.4 LoRA Key Conversion Modes + +**Auto-detect mode (recommended)** +```bash +python converter.py \ + --source /path/to/model/ \ + --output /path/to/output \ + --lora_path /path/to/lora.safetensors \ + --lora_key_convert auto \ + --single_file +``` + +**Use original key names (LoRA already in target format)** +```bash +python converter.py \ + --source /path/to/model/ \ + --output /path/to/output \ + --direction forward \ + --lora_path 
/path/to/lora.safetensors \ + --lora_key_convert same \ + --single_file +``` + +**Apply conversion (LoRA in source format)** +```bash +python converter.py \ + --source /path/to/model/ \ + --output /path/to/output \ + --direction forward \ + --lora_path /path/to/lora.safetensors \ + --lora_key_convert convert \ + --single_file +``` + +### 3. Architecture Format Conversion + +#### 3.1 LightX2V → Diffusers + +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P \ + --output /path/to/Wan2.1-I2V-14B-480P-Diffusers \ + --output_ext .safetensors \ + --model_type wan_dit \ + --direction forward \ + --chunk-size 100 +``` + +#### 3.2 Diffusers → LightX2V + +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P-Diffusers \ + --output /path/to/Wan2.1-I2V-14B-480P \ + --output_ext .safetensors \ + --model_type wan_dit \ + --direction backward \ + --save_by_block +``` + +### 4. Format Conversion + +#### 4.1 .pth → .safetensors + +```bash +python converter.py \ + --source /path/to/model.pth \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name model \ + --single_file +``` + +#### 4.2 Multiple .safetensors → Single File + +```bash +python converter.py \ + --source /path/to/model_directory/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_model \ + --single_file +``` diff --git a/tools/convert/readme_zh.md b/tools/convert/readme_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..6c5d0d07121975411fdbbc5185d93de0220007e8 --- /dev/null +++ b/tools/convert/readme_zh.md @@ -0,0 +1,438 @@ +# 模型转换工具 + +这是一个功能强大的模型权重转换工具,支持格式转换、量化、LoRA融合等多种功能。 + +## 主要特性 + +- **格式转换**: 支持 PyTorch (.pth) 和 SafeTensors (.safetensors) 格式互转 +- **模型量化**: 支持 INT8 和 FP8 量化,显著减小模型体积 +- **架构转换**: 支持 LightX2V 和 Diffusers 架构互转 +- **LoRA 融合**: 支持多种 LoRA 格式的加载和融合 +- **多模型支持**: 支持 Wan DiT、Qwen Image DiT、T5、CLIP 等 +- **灵活保存**: 支持单文件、按块、分块等多种保存方式 +- **并行处理**: 大模型转换支持并行加速 + +## 支持的模型类型 + +- `hunyuan_dit`: hunyuan DiT 1.5模型 +- `wan_dit`: Wan DiT 系列模型(默认) +- `wan_animate_dit`: Wan Animate DiT 模型 +- `qwen_image_dit`: Qwen Image DiT 模型 +- `wan_t5`: Wan T5 文本编码器 +- `wan_clip`: Wan CLIP 视觉编码器 + +## 核心参数说明 + +### 基础参数 + +- `-s, --source`: 输入路径(文件或目录) +- `-o, --output`: 输出目录路径 +- `-o_e, --output_ext`: 输出格式,可选 `.pth` 或 `.safetensors`(默认) +- `-o_n, --output_name`: 输出文件名(默认: `converted`) +- `-t, --model_type`: 模型类型(默认: `wan_dit`) + +### 架构转换参数 + +- `-d, --direction`: 转换方向 + - `None`: 不进行架构转换(默认) + - `forward`: LightX2V → Diffusers + - `backward`: Diffusers → LightX2V + +### 量化参数 + +- `--quantized`: 启用量化 +- `--bits`: 量化位宽,当前仅支持 8 位 +- `--linear_dtype`: 线性层量化类型 + - `torch.int8`: INT8 量化 + - `torch.float8_e4m3fn`: FP8 量化 +- `--non_linear_dtype`: 非线性层数据类型 + - `torch.bfloat16`: BF16 + - `torch.float16`: FP16 + - `torch.float32`: FP32(默认) +- `--device`: 量化使用的设备,可选 `cpu` 或 `cuda`(默认) +- `--comfyui_mode`: ComfyUI 兼容模式 +- `--full_quantized`: 全量化模式(ComfyUI 模式下有效) + +### LoRA 参数 + +- `--lora_path`: LoRA 文件路径,支持多个(用空格分隔) +- `--lora_strength`: LoRA 强度系数,支持多个(默认: 1.0) +- `--alpha`: LoRA alpha 参数,支持多个 +- `--lora_key_convert`: LoRA 键转换模式 + - `auto`: 自动检测(默认) + - `same`: 使用原始键名 + - `convert`: 应用与模型相同的转换 + +### 保存参数 + +- `--single_file`: 保存为单个文件(注意: 大模型会消耗大量内存) +- `-b, --save_by_block`: 按块保存(推荐用于 backward 转换) +- `-c, --chunk-size`: 分块大小(默认: 100,0 表示不分块) +- `--copy_no_weight_files`: 复制源目录中的非权重文件 + +### 性能参数 + +- `--parallel`: 启用并行处理(默认: True) +- `--no-parallel`: 禁用并行处理 + +## 支持的 LoRA 格式 + +工具自动检测并支持以下 LoRA 格式: + +1. 
**Standard**: `{key}.lora_up.weight` 和 `{key}.lora_down.weight` +2. **Diffusers**: `{key}_lora.up.weight` 和 `{key}_lora.down.weight` +3. **Diffusers V2**: `{key}.lora_B.weight` 和 `{key}.lora_A.weight` +4. **Diffusers V3**: `{key}.lora.up.weight` 和 `{key}.lora.down.weight` +5. **Mochi**: `{key}.lora_B` 和 `{key}.lora_A`(无 .weight 后缀) +6. **Transformers**: `{key}.lora_linear_layer.up.weight` 和 `{key}.lora_linear_layer.down.weight` +7. **Qwen**: `{key}.lora_B.default.weight` 和 `{key}.lora_A.default.weight` + +此外还支持差值(diff)格式: +- `.diff`: 权重差值 +- `.diff_b`: bias 差值 +- `.diff_m`: modulation 差值 + +## 使用示例 + +### 1. 模型量化 + +#### 1.1 Wan DiT 量化为 INT8 + +**多个 safetensors,按 dit block 存储** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan_int8 \ + --linear_dtype torch.int8 \ + --model_type wan_dit \ + --quantized \ + --save_by_block +``` + +**单个 safetensor 文件** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_int8_lightx2v \ + --linear_dtype torch.int8 \ + --model_type wan_dit \ + --quantized \ + --single_file +``` + +#### 1.2 Wan DiT 量化为 FP8 + +**多个 safetensors,按 dit block 存储** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan_fp8 \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --save_by_block +``` + +**单个 safetensor 文件** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --single_file +``` + +**ComfyUI 的 scaled_fp8 格式** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --single_file \ + --comfyui_mode +``` + +**ComfyUI 的全 FP8 格式** +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_dit \ + --quantized \ + --single_file \ + --comfyui_mode \ + --full_quantized +``` + +> **提示**: 对于其他 DIT 模型,切换 `--model_type` 参数即可 + +#### 1.3 T5 编码器量化 + +**INT8 量化** +```bash +python converter.py \ + --source /path/to/models_t5_umt5-xxl-enc-bf16.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_t5_umt5-xxl-enc-int8 \ + --linear_dtype torch.int8 \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_t5 \ + --quantized +``` + +**FP8 量化** +```bash +python converter.py \ + --source /path/to/models_t5_umt5-xxl-enc-bf16.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_t5_umt5-xxl-enc-fp8 \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.bfloat16 \ + --model_type wan_t5 \ + --quantized +``` + +#### 1.4 CLIP 编码器量化 + +**INT8 量化** +```bash +python converter.py \ + --source 
/path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \ + --linear_dtype torch.int8 \ + --non_linear_dtype torch.float16 \ + --model_type wan_clip \ + --quantized +``` + +**FP8 量化** +```bash +python converter.py \ + --source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \ + --output /path/to/output \ + --output_ext .pth \ + --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.float16 \ + --model_type wan_clip \ + --quantized +``` + +#### 1.5 Qwen25_vl 語言部分量化 + +**INT8 量化** +```bash +python converter.py \ + --source /path/to/hunyuanvideo-1.5/text_encoder/llm \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name qwen25vl-llm-int8 \ + --linear_dtype torch.int8 \ + --non_linear_dtype torch.float16 \ + --model_type qwen25vl_llm \ + --quantized \ + --single_file +``` + +**FP8 量化** +```bash +python converter.py \ + --source /path/to/hunyuanvideo-1.5/text_encoder/llm \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name qwen25vl-llm-fp8 \ + --linear_dtype torch.float8_e4m3fn \ + --non_linear_dtype torch.float16 \ + --model_type qwen25vl_llm \ + --quantized \ + --single_file +``` + +### 2. LoRA 融合 + +#### 2.1 融合单个 LoRA + +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_model \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --single_file +``` + +#### 2.2 融合多个 LoRA + +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_model \ + --model_type wan_dit \ + --lora_path /path/to/lora1.safetensors /path/to/lora2.safetensors \ + --lora_strength 1.0 0.8 \ + --single_file +``` + +#### 2.3 LoRA 融合后量化 + +**LoRA 融合 → FP8 量化** +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_quantized \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --quantized \ + --linear_dtype torch.float8_e4m3fn \ + --single_file +``` + +**LoRA 融合 → ComfyUI scaled_fp8** +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_quantized \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --quantized \ + --linear_dtype torch.float8_e4m3fn \ + --single_file \ + --comfyui_mode +``` + +**LoRA 融合 → ComfyUI 全 FP8** +```bash +python converter.py \ + --source /path/to/base_model/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_quantized \ + --model_type wan_dit \ + --lora_path /path/to/lora.safetensors \ + --lora_strength 1.0 \ + --quantized \ + --linear_dtype torch.float8_e4m3fn \ + --single_file \ + --comfyui_mode \ + --full_quantized +``` + +#### 2.4 LoRA 键转换模式 + +**自动检测模式(推荐)** +```bash +python converter.py \ + --source /path/to/model/ \ + --output /path/to/output \ + --lora_path /path/to/lora.safetensors \ + --lora_key_convert auto \ + --single_file +``` + +**使用原始键名(LoRA 已经是目标格式)** +```bash +python converter.py \ + --source /path/to/model/ \ + --output /path/to/output \ + --direction forward \ + --lora_path 
/path/to/lora.safetensors \ + --lora_key_convert same \ + --single_file +``` + +**应用转换(LoRA 使用源格式)** +```bash +python converter.py \ + --source /path/to/model/ \ + --output /path/to/output \ + --direction forward \ + --lora_path /path/to/lora.safetensors \ + --lora_key_convert convert \ + --single_file +``` + +### 3. 架构格式转换 + +#### 3.1 LightX2V → Diffusers + +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P \ + --output /path/to/Wan2.1-I2V-14B-480P-Diffusers \ + --output_ext .safetensors \ + --model_type wan_dit \ + --direction forward \ + --chunk-size 100 +``` + +#### 3.2 Diffusers → LightX2V + +```bash +python converter.py \ + --source /path/to/Wan2.1-I2V-14B-480P-Diffusers \ + --output /path/to/Wan2.1-I2V-14B-480P \ + --output_ext .safetensors \ + --model_type wan_dit \ + --direction backward \ + --save_by_block +``` + +### 4. 格式转换 + +#### 4.1 .pth → .safetensors + +```bash +python converter.py \ + --source /path/to/model.pth \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name model \ + --single_file +``` + +#### 4.2 多个 .safetensors → 单文件 + +```bash +python converter.py \ + --source /path/to/model_directory/ \ + --output /path/to/output \ + --output_ext .safetensors \ + --output_name merged_model \ + --single_file +``` diff --git a/tools/convert/seko_talk_converter.py b/tools/convert/seko_talk_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3c1dd19e884b25b00389823c6a3dad5cae0a15 --- /dev/null +++ b/tools/convert/seko_talk_converter.py @@ -0,0 +1,452 @@ +""" +Model Merge and Multi-Precision Conversion Script + +This script supports three conversion modes: +1. 'both' (default): Convert both R2V model and audio adapter +2. 'r2v': Only convert R2V model (R2V + distill via LoRA) +3. 
'audio': Only convert audio adapter + +Pipeline: +- R2V model: R2V + distill via LoRA → merged.safetensors (FP32) → BF16/FP8 +- Audio adapter: (optional: + LoRA) → audio_adapter.pt → BF16 → FP8 + +Usage Examples: + # Convert both (default) + python tools/convert/seko_talk_converter.py \ + --r2v_model /path/to/model.pt \ + --distill_model /path/to/model_ema.pt \ + --audio_adapter /path/to/audio_adapter.pt \ + --output_dir /data/output + + # Only convert R2V model + python tools/convert/seko_talk_converter.py \ + --mode r2v \ + --r2v_model /path/to/model.pt \ + --distill_model /path/to/model_ema.pt \ + --output_dir /data/output + + # Only convert audio adapter + python tools/convert/seko_talk_converter.py \ + --mode audio \ + --audio_adapter /path/to/audio_adapter.pt \ + --output_dir /data/output + + # Convert audio adapter with LoRA merge + python tools/convert/seko_talk_converter.py \ + --mode audio \ + --audio_adapter /path/to/audio_adapter.pt \ + --audio_lora /path/to/audio_lora.pt \ + --output_dir /data/output + +Output files (depending on mode): + - merged.safetensors (FP32, R2V + distill merged) + - merged_bf16.safetensors (BF16) + - merged_fp8.safetensors (FP8) + - audio_adapter_merged.safetensors (FP32, audio + lora merged, optional) + - audio_adapter_model.safetensors (BF16) + - audio_adapter_model_fp8.safetensors (FP8) +""" + +import argparse +import subprocess +import sys +from pathlib import Path + +import torch +from loguru import logger +from safetensors.torch import load_file, save_file +from tqdm import tqdm + + +def run_command(cmd: list, description: str): + """Run a subprocess command and handle errors.""" + logger.info(f"\n{description}") + logger.info("Command: " + " \\\n ".join(cmd)) + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + logger.error(f"{description} FAILED!") + logger.error(f"STDOUT:\n{result.stdout}") + logger.error(f"STDERR:\n{result.stderr}") + raise RuntimeError(f"{description} failed") + + logger.info(f"✓ {description} completed!") + return result + + +def load_checkpoint(ckpt_path: Path) -> dict: + """Load checkpoint from .pt or .safetensors file.""" + logger.info(f"Loading: {ckpt_path.name}") + + if ckpt_path.suffix in [".pt", ".pth"]: + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + elif ckpt_path.suffix == ".safetensors": + checkpoint = load_file(str(ckpt_path)) + else: + raise ValueError(f"Unsupported format: {ckpt_path.suffix}") + + logger.info(f" Loaded {len(checkpoint)} keys") + return checkpoint + + +def convert_to_bf16(state_dict: dict) -> dict: + """Convert all tensors to bfloat16.""" + logger.info("Converting to BF16...") + bf16_dict = {} + for key, tensor in tqdm(state_dict.items(), desc="BF16 conversion"): + bf16_dict[key] = tensor.to(torch.bfloat16) + return bf16_dict + + +def step1_merge_via_lora(r2v_model_path: Path, distill_model_path: Path, output_dir: Path, lora_alpha: float, temp_dir: Path) -> Path: + """ + Step 1: Merge R2V + distillation model via LoRA using converter.py. + Both models in FP32, output merged.safetensors (FP32). 
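+ + Under the hood this step shells out to converter.py; the call is roughly equivalent to (paths illustrative): + python tools/convert/converter.py -s <temp_dir>/model.safetensors -o <output_dir> -o_n merged --lora_path <temp_dir>/model_ema.safetensors --lora_alpha <lora_alpha> --single_file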
+ """ + logger.info("=" * 80) + logger.info("STEP 1: Merge R2V + Distillation via LoRA (FP32)") + logger.info("=" * 80) + + temp_dir.mkdir(parents=True, exist_ok=True) + + # Convert R2V to safetensors (keep FP32) + logger.info("\n[1.1] Converting R2V model to safetensors (FP32)...") + r2v_dict = load_checkpoint(r2v_model_path) + r2v_safetensors = temp_dir / "model.safetensors" + save_file(r2v_dict, str(r2v_safetensors)) + logger.info(f" Saved: {r2v_safetensors}") + + # Convert distill to safetensors (keep FP32 for LoRA merge) + logger.info("\n[1.2] Converting distillation model to safetensors (FP32)...") + distill_dict = load_checkpoint(distill_model_path) + distill_safetensors = temp_dir / "model_ema.safetensors" + save_file(distill_dict, str(distill_safetensors)) + logger.info(f" Saved: {distill_safetensors}") + + # Merge via LoRA using converter.py (FP32 + FP32 → FP32) + logger.info("\n[1.3] Merging via LoRA (converter.py)...") + cmd = [ + "python", + "tools/convert/converter.py", + "-s", + str(r2v_safetensors), + "-o", + str(output_dir), + "-o_n", + "merged", + "--lora_path", + str(distill_safetensors), + "--lora_alpha", + str(lora_alpha), + "--single_file", + ] + + run_command(cmd, "LoRA merge") + + merged_path = output_dir / "merged.safetensors" + if not merged_path.exists(): + raise FileNotFoundError(f"Merged file not found: {merged_path}") + + logger.info(f" ✓ Created: {merged_path} (FP32)") + return merged_path + + +def step2_convert_merged_to_bf16(merged_path: Path, output_dir: Path): + """ + Step 2: Convert merged.safetensors (FP32) to BF16. + """ + logger.info("=" * 80) + logger.info("STEP 2: Convert merged.safetensors (FP32) → BF16") + logger.info("=" * 80) + + merged_dict = load_file(str(merged_path)) + merged_bf16 = convert_to_bf16(merged_dict) + + bf16_path = output_dir / "merged_bf16.safetensors" + save_file(merged_bf16, str(bf16_path)) + logger.info(f" ✓ Created: {bf16_path}") + + +def step3_convert_merged_to_fp8(merged_path: Path, output_dir: Path, device: str = "cuda"): + """ + Step 3: Convert merged.safetensors (FP32) to FP8 using converter.py --quantized. + """ + logger.info("=" * 80) + logger.info("STEP 3: Convert merged.safetensors (FP32) → FP8") + logger.info("=" * 80) + + cmd = [ + "python", + "tools/convert/converter.py", + "-s", + str(merged_path), + "-o", + str(output_dir), + "-o_n", + "merged_fp8", + "--linear_type", + "fp8", + "--quantized", + "--device", + device, + "--single_file", + ] + + run_command(cmd, "Merged FP8 conversion") + + fp8_path = output_dir / "merged_fp8.safetensors" + logger.info(f" ✓ Created: {fp8_path}") + + +def step_audio_merge_lora(audio_adapter_path: Path, audio_lora_path: Path, output_dir: Path, lora_alpha: float, temp_dir: Path) -> Path: + """ + Merge audio adapter + LoRA using converter.py. + Both in FP32, output audio_adapter_merged.safetensors (FP32). 
+ """ + logger.info("=" * 80) + logger.info("AUDIO STEP 1: Merge Audio Adapter + LoRA (FP32)") + logger.info("=" * 80) + + temp_dir.mkdir(parents=True, exist_ok=True) + + logger.info("\n[1.1] Converting audio adapter to safetensors (FP32)...") + audio_dict = load_checkpoint(audio_adapter_path) + audio_safetensors = temp_dir / "audio_adapter.safetensors" + save_file(audio_dict, str(audio_safetensors)) + logger.info(f" Saved: {audio_safetensors}") + + logger.info("\n[1.2] Converting audio LoRA to safetensors (FP32)...") + lora_dict = load_checkpoint(audio_lora_path) + lora_safetensors = temp_dir / "audio_lora.safetensors" + save_file(lora_dict, str(lora_safetensors)) + logger.info(f" Saved: {lora_safetensors}") + + logger.info("\n[1.3] Merging via LoRA (converter.py)...") + cmd = [ + "python", + "tools/convert/converter.py", + "-s", + str(audio_safetensors), + "-o", + str(output_dir), + "-o_n", + "audio_adapter_merged", + "--lora_path", + str(lora_safetensors), + "--lora_alpha", + str(lora_alpha), + "--single_file", + ] + + run_command(cmd, "Audio LoRA merge") + + merged_path = output_dir / "audio_adapter_merged.safetensors" + if not merged_path.exists(): + raise FileNotFoundError(f"Merged audio file not found: {merged_path}") + + logger.info(f" ✓ Created: {merged_path} (FP32)") + return merged_path + + +def step4_convert_audio_adapter_to_bf16(audio_adapter_path: Path, output_dir: Path): + """ + Step 4: Convert audio adapter to BF16. + """ + logger.info("=" * 80) + logger.info("AUDIO STEP 2: Convert audio adapter → BF16") + logger.info("=" * 80) + + audio_dict = load_checkpoint(audio_adapter_path) + audio_bf16 = convert_to_bf16(audio_dict) + + bf16_path = output_dir / "audio_adapter_model.safetensors" + save_file(audio_bf16, str(bf16_path)) + logger.info(f" ✓ Created: {bf16_path}") + + +def step5_convert_audio_adapter_to_fp8(output_dir: Path): + """ + Step 5: Convert audio adapter BF16 to FP8 using quant_adapter.py. 
+ """ + logger.info("=" * 80) + logger.info("AUDIO STEP 3: Convert audio adapter → FP8") + logger.info("=" * 80) + + input_path = output_dir / "audio_adapter_model.safetensors" + output_path = output_dir / "audio_adapter_model_fp8.safetensors" + + cmd = ["python", "tools/convert/quant_adapter.py", "--model_path", str(input_path), "--output_path", str(output_path)] + + run_command(cmd, "Audio adapter FP8 conversion") + + logger.info(f" ✓ Created: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Merge R2V+distill via LoRA and convert to multiple formats") + + # Mode selection + parser.add_argument("--mode", type=str, choices=["both", "r2v", "audio"], default="both", help="Conversion mode: 'both' (default), 'r2v' (only R2V model), or 'audio' (only audio adapter)") + + # Inputs (conditionally required based on mode) + parser.add_argument("--r2v_model", type=str, help="Path to R2V model (.pt) [required for 'both' and 'r2v' modes]") + parser.add_argument("--distill_model", type=str, help="Path to distillation model (.pt) [required for 'both' and 'r2v' modes]") + parser.add_argument("--audio_adapter", type=str, help="Path to audio adapter (.pt) [required for 'both' and 'audio' modes]") + parser.add_argument("--audio_lora", type=str, help="Path to audio LoRA (.pt/.safetensors) [optional, for merging with audio adapter]") + parser.add_argument("--audio_lora_alpha", type=float, default=8.0, help="Alpha for audio LoRA merge (default: 8.0)") + + # Outputs + parser.add_argument("--output_dir", type=str, required=True, help="Output directory") + parser.add_argument("--temp_dir", type=str, default=None, help="Temp directory (default: output_dir/temp)") + + # Settings + parser.add_argument("--lora_alpha", type=float, default=8.0, help="Alpha for LoRA merge (default: 8.0)") + parser.add_argument("--device", type=str, default="cuda", help="Device for FP8 quantization (default: cuda)") + + # Options + parser.add_argument("--skip_merged_fp8", action="store_true", help="Skip merged FP8 conversion") + parser.add_argument("--skip_audio_fp8", action="store_true", help="Skip audio adapter FP8 conversion") + + args = parser.parse_args() + + # Validate required arguments based on mode + if args.mode in ["both", "r2v"]: + if not args.r2v_model or not args.distill_model: + parser.error("--r2v_model and --distill_model are required for 'both' and 'r2v' modes") + + if args.mode in ["both", "audio"]: + if not args.audio_adapter: + parser.error("--audio_adapter is required for 'both' and 'audio' modes") + + # Setup paths + output_dir = Path(args.output_dir) + temp_dir = Path(args.temp_dir) if args.temp_dir else output_dir / "temp" + + r2v_path = Path(args.r2v_model) if args.r2v_model else None + distill_path = Path(args.distill_model) if args.distill_model else None + audio_path = Path(args.audio_adapter) if args.audio_adapter else None + audio_lora_path = Path(args.audio_lora) if args.audio_lora else None + + # Validate file existence + if r2v_path and not r2v_path.exists(): + raise FileNotFoundError(f"R2V model not found: {r2v_path}") + if distill_path and not distill_path.exists(): + raise FileNotFoundError(f"Distill model not found: {distill_path}") + if audio_path and not audio_path.exists(): + raise FileNotFoundError(f"Audio adapter not found: {audio_path}") + if audio_lora_path and not audio_lora_path.exists(): + raise FileNotFoundError(f"Audio LoRA not found: {audio_lora_path}") + + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("=" * 80) + logger.info("MODEL 
CONVERSION PIPELINE") + logger.info("=" * 80) + logger.info(f"Mode: {args.mode}") + if r2v_path: + logger.info(f"R2V model: {r2v_path}") + if distill_path: + logger.info(f"Distill model: {distill_path}") + if audio_path: + logger.info(f"Audio adapter: {audio_path}") + if audio_lora_path: + logger.info(f"Audio LoRA: {audio_lora_path}") + logger.info(f"Output dir: {output_dir}") + if args.mode in ["both", "r2v"]: + logger.info(f"LoRA alpha: {args.lora_alpha}") + if audio_lora_path: + logger.info(f"Audio LoRA alpha: {args.audio_lora_alpha}") + logger.info(f"Device: {args.device}") + logger.info("=" * 80) + + # Execute pipeline based on mode + try: + merged_path = None + + # Process R2V model (modes: 'both', 'r2v') + if args.mode in ["both", "r2v"]: + logger.info("\n>>> Processing R2V MODEL") + + # Step 1: Merge R2V + Distill via LoRA + merged_path = step1_merge_via_lora(r2v_path, distill_path, output_dir, args.lora_alpha, temp_dir) + + # Step 2: Convert merged to BF16 + step2_convert_merged_to_bf16(merged_path, output_dir) + + # Step 3: Convert merged to FP8 + if not args.skip_merged_fp8: + step3_convert_merged_to_fp8(merged_path, output_dir, args.device) + + # Process audio adapter (modes: 'both', 'audio') + if args.mode in ["both", "audio"]: + logger.info("\n>>> Processing AUDIO ADAPTER") + + audio_source_path = audio_path + + # Optional: Merge audio adapter + LoRA + if audio_lora_path: + audio_source_path = step_audio_merge_lora(audio_path, audio_lora_path, output_dir, args.audio_lora_alpha, temp_dir) + + # Convert audio adapter to BF16 + step4_convert_audio_adapter_to_bf16(audio_source_path, output_dir) + + # Convert audio adapter to FP8 + if not args.skip_audio_fp8: + step5_convert_audio_adapter_to_fp8(output_dir) + + except Exception as e: + logger.error(f"\n{'=' * 80}") + logger.error("PIPELINE FAILED") + logger.error(f"{'=' * 80}") + logger.error(f"Error: {e}") + sys.exit(1) + + # Summary + logger.info("\n" + "=" * 80) + logger.info("✓ PIPELINE COMPLETED SUCCESSFULLY!") + logger.info("=" * 80) + logger.info(f"\nMode: {args.mode}") + logger.info(f"Output directory: {output_dir}\n") + logger.info("Generated files:") + + # Show files based on mode + if args.mode in ["both", "r2v"]: + logger.info(" ✓ merged.safetensors (FP32, R2V+distill merged)") + logger.info(" ✓ merged_bf16.safetensors (BF16)") + if not args.skip_merged_fp8: + logger.info(" ✓ merged_fp8.safetensors (FP8)") + + if args.mode in ["both", "audio"]: + if audio_lora_path: + logger.info(" ✓ audio_adapter_merged.safetensors (FP32, audio+lora merged)") + logger.info(" ✓ audio_adapter_model.safetensors (BF16)") + if not args.skip_audio_fp8: + logger.info(" ✓ audio_adapter_model_fp8.safetensors (FP8)") + + if args.mode in ["both", "r2v"]: + logger.info(f"\nTemp files: {temp_dir}") + + # Show conversion flow + logger.info("\nConversion flow:") + if args.mode in ["both", "r2v"]: + logger.info(" R2V model:") + logger.info(" 1. R2V (FP32) + Distill (FP32) --LoRA--> merged.safetensors (FP32)") + logger.info(" 2. merged.safetensors (FP32) --> merged_bf16.safetensors") + if not args.skip_merged_fp8: + logger.info(" 3. merged.safetensors (FP32) --> merged_fp8.safetensors") + + if args.mode in ["both", "audio"]: + logger.info(" Audio adapter:") + step_num = 1 + if audio_lora_path: + logger.info(f" {step_num}. audio_adapter.pt + audio_lora --LoRA--> audio_adapter_merged.safetensors (FP32)") + step_num += 1 + logger.info(f" {step_num}. 
audio_adapter --> audio_adapter_model.safetensors (BF16)") + step_num += 1 + if not args.skip_audio_fp8: + logger.info(f" {step_num}. audio_adapter_model.safetensors --> audio_adapter_model_fp8.safetensors") + + +if __name__ == "__main__": + main() diff --git a/tools/download_rife.py b/tools/download_rife.py new file mode 100644 index 0000000000000000000000000000000000000000..dd77c2fcc5c31f400e78681a7ad1f8d933554c25 --- /dev/null +++ b/tools/download_rife.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# coding: utf-8 + +import argparse +import os +import shutil +import sys +import zipfile +from pathlib import Path + +import requests + + +def get_base_dir(): + """Get project root directory""" + return Path(__file__).parent.parent + + +def download_file(url, save_path): + """Download file""" + print(f"Starting download: {url}") + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + downloaded_size = 0 + + with open(save_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + downloaded_size += len(chunk) + if total_size > 0: + progress = (downloaded_size / total_size) * 100 + print(f"\rDownload progress: {progress:.1f}%", end="", flush=True) + + print(f"\nDownload completed: {save_path}") + + +def extract_zip(zip_path, extract_to): + """Extract zip file""" + print(f"Starting extraction: {zip_path}") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_to) + print(f"Extraction completed: {extract_to}") + + +def find_flownet_pkl(extract_dir): + """Find flownet.pkl file in extracted directory""" + for root, dirs, files in os.walk(extract_dir): + for file in files: + if file == "flownet.pkl": + return os.path.join(root, file) + return None + + +def main(): + parser = argparse.ArgumentParser(description="Download RIFE model to specified directory") + parser.add_argument("target_directory", help="Target directory path") + + args = parser.parse_args() + + target_dir = Path(args.target_directory) + if not target_dir.is_absolute(): + target_dir = Path.cwd() / target_dir + + base_dir = get_base_dir() + temp_dir = base_dir / "_temp" + + # Create temporary directory + temp_dir.mkdir(exist_ok=True) + + target_dir.mkdir(parents=True, exist_ok=True) + + zip_url = "https://huggingface.co/hzwer/RIFE/resolve/main/RIFEv4.26_0921.zip" + zip_path = temp_dir / "RIFEv4.26_0921.zip" + + try: + # Download zip file + download_file(zip_url, zip_path) + + # Extract file + extract_zip(zip_path, temp_dir) + + # Find flownet.pkl file + flownet_pkl = find_flownet_pkl(temp_dir) + if flownet_pkl: + # Copy flownet.pkl to target directory + target_file = target_dir / "flownet.pkl" + shutil.copy2(flownet_pkl, target_file) + print(f"flownet.pkl copied to: {target_file}") + else: + print("Error: flownet.pkl file not found") + return 1 + + print("RIFE model download and installation completed!") + return 0 + + except Exception as e: + print(f"Error: {e}") + return 1 + finally: + # Clean up temporary files + print("Cleaning up temporary files...") + + # Delete zip file if exists + if zip_path.exists(): + try: + zip_path.unlink() + print(f"Deleted: {zip_path}") + except Exception as e: + print(f"Error deleting zip file: {e}") + + # Delete extracted folders + for item in temp_dir.iterdir(): + if item.is_dir(): + try: + shutil.rmtree(item) + print(f"Deleted directory: {item}") + except Exception as e: + print(f"Error deleting directory {item}: {e}") + + # Delete the temp 
directory itself if empty + if temp_dir.exists() and not any(temp_dir.iterdir()): + try: + temp_dir.rmdir() + print(f"Deleted temp directory: {temp_dir}") + except Exception as e: + print(f"Error deleting temp directory: {e}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/extract/convert_vigen_to_x2v_lora.py b/tools/extract/convert_vigen_to_x2v_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..171850efecf70fac17d27d445911db6dd3f904c0 --- /dev/null +++ b/tools/extract/convert_vigen_to_x2v_lora.py @@ -0,0 +1,144 @@ +### Using this script to convert ViGen-DiT Lora Format to Lightx2v +### +### Cmd line:python convert_vigen_to_x2v_lora.py model_lora.pt model_lora_converted.safetensors +### +### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT +### +import os +import sys + +import torch +from safetensors.torch import load_file, save_file + +if len(sys.argv) != 3: + print("用法: python convert_lora.py <输入文件> <输出文件.safetensors>") + sys.exit(1) + +ckpt_path = sys.argv[1] +output_path = sys.argv[2] + +if not os.path.exists(ckpt_path): + print(f"❌ 输入文件不存在: {ckpt_path}") + sys.exit(1) + +if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path) +else: + state_dict = torch.load(ckpt_path, map_location="cpu") + +if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] +elif "model" in state_dict: + state_dict = state_dict["model"] + +mapped_dict = {} + +# 映射表定义 +attn_map = { + "attn1": "self_attn", + "attn2": "cross_attn", +} +proj_map = { + "to_q": "q", + "to_k": "k", + "to_v": "v", + "to_out": "o", + "add_k_proj": "k_img", + "add_v_proj": "v_img", +} +lora_map = { + "lora_A": "lora_down", + "lora_B": "lora_up", +} + +for k, v in state_dict.items(): + # 预处理:将 to_out.0 / to_out.1 统一替换为 to_out + k = k.replace("to_out.0", "to_out").replace("to_out.1", "to_out") + k = k.replace(".default", "") # 去除.default + + parts = k.split(".") + + # === Attention Blocks === + if k.startswith("blocks.") and len(parts) >= 5: + block_id = parts[1] + + if parts[2].startswith("attn"): + attn_raw = parts[2] + proj_raw = parts[3] + lora_raw = parts[4] + + if attn_raw in attn_map and proj_raw in proj_map and lora_raw in lora_map: + attn_name = attn_map[attn_raw] + proj_name = proj_map[proj_raw] + lora_name = lora_map[lora_raw] + new_k = f"diffusion_model.blocks.{block_id}.{attn_name}.{proj_name}.{lora_name}.weight" + mapped_dict[new_k] = v + continue + else: + print(f"无法映射 attention key: {k}") + continue + # === FFN Blocks === + elif parts[2] == "ffn": + if parts[3:6] == ["net", "0", "proj"]: + layer_id = "0" + lora_raw = parts[6] + elif parts[3:5] == ["net", "2"]: + layer_id = "2" + lora_raw = parts[5] + else: + print(f"无法解析 FFN key: {k}") + continue + + if lora_raw not in lora_map: + print(f"未知 FFN LoRA 类型: {k}") + continue + + lora_name = lora_map[lora_raw] + new_k = f"diffusion_model.blocks.{block_id}.ffn.{layer_id}.{lora_name}.weight" + mapped_dict[new_k] = v + continue + # === Text Embedding === + elif k.startswith("condition_embedder.text_embedder.linear_"): + layer_id = parts[2].split("_")[1] + lora_raw = parts[3] + if lora_raw in lora_map: + lora_name = lora_map[lora_raw] + new_k = f"diffusion_model.text_embedding.{layer_id}.{lora_name}.weight" + mapped_dict[new_k] = v + continue + else: + print(f"text_embedder 未知 LoRA 类型: {k}") + continue + """ + # === Time Embedding === + elif k.startswith("condition_embedder.time_embedder.linear_"): + layer_id = parts[2].split("_")[1] + lora_raw = parts[3] + if lora_raw in lora_map: + 
lora_name = lora_map[lora_raw] + new_k = f"diffusion_model.time_embedding.{layer_id}.{lora_name}.weight" + mapped_dict[new_k] = v + continue + else: + print(f"time_embedder 未知 LoRA 类型: {k}") + continue + + # === Time Projection === + elif k.startswith("condition_embedder.time_proj."): + lora_raw = parts[2] + if lora_raw in lora_map: + lora_name = lora_map[lora_raw] + new_k = f"diffusion_model.time_projection.1.{lora_name}.weight" + mapped_dict[new_k] = v + continue + else: + print(f"time_proj 未知 LoRA 类型: {k}") + continue + """ + # fallback + print(f"未识别结构 key: {k}") + +# 保存 +print(f"\n✅ 成功重命名 {len(mapped_dict)} 个 LoRA 参数") +save_file(mapped_dict, output_path) +print(f"💾 已保存为: {output_path}") diff --git a/tools/extract/lora_extractor.py b/tools/extract/lora_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..05ffa528f5df0e0871bfc871939dc159d4e6d5d8 --- /dev/null +++ b/tools/extract/lora_extractor.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +LoRA Extractor Script +Extract LoRA weights from the difference between two models +""" + +import argparse +import os +from typing import Dict, Optional + +import torch +from safetensors import safe_open +from safetensors import torch as st +from tqdm import tqdm + + +def _get_torch_dtype(dtype_str: str) -> torch.dtype: + """ + Convert string to torch data type + + Args: + dtype_str: Data type string + + Returns: + Torch data type + """ + dtype_mapping = { + "float32": torch.float32, + "fp32": torch.float32, + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + } + + if dtype_str not in dtype_mapping: + raise ValueError(f"Unsupported data type: {dtype_str}") + + return dtype_mapping[dtype_str] + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="Extract LoRA weights from the difference between source and target models", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # Source model parameters + parser.add_argument("--source-model", type=str, required=True, help="Path to source model") + parser.add_argument("--source-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Source model format type") + + # Target model parameters + parser.add_argument("--target-model", type=str, required=True, help="Path to target model (fine-tuned model)") + parser.add_argument("--target-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Target model format type") + + # Output parameters + parser.add_argument("--output", type=str, required=True, help="Path to output LoRA model") + parser.add_argument("--output-format", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Output LoRA model format") + + # LoRA related parameters + parser.add_argument("--rank", type=int, default=32, help="LoRA rank value") + + parser.add_argument("--output-dtype", type=str, choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"], default="bf16", help="Output weight data type") + parser.add_argument("--diff-only", action="store_true", help="Save all weights as direct diff without LoRA decomposition") + + return parser.parse_args() + + +def load_model_weights(model_path: str, model_type: str) -> Dict[str, torch.Tensor]: + """ + Load model weights (using fp32 precision) + + Args: + model_path: Model file path or directory path + model_type: Model type ("safetensors" or "pytorch") + + Returns: + Model weights dictionary (fp32 
precision) + """ + print(f"Loading model: {model_path} (type: {model_type}, precision: fp32)") + + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model path does not exist: {model_path}") + + weights = {} + + if model_type == "safetensors": + if os.path.isdir(model_path): + # If it's a directory, load all .safetensors files in the directory + safetensors_files = [] + for file in os.listdir(model_path): + if file.endswith(".safetensors"): + safetensors_files.append(os.path.join(model_path, file)) + + if not safetensors_files: + raise ValueError(f"No .safetensors files found in directory: {model_path}") + + print(f"Found {len(safetensors_files)} safetensors files") + + # Load all files and merge weights + for file_path in sorted(safetensors_files): + print(f" Loading file: {os.path.basename(file_path)}") + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in weights: + print(f"Warning: weight key '{key}' is duplicated in multiple files, will be overwritten") + weights[key] = f.get_tensor(key) + + elif os.path.isfile(model_path): + # If it's a single file + if model_path.endswith(".safetensors"): + with safe_open(model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + weights[key] = f.get_tensor(key) + else: + raise ValueError(f"safetensors type file should end with .safetensors: {model_path}") + else: + raise ValueError(f"Invalid path type: {model_path}") + + elif model_type == "pytorch": + # Load pytorch format (.pt, .pth) + if model_path.endswith((".pt", ".pth")): + checkpoint = torch.load(model_path, map_location="cpu") + + # Handle possible nested structure + if isinstance(checkpoint, dict): + if "state_dict" in checkpoint: + weights = checkpoint["state_dict"] + elif "model" in checkpoint: + weights = checkpoint["model"] + else: + weights = checkpoint + else: + weights = checkpoint + else: + raise ValueError(f"pytorch type file should end with .pt or .pth: {model_path}") + else: + raise ValueError(f"Unsupported model type: {model_type}") + + # Convert all floating point weights to fp32 to ensure computational precision + print("Converting weights to fp32 to ensure computational precision...") + + converted_weights = {} + for key, tensor in weights.items(): + # Only convert floating point tensors, keep integer tensors unchanged + if tensor.dtype.is_floating_point: + converted_weights[key] = tensor.to(torch.float32) + else: + converted_weights[key] = tensor + + print(f"Successfully loaded model with {len(converted_weights)} weight tensors") + return converted_weights + + +def save_lora_weights(lora_weights: Dict[str, torch.Tensor], output_path: str, output_format: str, output_dtype: str = "bf16"): + """ + Save LoRA weights + + Args: + lora_weights: LoRA weights dictionary + output_path: Output path + output_format: Output format + output_dtype: Output data type + """ + print(f"Saving LoRA weights to: {output_path} (format: {output_format}, data type: {output_dtype})") + + # Ensure output directory exists + output_dir = os.path.dirname(output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Convert data type + target_dtype = _get_torch_dtype(output_dtype) + print(f"Converting LoRA weights to {output_dtype} type...") + + converted_weights = {} + with tqdm(lora_weights.items(), desc="Converting data type", unit="weights") as pbar: + for key, tensor in pbar: + # Only convert floating point tensors, keep integer tensors unchanged + if 
tensor.dtype.is_floating_point: + converted_weights[key] = tensor.to(target_dtype).contiguous() + else: + converted_weights[key] = tensor.contiguous() + + if output_format == "safetensors": + # Save as safetensors format + if not output_path.endswith(".safetensors"): + output_path += ".safetensors" + st.save_file(converted_weights, output_path) + + elif output_format == "pytorch": + # Save as pytorch format + if not output_path.endswith((".pt", ".pth")): + output_path += ".pt" + torch.save(converted_weights, output_path) + else: + raise ValueError(f"Unsupported output format: {output_format}") + + print(f"LoRA weights saved to: {output_path}") + + +def _compute_weight_diff(source_tensor: torch.Tensor, target_tensor: torch.Tensor, key: str) -> Optional[torch.Tensor]: + """ + Compute the difference between two weight tensors + + Args: + source_tensor: Source weight tensor + target_tensor: Target weight tensor + key: Weight key name (for logging) + + Returns: + Difference tensor, returns None if no change + """ + # Check if tensor shapes match + if source_tensor.shape != target_tensor.shape: + return None + + # Check if tensor data types match + if source_tensor.dtype != target_tensor.dtype: + target_tensor = target_tensor.to(source_tensor.dtype) + + # Compute difference + diff = target_tensor - source_tensor + + # Check if there are actual changes + if torch.allclose(diff, torch.zeros_like(diff), atol=1e-8): + # No change + return None + + return diff + + +def _decompose_to_lora(diff: torch.Tensor, key: str, rank: int) -> Dict[str, torch.Tensor]: + """ + Decompose weight difference into LoRA format + + Args: + diff: Weight difference tensor + key: Original weight key name + rank: LoRA rank + + Returns: + LoRA weights dictionary (containing lora_up and lora_down) + """ + # Ensure it's a 2D tensor + if len(diff.shape) != 2: + raise ValueError(f"LoRA decomposition only supports 2D weights, but got {len(diff.shape)}D tensor: {key}") + + a, b = diff.shape + + # Check if rank is reasonable + max_rank = min(a, b) + if rank > max_rank: + rank = max_rank + + # Choose compute device (prefer GPU, fallback to CPU) + device = "cuda" if torch.cuda.is_available() else "cpu" + diff_device = diff.to(device) + + # SVD decomposition + U, S, V = torch.linalg.svd(diff_device, full_matrices=False) + + # Take the first rank components + U = U[:, :rank] # (a, rank) + S = S[:rank] # (rank,) + V = V[:rank, :] # (rank, b) + + # Distribute square root of singular values to both matrices + S_sqrt = S.sqrt() + lora_up = U * S_sqrt.unsqueeze(0) # (a, rank) * (1, rank) = (a, rank) + lora_down = S_sqrt.unsqueeze(1) * V # (rank, 1) * (rank, b) = (rank, b) + + # Move back to CPU and convert to original data type, ensure contiguous + lora_up = lora_up.cpu().to(diff.dtype).contiguous() + lora_down = lora_down.cpu().to(diff.dtype).contiguous() + + # Generate LoRA weight key names + base_key = key.replace(".weight", "") + lora_up_key = "diffusion_model." + f"{base_key}.lora_up.weight" + lora_down_key = "diffusion_model." 
+ f"{base_key}.lora_down.weight" + + # Return the decomposed weights + lora_weights = {lora_up_key: lora_up, lora_down_key: lora_down} + + return lora_weights + + +def extract_lora_from_diff(source_weights: Dict[str, torch.Tensor], target_weights: Dict[str, torch.Tensor], rank: int = 16, diff_only: bool = False) -> Dict[str, torch.Tensor]: + """ + Extract LoRA weights from model difference + + Args: + source_weights: Source model weights + target_weights: Target model weights + rank: LoRA rank + diff_only: If True, save all weights as direct diff without LoRA decomposition + + Returns: + LoRA weights dictionary + """ + print("Starting LoRA weight extraction...") + if diff_only: + print("Mode: Direct diff only (no LoRA decomposition)") + else: + print(f"Mode: Smart extraction - rank: {rank}") + print(f"Source model weight count: {len(source_weights)}") + print(f"Target model weight count: {len(target_weights)}") + + lora_weights = {} + processed_count = 0 + diff_count = 0 + lora_count = 0 + similar_count = 0 + skipped_count = 0 + fail_count = 0 + + # Find common keys between two models + common_keys = set(source_weights.keys()) & set(target_weights.keys()) + source_only_keys = set(source_weights.keys()) - set(target_weights.keys()) + target_only_keys = set(target_weights.keys()) - set(source_weights.keys()) + + if source_only_keys: + print(f"Warning: Source model exclusive weight keys ({len(source_only_keys)} keys): {list(source_only_keys)[:5]}...") + if target_only_keys: + print(f"Warning: Target model exclusive weight keys ({len(target_only_keys)} keys): {list(target_only_keys)[:5]}...") + + print(f"Common weight keys count: {len(common_keys)}") + + # Process common keys, extract LoRA weights + common_keys_sorted = sorted(common_keys) + pbar = tqdm(common_keys_sorted, desc="Extracting LoRA weights", unit="layer") + + for key in pbar: + source_tensor = source_weights[key] + target_tensor = target_weights[key] + + # Update progress bar description + short_key = key.split(".")[-2:] if "." 
in key else [key] + pbar.set_postfix_str(f"Processing: {'.'.join(short_key)}") + + # Compute weight difference + diff = _compute_weight_diff(source_tensor, target_tensor, key) + + if diff is None: + # No change or shape mismatch + if source_tensor.shape == target_tensor.shape: + similar_count += 1 + else: + skipped_count += 1 + continue + + # Calculate parameter count + param_count = source_tensor.numel() + is_1d = len(source_tensor.shape) == 1 + + # Decide whether to save diff directly or perform LoRA decomposition + if diff_only or is_1d or param_count < 1000000: + # Save diff directly + lora_key = _generate_lora_diff_key(key) + if lora_key == "skip": + skipped_count += 1 + continue + lora_weights[lora_key] = diff + diff_count += 1 + + else: + # Perform LoRA decomposition + if len(diff.shape) == 2 and key.endswith(".weight"): + try: + decomposed_weights = _decompose_to_lora(diff, key, rank) + lora_weights.update(decomposed_weights) + lora_count += 1 + except Exception as e: + print(f"Error: {e}") + fail_count += 1 + + else: + print(f"Error: {key} is not a 2D weight tensor") + fail_count += 1 + + processed_count += 1 + + # Close progress bar + pbar.close() + + print(f"\nExtraction statistics:") + print(f" Processed weights: {processed_count}") + print(f" Direct diff: {diff_count}") + print(f" LoRA decomposition: {lora_count}") + print(f" Skipped weights: {skipped_count}") + print(f" Similar weights: {similar_count}") + print(f" Failed weights: {fail_count}") + print(f" Total extracted LoRA weights: {len(lora_weights)}") + print("LoRA weight extraction completed") + + return lora_weights + + +def _generate_lora_diff_key(original_key: str) -> str: + """ + Generate LoRA weight key based on original weight key + + Args: + original_key: Original weight key name + + Returns: + LoRA weight key name + """ + ret_key = "diffusion_model." 
+ original_key + if original_key.endswith(".weight"): + return ret_key.replace(".weight", ".diff") + elif original_key.endswith(".bias"): + return ret_key.replace(".bias", ".diff_b") + elif original_key.endswith(".modulation"): + return ret_key.replace(".modulation", ".diff_m") + else: + # If no matching suffix, skip + return "skip" + + +def main(): + """Main function""" + args = parse_args() + + print("=" * 50) + print("LoRA Extractor Started") + print("=" * 50) + print(f"Source model: {args.source_model} ({args.source_type})") + print(f"Target model: {args.target_model} ({args.target_type})") + print(f"Output path: {args.output} ({args.output_format})") + print(f"Output data type: {args.output_dtype}") + print(f"LoRA parameters: rank={args.rank}") + print(f"Diff only mode: {args.diff_only}") + print("=" * 50) + + try: + # Load source and target models + source_weights = load_model_weights(args.source_model, args.source_type) + target_weights = load_model_weights(args.target_model, args.target_type) + + # Extract LoRA weights + lora_weights = extract_lora_from_diff(source_weights, target_weights, rank=args.rank, diff_only=args.diff_only) + + # Save LoRA weights + save_lora_weights(lora_weights, args.output, args.output_format, args.output_dtype) + + print("=" * 50) + print("LoRA extraction completed!") + print("=" * 50) + + except Exception as e: + print(f"Error: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/tools/extract/lora_merger.py b/tools/extract/lora_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..1217da5fd79ff02c2b2e7f19977fe9b2d22f9f6a --- /dev/null +++ b/tools/extract/lora_merger.py @@ -0,0 +1,418 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +LoRA Merger Script +Merge a source model with LoRA weights to create a new model +""" + +import argparse +import os +from typing import Dict, Optional + +import torch +from safetensors import safe_open +from safetensors import torch as st +from tqdm import tqdm + + +def _get_torch_dtype(dtype_str: str) -> torch.dtype: + """ + Convert string to torch data type + + Args: + dtype_str: Data type string + + Returns: + Torch data type + """ + dtype_mapping = { + "float32": torch.float32, + "fp32": torch.float32, + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + } + + if dtype_str not in dtype_mapping: + raise ValueError(f"Unsupported data type: {dtype_str}") + + return dtype_mapping[dtype_str] + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="Merge a source model with LoRA weights to create a new model", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # Source model parameters + parser.add_argument("--source-model", type=str, required=True, help="Path to source model") + parser.add_argument("--source-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Source model format type") + + # LoRA parameters + parser.add_argument("--lora-model", type=str, required=True, help="Path to LoRA weights") + parser.add_argument("--lora-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="LoRA weights format type") + + # Output parameters + parser.add_argument("--output", type=str, required=True, help="Path to output merged model") + parser.add_argument("--output-format", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Output model format") + + # Merge parameters + 
parser.add_argument("--alpha", type=float, default=1.0, help="LoRA merge strength (alpha value)") + parser.add_argument("--output-dtype", type=str, choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"], default="bf16", help="Output weight data type") + + return parser.parse_args() + + +def load_model_weights(model_path: str, model_type: str) -> Dict[str, torch.Tensor]: + """ + Load model weights (using fp32 precision) + + Args: + model_path: Model file path or directory path + model_type: Model type ("safetensors" or "pytorch") + + Returns: + Model weights dictionary (fp32 precision) + """ + print(f"Loading model: {model_path} (type: {model_type}, precision: fp32)") + + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model path does not exist: {model_path}") + + weights = {} + + if model_type == "safetensors": + if os.path.isdir(model_path): + # If it's a directory, load all .safetensors files in the directory + safetensors_files = [] + for file in os.listdir(model_path): + if file.endswith(".safetensors"): + safetensors_files.append(os.path.join(model_path, file)) + + if not safetensors_files: + raise ValueError(f"No .safetensors files found in directory: {model_path}") + + print(f"Found {len(safetensors_files)} safetensors files") + + # Load all files and merge weights + for file_path in sorted(safetensors_files): + print(f" Loading file: {os.path.basename(file_path)}") + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in weights: + print(f"Warning: weight key '{key}' is duplicated in multiple files, will be overwritten") + weights[key] = f.get_tensor(key) + + elif os.path.isfile(model_path): + # If it's a single file + if model_path.endswith(".safetensors"): + with safe_open(model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + weights[key] = f.get_tensor(key) + else: + raise ValueError(f"safetensors type file should end with .safetensors: {model_path}") + else: + raise ValueError(f"Invalid path type: {model_path}") + + elif model_type == "pytorch": + # Load pytorch format (.pt, .pth) + if model_path.endswith((".pt", ".pth")): + checkpoint = torch.load(model_path, map_location="cpu") + + # Handle possible nested structure + if isinstance(checkpoint, dict): + if "state_dict" in checkpoint: + weights = checkpoint["state_dict"] + elif "model" in checkpoint: + weights = checkpoint["model"] + else: + weights = checkpoint + else: + weights = checkpoint + else: + raise ValueError(f"pytorch type file should end with .pt or .pth: {model_path}") + else: + raise ValueError(f"Unsupported model type: {model_type}") + + # Convert all floating point weights to fp32 to ensure computational precision + print("Converting weights to fp32 to ensure computational precision...") + + converted_weights = {} + for key, tensor in weights.items(): + # Only convert floating point tensors, keep integer tensors unchanged + if tensor.dtype.is_floating_point: + converted_weights[key] = tensor.to(torch.float32) + else: + converted_weights[key] = tensor + + print(f"Successfully loaded model with {len(converted_weights)} weight tensors") + return converted_weights + + +def save_model_weights(model_weights: Dict[str, torch.Tensor], output_path: str, output_format: str, output_dtype: str = "bf16"): + """ + Save model weights + + Args: + model_weights: Model weights dictionary + output_path: Output path + output_format: Output format + output_dtype: Output data type + """ + print(f"Saving merged model to: {output_path} (format: 
{output_format}, data type: {output_dtype})") + + # Ensure output directory exists + output_dir = os.path.dirname(output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Convert data type + target_dtype = _get_torch_dtype(output_dtype) + print(f"Converting model weights to {output_dtype} type...") + + converted_weights = {} + with tqdm(model_weights.items(), desc="Converting data type", unit="weights") as pbar: + for key, tensor in pbar: + # Only convert floating point tensors, keep integer tensors unchanged + if tensor.dtype.is_floating_point: + converted_weights[key] = tensor.to(target_dtype).contiguous() + else: + converted_weights[key] = tensor.contiguous() + + if output_format == "safetensors": + # Save as safetensors format + if not output_path.endswith(".safetensors"): + output_path += ".safetensors" + st.save_file(converted_weights, output_path) + + elif output_format == "pytorch": + # Save as pytorch format + if not output_path.endswith((".pt", ".pth")): + output_path += ".pt" + torch.save(converted_weights, output_path) + else: + raise ValueError(f"Unsupported output format: {output_format}") + + print(f"Merged model saved to: {output_path}") + + +def merge_lora_weights(source_weights: Dict[str, torch.Tensor], lora_weights: Dict[str, torch.Tensor], alpha: float = 1.0) -> Dict[str, torch.Tensor]: + """ + Merge source model with LoRA weights + + Args: + source_weights: Source model weights + lora_weights: LoRA weights + alpha: LoRA merge strength + + Returns: + Merged model weights + """ + print("Starting LoRA merge...") + print(f"Merge parameters - alpha: {alpha}") + print(f"Source model weight count: {len(source_weights)}") + print(f"LoRA weight count: {len(lora_weights)}") + + merged_weights = source_weights.copy() + processed_count = 0 + lora_merged_count = 0 + diff_merged_count = 0 + skipped_source_count = 0 + skipped_lora_count = 0 + skipped_source_keys = [] + skipped_lora_keys = [] + + # Group LoRA weights by base key + lora_pairs = {} + diff_weights = {} + + for lora_key, lora_tensor in lora_weights.items(): + if lora_key.endswith(".lora_up.weight"): + base_key = lora_key.replace(".lora_up.weight", "") + if base_key not in lora_pairs: + lora_pairs[base_key] = {} + lora_pairs[base_key]["up"] = lora_tensor + elif lora_key.endswith(".lora_down.weight"): + base_key = lora_key.replace(".lora_down.weight", "") + if base_key not in lora_pairs: + lora_pairs[base_key] = {} + lora_pairs[base_key]["down"] = lora_tensor + elif lora_key.endswith((".diff", ".diff_b", ".diff_m")): + diff_weights[lora_key] = lora_tensor + + print(f"Found {len(lora_pairs)} LoRA pairs and {len(diff_weights)} diff weights") + + # Process with progress bar + all_items = list(lora_pairs.items()) + list(diff_weights.items()) + pbar = tqdm(all_items, desc="Merging LoRA weights", unit="weight") + + for item in pbar: + if isinstance(item[1], dict): # LoRA pair + base_key, lora_pair = item + if "up" in lora_pair and "down" in lora_pair: + # Find corresponding source weight + source_key = _find_source_key(base_key, source_weights) + if source_key: + if source_weights[source_key].shape != (lora_pair["up"].shape[0], lora_pair["down"].shape[1]): + skipped_source_count += 1 + skipped_source_keys.append(source_key) + continue + lora_up = lora_pair["up"] + lora_down = lora_pair["down"] + + # Compute LoRA delta: alpha * (lora_up @ lora_down) + lora_delta = alpha * (lora_up @ lora_down) + + # Apply to source weight + merged_weights[source_key] = 
source_weights[source_key] + lora_delta + lora_merged_count += 1 + pbar.set_postfix_str(f"LoRA: {source_key.split('.')[-1]}") + else: + skipped_source_count += 1 + skipped_source_keys.append(base_key) + else: + print(f"Warning: Incomplete LoRA pair for: {base_key}") + skipped_lora_count += 1 + skipped_lora_keys.append(base_key) + else: # Diff weight + diff_key, diff_tensor = item + # Find corresponding source weight + source_key = _find_source_key_from_diff(diff_key, source_weights) + if source_key: + if source_weights[source_key].shape != diff_tensor.shape: + skipped_source_count += 1 + skipped_source_keys.append(source_key) + continue + # Apply diff: source + alpha * diff + merged_weights[source_key] = source_weights[source_key] + alpha * diff_tensor + diff_merged_count += 1 + pbar.set_postfix_str(f"Diff: {source_key.split('.')[-1]}") + else: + skipped_lora_count += 1 + skipped_lora_keys.append(diff_key) + + processed_count += 1 + + pbar.close() + + print(f"\nMerge statistics:") + print(f" Processed weights: {processed_count}") + print(f" LoRA merged: {lora_merged_count}") + print(f" Diff merged: {diff_merged_count}") + print(f" Skipped source weights: {skipped_source_count}") + if skipped_source_count > 0: + print(f" Skipped source keys:") + for key in skipped_source_keys: + print(f" {key}") + print(f" Skipped LoRA weights: {skipped_lora_count}") + if skipped_lora_count > 0: + print(f" Skipped LoRA keys:") + for key in skipped_lora_keys: + print(f" {key}") + print(f" Total merged model weights: {len(merged_weights)}") + print("LoRA merge completed") + + return merged_weights + + +def _find_source_key(lora_base_key: str, source_weights: Dict[str, torch.Tensor]) -> Optional[str]: + """ + Find corresponding source weight key for LoRA base key + + Args: + lora_base_key: LoRA base key (e.g., "diffusion_model.input_blocks.0.0.weight") + source_weights: Source model weights + + Returns: + Corresponding source key or None + """ + # Remove diffusion_model prefix if present + if lora_base_key.startswith("diffusion_model."): + source_key = lora_base_key[16:] + ".weight" # Remove "diffusion_model." and add ".weight" + else: + source_key = lora_base_key + ".weight" + + if source_key in source_weights: + return source_key + + # Try without adding .weight (in case it's already included) + if lora_base_key.startswith("diffusion_model."): + source_key_alt = lora_base_key[16:] + else: + source_key_alt = lora_base_key + + if source_key_alt in source_weights: + return source_key_alt + + return None + + +def _find_source_key_from_diff(diff_key: str, source_weights: Dict[str, torch.Tensor]) -> Optional[str]: + """ + Find corresponding source weight key for diff key + + Args: + diff_key: Diff key (e.g., "diffusion_model.input_blocks.0.diff") + source_weights: Source model weights + + Returns: + Corresponding source key or None + """ + # Remove diffusion_model prefix and diff suffix + if diff_key.startswith("diffusion_model."): + base_key = diff_key[16:] # Remove "diffusion_model." 
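+        # len("diffusion_model.") == 16, hence the hard-coded slice offset used here and in _find_source_key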
+    else:
+        base_key = diff_key
+
+    # Remove diff suffixes
+    if base_key.endswith(".diff"):
+        source_key = base_key[:-5] + ".weight"  # Replace ".diff" with ".weight"
+    elif base_key.endswith(".diff_b"):
+        source_key = base_key[:-7] + ".bias"  # Replace ".diff_b" with ".bias"
+    elif base_key.endswith(".diff_m"):
+        source_key = base_key[:-7] + ".modulation"  # Replace ".diff_m" with ".modulation"
+    else:
+        source_key = base_key
+
+    if source_key in source_weights:
+        return source_key
+
+    return None
+
+
+def main():
+    """Main function"""
+    args = parse_args()
+
+    print("=" * 50)
+    print("LoRA Merger Started")
+    print("=" * 50)
+    print(f"Source model: {args.source_model} ({args.source_type})")
+    print(f"LoRA weights: {args.lora_model} ({args.lora_type})")
+    print(f"Output path: {args.output} ({args.output_format})")
+    print(f"Output data type: {args.output_dtype}")
+    print(f"Merge parameters: alpha={args.alpha}")
+    print("=" * 50)
+
+    try:
+        # Load source model and LoRA weights
+        source_weights = load_model_weights(args.source_model, args.source_type)
+        lora_weights = load_model_weights(args.lora_model, args.lora_type)
+
+        # Merge LoRA weights with source model
+        merged_weights = merge_lora_weights(source_weights, lora_weights, alpha=args.alpha)
+
+        # Save merged model
+        save_model_weights(merged_weights, args.output, args.output_format, args.output_dtype)
+
+        print("=" * 50)
+        print("LoRA merge completed!")
+        print("=" * 50)
+
+    except Exception as e:
+        print(f"Error: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/preprocess/UserGuider.md b/tools/preprocess/UserGuider.md
new file mode 100644
index 0000000000000000000000000000000000000000..8a3e455e287293449121d5b171c22277f4802bc3
--- /dev/null
+++ b/tools/preprocess/UserGuider.md
@@ -0,0 +1,70 @@
+# Wan-animate Preprocessing User Guide
+
+## 1. Introduction
+
+
+Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the driving video, each has a distinct preprocessing pipeline.
+
+### 1.1 Animation Mode
+
+In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar.
+
+ - A simplified version of the pose retargeting pipeline is provided to help developers quickly implement this functionality.
+
+ - **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
+
+ - Community contributions to improve this feature are welcome.
+
+### 1.2 Replacement Mode
+
+ - Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
+
+ - **WARNING**: If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
+
+ - A simplified version for extracting the character's mask is also provided.
+   - **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail on multi-person videos (incorrect pose tracking). For multi-person videos, users are required to either develop their own solution or integrate a suitable open-source tool.
+
+---
+
+## 2. Preprocessing Instructions and Recommendations
+
+### 2.1 Basic Usage
+
+- Preprocessing requires some additional models: a pose detection model (mandatory), plus mask extraction and image editing models (optional, as needed). Place them according to the following directory structure:
+```
+    /path/to/your/ckpt_path/
+    ├── det/
+    │   └── yolov10m.onnx
+    ├── pose2d/
+    │   └── vitpose_h_wholebody.onnx
+    ├── sam2/
+    │   └── sam2_hiera_large.pt
+    └── FLUX.1-Kontext-dev/
+```
+- `video_path`, `refer_path`, and `save_path` correspond to the paths for the input driving video, the character image, and the preprocessed results.
+
+- When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated.
+
+- The `resolution_area` parameter determines the resolution for both preprocessing and the generation model; the target size is specified as a pixel area.
+
+- The `fps` parameter specifies the frame rate for video processing. A lower frame rate can improve generation efficiency, but may cause stuttering or choppiness.
+
+---
+
+### 2.2 Animation Mode
+
+- Three options are supported: no pose retargeting, basic pose retargeting, and enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are selected via the `retarget_flag` and `use_flux` parameters.
+
+- Using basic pose retargeting (specifying only `retarget_flag`) requires that both the reference character and the character in the first frame of the driving video are in a front-facing, stretched pose.
+
+- Otherwise, we recommend enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency is not maintained, the pose is incorrect, etc.). It is recommended to check the intermediate results as well as the final generated pose video; both are stored in `save_path`. Of course, users can also use a better image editing model, or explore the prompts for Flux on their own.
+
+---
+
+### 2.3 Replacement Mode
+
+- Specify `replace_flag` to enable data preprocessing for this mode. Preprocessing will additionally produce a mask for the character in the video, whose size and shape can be adjusted via the following parameters.
+- `iterations` and `k` enlarge the mask so that it covers more area.
+- `w_len` and `h_len` adjust the mask's shape: smaller values make the outline coarser, while larger values make it finer.
+
+- A smaller, finer-contoured mask preserves more of the original background, but may limit the character's generation area (given potential appearance differences, this can lead to some shape leakage). A larger, coarser mask allows the character generation to be more flexible and consistent, but because it includes more of the background, it may affect the background's consistency. We recommend adjusting the relevant parameters based on the specific input data.
diff --git a/tools/preprocess/__init__.py b/tools/preprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b76b6241a230f153b27d217b8ce4413e5ab90d9
--- /dev/null
+++ b/tools/preprocess/__init__.py
@@ -0,0 +1,3 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
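+# A minimal, hypothetical invocation of the preprocessing pipeline described in
+# tools/preprocess/UserGuider.md. The entry-script name below is an assumption; the flag
+# names simply mirror the parameters mentioned in the guide and may be wired differently:
+#
+#   python tools/preprocess/preprocess_data.py \
+#       --ckpt_path /path/to/your/ckpt_path \
+#       --video_path driving.mp4 --refer_path character.png --save_path ./processed \
+#       --resolution_area 1280 720 --fps 30 --retarget_flag --use_flux   # animation mode
+#
+# For replacement mode, add --replace_flag (and tune iterations/k/w_len/h_len for the mask).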
+from .process_pipepline import ProcessPipeline +from .video_predictor import SAM2VideoPredictor diff --git a/tools/preprocess/human_visualization.py b/tools/preprocess/human_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..519fbab343dce80ba2596130b6d92b8735fc8c3e --- /dev/null +++ b/tools/preprocess/human_visualization.py @@ -0,0 +1,1337 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import random +from typing import Dict, List + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from pose2d_utils import AAPoseMeta + + +def draw_handpose(canvas, keypoints, hand_score_th=0.6): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + eps = 0.01 + + H, W, C = canvas.shape + stickwidth = max(int(min(H, W) / 200), 1) + + edges = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + if k1[2] < hand_score_th or k2[2] < hand_score_th: + continue + + x1 = int(k1[0]) + y1 = int(k1[1]) + x2 = int(k2[0]) + y2 = int(k2[1]) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line( + canvas, + (x1, y1), + (x2, y2), + matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, + thickness=stickwidth, + ) + + for keypoint in keypoints: + if keypoint is None: + continue + if keypoint[2] < hand_score_th: + continue + + x, y = keypoint[0], keypoint[1] + x = int(x) + y = int(y) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1) + return canvas + + +def draw_handpose_new(canvas, keypoints, stickwidth_type="v2", hand_score_th=0.6): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 
+ """ + eps = 0.01 + + H, W, C = canvas.shape + if stickwidth_type == "v1": + stickwidth = max(int(min(H, W) / 200), 1) + elif stickwidth_type == "v2": + stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1) + + edges = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + if k1[2] < hand_score_th or k2[2] < hand_score_th: + continue + + x1 = int(k1[0]) + y1 = int(k1[1]) + x2 = int(k2[0]) + y2 = int(k2[1]) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line( + canvas, + (x1, y1), + (x2, y2), + matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, + thickness=stickwidth, + ) + + for keypoint in keypoints: + if keypoint is None: + continue + if keypoint[2] < hand_score_th: + continue + + x, y = keypoint[0], keypoint[1] + x = int(x) + y = int(y) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1) + return canvas + + +def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6): + H, W, C = img.shape + stickwidth = max(int(min(H, W) / 200), 1) + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + return img + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + return img + + +def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]: + """Convert the 133 keypoints from pose2d to body and hands keypoints. 
+ + Args: + kp2ds (np.ndarray): [133, 2] + + Returns: + List[np.ndarray]: _description_ + """ + kp2ds_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds_lhand = kp2ds[91:112] + kp2ds_rhand = kp2ds[112:133] + return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy() + + +def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head) + return pose_img + + +def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type="v2", draw_hand=True, draw_head=True): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_aapose_new(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head) + return pose_img + + +def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1) + kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False) + return pose_img + + +def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=False, draw_head=True): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_M(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head) + return pose_img + + +def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_nose(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand) + return pose_img + + +def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200): + return + + +def draw_M(img, kp2ds, threshold=0.6, data_to_json=None, idx=-1, kp2ds_lhand=None, kp2ds_rhand=None, draw_hand=False, stick_width_norm=200, draw_head=True): + """ + Draw keypoints and connections representing hand pose on a given canvas. 
+ + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + # import ipdb; ipdb.set_trace() + kp2ds[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 18, 19], 2] = 0 + if not draw_head: + kp2ds[[0, 14, 15, 16, 17], 2] = 0 + kp2ds_body = kp2ds + # kp2ds_body = kp2ds_body[:18] + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + # [2, 3], + # [2, 6], # shoulders + # [3, 4], + # [4, 5], # left arm + # [6, 7], + # [7, 8], # right arm + # [2, 9], + # [9, 10], + # [10, 11], # right leg + # [2, 12], + # [12, 13], + # [13, 14], # left leg + # [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + # [14, 19], + # [11, 20], # foot + ] + + colors = [ + # [255, 0, 0], + # [255, 85, 0], + # [255, 170, 0], + # [255, 255, 0], + # [170, 255, 0], + # [85, 255, 0], + # [0, 255, 0], + # [0, 255, 85], + # [0, 255, 170], + # [0, 255, 255], + # [0, 170, 255], + # [0, 85, 255], + # [0, 0, 255], + # [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + # foot + # [200, 200, 0], + # [100, 100, 0], + ] + + H, W, C = img.shape + stickwidth = max(int(min(H, W) / stick_width_norm), 1) + + for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + keypoint1 = kp2ds_body[k1_index - 1] + keypoint2 = kp2ds_body[k2_index - 1] + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) + img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } 
+ ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_nose( + img, + kp2ds, + threshold=0.6, + data_to_json=None, + idx=-1, + kp2ds_lhand=None, + kp2ds_rhand=None, + draw_hand=False, + stick_width_norm=200, +): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + kp2ds[1:, 2] = 0 + # kp2ds[0, 2] = 1 + kp2ds_body = kp2ds + # kp2ds_body = kp2ds_body[:18] + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + # [2, 3], + # [2, 6], # shoulders + # [3, 4], + # [4, 5], # left arm + # [6, 7], + # [7, 8], # right arm + # [2, 9], + # [9, 10], + # [10, 11], # right leg + # [2, 12], + # [12, 13], + # [13, 14], # left leg + # [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + # [14, 19], + # [11, 20], # foot + ] + + colors = [ + # [255, 0, 0], + # [255, 85, 0], + # [255, 170, 0], + # [255, 255, 0], + # [170, 255, 0], + # [85, 255, 0], + # [0, 255, 0], + # [0, 255, 85], + # [0, 255, 170], + # [0, 255, 255], + # [0, 170, 255], + # [0, 85, 255], + # [0, 0, 255], + # [85, 0, 255], + [170, 0, 255], + # [255, 0, 255], + # [255, 0, 170], + # [255, 0, 85], + # foot + # [200, 200, 0], + # [100, 100, 0], + ] + + H, W, C = img.shape + stickwidth = max(int(min(H, W) / stick_width_norm), 1) + + # for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + # keypoint1 = kp2ds_body[k1_index - 1] + # keypoint2 = kp2ds_body[k2_index - 1] + + # if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + # continue + + # Y = np.array([keypoint1[0], keypoint2[0]]) + # X = np.array([keypoint1[1], keypoint2[1]]) + # mX = np.mean(X) + # mY = np.mean(Y) + # length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + # angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + # polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + # cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose(img, kp2ds_lhand, 
hand_score_th=threshold) + img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_aapose(img, kp2ds, threshold=0.6, data_to_json=None, idx=-1, kp2ds_lhand=None, kp2ds_rhand=None, draw_hand=False, stick_width_norm=200, draw_head=True): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + if not draw_head: + kp2ds[[0, 14, 15, 16, 17], 2] = 0 + kp2ds_body = kp2ds + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + [2, 3], + [2, 6], # shoulders + [3, 4], + [4, 5], # left arm + [6, 7], + [7, 8], # right arm + [2, 9], + [9, 10], + [10, 11], # right leg + [2, 12], + [12, 13], + [13, 14], # left leg + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + [14, 19], + [11, 20], # foot + ] + + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + # foot + [200, 200, 0], + [100, 100, 0], + ] + + H, W, C = img.shape + stickwidth = max(int(min(H, W) / stick_width_norm), 1) + + for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + keypoint1 = kp2ds_body[k1_index - 1] + keypoint2 = kp2ds_body[k2_index - 1] + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) 
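+    # The loop above renders each limb in limbSeq as a filled ellipse whose major axis spans
+    # its two endpoint keypoints, using the paired colour dimmed to 60%; the loop below then
+    # marks every keypoint above the confidence threshold with a solid circle in full colour.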
+ + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) + img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_aapose_new(img, kp2ds, threshold=0.6, data_to_json=None, idx=-1, kp2ds_lhand=None, kp2ds_rhand=None, draw_hand=False, stickwidth_type="v2", draw_head=True): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 
+ """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + if not draw_head: + kp2ds[[0, 14, 15, 16, 17], 2] = 0 + kp2ds_body = kp2ds + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + [2, 3], + [2, 6], # shoulders + [3, 4], + [4, 5], # left arm + [6, 7], + [7, 8], # right arm + [2, 9], + [9, 10], + [10, 11], # right leg + [2, 12], + [12, 13], + [13, 14], # left leg + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + [14, 19], + [11, 20], # foot + ] + + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + # foot + [200, 200, 0], + [100, 100, 0], + ] + + H, W, C = img.shape + H, W, C = img.shape + + if stickwidth_type == "v1": + stickwidth = max(int(min(H, W) / 200), 1) + elif stickwidth_type == "v2": + stickwidth = max(int(min(H, W) / 200) - 1, 1) + else: + raise + + for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + keypoint1 = kp2ds_body[k1_index - 1] + keypoint2 = kp2ds_body[k2_index - 1] + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold) + img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_bbox(img, bbox, color=(255, 0, 0)): + img = load_image(img) + bbox = [int(bbox_tmp) for bbox_tmp in bbox] + cv2.rectangle(img, (bbox[0], bbox[1]), 
(bbox[2], bbox[3]), color, 2) + return img + + +def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False): + img = load_image(img, reverse) + + if skeleton is not None: + if skeleton == "coco17": + skeleton_list = [ + [6, 8], + [8, 10], + [5, 7], + [7, 9], + [11, 13], + [13, 15], + [12, 14], + [14, 16], + [5, 6], + [6, 12], + [12, 11], + [11, 5], + ] + color_list = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + ] + elif skeleton == "cocowholebody": + skeleton_list = [ + [6, 8], + [8, 10], + [5, 7], + [7, 9], + [11, 13], + [13, 15], + [12, 14], + [14, 16], + [5, 6], + [6, 12], + [12, 11], + [11, 5], + [15, 17], + [15, 18], + [15, 19], + [16, 20], + [16, 21], + [16, 22], + [91, 92, 93, 94, 95], + [91, 96, 97, 98, 99], + [91, 100, 101, 102, 103], + [91, 104, 105, 106, 107], + [91, 108, 109, 110, 111], + [112, 113, 114, 115, 116], + [112, 117, 118, 119, 120], + [112, 121, 122, 123, 124], + [112, 125, 126, 127, 128], + [112, 129, 130, 131, 132], + ] + color_list = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + ] + else: + color_list = [color] + for _idx, _skeleton in enumerate(skeleton_list): + for i in range(len(_skeleton) - 1): + cv2.line( + img, + (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])), + (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])), + color_list[_idx % len(color_list)], + 3, + ) + + for _idx, kp2d in enumerate(kp2ds): + if kp2d[2] > threshold: + cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1) + # cv2.putText(img, + # str(_idx), + # (int(kp2d[0, i, 0])*1, + # int(kp2d[0, i, 1])*1), + # cv2.FONT_HERSHEY_SIMPLEX, + # 0.75, + # color, + # 2 + # ) + + return img + + +def draw_mask(img, mask, background=0, return_rgba=False): + img = load_image(img) + h, w, _ = img.shape + if type(background) == int: # noqa + background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background + backgournd = cv2.resize(background, (w, h)) + img_rgba = np.concatenate([img, mask], -1) + return alphaMerge(img_rgba, background, 0, 0, return_rgba=True) + + +def draw_pcd(pcd_list, save_path=None): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + color_list = ["r", "g", "b", "y", "p"] + + for _idx, _pcd in enumerate(pcd_list): + ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx], marker="o") + + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + + if save_path is not None: + plt.savefig(save_path) + else: + plt.savefig("tmp.png") + + +def load_image(img, reverse=False): + if type(img) == str: # noqa + img = cv2.imread(img) + if reverse: + img = img.astype(np.float32) + img = img[:, :, ::-1] + img = img.astype(np.uint8) + return img + + +def draw_skeleten(meta): + kps = [] + for i, kp in enumerate(meta["keypoints_body"]): + if kp is None: + # if kp is None: + kps.append([0, 0, 0]) + else: + kps.append([*kp, 1]) + kps = np.array(kps) + + kps[:, 0] *= meta["width"] + kps[:, 1] *= meta["height"] + pose_img = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8) + + pose_img = draw_aapose( + pose_img, + kps, + draw_hand=True, + kp2ds_lhand=meta["keypoints_left_hand"], + kp2ds_rhand=meta["keypoints_right_hand"], + ) + return pose_img + + +def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray: + """ + Args: + pncc: [H,W,3] + meta: required keys: keypoints_body: [N, 3] keypoints_left_hand, keypoints_right_hand + Return: + np.ndarray [H, W, 3] + """ + # preprocess 
keypoints + kps = [] + for i, kp in enumerate(meta["keypoints_body"]): + if kp is None: + # if kp is None: + kps.append([0, 0, 0]) + elif i in [14, 15, 16, 17]: + kps.append([0, 0, 0]) + else: + kps.append([*kp]) + kps = np.stack(kps) + + kps[:, 0] *= pncc.shape[1] + kps[:, 1] *= pncc.shape[0] + + # draw neck + canvas = np.zeros_like(pncc) + if kps[0][2] > 0.6 and kps[1][2] > 0.6: + canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255]) + + # draw pncc + mask = (pncc > 0).max(axis=2) + canvas[mask] = pncc[mask] + pncc = canvas + + # draw other skeleten + kps[0] = 0 + + meta["keypoints_left_hand"][:, 0] *= meta["width"] + meta["keypoints_left_hand"][:, 1] *= meta["height"] + + meta["keypoints_right_hand"][:, 0] *= meta["width"] + meta["keypoints_right_hand"][:, 1] *= meta["height"] + pose_img = draw_aapose( + pncc, + kps, + draw_hand=True, + kp2ds_lhand=meta["keypoints_left_hand"], + kp2ds_rhand=meta["keypoints_right_hand"], + ) + return pose_img + + +FACE_CUSTOM_STYLE = { + "eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False}, + "left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]}, + "right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]}, + "left_eye": {"indexs": [36, 37, 38, 39, 40, 41], "color": [255, 255, 0], "close": True}, + "right_eye": {"indexs": [42, 43, 44, 45, 46, 47], "color": [255, 0, 255], "close": True}, + "mouth_outside": {"indexs": list(range(48, 60)), "color": [100, 255, 50], "close": True}, + "mouth_inside": {"indexs": [60, 61, 62, 63, 64, 65, 66, 67], "color": [255, 100, 50], "close": True}, +} + + +def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE): + """ + Args: + img: [H, W, 3] + kps: [70, 2] + """ + img = img.copy() + for key, item in style.items(): + pts = np.array(kps[item["indexs"]]).astype(np.int32) + connect = item.get("connect", True) + color = item["color"] + close = item.get("close", False) + if connect: + cv2.polylines(img, [pts], close, color, thickness=thickness) + else: + for kp in pts: + kp = np.array(kp).astype(np.int32) + cv2.circle(img, kp, thickness * 2, color=color, thickness=-1) + return img + + +def draw_traj(metas: List[AAPoseMeta], threshold=0.6): + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + [100, 255, 50], + [255, 100, 50], + # foot + [200, 200, 0], + [100, 100, 0], + ] + limbSeq = [ + [1, 2], + [1, 5], # shoulders + [2, 3], + [3, 4], # left arm + [5, 6], + [6, 7], # right arm + [1, 8], + [8, 9], + [9, 10], # right leg + [1, 11], + [11, 12], + [12, 13], # left leg + # face (nose, eyes, ears) + [13, 18], + [10, 19], # foot + ] + + face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]] + kp_body = np.array([meta.kps_body for meta in metas]) + kp_body_p = np.array([meta.kps_body_p for meta in metas]) + + face_seq = random.sample(face_seq, 2) + + kp_lh = np.array([meta.kps_lhand for meta in metas]) + kp_rh = np.array([meta.kps_rhand for meta in metas]) + + kp_lh_p = np.array([meta.kps_lhand_p for meta in metas]) + kp_rh_p = np.array([meta.kps_rhand_p for meta in metas]) + + # kp_lh = np.concatenate([kp_lh, kp_lh_p], axis=-1) + # kp_rh = np.concatenate([kp_rh, kp_rh_p], axis=-1) + + new_limbSeq = [] + key_point_list = [] + for _idx, ((k1_index, k2_index)) in enumerate(limbSeq): + vis = (kp_body_p[:, 
k1_index] > threshold) * (kp_body_p[:, k2_index] > threshold) * 1 + if vis.sum() * 1.0 / vis.shape[0] > 0.4: + new_limbSeq.append([k1_index, k2_index]) + + for _idx, ((k1_index, k2_index)) in enumerate(limbSeq): + keypoint1 = kp_body[:, k1_index - 1] + keypoint2 = kp_body[:, k2_index - 1] + interleave = random.randint(4, 7) + randind = random.randint(0, interleave - 1) + # randind = random.rand(range(interleave), sampling_num) + + Y = np.array([keypoint1[:, 0], keypoint2[:, 0]]) + X = np.array([keypoint1[:, 1], keypoint2[:, 1]]) + + vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1 + + # for randidx in randind: + t = randind / interleave + x = (1 - t) * Y[0, :] + t * Y[1, :] + y = (1 - t) * X[0, :] + t * X[1, :] + + # np.array([1]) + x = x.astype(int) + y = y.astype(int) + + new_array = np.array([x, y, vis]).T + + key_point_list.append(new_array) + + indx_lh = random.randint(0, kp_lh.shape[1] - 1) + lh = kp_lh[:, indx_lh, :] + lh_p = kp_lh_p[:, indx_lh : indx_lh + 1] + lh = np.concatenate([lh, lh_p], axis=-1) + + indx_rh = random.randint(0, kp_rh.shape[1] - 1) + rh = kp_rh[:, random.randint(0, kp_rh.shape[1] - 1), :] + rh_p = kp_rh_p[:, indx_rh : indx_rh + 1] + rh = np.concatenate([rh, rh_p], axis=-1) + + lh[-1, :] = (lh[-1, :] > threshold) * 1 + rh[-1, :] = (rh[-1, :] > threshold) * 1 + + # print(rh.shape, new_array.shape) + # exit() + key_point_list.append(lh.astype(int)) + key_point_list.append(rh.astype(int)) + + key_points_list = np.stack(key_point_list) + num_points = len(key_points_list) + sample_colors = random.sample(colors, num_points) + + stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2) + + image_list_ori = [] + for i in range(key_points_list.shape[-2]): + _image_vis = np.zeros((metas[0].width, metas[0].height, 3)) + points = key_points_list[:, i, :] + for idx, point in enumerate(points): + x, y, vis = point + if vis == 1: + cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1) + + image_list_ori.append(_image_vis) + + return image_list_ori + + return [np.zeros([meta.width, meta.height, 3], dtype=np.uint8) for meta in metas] + + +if __name__ == "__main__": + meta = { + "image_id": "00472.jpg", + "height": 540, + "width": 414, + "category_id": 1, + "keypoints_body": [ + [0.5084776947463768, 0.11350188078703703], + [0.504467655495169, 0.20419560185185184], + [0.3982016153381642, 0.198046875], + [0.3841664779589372, 0.34869068287037036], + [0.3901815368357488, 0.4670536747685185], + [0.610733695652174, 0.2103443287037037], + [0.6167487545289855, 0.3517650462962963], + [0.6448190292874396, 0.4762767650462963], + [0.4523371452294686, 0.47320240162037036], + [0.4503321256038647, 0.6776475694444445], + [0.47639738073671495, 0.8544234664351852], + [0.5766483620169082, 0.47320240162037036], + [0.5666232638888888, 0.6761103877314815], + [0.534542949879227, 0.863646556712963], + [0.4864224788647343, 0.09505570023148148], + [0.5285278910024155, 0.09351851851851851], + [0.46236224335748793, 0.10581597222222222], + [0.5586031853864735, 0.10274160879629629], + [0.4994551064311594, 0.9405056423611111], + [0.4152442821557971, 0.9312825520833333], + ], + "keypoints_left_hand": [ + [267.78515625, 263.830078125, 1.2840936183929443], + [265.294921875, 269.640625, 1.2546794414520264], + [263.634765625, 277.111328125, 1.2863062620162964], + [262.8046875, 285.412109375, 1.267038345336914], + [261.14453125, 292.8828125, 1.280144453048706], + [273.595703125, 281.26171875, 1.2592815160751343], + [271.10546875, 291.22265625, 
1.3256099224090576], + [265.294921875, 294.54296875, 1.2368024587631226], + [261.14453125, 294.54296875, 0.9771889448165894], + [274.42578125, 282.091796875, 1.250044584274292], + [269.4453125, 291.22265625, 1.2571144104003906], + [264.46484375, 292.8828125, 1.177802324295044], + [260.314453125, 292.052734375, 0.9283463358879089], + [273.595703125, 282.091796875, 1.1834490299224854], + [269.4453125, 290.392578125, 1.188171625137329], + [265.294921875, 290.392578125, 1.192609429359436], + [261.974609375, 289.5625, 0.9366656541824341], + [271.935546875, 281.26171875, 1.0946396589279175], + [268.615234375, 287.072265625, 0.9906131029129028], + [265.294921875, 287.90234375, 1.0219476222991943], + [262.8046875, 287.072265625, 0.9240120053291321], + ], + "keypoints_right_hand": [ + [161.53515625, 258.849609375, 1.2069408893585205], + [168.17578125, 263.0, 1.1846840381622314], + [173.986328125, 269.640625, 1.1435924768447876], + [173.986328125, 277.94140625, 1.1802611351013184], + [173.986328125, 286.2421875, 1.2599592208862305], + [165.685546875, 275.451171875, 1.0633569955825806], + [167.345703125, 286.2421875, 1.1693341732025146], + [169.8359375, 291.22265625, 1.2698509693145752], + [170.666015625, 294.54296875, 1.0619274377822876], + [160.705078125, 276.28125, 1.0995020866394043], + [163.1953125, 287.90234375, 1.2735884189605713], + [166.515625, 291.22265625, 1.339503526687622], + [169.005859375, 294.54296875, 1.0835273265838623], + [157.384765625, 277.111328125, 1.0866981744766235], + [161.53515625, 287.072265625, 1.2468621730804443], + [164.025390625, 289.5625, 1.2817761898040771], + [166.515625, 292.052734375, 1.099466323852539], + [155.724609375, 277.111328125, 1.1065717935562134], + [159.044921875, 285.412109375, 1.1924479007720947], + [160.705078125, 287.072265625, 1.1304771900177002], + [162.365234375, 287.90234375, 1.0040509700775146], + ], + } + demo_meta = AAPoseMeta(meta) + res = draw_traj([demo_meta] * 5) + cv2.imwrite("traj.png", res[0][..., ::-1]) diff --git a/tools/preprocess/pose2d.py b/tools/preprocess/pose2d.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd26e4222d47d8276ee0afc0e7acbf3832f26e8 --- /dev/null +++ b/tools/preprocess/pose2d.py @@ -0,0 +1,414 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
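+# Usage sketch: Pose2d (defined below) chains a YOLO ONNX person detector with a
+# ViTPose ONNX wholebody model. The checkpoint file names mirror the ones wired up
+# in tools/preprocess/preprocess_data.py; "<ckpt_dir>" is a placeholder, not a path
+# shipped with this repo.
+#
+#   pose2d = Pose2d(
+#       checkpoint="<ckpt_dir>/pose2d/vitpose_h_wholebody.onnx",
+#       detector_checkpoint="<ckpt_dir>/det/yolov10m.onnx",
+#       device="cuda",
+#   )
+#   metas = pose2d("driving_video.mp4")  # one keypoint meta dict per frame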
+import os +from typing import List, Union + +import cv2 +import numpy as np +import onnxruntime +import torch +from pose2d_utils import bbox_from_detector, box_convert_simple, crop, keypoints_from_heatmaps, load_pose_metas_from_kp2ds_seq, read_img + + +class SimpleOnnxInference(object): + def __init__(self, checkpoint, device="cuda", reverse_input=False, **kwargs): + if isinstance(device, str): + device = torch.device(device) + if device.type == "cuda": + device = "{}:{}".format(device.type, device.index) + providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + self.device = device + if not os.path.exists(checkpoint): + raise RuntimeError("{} is not existed!".format(checkpoint)) + + if os.path.isdir(checkpoint): + checkpoint = os.path.join(checkpoint, "end2end.onnx") + + self.session = onnxruntime.InferenceSession(checkpoint, providers=providers) + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input else self.session.get_inputs()[0].shape[2:][::-1] + self.input_resolution = np.array(self.input_resolution) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def get_output_names(self): + output_names = [] + for node in self.session.get_outputs(): + output_names.append(node.name) + return output_names + + def set_device(self, device): + if isinstance(device, str): + device = torch.device(device) + if device.type == "cuda": + device = "{}:{}".format(device.type, device.index) + providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + self.session.set_providers(providers) + self.device = device + + +class Yolo(SimpleOnnxInference): + def __init__( + self, + checkpoint, + device="cuda", + threshold_conf=0.05, + threshold_multi_persons=0.1, + input_resolution=(640, 640), + threshold_iou=0.5, + threshold_bbox_shape_ratio=0.4, + cat_id=[1], + select_type="max", + strict=True, + sorted_func=None, + **kwargs, + ): + super(Yolo, self).__init__(checkpoint, device=device, **kwargs) + + model_inputs = self.session.get_inputs() + input_shape = model_inputs[0].shape + + self.input_width = 640 + self.input_height = 640 + + self.threshold_multi_persons = threshold_multi_persons + self.threshold_conf = threshold_conf + self.threshold_iou = threshold_iou + self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio + self.input_resolution = input_resolution + self.cat_id = cat_id + self.select_type = select_type + self.strict = strict + self.sorted_func = sorted_func + + def preprocess(self, input_image): + """ + Preprocesses the input image before performing inference. + + Returns: + image_data: Preprocessed image data ready for inference. 
+ """ + img = read_img(input_image) + # Get the height and width of the input image + img_height, img_width = img.shape[:2] + # Resize the image to match the input shape + img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0])) + # Normalize the image data by dividing it by 255.0 + image_data = np.array(img) / 255.0 + # Transpose the image to have the channel dimension as the first dimension + image_data = np.transpose(image_data, (2, 0, 1)) # Channel first + # Expand the dimensions of the image data to match the expected input shape + # image_data = np.expand_dims(image_data, axis=0).astype(np.float32) + image_data = image_data.astype(np.float32) + # Return the preprocessed image data + return image_data, np.array([img_height, img_width]) + + def postprocess(self, output, shape_raw, cat_id=[1]): + """ + Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs. + + Args: + input_image (numpy.ndarray): The input image. + output (numpy.ndarray): The output of the model. + + Returns: + numpy.ndarray: The input image with detections drawn on it. + """ + # Transpose and squeeze the output to match the expected shape + + outputs = np.squeeze(output) + if len(outputs.shape) == 1: + outputs = outputs[None] + if output.shape[-1] != 6 and output.shape[1] == 84: + outputs = np.transpose(outputs) + + # Get the number of rows in the outputs array + rows = outputs.shape[0] + + # Calculate the scaling factors for the bounding box coordinates + x_factor = shape_raw[1] / self.input_width + y_factor = shape_raw[0] / self.input_height + + # Lists to store the bounding boxes, scores, and class IDs of the detections + boxes = [] + scores = [] + class_ids = [] + + if outputs.shape[-1] == 6: + max_scores = outputs[:, 4] + classid = outputs[:, -1] + + threshold_conf_masks = max_scores >= self.threshold_conf + classid_masks = classid[threshold_conf_masks] != 3.14159 + + max_scores = max_scores[threshold_conf_masks][classid_masks] + classid = classid[threshold_conf_masks][classid_masks] + + boxes = outputs[:, :4][threshold_conf_masks][classid_masks] + boxes[:, [0, 2]] *= x_factor + boxes[:, [1, 3]] *= y_factor + boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + boxes = boxes.astype(np.int32) + + else: + classes_scores = outputs[:, 4:] + max_scores = np.amax(classes_scores, -1) + threshold_conf_masks = max_scores >= self.threshold_conf + + classid = np.argmax(classes_scores[threshold_conf_masks], -1) + + classid_masks = classid != 3.14159 + + classes_scores = classes_scores[threshold_conf_masks][classid_masks] + max_scores = max_scores[threshold_conf_masks][classid_masks] + classid = classid[classid_masks] + + xywh = outputs[:, :4][threshold_conf_masks][classid_masks] + + x = xywh[:, 0:1] + y = xywh[:, 1:2] + w = xywh[:, 2:3] + h = xywh[:, 3:4] + + left = (x - w / 2) * x_factor + top = (y - h / 2) * y_factor + width = w * x_factor + height = h * y_factor + boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32) + + boxes = boxes.tolist() + scores = max_scores.tolist() + class_ids = classid.tolist() + + # Apply non-maximum suppression to filter out overlapping bounding boxes + indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou) + # Iterate over the selected indices after non-maximum suppression + + results = [] + for i in indices: + # Get the box, score, and class ID corresponding to the index + box = box_convert_simple(boxes[i], "xywh2xyxy") + score = scores[i] + class_id = 
class_ids[i] + results.append(box + [score] + [class_id]) + # # Draw the detection on the input image + + # Return the modified input image + return np.array(results) + + def process_results(self, results, shape_raw, cat_id=[1], single_person=True): + if isinstance(results, tuple): + det_results = results[0] + else: + det_results = results + + person_results = [] + person_count = 0 + if len(results): + max_idx = -1 + max_bbox_size = shape_raw[0] * shape_raw[1] * -10 + max_bbox_shape = -1 + + bboxes = [] + idx_list = [] + for i in range(results.shape[0]): + bbox = results[i] + if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf): + idx_list.append(i) + bbox_shape = max((bbox[2] - bbox[0]), (bbox[3] - bbox[1])) + if bbox_shape > max_bbox_shape: + max_bbox_shape = bbox_shape + + results = results[idx_list] + + for i in range(results.shape[0]): + bbox = results[i] + bboxes.append(bbox) + if self.select_type == "max": + bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + elif self.select_type == "center": + bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1 + bbox_shape = max((bbox[2] - bbox[0]), (bbox[3] - bbox[1])) + if bbox_size > max_bbox_size: + if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio: + continue + max_bbox_size = bbox_size + max_bbox_shape = bbox_shape + max_idx = i + + if self.sorted_func is not None and len(bboxes) > 0: + max_idx = self.sorted_func(bboxes, shape_raw) + bbox = bboxes[max_idx] + if self.select_type == "max": + max_bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + elif self.select_type == "center": + max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1 + + if max_idx != -1: + person_count = 1 + + if max_idx != -1: + person = {} + person["bbox"] = results[max_idx, :5] + person["track_id"] = int(0) + person_results.append(person) + + for i in range(results.shape[0]): + bbox = results[i] + if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf): + if self.select_type == "max": + bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + elif self.select_type == "center": + bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1 + if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size: + person_count += 1 + if not single_person: + person = {} + person["bbox"] = results[i, :5] + person["track_id"] = int(person_count - 1) + person_results.append(person) + return person_results + else: + return None + + def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs): + result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id) + result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person) + if result is not None and len(result) != 0: + person_results[i] = result + + def forward(self, img, shape_raw, **kwargs): + """ + Performs inference using an ONNX model and returns the output image with drawn detections. + + Returns: + output_img: The output image with drawn detections. 
+ """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + shape_raw = shape_raw.cpu().numpy() + + outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0] + person_results = [[{"bbox": np.array([0.0, 0.0, 1.0 * shape_raw[i][1], 1.0 * shape_raw[i][0], -1]), "track_id": -1}] for i in range(len(outputs))] + + for i in range(len(outputs)): + self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs) + return person_results + + +class ViTPose(SimpleOnnxInference): + def __init__(self, checkpoint, device="cuda", **kwargs): + super(ViTPose, self).__init__(checkpoint, device=device) + + def forward(self, img, center, scale, **kwargs): + heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0] + points, prob = keypoints_from_heatmaps(heatmaps=heatmaps, center=center, scale=scale * 200, unbiased=True, use_udp=False) + return np.concatenate([points, prob], axis=2) + + @staticmethod + def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs): + if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10: + bbox = np.array([0, 0, img.shape[1], img.shape[0]]) + + bbox_xywh = bbox + if mask is not None: + img = np.where(mask > 128, img, mask) + + if isinstance(input_resolution, int): + center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale) + img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution)) + else: + center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale) + img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1])) + + IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406]) + IMG_NORM_STD = np.array([0.229, 0.224, 0.225]) + img_norm = (img / 255.0 - IMG_NORM_MEAN) / IMG_NORM_STD + img_norm = img_norm.transpose(2, 0, 1).astype(np.float32) + return img_norm, np.array(center), np.array(scale) + + +class Pose2d: + def __init__(self, checkpoint, detector_checkpoint=None, device="cuda", **kwargs): + if detector_checkpoint is not None: + self.detector = Yolo(detector_checkpoint, device) + else: + self.detector = None + + self.model = ViTPose(checkpoint, device) + self.device = device + + def load_images(self, inputs): + """ + Load images from various input types. 
+ + Args: + inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path, + single image array, or list of image arrays + + Returns: + List[np.ndarray]: List of RGB image arrays + + Raises: + ValueError: If file format is unsupported or image cannot be read + """ + if isinstance(inputs, str): + if inputs.lower().endswith((".mp4", ".avi", ".mov", ".mkv")): + cap = cv2.VideoCapture(inputs) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + images = frames + elif inputs.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")): + img = cv2.cvtColor(cv2.imread(inputs), cv2.COLOR_BGR2RGB) + if img is None: + raise ValueError(f"Cannot read image: {inputs}") + images = [img] + else: + raise ValueError(f"Unsupported file format: {inputs}") + + elif isinstance(inputs, np.ndarray): + images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs] + elif isinstance(inputs, list): + images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs] + return images + + def __call__(self, inputs: Union[str, np.ndarray, List[np.ndarray]], return_image: bool = False, **kwargs): + """ + Process input and estimate 2D keypoints. + + Args: + inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path, + single image array, or list of image arrays + **kwargs: Additional arguments for processing + + Returns: + np.ndarray: Array of detected 2D keypoints for all input images + """ + images = self.load_images(inputs) + H, W = images[0].shape[:2] + if self.detector is not None: + bboxes = [] + for _image in images: + img, shape = self.detector.preprocess(_image) + bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"]) + else: + bboxes = [None] * len(images) + + kp2ds = [] + for _image, _bbox in zip(images, bboxes): + img, center, scale = self.model.preprocess(_image, _bbox) + kp2ds.append(self.model(img[None], center[None], scale[None])) + kp2ds = np.concatenate(kp2ds, 0) + metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H) + return metas diff --git a/tools/preprocess/pose2d_utils.py b/tools/preprocess/pose2d_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf93e1189d9ead32d20a8a98c73b1885babcea4 --- /dev/null +++ b/tools/preprocess/pose2d_utils.py @@ -0,0 +1,1117 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import os.path as osp  # needed by read_img() when check_exist=True
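+# Keypoint layout note: the slicing helpers in this module (split_kp2ds_for_aa,
+# AAPoseMeta.load_from_kp2ds) assume the 133-point COCO-WholeBody ordering,
+# approximately indices 0-16 body, 17-22 feet, 23-90 face, 91-111 left hand and
+# 112-132 right hand, and average paired body/foot indices to build the 20-point
+# body skeleton used by the visualization code.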
+import warnings +from typing import List + +import cv2 +import numpy as np +from PIL import Image + + +def box_convert_simple(box, convert_type="xyxy2xywh"): + if convert_type == "xyxy2xywh": + return [box[0], box[1], box[2] - box[0], box[3] - box[1]] + elif convert_type == "xywh2xyxy": + return [box[0], box[1], box[2] + box[0], box[3] + box[1]] + elif convert_type == "xyxy2ctwh": + return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]] + elif convert_type == "ctwh2xyxy": + return [box[0] - box[2] // 2, box[1] - box[3] // 2, box[0] + (box[2] - box[2] // 2), box[1] + (box[3] - box[3] // 2)] + + +def read_img(image, convert="RGB", check_exist=False): + if isinstance(image, str): + if check_exist and not osp.exists(image): + return None + try: + img = Image.open(image) + if convert: + img = img.convert(convert) + except: # noqa + raise IOError("File error: ", image) + return np.asarray(img) + else: + if isinstance(image, np.ndarray): + if convert: + return image[..., ::-1] + else: + if convert: + img = img.convert(convert) + return np.asarray(img) + + +class AAPoseMeta: + def __init__(self, meta=None, kp2ds=None): + self.image_id = "" + self.height = 0 + self.width = 0 + + self.kps_body: np.ndarray = None + self.kps_lhand: np.ndarray = None + self.kps_rhand: np.ndarray = None + self.kps_face: np.ndarray = None + self.kps_body_p: np.ndarray = None + self.kps_lhand_p: np.ndarray = None + self.kps_rhand_p: np.ndarray = None + self.kps_face_p: np.ndarray = None + + if meta is not None: + self.load_from_meta(meta) + elif kp2ds is not None: + self.load_from_kp2ds(kp2ds) + + def is_valid(self, kp, p, threshold): + x, y = kp + if x < 0 or y < 0 or x > self.width or y > self.height or p < threshold: + return False + else: + return True + + def get_bbox(self, kp, kp_p, threshold=0.5): + kps = kp[kp_p > threshold] + if kps.size == 0: + return 0, 0, 0, 0 + x0, y0 = kps.min(axis=0) + x1, y1 = kps.max(axis=0) + return x0, y0, x1, y1 + + def crop(self, x0, y0, x1, y1): + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] -= x0 + kps[:, 1] -= y0 + self.width = x1 - x0 + self.height = y1 - y0 + return self + + def resize(self, width, height): + scale_x = width / self.width + scale_y = height / self.height + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] *= scale_x + kps[:, 1] *= scale_y + self.width = width + self.height = height + return self + + def get_kps_body_with_p(self, normalize=False): + kps_body = self.kps_body.copy() + if normalize: + kps_body = kps_body / np.array([self.width, self.height]) + + return np.concatenate([kps_body, self.kps_body_p[:, None]]) + + @staticmethod + def from_kps_face(kps_face: np.ndarray, height: int, width: int): + pose_meta = AAPoseMeta() + pose_meta.kps_face = kps_face[:, :2] + if kps_face.shape[1] == 3: + pose_meta.kps_face_p = kps_face[:, 2] + else: + pose_meta.kps_face_p = kps_face[:, 0] * 0 + 1 + pose_meta.height = height + pose_meta.width = width + return pose_meta + + @staticmethod + def from_kps_body(kps_body: np.ndarray, height: int, width: int): + pose_meta = AAPoseMeta() + pose_meta.kps_body = kps_body[:, :2] + pose_meta.kps_body_p = kps_body[:, 2] + pose_meta.height = height + pose_meta.width = width + return pose_meta + + @staticmethod + def from_humanapi_meta(meta): + pose_meta = AAPoseMeta() + width, height = meta["width"], meta["height"] + pose_meta.width = width + 
pose_meta.height = height + pose_meta.kps_body = meta["keypoints_body"][:, :2] * (width, height) + pose_meta.kps_body_p = meta["keypoints_body"][:, 2] + pose_meta.kps_lhand = meta["keypoints_left_hand"][:, :2] * (width, height) + pose_meta.kps_lhand_p = meta["keypoints_left_hand"][:, 2] + pose_meta.kps_rhand = meta["keypoints_right_hand"][:, :2] * (width, height) + pose_meta.kps_rhand_p = meta["keypoints_right_hand"][:, 2] + if "keypoints_face" in meta: + pose_meta.kps_face = meta["keypoints_face"][:, :2] * (width, height) + pose_meta.kps_face_p = meta["keypoints_face"][:, 2] + return pose_meta + + def load_from_meta(self, meta, norm_body=True, norm_hand=False): + self.image_id = meta.get("image_id", "00000.png") + self.height = meta["height"] + self.width = meta["width"] + kps_body_p = [] + kps_body = [] + for kp in meta["keypoints_body"]: + if kp is None: + kps_body.append([0, 0]) + kps_body_p.append(0) + else: + kps_body.append(kp) + kps_body_p.append(1) + + self.kps_body = np.array(kps_body) + self.kps_body[:, 0] *= self.width + self.kps_body[:, 1] *= self.height + self.kps_body_p = np.array(kps_body_p) + + self.kps_lhand = np.array(meta["keypoints_left_hand"])[:, :2] + self.kps_lhand_p = np.array(meta["keypoints_left_hand"])[:, 2] + self.kps_rhand = np.array(meta["keypoints_right_hand"])[:, :2] + self.kps_rhand_p = np.array(meta["keypoints_right_hand"])[:, 2] + + @staticmethod + def load_from_kp2ds(kp2ds: List[np.ndarray], width: int, height: int): + """input 133x3 numpy keypoints and output AAPoseMeta + + Args: + kp2ds (List[np.ndarray]): _description_ + width (int): _description_ + height (int): _description_ + + Returns: + _type_: _description_ + """ + pose_meta = AAPoseMeta() + pose_meta.width = width + pose_meta.height = height + kps_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kps_lhand = kp2ds[91:112] + kps_rhand = kp2ds[112:133] + kps_face = np.concatenate([kp2ds[23 : 23 + 68], kp2ds[1:3]], axis=0) + pose_meta.kps_body = kps_body[:, :2] + pose_meta.kps_body_p = kps_body[:, 2] + pose_meta.kps_lhand = kps_lhand[:, :2] + pose_meta.kps_lhand_p = kps_lhand[:, 2] + pose_meta.kps_rhand = kps_rhand[:, :2] + pose_meta.kps_rhand_p = kps_rhand[:, 2] + pose_meta.kps_face = kps_face[:, :2] + pose_meta.kps_face_p = kps_face[:, 2] + return pose_meta + + @staticmethod + def from_dwpose(dwpose_det_res, height, width): + pose_meta = AAPoseMeta() + pose_meta.kps_body = dwpose_det_res["bodies"]["candidate"] + pose_meta.kps_body_p = dwpose_det_res["bodies"]["score"] + pose_meta.kps_body[:, 0] *= width + pose_meta.kps_body[:, 1] *= height + + pose_meta.kps_lhand, pose_meta.kps_rhand = dwpose_det_res["hands"] + pose_meta.kps_lhand[:, 0] *= width + pose_meta.kps_lhand[:, 1] *= height + pose_meta.kps_rhand[:, 0] *= width + pose_meta.kps_rhand[:, 1] *= height + pose_meta.kps_lhand_p, pose_meta.kps_rhand_p = dwpose_det_res["hands_score"] + + pose_meta.kps_face = dwpose_det_res["faces"][0] + pose_meta.kps_face[:, 0] *= width + pose_meta.kps_face[:, 1] *= height + pose_meta.kps_face_p = dwpose_det_res["faces_score"][0] + return pose_meta + + def save_json(self): + pass + + def draw_aapose(self, img, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True): + from .human_visualization import draw_aapose_by_meta + + return draw_aapose_by_meta(img, self, threshold, stick_width_norm, draw_hand, draw_head) + + def translate(self, x0, y0): + all_kps = [self.kps_body, 
self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] -= x0 + kps[:, 1] -= y0 + + def scale(self, sx, sy): + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] *= sx + kps[:, 1] *= sy + + def padding_resize2(self, height=512, width=512): + """kps will be changed inplace""" + + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + + ori_height, ori_width = self.height, self.width + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + padding = int((width - new_width) / 2) + padding_width = padding + padding_height = 0 + scale = height / ori_height + + for kps in all_kps: + if kps is not None: + kps[:, 0] = kps[:, 0] * scale + padding + kps[:, 1] = kps[:, 1] * scale + + else: + new_height = int(width / ori_width * ori_height) + padding = int((height - new_height) / 2) + padding_width = 0 + padding_height = padding + scale = width / ori_width + for kps in all_kps: + if kps is not None: + kps[:, 1] = kps[:, 1] * scale + padding + kps[:, 0] = kps[:, 0] * scale + + self.width = width + self.height = height + return self + + +def transform_preds(coords, center, scale, output_size, use_udp=False): + """Get final keypoint predictions from heatmaps and apply scaling and + translation to map them back to the image. + + Note: + num_keypoints: K + + Args: + coords (np.ndarray[K, ndims]): + + * If ndims=2, corrds are predicted keypoint location. + * If ndims=4, corrds are composed of (x, y, scores, tags) + * If ndims=5, corrds are composed of (x, y, scores, tags, + flipped_tags) + + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + use_udp (bool): Use unbiased data processing + + Returns: + np.ndarray: Predicted coordinates in the images. + """ + assert coords.shape[1] in (2, 4, 5) + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + + # Recover the scale which is normalized by a factor of 200. + # scale = scale * 200.0 + + if use_udp: + scale_x = scale[0] / (output_size[0] - 1.0) + scale_y = scale[1] / (output_size[1] - 1.0) + else: + scale_x = scale[0] / output_size[0] + scale_y = scale[1] / output_size[1] + + target_coords = np.ones_like(coords) + target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5 + target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5 + + return target_coords + + +def _calc_distances(preds, targets, mask, normalize): + """Calculate the normalized distances between preds and target. + + Note: + batch_size: N + num_keypoints: K + dimension of keypoints: D (normally, D=2 or D=3) + + Args: + preds (np.ndarray[N, K, D]): Predicted keypoint location. + targets (np.ndarray[N, K, D]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize (np.ndarray[N, D]): Typical value is heatmap_size + + Returns: + np.ndarray[K, N]: The normalized distances. \ + If target keypoints are missing, the distance is -1. 
+ """ + N, K, _ = preds.shape + # set mask=0 when normalize==0 + _mask = mask.copy() + _mask[np.where((normalize == 0).sum(1))[0], :] = False + distances = np.full((N, K), -1, dtype=np.float32) + # handle invalid values + normalize[np.where(normalize <= 0)] = 1e6 + distances[_mask] = np.linalg.norm(((preds - targets) / normalize[:, None, :])[_mask], axis=-1) + return distances.T + + +def _distance_acc(distances, thr=0.5): + """Return the percentage below the distance threshold, while ignoring + distances values with -1. + + Note: + batch_size: N + Args: + distances (np.ndarray[N, ]): The normalized distances. + thr (float): Threshold of the distances. + + Returns: + float: Percentage of distances below the threshold. \ + If all target keypoints are missing, return -1. + """ + distance_valid = distances != -1 + num_distance_valid = distance_valid.sum() + if num_distance_valid > 0: + return (distances[distance_valid] < thr).sum() / num_distance_valid + return -1 + + +def _get_max_preds(heatmaps): + """Get keypoint predictions from score maps. + + Note: + batch_size: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray" + assert heatmaps.ndim == 4, "batch_images should be 4-ndim" + + N, K, _, W = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((N, K, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) + maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + preds[:, :, 0] = preds[:, :, 0] % W + preds[:, :, 1] = preds[:, :, 1] // W + + preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1) + return preds, maxvals + + +def _get_max_preds_3d(heatmaps): + """Get keypoint predictions from 3D score maps. + + Note: + batch size: N + num keypoints: K + heatmap depth size: D + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. + + - preds (np.ndarray[N, K, 3]): Predicted keypoint location. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray" + assert heatmaps.ndim == 5, "heatmaps should be 5-ndim" + + N, K, D, H, W = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((N, K, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) + maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) + + preds = np.zeros((N, K, 3), dtype=np.float32) + _idx = idx[..., 0] + preds[..., 2] = _idx // (H * W) + preds[..., 1] = (_idx // W) % H + preds[..., 0] = _idx % W + + preds = np.where(maxvals > 0.0, preds, -1) + return preds, maxvals + + +def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints from heatmaps. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. 
+ + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + output (np.ndarray[N, K, H, W]): Model output heatmaps. + target (np.ndarray[N, K, H, W]): Groundtruth heatmaps. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. Default 0.05. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - np.ndarray[K]: Accuracy of each keypoint. + - float: Averaged accuracy across all keypoints. + - int: Number of valid keypoints. + """ + N, K, H, W = output.shape + if K == 0: + return None, 0, 0 + if normalize is None: + normalize = np.tile(np.array([[H, W]]), (N, 1)) + + pred, _ = _get_max_preds(output) + gt, _ = _get_max_preds(target) + return keypoint_pck_accuracy(pred, gt, mask, thr, normalize) + + +def keypoint_pck_accuracy(pred, gt, mask, thr, normalize): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - acc (np.ndarray[K]): Accuracy of each keypoint. + - avg_acc (float): Averaged accuracy across all keypoints. + - cnt (int): Number of valid keypoints. + """ + distances = _calc_distances(pred, gt, mask, normalize) + + acc = np.array([_distance_acc(d, thr) for d in distances]) + valid_acc = acc[acc >= 0] + cnt = len(valid_acc) + avg_acc = valid_acc.mean() if cnt > 0 else 0 + return acc, avg_acc, cnt + + +def keypoint_auc(pred, gt, mask, normalize, num_step=20): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize (float): Normalization factor. + + Returns: + float: Area under curve. + """ + nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1)) + x = [1.0 * i / num_step for i in range(num_step)] + y = [] + for thr in x: + _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor) + y.append(avg_acc) + + auc = 0 + for i in range(num_step): + auc += 1.0 / num_step * y[i] + return auc + + +def keypoint_nme(pred, gt, mask, normalize_factor): + """Calculate the normalized mean error (NME). + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. 
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize_factor (np.ndarray[N, 2]): Normalization factor. + + Returns: + float: normalized mean error + """ + distances = _calc_distances(pred, gt, mask, normalize_factor) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +def keypoint_epe(pred, gt, mask): + """Calculate the end-point error. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + + Returns: + float: Average end-point error. + """ + + distances = _calc_distances(pred, gt, mask, np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32)) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +def _taylor(heatmap, coord): + """Distribution aware coordinate decoding method. + + Note: + - heatmap height: H + - heatmap width: W + + Args: + heatmap (np.ndarray[H, W]): Heatmap of a particular joint type. + coord (np.ndarray[2,]): Coordinates of the predicted keypoints. + + Returns: + np.ndarray[2,]: Updated coordinates. + """ + H, W = heatmap.shape[:2] + px, py = int(coord[0]), int(coord[1]) + if 1 < px < W - 2 and 1 < py < H - 2: + dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1]) + dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px]) + dxx = 0.25 * (heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2]) + dxy = 0.25 * (heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] - heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1]) + dyy = 0.25 * (heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] + heatmap[py - 2 * 1][px]) + derivative = np.array([[dx], [dy]]) + hessian = np.array([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = np.linalg.inv(hessian) + offset = -hessianinv @ derivative + offset = np.squeeze(np.array(offset.T), axis=0) + coord += offset + return coord + + +def post_dark_udp(coords, batch_heatmaps, kernel=3): + """DARK post-pocessing. Implemented by udp. Paper ref: Huang et al. The + Devil is in the Details: Delving into Unbiased Data Processing for Human + Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + + Note: + - batch size: B + - num keypoints: K + - num persons: N + - height of heatmaps: H + - width of heatmaps: W + + B=1 for bottom_up paradigm where all persons share the same heatmap. + B=N for top_down paradigm where each person has its own heatmaps. + + Args: + coords (np.ndarray[N, K, 2]): Initial coordinates of human pose. + batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps + kernel (int): Gaussian kernel size (K) for modulation. + + Returns: + np.ndarray([N, K, 2]): Refined coordinates. 
+ """ + if not isinstance(batch_heatmaps, np.ndarray): + batch_heatmaps = batch_heatmaps.cpu().numpy() + B, K, H, W = batch_heatmaps.shape + N = coords.shape[0] + assert B == 1 or B == N + for heatmaps in batch_heatmaps: + for heatmap in heatmaps: + cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap) + np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps) + np.log(batch_heatmaps, batch_heatmaps) + + batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten() + + index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K) + index = index.astype(int).reshape(-1, 1) + i_ = batch_heatmaps_pad[index] + ix1 = batch_heatmaps_pad[index + 1] + iy1 = batch_heatmaps_pad[index + W + 2] + ix1y1 = batch_heatmaps_pad[index + W + 3] + ix1_y1_ = batch_heatmaps_pad[index - W - 3] + ix1_ = batch_heatmaps_pad[index - 1] + iy1_ = batch_heatmaps_pad[index - 2 - W] + + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(N, K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(N, K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze() + return coords + + +def _gaussian_blur(heatmaps, kernel=11): + """Modulate heatmap distribution with Gaussian. + sigma = 0.3*((kernel_size-1)*0.5-1)+0.8 + sigma~=3 if k=17 + sigma=2 if k=11; + sigma~=1.5 if k=7; + sigma~=1 if k=3; + + Note: + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([N, K, H, W]): Modulated heatmap distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + height = heatmaps.shape[2] + width = heatmaps.shape[3] + for i in range(batch_size): + for j in range(num_joints): + origin_max = np.max(heatmaps[i, j]) + dr = np.zeros((height + 2 * border, width + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = heatmaps[i, j].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmaps[i, j] = dr[border:-border, border:-border].copy() + heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j]) + return heatmaps + + +def keypoints_from_regression(regression_preds, center, scale, img_size): + """Get final keypoint predictions from regression vectors and transform + them back to the image. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + regression_preds (np.ndarray[N, K, 2]): model prediction. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + img_size (list(img_width, img_height)): model input image size. + + Returns: + tuple: + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 
+ """ + N, K, _ = regression_preds.shape + preds, maxvals = regression_preds, np.ones((N, K, 1), dtype=np.float32) + + preds = preds * img_size + + # Transform back to the image + for i in range(N): + preds[i] = transform_preds(preds[i], center[i], scale[i], img_size) + + return preds, maxvals + + +def keypoints_from_heatmaps(heatmaps, center, scale, unbiased=False, post_process="default", kernel=11, valid_radius_factor=0.0546875, use_udp=False, target_type="GaussianHeatmap"): + """Get final keypoint predictions from heatmaps and transform them back to + the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + post_process (str/None): Choice of methods to post-process + heatmaps. Currently supported: None, 'default', 'unbiased', + 'megvii'. + unbiased (bool): Option to use unbiased decoding. Mutually + exclusive with megvii. + Note: this arg is deprecated and unbiased=True can be replaced + by post_process='unbiased' + Paper ref: Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + valid_radius_factor (float): The radius factor of the positive area + in classification heatmap for UDP. + use_udp (bool): Use unbiased data processing. + target_type (str): 'GaussianHeatmap' or 'CombinedTarget'. + GaussianHeatmap: Classification target with gaussian distribution. + CombinedTarget: The combination of classification target + (response map) and regression target (offset map). + Paper ref: Huang et al. The Devil is in the Details: Delving into + Unbiased Data Processing for Human Pose Estimation (CVPR 2020). + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 
+ """ + # Avoid being affected + heatmaps = heatmaps.copy() + + # detect conflicts + if unbiased: + assert post_process not in [False, None, "megvii"] + if post_process in ["megvii", "unbiased"]: + assert kernel > 0 + if use_udp: + assert not post_process == "megvii" + + # normalize configs + if post_process is False: + warnings.warn("post_process=False is deprecated, please use post_process=None instead", DeprecationWarning) + post_process = None + elif post_process is True: + if unbiased is True: + warnings.warn("post_process=True, unbiased=True is deprecated, please use post_process='unbiased' instead", DeprecationWarning) + post_process = "unbiased" + else: + warnings.warn("post_process=True, unbiased=False is deprecated, please use post_process='default' instead", DeprecationWarning) + post_process = "default" + elif post_process == "default": + if unbiased is True: + warnings.warn("unbiased=True is deprecated, please use post_process='unbiased' instead", DeprecationWarning) + post_process = "unbiased" + + # start processing + if post_process == "megvii": + heatmaps = _gaussian_blur(heatmaps, kernel=kernel) + + N, K, H, W = heatmaps.shape + if use_udp: + if target_type.lower() == "GaussianHeatMap".lower(): + preds, maxvals = _get_max_preds(heatmaps) + preds = post_dark_udp(preds, heatmaps, kernel=kernel) + elif target_type.lower() == "CombinedTarget".lower(): + for person_heatmaps in heatmaps: + for i, heatmap in enumerate(person_heatmaps): + kt = 2 * kernel + 1 if i % 3 == 0 else kernel + cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap) + # valid radius is in direct proportion to the height of heatmap. + valid_radius = valid_radius_factor * H + offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius + offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius + heatmaps = heatmaps[:, ::3, :] + preds, maxvals = _get_max_preds(heatmaps) + index = preds[..., 0] + preds[..., 1] * W + index += W * H * np.arange(0, N * K / 3) + index = index.astype(int).reshape(N, K // 3, 1) + preds += np.concatenate((offset_x[index], offset_y[index]), axis=2) + else: + raise ValueError("target_type should be either 'GaussianHeatmap' or 'CombinedTarget'") + else: + preds, maxvals = _get_max_preds(heatmaps) + if post_process == "unbiased": # alleviate biased coordinate + # apply Gaussian distribution modulation. + heatmaps = np.log(np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10)) + for n in range(N): + for k in range(K): + preds[n][k] = _taylor(heatmaps[n][k], preds[n][k]) + elif post_process is not None: + # add +/-0.25 shift to the predicted locations for higher acc. + for n in range(N): + for k in range(K): + heatmap = heatmaps[n][k] + px = int(preds[n][k][0]) + py = int(preds[n][k][1]) + if 1 < px < W - 1 and 1 < py < H - 1: + diff = np.array([heatmap[py][px + 1] - heatmap[py][px - 1], heatmap[py + 1][px] - heatmap[py - 1][px]]) + preds[n][k] += np.sign(diff) * 0.25 + if post_process == "megvii": + preds[n][k] += 0.5 + + # Transform back to the image + for i in range(N): + preds[i] = transform_preds(preds[i], center[i], scale[i], [W, H], use_udp=use_udp) + + if post_process == "megvii": + maxvals = maxvals / 255.0 + 0.5 + + return preds, maxvals + + +def keypoints_from_heatmaps3d(heatmaps, center, scale): + """Get final keypoint predictions from 3d heatmaps and transform them back + to the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap depth size: D + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. 
+ center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \ + in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + N, K, D, H, W = heatmaps.shape + preds, maxvals = _get_max_preds_3d(heatmaps) + # Transform back to the image + for i in range(N): + preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i], [W, H]) + return preds, maxvals + + +def multilabel_classification_accuracy(pred, gt, mask, thr=0.5): + """Get multi-label classification accuracy. + + Note: + - batch size: N + - label number: L + + Args: + pred (np.ndarray[N, L, 2]): model predicted labels. + gt (np.ndarray[N, L, 2]): ground-truth labels. + mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of + ground-truth labels. + + Returns: + float: multi-label classification accuracy. + """ + # we only compute accuracy on the samples with ground-truth of all labels. + valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0) + pred, gt = pred[valid], gt[valid] + + if pred.shape[0] == 0: + acc = 0.0 # when no sample is with gt labels, set acc to 0. + else: + # The classification of a sample is regarded as correct + # only if it's correct for all labels. + acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean() + return acc + + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + # res: (height, width), (rows, cols) + crop_aspect_ratio = res[0] / float(res[1]) + h = 200 * scale + w = h / crop_aspect_ratio + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / w + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / w + 0.5) + t[1, 2] = res[0] * (-float(center[1]) / h + 0.5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3)) + rot_rad = rot * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0, 2] = -res[1] / 2 + t_mat[1, 2] = -res[0] / 2 + t_inv = t_mat.copy() + t_inv[:2, 2] *= -1 + t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) + return t + + +def transform(pt, center, scale, res, invert=0, rot=0): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T + new_pt = np.dot(t, new_pt) + return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1 + + +def bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25): + """ + Get center and scale of bounding box from bounding box. + The expected format is [min_x, min_y, max_x, max_y]. 
+ """ + CROP_IMG_HEIGHT, CROP_IMG_WIDTH = input_resolution + CROP_ASPECT_RATIO = CROP_IMG_HEIGHT / float(CROP_IMG_WIDTH) + + # center + center_x = (bbox[0] + bbox[2]) / 2.0 + center_y = (bbox[1] + bbox[3]) / 2.0 + center = np.array([center_x, center_y]) + + # scale + bbox_w = bbox[2] - bbox[0] + bbox_h = bbox[3] - bbox[1] + bbox_size = max(bbox_w * CROP_ASPECT_RATIO, bbox_h) + + scale = np.array([bbox_size / CROP_ASPECT_RATIO, bbox_size]) / 200.0 + # scale = bbox_size / 200.0 + # adjust bounding box tightness + scale *= rescale + return center, scale + + +def crop(img, center, scale, res): + """ + Crop image according to the supplied bounding box. + res: [rows, cols] + """ + # Upper left point + ul = np.array(transform([1, 1], center, max(scale), res, invert=1)) - 1 + # Bottom right point + br = np.array(transform([res[1] + 1, res[0] + 1], center, max(scale), res, invert=1)) - 1 + + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape, dtype=np.float32) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + try: + new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[old_y[0] : old_y[1], old_x[0] : old_x[1]] + except Exception as e: + print(e) + + new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows) + return new_img, new_shape, (old_x, old_y), (new_x, new_y) # , ul, br + + +def split_kp2ds_for_aa(kp2ds, ret_face=False): + kp2ds_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds_lhand = kp2ds[91:112] + kp2ds_rhand = kp2ds[112:133] + kp2ds_face = kp2ds[22:91] + if ret_face: + return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy(), kp2ds_face.copy() + return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy() + + +def load_pose_metas_from_kp2ds_seq_list(kp2ds_seq, width, height): + metas = [] + for kps in kp2ds_seq: + if len(kps) != 1: + return None + kps = kps[0].copy() + kps[:, 0] /= width + kps[:, 1] /= height + kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True) + + if kp2ds_body[:, :2].min(axis=1).max() < 0: + kp2ds_body = last_kp2ds_body + last_kp2ds_body = kp2ds_body + + meta = { + "width": width, + "height": height, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + "keypoints_face": kp2ds_face.tolist(), + } + metas.append(meta) + return metas + + +def load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height): + metas = [] + for kps in kp2ds_seq: + kps = kps.copy() + kps[:, 0] /= width + kps[:, 1] /= height + kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True) + + # 排除全部小于0的情况 + if kp2ds_body[:, :2].min(axis=1).max() < 0: + kp2ds_body = last_kp2ds_body + last_kp2ds_body = kp2ds_body + + meta = { + "width": width, + "height": height, + "keypoints_body": kp2ds_body, + "keypoints_left_hand": kp2ds_lhand, + "keypoints_right_hand": kp2ds_rhand, + "keypoints_face": kp2ds_face, + } + metas.append(meta) + return metas diff --git 
a/tools/preprocess/preprocess_data.py b/tools/preprocess/preprocess_data.py new file mode 100644 index 0000000000000000000000000000000000000000..8f60e30c2659bf19c418e9667fd8e46b23b98b20 --- /dev/null +++ b/tools/preprocess/preprocess_data.py @@ -0,0 +1,88 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import os + +from process_pipepline import ProcessPipeline + + +def get_preprocess_parser(): + parser = argparse.ArgumentParser(description="The preprocessing pipeline for Wan-animate.") + + parser.add_argument("--ckpt_path", type=str, default=None, help="The path to the preprocessing model's checkpoint directory. ") + + parser.add_argument("--video_path", type=str, default=None, help="The path to the driving video.") + parser.add_argument("--refer_path", type=str, default=None, help="The path to the refererence image.") + parser.add_argument("--save_path", type=str, default=None, help="The path to save the processed results.") + + parser.add_argument( + "--resolution_area", + type=int, + nargs=2, + default=[1280, 720], + help="The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio.", + ) + parser.add_argument("--fps", type=int, default=30, help="The target FPS for processing the driving video. Set to -1 to use the video's original FPS.") + + parser.add_argument("--replace_flag", action="store_true", default=False, help="Whether to use replacement mode.") + parser.add_argument("--retarget_flag", action="store_true", default=False, help="Whether to use pose retargeting. Currently only supported in animation mode") + parser.add_argument( + "--use_flux", + action="store_true", + default=False, + help="Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose", + ) + + # Parameters for the mask strategy in replacement mode. These control the mask's size and shape. Refer to https://arxiv.org/pdf/2502.06145 + parser.add_argument("--iterations", type=int, default=3, help="Number of iterations for mask dilation.") + parser.add_argument("--k", type=int, default=7, help="Number of kernel size for mask dilation.") + parser.add_argument( + "--w_len", + type=int, + default=1, + help="The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.", + ) + parser.add_argument( + "--h_len", + type=int, + default=1, + help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.", + ) + return parser + + +def process_input_video(args): + args_dict = vars(args) + print(args_dict) + + assert len(args.resolution_area) == 2, "resolution_area should be a list of two integers [width, height]" + assert not args.use_flux or args.retarget_flag, "Image editing with FLUX can only be used when pose retargeting is enabled." 
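+
+    # Expected layout under --ckpt_path (illustrative; the exact filenames are the
+    # ones joined below):
+    #   pose2d/vitpose_h_wholebody.onnx   2D whole-body pose estimator
+    #   det/yolov10m.onnx                 person detector
+    #   sam2/sam2_hiera_large.pt          only required with --replace_flag
+    #   FLUX.1-Kontext-dev/               only required with --use_flux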
+ + pose2d_checkpoint_path = os.path.join(args.ckpt_path, "pose2d/vitpose_h_wholebody.onnx") + det_checkpoint_path = os.path.join(args.ckpt_path, "det/yolov10m.onnx") + + sam2_checkpoint_path = os.path.join(args.ckpt_path, "sam2/sam2_hiera_large.pt") if args.replace_flag else None + flux_kontext_path = os.path.join(args.ckpt_path, "FLUX.1-Kontext-dev") if args.use_flux else None + process_pipeline = ProcessPipeline( + det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path + ) + os.makedirs(args.save_path, exist_ok=True) + process_pipeline( + video_path=args.video_path, + refer_image_path=args.refer_path, + output_path=args.save_path, + resolution_area=args.resolution_area, + fps=args.fps, + iterations=args.iterations, + k=args.k, + w_len=args.w_len, + h_len=args.h_len, + retarget_flag=args.retarget_flag, + use_flux=args.use_flux, + replace_flag=args.replace_flag, + ) + + +if __name__ == "__main__": + parser = get_preprocess_parser() + args = parser.parse_args() + process_input_video(args) diff --git a/tools/preprocess/process_pipepline.py b/tools/preprocess/process_pipepline.py new file mode 100644 index 0000000000000000000000000000000000000000..5a490242ba675ee2444500a3617d4df8c53ac5ce --- /dev/null +++ b/tools/preprocess/process_pipepline.py @@ -0,0 +1,355 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import shutil + +import cv2 +import numpy as np +import torch +from PIL import Image +from diffusers import FluxKontextPipeline +from loguru import logger + +try: + import moviepy.editor as mpy +except: # noqa + import moviepy as mpy + +import sam2.modeling.sam.transformer as transformer +from decord import VideoReader +from human_visualization import draw_aapose_by_meta_new +from pose2d import Pose2d +from pose2d_utils import AAPoseMeta +from retarget_pose import get_retarget_pose +from utils import get_aug_mask, get_face_bboxes, get_frame_indices, get_mask_body_img, padding_resize, resize_by_area + +transformer.USE_FLASH_ATTN = False +transformer.MATH_KERNEL_ON = True +transformer.OLD_GPU = True +from sam_utils import build_sam2_video_predictor # noqa + + +class ProcessPipeline: + def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path): + self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path) + + model_cfg = "sam2_hiera_l.yaml" + if sam_checkpoint_path is not None: + self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path) + if flux_kontext_path is not None: + self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda") + + def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False): + if replace_flag: + video_reader = VideoReader(video_path) + frame_num = len(video_reader) + print("frame_num: {}".format(frame_num)) + + video_fps = video_reader.get_avg_fps() + print("video_fps: {}".format(video_fps)) + print("fps: {}".format(fps)) + + # TODO: Maybe we can switch to PyAV later, which can get accurate frame num + duration = video_reader.get_frame_timestamp(-1)[-1] + expected_frame_num = int(duration * video_fps + 0.5) + ratio = abs((frame_num - expected_frame_num) / frame_num) + if ratio > 0.1: + print("Warning: The difference between the actual 
number of frames and the expected number of frames is two large") + frame_num = expected_frame_num + + if fps == -1: + fps = video_fps + + target_num = int(frame_num / video_fps * fps) + print("target_num: {}".format(target_num)) + idxs = get_frame_indices(frame_num, video_fps, target_num, fps) + frames = video_reader.get_batch(idxs).asnumpy() + + frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames] + height, width = frames[0].shape[:2] + logger.info(f"Processing pose meta") + + tpl_pose_metas = self.pose2d(frames) + + face_images = [] + for idx, meta in enumerate(tpl_pose_metas): + face_bbox_for_image = get_face_bboxes(meta["keypoints_face"][:, :2], scale=1.3, image_shape=(frames[0].shape[0], frames[0].shape[1])) + + x1, x2, y1, y2 = face_bbox_for_image + face_image = frames[idx][y1:y2, x1:x2] + face_image = cv2.resize(face_image, (512, 512)) + face_images.append(face_image) + + logger.info(f"Processing reference image: {refer_image_path}") + refer_img = cv2.imread(refer_image_path) + src_ref_path = os.path.join(output_path, "src_ref.png") + shutil.copy(refer_image_path, src_ref_path) + refer_img = refer_img[..., ::-1] + + refer_img = padding_resize(refer_img, height, width) + logger.info(f"Processing template video: {video_path}") + tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas] + cond_images = [] + + for idx, meta in enumerate(tpl_retarget_pose_metas): + canvas = np.zeros_like(refer_img) + conditioning_image = draw_aapose_by_meta_new(canvas, meta) + cond_images.append(conditioning_image) + masks = self.get_mask(frames, 400, tpl_pose_metas) + + bg_images = [] + aug_masks = [] + + for frame, mask in zip(frames, masks): + if iterations > 0: + _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k) + each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len) + else: + each_aug_mask = mask + + each_bg_image = frame * (1 - each_aug_mask[:, :, None]) + bg_images.append(each_bg_image) + aug_masks.append(each_aug_mask) + + src_face_path = os.path.join(output_path, "src_face.mp4") + mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path) + + src_pose_path = os.path.join(output_path, "src_pose.mp4") + mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path) + + src_bg_path = os.path.join(output_path, "src_bg.mp4") + mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path) + + aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks] + src_mask_path = os.path.join(output_path, "src_mask.mp4") + mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path) + return True + else: + logger.info(f"Processing reference image: {refer_image_path}") + refer_img = cv2.imread(refer_image_path) + src_ref_path = os.path.join(output_path, "src_ref.png") + shutil.copy(refer_image_path, src_ref_path) + refer_img = refer_img[..., ::-1] + + refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16) + + refer_pose_meta = self.pose2d([refer_img])[0] + + logger.info(f"Processing template video: {video_path}") + video_reader = VideoReader(video_path) + frame_num = len(video_reader) + print("frame_num: {}".format(frame_num)) + + video_fps = video_reader.get_avg_fps() + print("video_fps: {}".format(video_fps)) + print("fps: {}".format(fps)) + + # TODO: Maybe we can switch to PyAV later, which can get accurate frame num + duration = 
video_reader.get_frame_timestamp(-1)[-1] + expected_frame_num = int(duration * video_fps + 0.5) + ratio = abs((frame_num - expected_frame_num) / frame_num) + if ratio > 0.1: + print("Warning: The difference between the actual number of frames and the expected number of frames is two large") + frame_num = expected_frame_num + + if fps == -1: + fps = video_fps + + target_num = int(frame_num / video_fps * fps) + print("target_num: {}".format(target_num)) + idxs = get_frame_indices(frame_num, video_fps, target_num, fps) + frames = video_reader.get_batch(idxs).asnumpy() + + logger.info(f"Processing pose meta") + + tpl_pose_meta0 = self.pose2d(frames[:1])[0] + tpl_pose_metas = self.pose2d(frames) + + face_images = [] + for idx, meta in enumerate(tpl_pose_metas): + face_bbox_for_image = get_face_bboxes(meta["keypoints_face"][:, :2], scale=1.3, image_shape=(frames[0].shape[0], frames[0].shape[1])) + + x1, x2, y1, y2 = face_bbox_for_image + face_image = frames[idx][y1:y2, x1:x2] + face_image = cv2.resize(face_image, (512, 512)) + face_images.append(face_image) + + if retarget_flag: + if use_flux: + tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta) + refer_input = Image.fromarray(refer_img) + refer_edit = self.flux_kontext( + image=refer_input, + height=refer_img.shape[0], + width=refer_img.shape[1], + prompt=refer_prompt, + guidance_scale=2.5, + num_inference_steps=28, + ).images[0] + + refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1])) + refer_edit_path = os.path.join(output_path, "refer_edit.png") + refer_edit.save(refer_edit_path) + refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0] + + tpl_img = frames[1] + tpl_input = Image.fromarray(tpl_img) + + tpl_edit = self.flux_kontext( + image=tpl_input, + height=tpl_img.shape[0], + width=tpl_img.shape[1], + prompt=tpl_prompt, + guidance_scale=2.5, + num_inference_steps=28, + ).images[0] + + tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1])) + tpl_edit_path = os.path.join(output_path, "tpl_edit.png") + tpl_edit.save(tpl_edit_path) + tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0] + tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta) + else: + tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None) + else: + tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas] + + cond_images = [] + for idx, meta in enumerate(tpl_retarget_pose_metas): + if retarget_flag: + canvas = np.zeros_like(refer_img) + conditioning_image = draw_aapose_by_meta_new(canvas, meta) + else: + canvas = np.zeros_like(frames[0]) + conditioning_image = draw_aapose_by_meta_new(canvas, meta) + conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1]) + + cond_images.append(conditioning_image) + + src_face_path = os.path.join(output_path, "src_face.mp4") + mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path) + + src_pose_path = os.path.join(output_path, "src_pose.mp4") + mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path) + return True + + def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta): + arm_visible = False + leg_visible = False + for tpl_pose_meta in tpl_pose_metas: + tpl_keypoints = tpl_pose_meta["keypoints_body"] + if tpl_keypoints[3].all() != 0 or 
tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0: + if ( + (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) + or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) + or (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) + or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75) + ): + arm_visible = True + if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0: + if ( + (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) + or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) + or (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) + or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75) + ): + leg_visible = True + if arm_visible and leg_visible: + break + + if leg_visible: + if tpl_pose_meta["width"] > tpl_pose_meta["height"]: + tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image." + else: + tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image." + + if refer_pose_meta["width"] > refer_pose_meta["height"]: + refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image." + else: + refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image." + elif arm_visible: + if tpl_pose_meta["width"] > tpl_pose_meta["height"]: + tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image." + else: + tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image." + + if refer_pose_meta["width"] > refer_pose_meta["height"]: + refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image." + else: + refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image." + else: + tpl_prompt = "Change the person to face forward." + refer_prompt = "Change the person to face forward." 
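+        # In short: visible legs -> prompts for a full standing pose (T-pose for
+        # landscape frames, arms straight down for portrait frames); only visible
+        # arms -> the same poses without the standing/feet requirement; otherwise
+        # the prompts simply turn the person to face forward.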
+ + return tpl_prompt, refer_prompt + + def get_mask(self, frames, th_step, kp2ds_all): + frame_num = len(frames) + if frame_num < th_step: + num_step = 1 + else: + num_step = (frame_num + th_step) // th_step + + all_mask = [] + for index in range(num_step): + each_frames = frames[index * th_step : (index + 1) * th_step] + + kp2ds = kp2ds_all[index * th_step : (index + 1) * th_step] + if len(each_frames) > 4: + key_frame_num = 4 + elif 4 >= len(each_frames) > 0: + key_frame_num = 1 + else: + continue + + key_frame_step = len(kp2ds) // key_frame_num + key_frame_index_list = list(range(0, len(kp2ds), key_frame_step)) + + key_points_index = [0, 1, 2, 5, 8, 11, 10, 13] + key_frame_body_points_list = [] + for key_frame_index in key_frame_index_list: + keypoints_body_list = [] + body_key_points = kp2ds[key_frame_index]["keypoints_body"] + for each_index in key_points_index: + each_keypoint = body_key_points[each_index] + if None is each_keypoint: + continue + keypoints_body_list.append(each_keypoint) + + keypoints_body = np.array(keypoints_body_list)[:, :2] + wh = np.array([[kp2ds[0]["width"], kp2ds[0]["height"]]]) + points = (keypoints_body * wh).astype(np.int32) + key_frame_body_points_list.append(points) + + inference_state = self.predictor.init_state_v2(frames=each_frames) + self.predictor.reset_state(inference_state) + ann_obj_id = 1 + for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list): + labels = np.array([1] * points.shape[0], np.int32) + _, out_obj_ids, out_mask_logits = self.predictor.add_new_points( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, + ) + + video_segments = {} + for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state): + video_segments[out_frame_idx] = {out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)} + + for out_frame_idx in range(len(video_segments)): + for out_obj_id, out_mask in video_segments[out_frame_idx].items(): + out_mask = out_mask[0].astype(np.uint8) + all_mask.append(out_mask) + + return all_mask + + def convert_list_to_array(self, metas): + metas_list = [] + for meta in metas: + for key, value in meta.items(): + if type(value) is list: + value = np.array(value) + meta[key] = value + metas_list.append(meta) + return metas_list diff --git a/tools/preprocess/retarget_pose.py b/tools/preprocess/retarget_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..3e24b7ba5c029d8af7c276f9dbf46553934eb8ec --- /dev/null +++ b/tools/preprocess/retarget_pose.py @@ -0,0 +1,850 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
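+#
+# Retargets a driving ("template") 2D pose sequence onto the reference character's
+# skeleton: per-limb bone-length ratios between the two skeletons are estimated from
+# a single frame (optionally from the FLUX-edited, standardized poses), every frame
+# of the driving sequence is rescaled limb by limb, and the result is re-anchored at
+# the feet (full body) or at the neck otherwise.
+#
+# Typical call (mirrors the use in process_pipepline.py; the *_edit metas may be
+# None when FLUX-based editing is disabled):
+#   metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas,
+#                             tpl_edit_pose_meta0, refer_edit_pose_meta)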
+import copy +import math +from typing import NamedTuple + +import numpy as np +from pose2d_utils import AAPoseMeta +from tqdm import tqdm + +# load skeleton name and bone lines +keypoint_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", +] + + +limbSeq = [ + [2, 3], + [2, 6], # shoulders + [3, 4], + [4, 5], # left arm + [6, 7], + [7, 8], # right arm + [2, 9], + [9, 10], + [10, 11], # right leg + [2, 12], + [12, 13], + [13, 14], # left leg + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + [14, 19], # left foot + [11, 20], # right foot +] + +eps = 0.01 + + +class Keypoint(NamedTuple): + x: float + y: float + score: float = 1.0 + id: int = -1 + + +# for each limb, calculate src & dst bone's length +# and calculate their ratios +def get_length(skeleton, limb): + k1_index, k2_index = limb + + H, W = skeleton["height"], skeleton["width"] + keypoints = skeleton["keypoints_body"] + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + return None, None, None + + X = np.array([keypoint1[0], keypoint2[0]]) * float(W) + Y = np.array([keypoint1[1], keypoint2[1]]) * float(H) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + return X, Y, length + + +def get_handpose_meta(keypoints, delta, src_H, src_W): + new_keypoints = [] + + for idx, keypoint in enumerate(keypoints): + if keypoint is None: + new_keypoints.append(None) + continue + if keypoint.score == 0: + new_keypoints.append(None) + continue + + x, y = keypoint.x, keypoint.y + x = int(x * src_W + delta[0]) + y = int(y * src_H + delta[1]) + + new_keypoints.append( + Keypoint( + x=x, + y=y, + score=keypoint.score, + ) + ) + + return new_keypoints + + +def deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th=0.5): + left_hand = [] + right_hand = [] + + left_delta_x = hand_res["left"][0][0] * (l_ratio - 1) + left_delta_y = hand_res["left"][0][1] * (l_ratio - 1) + + right_delta_x = hand_res["right"][0][0] * (r_ratio - 1) + right_delta_y = hand_res["right"][0][1] * (r_ratio - 1) + + length = len(hand_res["left"]) + + for i in range(length): + # left hand + if hand_res["left"][i][2] < hand_score_th: + left_hand.append( + Keypoint( + x=-1, + y=-1, + score=0, + ) + ) + else: + left_hand.append(Keypoint(x=hand_res["left"][i][0] * l_ratio - left_delta_x, y=hand_res["left"][i][1] * l_ratio - left_delta_y, score=hand_res["left"][i][2])) + + # right hand + if hand_res["right"][i][2] < hand_score_th: + right_hand.append( + Keypoint( + x=-1, + y=-1, + score=0, + ) + ) + else: + right_hand.append(Keypoint(x=hand_res["right"][i][0] * r_ratio - right_delta_x, y=hand_res["right"][i][1] * r_ratio - right_delta_y, score=hand_res["right"][i][2])) + + return right_hand, left_hand + + +def get_scaled_pose(canvas, src_canvas, keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, id, scale_min, threshold=0.4): + H, W = canvas + src_H, src_W = src_canvas + + new_length_list = [] + angle_list = [] + + # keypoints from 0-1 to H/W range + for idx in range(len(keypoints)): + if keypoints[idx] is None or len(keypoints[idx]) == 0: + continue + + keypoints[idx] = [keypoints[idx][0] * src_W, keypoints[idx][1] * src_H, keypoints[idx][2]] + + # first traverse, get new_length_list and angle_list + 
for idx, (k1_index, k2_index) in enumerate(limbSeq): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0: + new_length_list.append(None) + angle_list.append(None) + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) # * float(W) + X = np.array([keypoint1[1], keypoint2[1]]) # * float(H) + + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + new_length = length * bone_ratio_list[idx] + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + new_length_list.append(new_length) + angle_list.append(angle) + + # Keep foot length within 0.5x calf length + foot_lower_leg_ratio = 0.5 + if new_length_list[8] != None and new_length_list[18] != None: # noqa + if new_length_list[18] > new_length_list[8] * foot_lower_leg_ratio: + new_length_list[18] = new_length_list[8] * foot_lower_leg_ratio + + if new_length_list[11] != None and new_length_list[17] != None: # noqa + if new_length_list[17] > new_length_list[11] * foot_lower_leg_ratio: + new_length_list[17] = new_length_list[11] * foot_lower_leg_ratio + + # second traverse, calculate new keypoints + rescale_keypoints = keypoints.copy() + + for idx, (k1_index, k2_index) in enumerate(limbSeq): + # update dst_keypoints + start_keypoint = rescale_keypoints[k1_index - 1] + new_length = new_length_list[idx] + angle = angle_list[idx] + + if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0: + continue + + # calculate end_keypoint + delta_x = new_length * math.cos(math.radians(angle)) + delta_y = new_length * math.sin(math.radians(angle)) + + end_keypoint_x = start_keypoint[0] - delta_x + end_keypoint_y = start_keypoint[1] - delta_y + + # update keypoints + rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y, rescale_keypoints[k2_index - 1][2]] + + if id == 0: + if body_flag == "full_body" and rescale_keypoints[8] != None and rescale_keypoints[11] != None: # noqa + delta_ground_x_offset_first_frame = (rescale_keypoints[8][0] + rescale_keypoints[11][0]) / 2 - rescaled_src_ground_x + delta_ground_x += delta_ground_x_offset_first_frame + elif body_flag == "half_body" and rescale_keypoints[1] != None: # noqa + delta_ground_x_offset_first_frame = rescale_keypoints[1][0] - rescaled_src_ground_x + delta_ground_x += delta_ground_x_offset_first_frame + + # offset all keypoints + for idx in range(len(rescale_keypoints)): + if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0: + continue + rescale_keypoints[idx][0] -= delta_ground_x + rescale_keypoints[idx][1] -= delta_ground_y + + # rescale keypoints to original size + rescale_keypoints[idx][0] /= scale_min + rescale_keypoints[idx][1] /= scale_min + + # Scale hand proportions based on body skeletal ratios + r_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min + l_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min + left_hand, right_hand = deal_hand_keypoints(keypoints_hand, r_ratio, l_ratio, hand_score_th=threshold) + + left_hand_new = left_hand.copy() + right_hand_new = right_hand.copy() + + if rescale_keypoints[4] == None and rescale_keypoints[7] == None: # noqa + pass + + elif rescale_keypoints[4] == None and rescale_keypoints[7] != None: # noqa + right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2]) + right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, 
src_W) + + elif rescale_keypoints[4] != None and rescale_keypoints[7] == None: # noqa + left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2]) + left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W) + + else: + # get left_hand and right_hand offset + left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2]) + right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2]) + + if keypoints[4][0] != None and left_hand[0].x != -1: # noqa + left_hand_root_offset = np.array((keypoints[4][0] - left_hand[0].x * src_W, keypoints[4][1] - left_hand[0].y * src_H)) + left_hand_delta += left_hand_root_offset + + if keypoints[7][0] != None and right_hand[0].x != -1: # noqa + right_hand_root_offset = np.array((keypoints[7][0] - right_hand[0].x * src_W, keypoints[7][1] - right_hand[0].y * src_H)) + right_hand_delta += right_hand_root_offset + + dis_left_hand = ((keypoints[4][0] - left_hand[0].x * src_W) ** 2 + (keypoints[4][1] - left_hand[0].y * src_H) ** 2) ** 0.5 + dis_right_hand = ((keypoints[7][0] - left_hand[0].x * src_W) ** 2 + (keypoints[7][1] - left_hand[0].y * src_H) ** 2) ** 0.5 + + if dis_left_hand > dis_right_hand: + right_hand_new = get_handpose_meta(left_hand, right_hand_delta, src_H, src_W) + left_hand_new = get_handpose_meta(right_hand, left_hand_delta, src_H, src_W) + else: + left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W) + right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W) + + # get normalized keypoints_body + norm_body_keypoints = [] + for body_keypoint in rescale_keypoints: + if body_keypoint != None: # noqa + norm_body_keypoints.append([body_keypoint[0] / W, body_keypoint[1] / H, body_keypoint[2]]) + else: + norm_body_keypoints.append(None) + + frame_info = { + "height": H, + "width": W, + "keypoints_body": norm_body_keypoints, + "keypoints_left_hand": left_hand_new, + "keypoints_right_hand": right_hand_new, + } + + return frame_info + + +def rescale_skeleton(H, W, keypoints, bone_ratio_list): + rescale_keypoints = keypoints.copy() + + new_length_list = [] + angle_list = [] + + # keypoints from 0-1 to H/W range + for idx in range(len(rescale_keypoints)): + if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0: + continue + + rescale_keypoints[idx] = [rescale_keypoints[idx][0] * W, rescale_keypoints[idx][1] * H] + + # first traverse, get new_length_list and angle_list + for idx, (k1_index, k2_index) in enumerate(limbSeq): + keypoint1 = rescale_keypoints[k1_index - 1] + keypoint2 = rescale_keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0: + new_length_list.append(None) + angle_list.append(None) + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) # * float(W) + X = np.array([keypoint1[1], keypoint2[1]]) # * float(H) + + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + new_length = length * bone_ratio_list[idx] + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + new_length_list.append(new_length) + angle_list.append(angle) + + # # second traverse, calculate new keypoints + for idx, (k1_index, k2_index) in enumerate(limbSeq): + # update dst_keypoints + start_keypoint = rescale_keypoints[k1_index - 1] + new_length = new_length_list[idx] + angle = angle_list[idx] + + if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 
1]) == 0: + continue + + # calculate end_keypoint + delta_x = new_length * math.cos(math.radians(angle)) + delta_y = new_length * math.sin(math.radians(angle)) + + end_keypoint_x = start_keypoint[0] - delta_x + end_keypoint_y = start_keypoint[1] - delta_y + + # update keypoints + rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y] + + return rescale_keypoints + + +def fix_lack_keypoints_use_sym(skeleton): + keypoints = skeleton["keypoints_body"] + H, W = skeleton["height"], skeleton["width"] + + limb_points_list = [ + [3, 4, 5], + [6, 7, 8], + [12, 13, 14, 19], + [9, 10, 11, 20], + ] + + for limb_points in limb_points_list: + miss_flag = False + for point in limb_points: + if keypoints[point - 1] is None: + miss_flag = True + continue + if miss_flag: + skeleton["keypoints_body"][point - 1] = None + + repair_limb_seq_left = [ + [3, 4], + [4, 5], # left arm + [12, 13], + [13, 14], # left leg + [14, 19], # left foot + ] + + repair_limb_seq_right = [ + [6, 7], + [7, 8], # right arm + [9, 10], + [10, 11], # right leg + [11, 20], # right foot + ] + + repair_limb_seq = [repair_limb_seq_left, repair_limb_seq_right] + + for idx_part, part in enumerate(repair_limb_seq): + for idx, limb in enumerate(part): + k1_index, k2_index = limb + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 != None and keypoint2 is None: # noqa + # reference to symmetric limb + sym_limb = repair_limb_seq[1 - idx_part][idx] + k1_index_sym, k2_index_sym = sym_limb + keypoint1_sym = keypoints[k1_index_sym - 1] + keypoint2_sym = keypoints[k2_index_sym - 1] + ref_length = 0 + + if keypoint1_sym != None and keypoint2_sym != None: # noqa + X = np.array([keypoint1_sym[0], keypoint2_sym[0]]) * float(W) + Y = np.array([keypoint1_sym[1], keypoint2_sym[1]]) * float(H) + ref_length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + else: + ref_length_left, ref_length_right = 0, 0 + if keypoints[1] != None and keypoints[8] != None: # noqa + X = np.array([keypoints[1][0], keypoints[8][0]]) * float(W) + Y = np.array([keypoints[1][1], keypoints[8][1]]) * float(H) + ref_length_left = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + if idx <= 1: # arms + ref_length_left /= 2 + + if keypoints[1] != None and keypoints[11] != None: # noqa + X = np.array([keypoints[1][0], keypoints[11][0]]) * float(W) + Y = np.array([keypoints[1][1], keypoints[11][1]]) * float(H) + ref_length_right = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + if idx <= 1: # arms + ref_length_right /= 2 + elif idx == 4: # foot + ref_length_right /= 5 + + ref_length = max(ref_length_left, ref_length_right) + + if ref_length != 0: + skeleton["keypoints_body"][k2_index - 1] = [0, 0] # init + skeleton["keypoints_body"][k2_index - 1][0] = skeleton["keypoints_body"][k1_index - 1][0] + skeleton["keypoints_body"][k2_index - 1][1] = skeleton["keypoints_body"][k1_index - 1][1] + ref_length / H + return skeleton + + +def rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list): + modify_bone_list = [[0, 1], [2, 4], [3, 5], [6, 9], [7, 10], [8, 11], [17, 18]] + + for modify_bone in modify_bone_list: + new_ratio = max(ratio_list[modify_bone[0]], ratio_list[modify_bone[1]]) + ratio_list[modify_bone[0]] = new_ratio + ratio_list[modify_bone[1]] = new_ratio + + if ratio_list[13] != None and ratio_list[15] != None: # noqa + ratio_eye_avg = (ratio_list[13] + ratio_list[15]) / 2 + ratio_list[13] = ratio_eye_avg + ratio_list[15] = ratio_eye_avg + + if ratio_list[14] != None and ratio_list[16] != None: # noqa + 
ratio_eye_avg = (ratio_list[14] + ratio_list[16]) / 2 + ratio_list[14] = ratio_eye_avg + ratio_list[16] = ratio_eye_avg + + return ratio_list, src_length_list, dst_length_list + + +def check_full_body(keypoints, threshold=0.4): + body_flag = "half_body" + + # 1. If ankle points exist, confidence is greater than the threshold, and points do not exceed the frame, return full_body + if keypoints[10] != None and keypoints[13] != None and keypoints[8] != None and keypoints[11] != None: # noqa + if ( + (keypoints[10][1] <= 1 and keypoints[13][1] <= 1) + and (keypoints[10][2] >= threshold and keypoints[13][2] >= threshold) + and (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) + and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold) + ): + body_flag = "full_body" + return body_flag + + # 2. If hip points exist, return three_quarter_body + if keypoints[8] != None and keypoints[11] != None: # noqa + if (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold): + body_flag = "three_quarter_body" + return body_flag + + return body_flag + + +def check_full_body_both(flag1, flag2): + body_flag_dict = {"full_body": 2, "three_quarter_body": 1, "half_body": 0} + + body_flag_dict_reverse = {2: "full_body", 1: "three_quarter_body", 0: "half_body"} + + flag1_num = body_flag_dict[flag1] + flag2_num = body_flag_dict[flag2] + flag_both_num = min(flag1_num, flag2_num) + return body_flag_dict_reverse[flag_both_num] + + +def write_to_poses(data_to_json, none_idx, dst_shape, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, scale_min): + outputs = [] + length = len(data_to_json) + for id in tqdm(range(length)): + src_height, src_width = data_to_json[id]["height"], data_to_json[id]["width"] + width, height = dst_shape + keypoints = data_to_json[id]["keypoints_body"] + for idx in range(len(keypoints)): + if idx in none_idx: + keypoints[idx] = None + new_keypoints = keypoints.copy() + + # get hand keypoints + keypoints_hand = {"left": data_to_json[id]["keypoints_left_hand"], "right": data_to_json[id]["keypoints_right_hand"]} + # Normalize hand coordinates to 0-1 range + for hand_idx in range(len(data_to_json[id]["keypoints_left_hand"])): + data_to_json[id]["keypoints_left_hand"][hand_idx][0] = data_to_json[id]["keypoints_left_hand"][hand_idx][0] / src_width + data_to_json[id]["keypoints_left_hand"][hand_idx][1] = data_to_json[id]["keypoints_left_hand"][hand_idx][1] / src_height + + for hand_idx in range(len(data_to_json[id]["keypoints_right_hand"])): + data_to_json[id]["keypoints_right_hand"][hand_idx][0] = data_to_json[id]["keypoints_right_hand"][hand_idx][0] / src_width + data_to_json[id]["keypoints_right_hand"][hand_idx][1] = data_to_json[id]["keypoints_right_hand"][hand_idx][1] / src_height + + frame_info = get_scaled_pose( + (height, width), (src_height, src_width), new_keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, id, scale_min + ) + outputs.append(frame_info) + + return outputs + + +def calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag): + if scale_ratio_flag: + headw = max(skeleton["keypoints_body"][0][0], skeleton["keypoints_body"][14][0], skeleton["keypoints_body"][15][0], skeleton["keypoints_body"][16][0], skeleton["keypoints_body"][17][0]) - min( + skeleton["keypoints_body"][0][0], skeleton["keypoints_body"][14][0], skeleton["keypoints_body"][15][0], skeleton["keypoints_body"][16][0], skeleton["keypoints_body"][17][0] + ) + 
headw_edit = max( + skeleton_edit["keypoints_body"][0][0], + skeleton_edit["keypoints_body"][14][0], + skeleton_edit["keypoints_body"][15][0], + skeleton_edit["keypoints_body"][16][0], + skeleton_edit["keypoints_body"][17][0], + ) - min( + skeleton_edit["keypoints_body"][0][0], + skeleton_edit["keypoints_body"][14][0], + skeleton_edit["keypoints_body"][15][0], + skeleton_edit["keypoints_body"][16][0], + skeleton_edit["keypoints_body"][17][0], + ) + headw_ratio = headw / headw_edit + + _, _, shoulder = get_length(skeleton, [6, 3]) + _, _, shoulder_edit = get_length(skeleton_edit, [6, 3]) + shoulder_ratio = shoulder / shoulder_edit + + return max(headw_ratio, shoulder_ratio) + + else: + return 1 + + +def retarget_pose(src_skeleton, dst_skeleton, all_src_skeleton, src_skeleton_edit, dst_skeleton_edit, threshold=0.4): + if src_skeleton_edit is not None and dst_skeleton_edit is not None: # noqa + use_edit_for_base = True + else: + use_edit_for_base = False + + src_skeleton_ori = copy.deepcopy(src_skeleton) + + dst_skeleton_ori_h, dst_skeleton_ori_w = dst_skeleton["height"], dst_skeleton["width"] + if ( + src_skeleton["keypoints_body"][0] != None # noqa + and src_skeleton["keypoints_body"][10] != None # noqa + and src_skeleton["keypoints_body"][13] != None # noqa + and dst_skeleton["keypoints_body"][0] != None # noqa + and dst_skeleton["keypoints_body"][10] != None # noqa + and dst_skeleton["keypoints_body"][13] != None # noqa + and src_skeleton["keypoints_body"][0][2] > 0.5 + and src_skeleton["keypoints_body"][10][2] > 0.5 + and src_skeleton["keypoints_body"][13][2] > 0.5 + and dst_skeleton["keypoints_body"][0][2] > 0.5 + and dst_skeleton["keypoints_body"][10][2] > 0.5 + and dst_skeleton["keypoints_body"][13][2] > 0.5 + ): + src_height = src_skeleton["height"] * abs((src_skeleton["keypoints_body"][10][1] + src_skeleton["keypoints_body"][13][1]) / 2 - src_skeleton["keypoints_body"][0][1]) + dst_height = dst_skeleton["height"] * abs((dst_skeleton["keypoints_body"][10][1] + dst_skeleton["keypoints_body"][13][1]) / 2 - dst_skeleton["keypoints_body"][0][1]) + scale_min = 1.0 * src_height / dst_height + elif ( + src_skeleton["keypoints_body"][0] != None # noqa + and src_skeleton["keypoints_body"][8] != None # noqa + and src_skeleton["keypoints_body"][11] != None # noqa + and dst_skeleton["keypoints_body"][0] != None # noqa + and dst_skeleton["keypoints_body"][8] != None # noqa + and dst_skeleton["keypoints_body"][11] != None # noqa + and src_skeleton["keypoints_body"][0][2] > 0.5 + and src_skeleton["keypoints_body"][8][2] > 0.5 + and src_skeleton["keypoints_body"][11][2] > 0.5 + and dst_skeleton["keypoints_body"][0][2] > 0.5 + and dst_skeleton["keypoints_body"][8][2] > 0.5 + and dst_skeleton["keypoints_body"][11][2] > 0.5 + ): + src_height = src_skeleton["height"] * abs((src_skeleton["keypoints_body"][8][1] + src_skeleton["keypoints_body"][11][1]) / 2 - src_skeleton["keypoints_body"][0][1]) + dst_height = dst_skeleton["height"] * abs((dst_skeleton["keypoints_body"][8][1] + dst_skeleton["keypoints_body"][11][1]) / 2 - dst_skeleton["keypoints_body"][0][1]) + scale_min = 1.0 * src_height / dst_height + else: + scale_min = np.sqrt(src_skeleton["height"] * src_skeleton["width"]) / np.sqrt(dst_skeleton["height"] * dst_skeleton["width"]) + + if use_edit_for_base: + scale_ratio_flag = False + if ( + src_skeleton_edit["keypoints_body"][0] != None # noqa + and src_skeleton_edit["keypoints_body"][10] != None # noqa + and src_skeleton_edit["keypoints_body"][13] != None # noqa + and 
dst_skeleton_edit["keypoints_body"][0] != None # noqa + and dst_skeleton_edit["keypoints_body"][10] != None # noqa + and dst_skeleton_edit["keypoints_body"][13] != None # noqa + and src_skeleton_edit["keypoints_body"][0][2] > 0.5 + and src_skeleton_edit["keypoints_body"][10][2] > 0.5 + and src_skeleton_edit["keypoints_body"][13][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][0][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][10][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][13][2] > 0.5 + ): + src_height_edit = src_skeleton_edit["height"] * abs( + (src_skeleton_edit["keypoints_body"][10][1] + src_skeleton_edit["keypoints_body"][13][1]) / 2 - src_skeleton_edit["keypoints_body"][0][1] + ) + dst_height_edit = dst_skeleton_edit["height"] * abs( + (dst_skeleton_edit["keypoints_body"][10][1] + dst_skeleton_edit["keypoints_body"][13][1]) / 2 - dst_skeleton_edit["keypoints_body"][0][1] + ) + scale_min_edit = 1.0 * src_height_edit / dst_height_edit + elif ( + src_skeleton_edit["keypoints_body"][0] != None # noqa + and src_skeleton_edit["keypoints_body"][8] != None # noqa + and src_skeleton_edit["keypoints_body"][11] != None # noqa + and dst_skeleton_edit["keypoints_body"][0] != None # noqa + and dst_skeleton_edit["keypoints_body"][8] != None # noqa + and dst_skeleton_edit["keypoints_body"][11] != None # noqa + and src_skeleton_edit["keypoints_body"][0][2] > 0.5 + and src_skeleton_edit["keypoints_body"][8][2] > 0.5 + and src_skeleton_edit["keypoints_body"][11][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][0][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][8][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][11][2] > 0.5 + ): + src_height_edit = src_skeleton_edit["height"] * abs( + (src_skeleton_edit["keypoints_body"][8][1] + src_skeleton_edit["keypoints_body"][11][1]) / 2 - src_skeleton_edit["keypoints_body"][0][1] + ) + dst_height_edit = dst_skeleton_edit["height"] * abs( + (dst_skeleton_edit["keypoints_body"][8][1] + dst_skeleton_edit["keypoints_body"][11][1]) / 2 - dst_skeleton_edit["keypoints_body"][0][1] + ) + scale_min_edit = 1.0 * src_height_edit / dst_height_edit + else: + scale_min_edit = np.sqrt(src_skeleton_edit["height"] * src_skeleton_edit["width"]) / np.sqrt(dst_skeleton_edit["height"] * dst_skeleton_edit["width"]) + scale_ratio_flag = True + + # Flux may change the scale, compensate for it here + ratio_src = calculate_scale_ratio(src_skeleton, src_skeleton_edit, scale_ratio_flag) + ratio_dst = calculate_scale_ratio(dst_skeleton, dst_skeleton_edit, scale_ratio_flag) + + dst_skeleton_edit["height"] = int(dst_skeleton_edit["height"] * scale_min_edit) + dst_skeleton_edit["width"] = int(dst_skeleton_edit["width"] * scale_min_edit) + for idx in range(len(dst_skeleton_edit["keypoints_left_hand"])): + dst_skeleton_edit["keypoints_left_hand"][idx][0] *= scale_min_edit + dst_skeleton_edit["keypoints_left_hand"][idx][1] *= scale_min_edit + for idx in range(len(dst_skeleton_edit["keypoints_right_hand"])): + dst_skeleton_edit["keypoints_right_hand"][idx][0] *= scale_min_edit + dst_skeleton_edit["keypoints_right_hand"][idx][1] *= scale_min_edit + + dst_skeleton["height"] = int(dst_skeleton["height"] * scale_min) + dst_skeleton["width"] = int(dst_skeleton["width"] * scale_min) + for idx in range(len(dst_skeleton["keypoints_left_hand"])): + dst_skeleton["keypoints_left_hand"][idx][0] *= scale_min + dst_skeleton["keypoints_left_hand"][idx][1] *= scale_min + for idx in range(len(dst_skeleton["keypoints_right_hand"])): + dst_skeleton["keypoints_right_hand"][idx][0] *= scale_min + 
dst_skeleton["keypoints_right_hand"][idx][1] *= scale_min + + dst_body_flag = check_full_body(dst_skeleton["keypoints_body"], threshold) + src_body_flag = check_full_body(src_skeleton_ori["keypoints_body"], threshold) + body_flag = check_full_body_both(dst_body_flag, src_body_flag) + # print('body_flag: ', body_flag) + + if use_edit_for_base: + src_skeleton_edit = fix_lack_keypoints_use_sym(src_skeleton_edit) + dst_skeleton_edit = fix_lack_keypoints_use_sym(dst_skeleton_edit) + else: + src_skeleton = fix_lack_keypoints_use_sym(src_skeleton) + dst_skeleton = fix_lack_keypoints_use_sym(dst_skeleton) + + none_idx = [] + for idx in range(len(dst_skeleton["keypoints_body"])): + if dst_skeleton["keypoints_body"][idx] == None or src_skeleton["keypoints_body"][idx] == None: # noqa + src_skeleton["keypoints_body"][idx] = None + dst_skeleton["keypoints_body"][idx] = None + none_idx.append(idx) + + # get bone ratio list + ratio_list, src_length_list, dst_length_list = [], [], [] + for idx, limb in enumerate(limbSeq): + if use_edit_for_base: + src_X, src_Y, src_length = get_length(src_skeleton_edit, limb) + dst_X, dst_Y, dst_length = get_length(dst_skeleton_edit, limb) + + if src_X is None or src_Y is None or dst_X is None or dst_Y is None: # noqa + ratio = -1 + else: + ratio = 1.0 * dst_length * ratio_dst / src_length / ratio_src + + else: + src_X, src_Y, src_length = get_length(src_skeleton, limb) + dst_X, dst_Y, dst_length = get_length(dst_skeleton, limb) + + if src_X is None or src_Y is None or dst_X is None or dst_Y is None: # noqa + ratio = -1 + else: + ratio = 1.0 * dst_length / src_length + + ratio_list.append(ratio) + src_length_list.append(src_length) + dst_length_list.append(dst_length) + + for idx, ratio in enumerate(ratio_list): + if ratio == -1: + if ratio_list[0] != -1 and ratio_list[1] != -1: + ratio_list[idx] = (ratio_list[0] + ratio_list[1]) / 2 + + # Consider adding constraints when Flux fails to correct head pose, causing neck issues. 
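+    # (ratio_list[12] corresponds to limbSeq[12] = [2, 1], the neck-to-nose bone;
+    # the commented-out clamp below would cap it at 1.25x the mean of the two
+    # shoulder ratios, ratio_list[0] and ratio_list[1].)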
+ # if ratio_list[12] > (ratio_list[0]+ratio_list[1])/2*1.25: + # ratio_list[12] = (ratio_list[0]+ratio_list[1])/2*1.25 + + ratio_list, src_length_list, dst_length_list = rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list) + + rescaled_src_skeleton_ori = rescale_skeleton(src_skeleton_ori["height"], src_skeleton_ori["width"], src_skeleton_ori["keypoints_body"], ratio_list) + + # get global translation offset_x and offset_y + if body_flag == "full_body": + # print('use foot mark.') + dst_ground_y = max(dst_skeleton["keypoints_body"][10][1], dst_skeleton["keypoints_body"][13][1]) * dst_skeleton["height"] + # The midpoint between toe and ankle + if dst_skeleton["keypoints_body"][18] != None and dst_skeleton["keypoints_body"][19] != None: # noqa + right_foot_mid = (dst_skeleton["keypoints_body"][10][1] + dst_skeleton["keypoints_body"][19][1]) / 2 + left_foot_mid = (dst_skeleton["keypoints_body"][13][1] + dst_skeleton["keypoints_body"][18][1]) / 2 + dst_ground_y = max(left_foot_mid, right_foot_mid) * dst_skeleton["height"] + + rescaled_src_ground_y = max(rescaled_src_skeleton_ori[10][1], rescaled_src_skeleton_ori[13][1]) + delta_ground_y = rescaled_src_ground_y - dst_ground_y + + dst_ground_x = (dst_skeleton["keypoints_body"][8][0] + dst_skeleton["keypoints_body"][11][0]) * dst_skeleton["width"] / 2 + rescaled_src_ground_x = (rescaled_src_skeleton_ori[8][0] + rescaled_src_skeleton_ori[11][0]) / 2 + delta_ground_x = rescaled_src_ground_x - dst_ground_x + delta_x, delta_y = delta_ground_x, delta_ground_y + + else: + # print('use neck mark.') + # use neck keypoint as mark + src_neck_y = rescaled_src_skeleton_ori[1][1] + dst_neck_y = dst_skeleton["keypoints_body"][1][1] + delta_neck_y = src_neck_y - dst_neck_y * dst_skeleton["height"] + + src_neck_x = rescaled_src_skeleton_ori[1][0] + dst_neck_x = dst_skeleton["keypoints_body"][1][0] + delta_neck_x = src_neck_x - dst_neck_x * dst_skeleton["width"] + delta_x, delta_y = delta_neck_x, delta_neck_y + rescaled_src_ground_x = src_neck_x + + dst_shape = (dst_skeleton_ori_w, dst_skeleton_ori_h) + output = write_to_poses(all_src_skeleton, none_idx, dst_shape, ratio_list, delta_x, delta_y, rescaled_src_ground_x, body_flag, scale_min) + return output + + +def get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tql_edit_pose_meta0, refer_edit_pose_meta): + for key, value in tpl_pose_meta0.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array([[tpl_pose_meta0["width"], tpl_pose_meta0["height"], 1.0]]) + if not isinstance(value, list): + value = value.tolist() + tpl_pose_meta0[key] = value + + for key, value in refer_pose_meta.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array([[refer_pose_meta["width"], refer_pose_meta["height"], 1.0]]) + if not isinstance(value, list): + value = value.tolist() + refer_pose_meta[key] = value + + tpl_pose_metas_new = [] + for meta in tpl_pose_metas: + for key, value in meta.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array([[meta["width"], meta["height"], 1.0]]) + if not isinstance(value, list): + value = value.tolist() + meta[key] = value + tpl_pose_metas_new.append(meta) + + if tql_edit_pose_meta0 is not None: + for key, value in tql_edit_pose_meta0.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * 
np.array([[tql_edit_pose_meta0["width"], tql_edit_pose_meta0["height"], 1.0]]) + if not isinstance(value, list): + value = value.tolist() + tql_edit_pose_meta0[key] = value + + if refer_edit_pose_meta is not None: + for key, value in refer_edit_pose_meta.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array([[refer_edit_pose_meta["width"], refer_edit_pose_meta["height"], 1.0]]) + if not isinstance(value, list): + value = value.tolist() + refer_edit_pose_meta[key] = value + + retarget_tpl_pose_metas = retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas_new, tql_edit_pose_meta0, refer_edit_pose_meta) + + pose_metas = [] + for meta in retarget_tpl_pose_metas: + pose_meta = AAPoseMeta() + width, height = meta["width"], meta["height"] + pose_meta.width = width + pose_meta.height = height + pose_meta.kps_body = np.array(meta["keypoints_body"])[:, :2] * (width, height) + pose_meta.kps_body_p = np.array(meta["keypoints_body"])[:, 2] + + kps_lhand = [] + kps_lhand_p = [] + for each_kps_lhand in meta["keypoints_left_hand"]: + if each_kps_lhand is not None: + kps_lhand.append([each_kps_lhand.x, each_kps_lhand.y]) + kps_lhand_p.append(each_kps_lhand.score) + else: + kps_lhand.append([None, None]) + kps_lhand_p.append(0.0) + + pose_meta.kps_lhand = np.array(kps_lhand) + pose_meta.kps_lhand_p = np.array(kps_lhand_p) + + kps_rhand = [] + kps_rhand_p = [] + for each_kps_rhand in meta["keypoints_right_hand"]: + if each_kps_rhand is not None: + kps_rhand.append([each_kps_rhand.x, each_kps_rhand.y]) + kps_rhand_p.append(each_kps_rhand.score) + else: + kps_rhand.append([None, None]) + kps_rhand_p.append(0.0) + + pose_meta.kps_rhand = np.array(kps_rhand) + pose_meta.kps_rhand_p = np.array(kps_rhand_p) + + pose_metas.append(pose_meta) + + return pose_metas diff --git a/tools/preprocess/sam_utils.py b/tools/preprocess/sam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb2149061d53e32b24785a235e7c02cef0b6f7f --- /dev/null +++ b/tools/preprocess/sam_utils.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025. Your modifications here. +# This file wraps and extends sam2.utils.misc for custom modifications. + +import os + +import numpy as np +import torch +from PIL import Image +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf +from sam2.build_sam import _load_checkpoint +from sam2.utils.misc import * +from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor +from tqdm import tqdm + + +def _load_img_v2_as_tensor(img, image_size): + img_pil = Image.fromarray(img.astype(np.uint8)) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + frame_names=None, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. 
+ + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError("Only JPEG frames are supported at this moment") + if frame_names is None: + frame_names = [p for p in os.listdir(jpg_folder) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader(img_paths, image_size, offload_video_to_cpu, img_mean, img_std) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_v2( + frames, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + frame_names=None, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. 
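+
+    Unlike `load_video_frames`, this variant takes an in-memory sequence of frames
+    (H x W x 3 uint8 arrays) rather than a directory of JPEG files.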
+ """ + num_frames = len(frames) + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, frame in enumerate(tqdm(frames, desc="video frame")): + images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + hydra_overrides = [ + "++model._target_=video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + + hydra_overrides.extend(hydra_overrides_extra) + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model diff --git a/tools/preprocess/utils.py b/tools/preprocess/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a18c244045dc15b87e388843d0a180588d18c529 --- /dev/null +++ b/tools/preprocess/utils.py @@ -0,0 +1,219 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
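+#
+# Shared helpers for the preprocessing pipeline: area-targeted resizing to
+# divisor-aligned resolutions, padding resize, body/face mask utilities, face-bbox
+# expansion from 2D keypoints, and FPS-matched frame-index sampling.
+#
+# Example of the frame sampling: a 300-frame clip at 60 fps resampled to 30 fps
+# gives get_frame_indices(300, 60, 150, 30) -> [0, 2, 4, ..., 298].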
+import math
+import random
+
+import cv2
+import numpy as np
+
+
+def get_mask_boxes(mask):
+    """
+    Compute the tight bounding box of a binary mask.
+    Args:
+        mask: [h, w] binary mask.
+    Returns:
+        bbox: int32 array [x_min, y_min, x_max, y_max].
+    """
+    y_coords, x_coords = np.nonzero(mask)
+    x_min = x_coords.min()
+    x_max = x_coords.max()
+    y_min = y_coords.min()
+    y_max = y_coords.max()
+    bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32)
+    return bbox
+
+
+def get_aug_mask(body_mask, w_len=10, h_len=20):
+    body_bbox = get_mask_boxes(body_mask)
+
+    bbox_wh = body_bbox[2:4] - body_bbox[0:2]
+    w_slice = max(np.int32(bbox_wh[0] / w_len), 1)  # at least 1 px so range() below never gets a zero step
+    h_slice = max(np.int32(bbox_wh[1] / h_len), 1)
+
+    for each_w in range(body_bbox[0], body_bbox[2], w_slice):
+        w_start = min(each_w, body_bbox[2])
+        w_end = min((each_w + w_slice), body_bbox[2])
+        # print(w_start, w_end)
+        for each_h in range(body_bbox[1], body_bbox[3], h_slice):
+            h_start = min(each_h, body_bbox[3])
+            h_end = min((each_h + h_slice), body_bbox[3])
+            if body_mask[h_start:h_end, w_start:w_end].sum() > 0:
+                body_mask[h_start:h_end, w_start:w_end] = 1
+
+    return body_mask
+
+
+def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1):
+    kernel = np.ones((k, k), np.uint8)
+    dilation = cv2.dilate(hand_mask, kernel, iterations=iterations)
+    mask_hand_img = img_copy * (1 - dilation[:, :, None])
+
+    return mask_hand_img, dilation
+
+
+def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug):
+    h, w = image_shape
+    kp2ds_face = kp2ds.copy()[23:91, :2]
+
+    min_x, min_y = np.min(kp2ds_face, axis=0)
+    max_x, max_y = np.max(kp2ds_face, axis=0)
+
+    initial_width = max_x - min_x
+    initial_height = max_y - min_y
+
+    initial_area = initial_width * initial_height
+
+    expanded_area = initial_area * scale
+
+    new_width = np.sqrt(expanded_area * (initial_width / initial_height))
+    new_height = np.sqrt(expanded_area * (initial_height / initial_width))
+
+    delta_width = (new_width - initial_width) / 2
+    delta_height = (new_height - initial_height) / 4
+
+    if ratio_aug:
+        if random.random() > 0.5:
+            delta_width += random.uniform(0, initial_width // 10)
+        else:
+            delta_height += random.uniform(0, initial_height // 10)
+
+    expanded_min_x = max(min_x - delta_width, 0)
+    expanded_max_x = min(max_x + delta_width, w)
+    expanded_min_y = max(min_y - 3 * delta_height, 0)
+    expanded_max_y = min(max_y + delta_height, h)
+
+    return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
+
+
+def calculate_new_size(orig_w, orig_h, target_area, divisor=64):
+    target_ratio = orig_w / orig_h
+
+    def check_valid(w, h):
+        if w <= 0 or h <= 0:
+            return False
+        return w * h <= target_area and w % divisor == 0 and h % divisor == 0
+
+    def get_ratio_diff(w, h):
+        return abs(w / h - target_ratio)
+
+    def round_to_64(value, round_up=False, divisor=64):
+        if round_up:
+            return divisor * ((value + (divisor - 1)) // divisor)
+        return divisor * (value // divisor)
+
+    possible_sizes = []
+
+    max_area_h = int(np.sqrt(target_area / target_ratio))
+    max_area_w = int(max_area_h * target_ratio)
+
+    max_h = round_to_64(max_area_h, round_up=True, divisor=divisor)
+    max_w = round_to_64(max_area_w, round_up=True, divisor=divisor)
+
+    for h in range(divisor, max_h + divisor, divisor):
+        ideal_w = h * target_ratio
+
+        w_down = round_to_64(ideal_w)
+        w_up = round_to_64(ideal_w, round_up=True)
+
+        for w in [w_down, w_up]:
+            if check_valid(w, h):
+                possible_sizes.append((w, h, get_ratio_diff(w, h)))
+
+    if not possible_sizes:
+        raise ValueError("Cannot find a suitable size")
+
+    possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2]))
+
+    best_w, best_h, _ = possible_sizes[0]
+    return int(best_w), int(best_h)
+
+
+def resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0)):
+    h, w = image.shape[:2]
+    try:
+        new_w, new_h = calculate_new_size(w, h, target_area, divisor)
+    except Exception:  # fall back to simple area-based sizing below
+        aspect_ratio = w / h
+
+        if keep_aspect_ratio:
+            new_h = math.sqrt(target_area / aspect_ratio)
+            new_w = target_area / new_h
+        else:
+            new_w = new_h = math.sqrt(target_area)
+
+        new_w, new_h = int((new_w // divisor) * divisor), int((new_h // divisor) * divisor)
+
+    interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR
+
+    resized_image = padding_resize(image, height=new_h, width=new_w, padding_color=padding_color, interpolation=interpolation)
+    return resized_image
+
+
+def padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
+    ori_height = img_ori.shape[0]
+    ori_width = img_ori.shape[1]
+    channel = img_ori.shape[2]
+
+    img_pad = np.zeros((height, width, channel))
+    if channel == 1:
+        img_pad[:, :, 0] = padding_color[0]
+    else:
+        img_pad[:, :, 0] = padding_color[0]
+        img_pad[:, :, 1] = padding_color[1]
+        img_pad[:, :, 2] = padding_color[2]
+
+    if (ori_height / ori_width) > (height / width):
+        new_width = int(height / ori_height * ori_width)
+        img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
+        padding = int((width - new_width) / 2)
+        if len(img.shape) == 2:
+            img = img[:, :, np.newaxis]
+        img_pad[:, padding : padding + new_width, :] = img
+    else:
+        new_height = int(width / ori_width * ori_height)
+        img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
+        padding = int((height - new_height) / 2)
+        if len(img.shape) == 2:
+            img = img[:, :, np.newaxis]
+        img_pad[padding : padding + new_height, :, :] = img
+
+    img_pad = np.uint8(img_pad)
+
+    return img_pad
+
+
+def get_frame_indices(frame_num, video_fps, clip_length, train_fps):
+    start_frame = 0
+    times = np.arange(0, clip_length) / train_fps
+    frame_indices = start_frame + np.round(times * video_fps).astype(int)
+    frame_indices = np.clip(frame_indices, 0, frame_num - 1)
+
+    return frame_indices.tolist()
+
+
+def get_face_bboxes(kp2ds, scale, image_shape):  # NOTE: shadows the 4-argument get_face_bboxes defined above
+    h, w = image_shape
+    kp2ds_face = kp2ds.copy()[1:] * (w, h)
+
+    min_x, min_y = np.min(kp2ds_face, axis=0)
+    max_x, max_y = np.max(kp2ds_face, axis=0)
+
+    initial_width = max_x - min_x
+    initial_height = max_y - min_y
+
+    initial_area = initial_width * initial_height
+
+    expanded_area = initial_area * scale
+
+    new_width = np.sqrt(expanded_area * (initial_width / initial_height))
+    new_height = np.sqrt(expanded_area * (initial_height / initial_width))
+
+    delta_width = (new_width - initial_width) / 2
+    delta_height = (new_height - initial_height) / 4
+
+    expanded_min_x = max(min_x - delta_width, 0)
+    expanded_max_x = min(max_x + delta_width, w)
+    expanded_min_y = max(min_y - 3 * delta_height, 0)
+    expanded_max_y = min(max_y + delta_height, h)
+
+    return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
diff --git a/tools/preprocess/video_predictor.py b/tools/preprocess/video_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c4d819a1816f7ce0ce30c869373408c5d18e85
--- /dev/null
+++ b/tools/preprocess/video_predictor.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2025. Your modifications here.
+# A wrapper for sam2 functions
+from collections import OrderedDict
+
+import torch
+from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
+from sam_utils import load_video_frames, load_video_frames_v2
+
+
+class SAM2VideoPredictor(_SAM2VideoPredictor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @torch.inference_mode()
+    def init_state(self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, frame_names=None):
+        """Initialize an inference state."""
+        images, video_height, video_width = load_video_frames(
+            video_path=video_path, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, frame_names=frame_names
+        )
+        inference_state = {}
+        inference_state["images"] = images
+        inference_state["num_frames"] = len(images)
+        # whether to offload the video frames to CPU memory
+        # turning on this option saves the GPU memory with only a very small overhead
+        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+        # whether to offload the inference state to CPU memory
+        # turning on this option saves the GPU memory at the cost of a lower tracking fps
+        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+        # and from 24 to 21 when tracking two objects)
+        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+        # the original video height and width, used for resizing final output scores
+        inference_state["video_height"] = video_height
+        inference_state["video_width"] = video_width
+        inference_state["device"] = torch.device("cuda")
+        if offload_state_to_cpu:
+            inference_state["storage_device"] = torch.device("cpu")
+        else:
+            inference_state["storage_device"] = torch.device("cuda")
+        # inputs on each frame
+        inference_state["point_inputs_per_obj"] = {}
+        inference_state["mask_inputs_per_obj"] = {}
+        # visual features on a small number of recently visited frames for quick interactions
+        inference_state["cached_features"] = {}
+        # values that don't change across frames (so we only need to hold one copy of them)
+        inference_state["constants"] = {}
+        # mapping between client-side object id and model-side object index
+        inference_state["obj_id_to_idx"] = OrderedDict()
+        inference_state["obj_idx_to_id"] = OrderedDict()
+        inference_state["obj_ids"] = []
+        # A storage to hold the model's tracking results and states on each frame
+        inference_state["output_dict"] = {
+            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+        }
+        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+        inference_state["output_dict_per_obj"] = {}
+        # A temporary storage to hold new outputs when users interact with a frame
+        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+        inference_state["temp_output_dict_per_obj"] = {}
+        # Frames that already hold consolidated outputs from click or mask inputs
+        # (we directly use their consolidated outputs during tracking)
+        inference_state["consolidated_frame_inds"] = {
+            "cond_frame_outputs": set(),  # set containing frame indices
+            "non_cond_frame_outputs": set(),  # set containing frame indices
+        }
+        # metadata for each tracking frame (e.g. which direction it's tracked)
+        inference_state["tracking_has_started"] = False
+        inference_state["frames_already_tracked"] = {}
+        # Warm up the visual backbone and cache the image feature on frame 0
+        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        return inference_state
+
+    @torch.inference_mode()
+    def init_state_v2(self, frames, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, frame_names=None):
+        """Initialize an inference state."""
+        images, video_height, video_width = load_video_frames_v2(
+            frames=frames, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, frame_names=frame_names
+        )
+        inference_state = {}
+        inference_state["images"] = images
+        inference_state["num_frames"] = len(images)
+        # whether to offload the video frames to CPU memory
+        # turning on this option saves the GPU memory with only a very small overhead
+        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+        # whether to offload the inference state to CPU memory
+        # turning on this option saves the GPU memory at the cost of a lower tracking fps
+        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+        # and from 24 to 21 when tracking two objects)
+        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+        # the original video height and width, used for resizing final output scores
+        inference_state["video_height"] = video_height
+        inference_state["video_width"] = video_width
+        inference_state["device"] = torch.device("cuda")
+        if offload_state_to_cpu:
+            inference_state["storage_device"] = torch.device("cpu")
+        else:
+            inference_state["storage_device"] = torch.device("cuda")
+        # inputs on each frame
+        inference_state["point_inputs_per_obj"] = {}
+        inference_state["mask_inputs_per_obj"] = {}
+        # visual features on a small number of recently visited frames for quick interactions
+        inference_state["cached_features"] = {}
+        # values that don't change across frames (so we only need to hold one copy of them)
+        inference_state["constants"] = {}
+        # mapping between client-side object id and model-side object index
+        inference_state["obj_id_to_idx"] = OrderedDict()
+        inference_state["obj_idx_to_id"] = OrderedDict()
+        inference_state["obj_ids"] = []
+        # A storage to hold the model's tracking results and states on each frame
+        inference_state["output_dict"] = {
+            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+        }
+        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+        inference_state["output_dict_per_obj"] = {}
+        # A temporary storage to hold new outputs when users interact with a frame
+        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+        inference_state["temp_output_dict_per_obj"] = {}
+        # Frames that already hold consolidated outputs from click or mask inputs
+        # (we directly use their consolidated outputs during tracking)
+        inference_state["consolidated_frame_inds"] = {
+            "cond_frame_outputs": set(),  # set containing frame indices
+            "non_cond_frame_outputs": set(),  # set containing frame indices
+        }
+        # metadata for each tracking frame (e.g. which direction it's tracked)
+        inference_state["tracking_has_started"] = False
+        inference_state["frames_already_tracked"] = {}
+        inference_state["frames_tracked_per_obj"] = {}
+        # Warm up the visual backbone and cache the image feature on frame 0
+        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        return inference_state
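+
+
+if __name__ == "__main__":
+    # Minimal usage sketch of the wrapper above. The config / checkpoint paths and the
+    # click coordinates are placeholders; `add_new_points_or_box` and `propagate_in_video`
+    # are inherited unchanged from the upstream sam2 predictor. A CUDA device is assumed.
+    from sam_utils import build_sam2_video_predictor
+
+    predictor = build_sam2_video_predictor(
+        config_file="sam2_hiera_l.yaml",  # placeholder config name
+        ckpt_path="checkpoints/sam2_hiera_large.pt",  # placeholder checkpoint path
+    )
+    # a directory of frames named 0.jpg, 1.jpg, ... (see load_video_frames in sam_utils)
+    state = predictor.init_state(video_path="path/to/frames")
+    # one positive click on the object to track in the first frame
+    predictor.add_new_points_or_box(inference_state=state, frame_idx=0, obj_id=1, points=[[480, 270]], labels=[1])
+    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
+        masks = (mask_logits > 0.0).cpu().numpy()  # binary masks at the original video resolution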