*.mp4
*.DS_Store
weights/
__pycache__/
MIT License
Copyright (c) 2025 Meituan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# LongCat-Video
## Paper
[LongCat-Video Technical Report](https://arxiv.org/abs/2510.22200)
## Model Introduction
LongCat-Video is a foundational video generation model with 13.6B parameters that performs strongly across a variety of video generation tasks. It particularly excels at efficient, high-quality long-video generation.
**Key Features**
- 🌟 **Unified architecture for multiple tasks:** LongCat-Video unifies Text-to-Video, Image-to-Video, and Video-Continuation within a single video generation framework. A single model natively supports all of these tasks and consistently delivers strong performance on each of them.
- 🌟 **Long video generation:** LongCat-Video is natively pretrained on the Video-Continuation task, enabling it to produce minutes-long videos without color drifting or quality degradation.
- 🌟 **Efficient inference:** LongCat-Video generates 720p, 30fps videos within minutes by employing a coarse-to-fine generation strategy along both the temporal and spatial axes. Block Sparse Attention further improves inference efficiency, particularly at high resolutions.
- 🌟 **Strong performance with multi-reward RLHF:** Powered by multi-reward Group Relative Policy Optimization (GRPO), comprehensive evaluations on internal and public benchmarks show that LongCat-Video performs on par with leading open-source video generation models and the latest commercial solutions.
![alt text](image.png)
## Environment Dependencies
| Software | Version |
| :------: | :------: |
| DTK | 25.04.2 |
| python | 3.10.12 |
| transformers | 4.57.1 |
| vllm | 0.11.0+das.opt1.alpha.8e22ded.dtk25042 |
| torch | 2.5.1+das.opt1.dtk25042 |
| triton | 3.1+das.opt1.3c5d12d.dtk25041 |
| flash_attn | 2.6.1+das.opt1.dtk2504 |
| flash_mla | 1.0.0+das.opt1.dtk25042 |
Currently, only the following image is supported:
- Adjust the mount paths passed to `-v` according to where your model and code actually live.
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-py3.10
docker run -it --shm-size 60g --network=host --name longcat_video --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /opt/hyhal/:/opt/hyhal/:ro -v /path/your_code_path/:/path/your_code_path/ image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-py3.10 bash
```
More images are available for download from [光源](https://sourcefind.cn/#/service-list).
The specialized deep-learning libraries required by this project for DCU cards can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
```bash
pip install -r requirements.txt
```
## Dataset
Not available yet.
## Training
Not available yet.
## Inference
### PyTorch
#### Single-node inference
See the run.sh script for reference.
##### Text-to-Video
```shell
# Single-GPU inference
torchrun run_demo_text_to_video.py --checkpoint_dir=./weights/LongCat-Video
# Multi-GPU inference
torchrun --nproc_per_node=8 run_demo_text_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video
```
##### Image-to-Video
```shell
# Single-GPU inference
torchrun run_demo_image_to_video.py --checkpoint_dir=./weights/LongCat-Video
# Multi-GPU inference
torchrun --nproc_per_node=8 run_demo_image_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video
```
##### Long-Video Generation
```shell
# Single-GPU inference
torchrun run_demo_long_video.py --checkpoint_dir=./weights/LongCat-Video
# Multi-GPU inference
torchrun --nproc_per_node=8 run_demo_long_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video
```
## Demo
<div align="center">
<video src="https://github.com/user-attachments/assets/00fa63f0-9c4e-461a-a79e-c662ad596d7d" width="2264" height="384"> </video>
</div>
### Accuracy
Inference accuracy on DCU is consistent with GPU. Inference framework: transformers.
## Pretrained Weights
| Model | Parameters | DCU Model | Minimum Cards Required | Download Link |
|:-----:|:----------:|:----------:|:---------------------:|:----------:|
| LongCat-Video | 13.6 B | K100AI | 8 | 🤗 [Huggingface](https://huggingface.co/meituan-longcat/LongCat-Video) |
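The weights can also be fetched programmatically. Below is a minimal sketch using the `huggingface_hub` Python API; the target directory matches the `--checkpoint_dir` passed to the inference scripts above.
```python
# Minimal download sketch using the huggingface_hub Python API.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meituan-longcat/LongCat-Video",
    local_dir="./weights/LongCat-Video",  # matches --checkpoint_dir in the commands above
)
```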
## Source Repository & Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/longcat-video-pytorch
## References
- https://github.com/meituan-longcat/LongCat-Video
# LongCat-Video
<div align="center">
<img src="assets/longcat-video_logo.svg" width="45%" alt="LongCat-Video" />
</div>
<hr>
<div align="center" style="line-height: 1;">
<a href='https://meituan-longcat.github.io/LongCat-Video/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://github.com/meituan-longcat/LongCat-Video/blob/main/longcatvideo_tech_report.pdf'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/meituan-longcat/LongCat-Video'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
</div>
<div align="center" style="line-height: 1;">
<a href='https://github.com/meituan-longcat/LongCat-Flash-Chat/blob/main/figures/wechat_official_accounts.png'><img src='https://img.shields.io/badge/WeChat-LongCat-brightgreen?logo=wechat&logoColor=white'></a>
<a href='https://x.com/Meituan_LongCat'><img src='https://img.shields.io/badge/Twitter-LongCat-white?logo=x&logoColor=white'></a>
</div>
<div align="center" style="line-height: 1;">
<a href='LICENSE'><img src='https://img.shields.io/badge/License-MIT-f5de53?&color=f5de53'></a>
</div>
## Model Introduction
We introduce LongCat-Video, a foundational video generation model with 13.6B parameters, delivering strong performance across *Text-to-Video*, *Image-to-Video*, and *Video-Continuation* generation tasks. It particularly excels in efficient and high-quality long video generation, representing our first step toward world models.
### Key Features
- 🌟 **Unified architecture for multiple tasks**: LongCat-Video unifies *Text-to-Video*, *Image-to-Video*, and *Video-Continuation* tasks within a single video generation framework. It natively supports all these tasks with a single model and consistently delivers strong performance across each individual task.
- 🌟 **Long video generation**: LongCat-Video is natively pretrained on *Video-Continuation* tasks, enabling it to produce minutes-long videos without color drifting or quality degradation.
- 🌟 **Efficient inference**: LongCat-Video generates $720p$, $30fps$ videos within minutes by employing a coarse-to-fine generation strategy along both the temporal and spatial axes. Block Sparse Attention further enhances efficiency, particularly at high resolutions.
- 🌟 **Strong performance with multi-reward RLHF**: Powered by multi-reward Group Relative Policy Optimization (GRPO), comprehensive evaluations on both internal and public benchmarks demonstrate that LongCat-Video achieves performance comparable to leading open-source video generation models as well as the latest commercial solutions.
For more detail, please refer to the comprehensive [***LongCat-Video Technical Report***](https://github.com/meituan-longcat/LongCat-Video/blob/main/longcatvideo_tech_report.pdf).
## 🎥 Teaser Video
<div align="center">
<video src="https://github.com/user-attachments/assets/00fa63f0-9c4e-461a-a79e-c662ad596d7d" width="2264" height="384"> </video>
</div>
## Quick Start
### Installation
Clone the repo:
```shell
git clone --single-branch --branch main https://github.com/meituan-longcat/LongCat-Video
cd LongCat-Video
```
Install dependencies:
```shell
# create conda environment
conda create -n longcat-video python=3.10
conda activate longcat-video
# install torch (configure according to your CUDA version)
pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# install flash-attn-2
pip install ninja
pip install psutil
pip install packaging
pip install flash_attn==2.7.4.post1
# install other requirements
pip install -r requirements.txt
```
FlashAttention-2 is enabled in the model config by default; you can also change the model config ("./weights/LongCat-Video/dit/config.json") to use FlashAttention-3 or xformers once installed.
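As a rough illustration of that config edit (the `attention_backend` key below is a placeholder; check the shipped `dit/config.json` for the actual field name and accepted values):
```python
# Hypothetical sketch: the key name and accepted values are assumptions;
# inspect ./weights/LongCat-Video/dit/config.json before editing.
import json

cfg_path = "./weights/LongCat-Video/dit/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["attention_backend"] = "xformers"  # e.g. switch away from flash-attn-2

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```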
### Model Download
| Models | Download Link |
| --- | --- |
| LongCat-Video | 🤗 [Huggingface](https://huggingface.co/meituan-longcat/LongCat-Video) |
Download models using huggingface-cli:
```shell
pip install "huggingface_hub[cli]"
huggingface-cli download meituan-longcat/LongCat-Video --local-dir ./weights/LongCat-Video
```
### Run Text-to-Video
```shell
# Single-GPU inference
torchrun run_demo_text_to_video.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_text_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Image-to-Video
```shell
# Single-GPU inference
torchrun run_demo_image_to_video.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_image_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Video-Continuation
```shell
# Single-GPU inference
torchrun run_demo_video_continuation.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_video_continuation.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Long-Video Generation
```shell
# Single-GPU inference
torchrun run_demo_long_video.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_long_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Streamlit
```shell
# Single-GPU inference
streamlit run ./run_streamlit.py --server.fileWatcherType none --server.headless=false
```
## Evaluation Results
### Text-to-Video
The *Text-to-Video* MOS evaluation results on our internal benchmark.
| **MOS score** | **Veo3** | **PixVerse-V5** | **Wan 2.2-T2V-A14B** | **LongCat-Video** |
|---------------|-------------------|--------------------|-------------|-------------|
| **Accessibility** | Proprietary | Proprietary | Open Source | Open Source |
| **Architecture** | - | - | MoE | Dense |
| **# Total Params** | - | - | 28B | 13.6B |
| **# Activated Params** | - | - | 14B | 13.6B |
| Text-Alignment↑ | 3.99 | 3.81 | 3.70 | 3.76 |
| Visual Quality↑ | 3.23 | 3.13 | 3.26 | 3.25 |
| Motion Quality↑ | 3.86 | 3.81 | 3.78 | 3.74 |
| Overall Quality↑ | 3.48 | 3.36 | 3.35 | 3.38 |
### Image-to-Video
The *Image-to-Video* MOS evaluation results on our internal benchmark.
| **MOS score** | **Seedance 1.0** | **Hailuo-02** | **Wan 2.2-I2V-A14B** | **LongCat-Video** |
|---------------|-------------------|--------------------|-------------|-------------|
| **Accessibility** | Proprietary | Proprietary | Open Source | Open Source |
| **Architecture** | - | - | MoE | Dense |
| **# Total Params** | - | - | 28B | 13.6B |
| **# Activated Params** | - | - | 14B | 13.6B |
| Image-Alignment↑ | 4.12 | 4.18 | 4.18 | 4.04 |
| Text-Alignment↑ | 3.70 | 3.85 | 3.33 | 3.49 |
| Visual Quality↑ | 3.22 | 3.18 | 3.23 | 3.27 |
| Motion Quality↑ | 3.77 | 3.80 | 3.79 | 3.59 |
| Overall Quality↑ | 3.35 | 3.27 | 3.26 | 3.17 |
## License Agreement
The **model weights** are released under the **MIT License**.
Any contributions to this repository are licensed under the MIT License, unless otherwise stated. This license does not grant any rights to use Meituan trademarks or patents.
See the [LICENSE](LICENSE) file for the full license text.
## Usage Considerations
This model has not been specifically designed or comprehensively evaluated for every possible downstream application.
Developers should take into account the known limitations of large language models, including performance variations across different languages, and carefully assess accuracy, safety, and fairness before deploying the model in sensitive or high-risk scenarios.
It is the responsibility of developers and downstream users to understand and comply with all applicable laws and regulations relevant to their use case, including but not limited to data protection, privacy, and content safety requirements.
Nothing in this Model Card should be interpreted as altering or restricting the terms of the MIT License under which the model is released.
## Citation
We kindly encourage citation of our work if you find it useful.
```
@misc{meituan2025longcatvideotechnicalreport,
title={LongCat-Video Technical Report},
author={Meituan LongCat Team},
year={2025},
eprint={2510.22200},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.22200},
}
```
## Acknowledgements
We would like to thank the contributors to the [Wan](https://huggingface.co/Wan-AI), [UMT5-XXL](https://huggingface.co/google/umt5-xxl), [Diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
## Contact
Please contact us at <a href="mailto:longcat-team@meituan.com">longcat-team@meituan.com</a> or join our WeChat Group if you have any questions.
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
fill="none"
version="1.1"
width="594.24902"
height="100"
viewBox="0 0 594.24902 100"
id="svg7"
sodipodi:docname="longcat-video_logo.svg"
inkscape:version="1.4.2 (ebf0e940, 2025-05-08)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs7" />
<sodipodi:namedview
id="namedview7"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="1.858108"
inkscape:cx="276.08729"
inkscape:cy="6.996364"
inkscape:window-width="2304"
inkscape:window-height="1243"
inkscape:window-x="0"
inkscape:window-y="25"
inkscape:window-maximized="1"
inkscape:current-layer="g7"
inkscape:export-bgcolor="#ffffffff" />
<g
id="g7">
<g
id="g6">
<g
id="g3" />
<g
id="g4">
<path
d="m 4.62614,81.999556 c -1.321058,0 -2.279197,-1.2581 -1.9280502,-2.5317 L 19.566,18.291386 c 0.721,-2.615181 3.781,-3.769321 6.0495,-2.281758 l 22.1949,14.554028 c 1.33,0.8721 3.0502,0.8735 4.3816,0.0034 L 74.481,16.001644 c 2.2706,-1.483766 5.3282,-0.32536 6.0457,2.290412 l 16.7797,61.1784 c 0.3491,1.273 -0.6088,2.5291 -1.9288,2.5291 H 73.9544 c 3.8999,-4.5155 6.0457,-10.2829 6.0457,-16.2494 v -0.6963 c 0,-5.8361 -2.123,-11.473 -5.9728,-15.8593 l -2.7494,-13.8003 c -0.1615,-0.8108 -0.8732,-1.3947 -1.7,-1.3947 -0.3751,0 -0.74,0.1216 -1.0401,0.3466 l -10.3659,7.7745 c -0.7382,0.5536 -1.693,0.73 -2.5803,0.4764 -3.6546,-1.0441 -7.5285,-1.0441 -11.1831,0 -0.8873,0.2536 -1.8421,0.0772 -2.5803,-0.4764 l -10.3705,-7.7779 c -0.297,-0.2228 -0.6584,-0.3432 -1.0297,-0.3432 -0.8275,0 -1.5372,0.5904 -1.6877,1.404 l -2.6674,14.4205 c -3.8919,3.9576 -6.0728,9.286 -6.0728,14.8366 v 1.3191 c 0,5.8204 2.0821,11.4489 5.8699,15.8681 l 0.1301,0.1517 z"
fill-rule="evenodd"
fill="#29e154"
fill-opacity="1"
id="path3" />
</g>
<g
id="g5">
<path
d="m 39,70 h 6 V 56 h -5.090909 z m 22,0 H 55 V 56 h 5.0909 z"
fill="#000000"
fill-opacity="1"
id="path4" />
<path
d="M 37.93296,71 H 46 V 55 H 38.97192 Z M 54,71 h 8.067 L 61.0281,55 H 54 Z M 44,70 H 39 L 39.909091,56 H 45 V 70 Z M 60.9351,69 61,70 H 59.9979 56 55 V 56 h 5.0909 z"
fill-rule="evenodd"
fill="#ffffff"
fill-opacity="1"
id="path5" />
</g>
</g>
<text
xml:space="preserve"
style="font-size:58.6667px;text-align:start;writing-mode:lr-tb;direction:ltr;text-anchor:start;fill:#000000"
x="112.31583"
y="72.60601"
id="text7"><tspan
sodipodi:role="line"
id="tspan7"
x="112.31583"
y="72.60601"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-size:58.6667px;font-family:meituan;-inkscape-font-specification:'meituan Bold';stroke:#ffffff;stroke-opacity:1">LongCat-Video</tspan></text>
</g>
</svg>
import os
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import math
from .common import _attn_fwd_gating, _attn_bwd_preprocess, configs_gating_preset
from .flash_attn_bsa_varlen_mask import (
_attn_fwd_bsa_varlen, _attn_fwd_bsa_varlen_align, _attn_bwd_dkdv_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_wrapper, _attn_bwd_dq_bsa_varlen_align_wrapper,
configs_fwd_bsa_varlen_preset, configs_fwd_bsa_varlen_align_preset, configs_bwd_dkdv_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_preset, configs_bwd_dq_bsa_varlen_align_preset
)
from .communicate import p2p_communicate
from ..context_parallel import context_parallel_util
torch._dynamo.config.cache_size_limit = 32
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def supports_tma():
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
if HAS_TMA_DESC:
print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", )
else:
print("TMA benchmarks will be running without grid constant TMA descriptor.", )
# TmaAutoTuneHelper used in htyu's PR #5622
class TmaAutoTuneHelper:
# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc
def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()
TMA_SIZE = 128
def __init__(self):
self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor)
self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}
# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
else:
self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
desc_x.copy_(buf_x, non_blocking=True)
# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
desc_x.copy_(buf_x, non_blocking=True)
def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]
@triton.jit
def create_mask_from_indices_kernel(
block_indices,
block_mask,
stride_bz, stride_bh, stride_bm, stride_bs,
stride_mz, stride_mh, stride_mm, stride_mn,
H,
):
i_zh, i_m, i_s = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_z, i_h = i_zh // H, i_zh % H
off_b = i_z.to(tl.int64) * stride_bz + i_h.to(tl.int64) * stride_bh + i_m.to(tl.int64) * stride_bm + i_s.to(tl.int64) * stride_bs
b_i = tl.load(block_indices + off_b)
off_m = i_z.to(tl.int64) * stride_mz + i_h.to(tl.int64) * stride_mh + i_m.to(tl.int64) * stride_mm + b_i.to(tl.int64) * stride_mn
b_m = 1
tl.store(block_mask + off_m, b_m.to(block_mask.dtype.element_ty))
def create_mask_from_indices_triton(
block_indices,
N_cols
):
B, H, N_rows, S = block_indices.shape
block_mask = torch.zeros((B, H, N_rows, N_cols), dtype=torch.bool, device=block_indices.device)
create_mask_from_indices_kernel[(B * H, N_rows, S)](
block_indices,
block_mask,
block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
block_mask.stride(0), block_mask.stride(1), block_mask.stride(2), block_mask.stride(3),
H,
)
return block_mask
@torch.compile
def create_mask_from_indices_varlen(block_indices, N_cols_mask):
B, H, M, _ = block_indices.shape
device = block_indices.device
mask = torch.zeros((B, H, M, N_cols_mask), dtype=torch.bool, device=device)
valid = block_indices < N_cols_mask
b_idx = torch.arange(B, device=device)[:, None, None, None].expand_as(block_indices)
h_idx = torch.arange(H, device=device)[None, :, None, None].expand_as(block_indices)
m_idx = torch.arange(M, device=device)[None, None, :, None].expand_as(block_indices)
valid_coords = (b_idx[valid], h_idx[valid], m_idx[valid], block_indices[valid])
mask[valid_coords] = True
return mask
@torch.compile
def create_indices_k_from_indices_q_varlen(
block_indices,
N_cols_mask # indicate the number of the last dimension of the bool mask, since this information cannot be determined by block_indices, which may contain invalid elements
):
block_mask_qk = create_mask_from_indices_varlen(block_indices, N_cols_mask)
B, H, M, N = block_mask_qk.shape
block_mask_kq = block_mask_qk.permute(0, 1, 3, 2)
indices = torch.arange(M, device=block_indices.device).view(1, 1, 1, -1).expand_as(block_mask_kq)
block_indices_k = torch.where(block_mask_kq, indices, M)
block_indices_k, _ = torch.sort(block_indices_k, dim=-1)
block_indices_k_lens = (block_indices_k < M).sum(dim=-1)
return block_indices_k, block_indices_k_lens
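# Illustrative sketch (not part of the original module; assumes torch.compile is
# usable in this environment): shows on tiny CPU tensors how q-side block indices
# are scattered into a boolean block mask and then converted into k-side
# (transposed) indices plus the number of valid entries per kv block.
def _demo_indices_q_to_k_varlen():
    # 1 batch, 1 head, 2 query blocks each selecting 2 of 3 kv blocks
    block_indices = torch.tensor([[[[0, 2],
                                    [2, 1]]]])
    mask = create_mask_from_indices_varlen(block_indices, N_cols_mask=3)
    expected = torch.tensor([[True, False, True],
                             [False, True, True]])
    assert torch.equal(mask[0, 0], expected)
    block_indices_k, block_indices_k_lens = create_indices_k_from_indices_q_varlen(block_indices, 3)
    # kv block 2 was selected by both query blocks, so its row has 2 valid entries
    assert block_indices_k_lens[0, 0].tolist() == [1, 1, 2]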
@torch.compile
def mean_pooling_compression(
x: torch.Tensor,
block_size: int
) -> torch.Tensor:
B, H, S = x.shape[:3]
num_block = math.ceil(S / block_size)
if S % block_size != 0:
x = F.pad(x, (0, 0, 0, num_block * block_size - S))
x_cmp = x.view(B, H, num_block, block_size, -1).mean(dim=3)
return x_cmp
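# Illustrative sketch (not part of the original module; shapes are hypothetical):
# mean-pools a sequence into per-block summaries, which is how q/k are compressed
# before computing the block-selection (gating) scores.
def _demo_mean_pooling_compression():
    B, H, S, D, block_size = 1, 2, 10, 4, 4
    x = torch.randn(B, H, S, D)
    x_cmp = mean_pooling_compression(x, block_size)
    # ceil(10 / 4) = 3 blocks; the last, partial block is zero-padded before averaging
    assert x_cmp.shape == (B, H, 3, D)
    assert torch.allclose(x_cmp[:, :, 0], x[:, :, 0:4].mean(dim=2))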
@torch.compile
def cal_score(q, k):
k_transposed = k.transpose(-1, -2) # [b, h, d, s_k]
score = torch.matmul(q, k_transposed) # [b, h, s_q, s_k]
return score
def cal_score_triton(q, k):
B, H, s_q, D = q.shape
s_k = k.shape[2]
score = torch.empty(B, H, s_q, s_k, device=q.device, dtype=q.dtype)
kernel_config = {} if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1' else configs_gating_preset['default']
grid = lambda args: (triton.cdiv(s_q, args["BLOCK_M"]), B * H, 1)
_attn_fwd_gating[grid](
q, k, score,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
score.stride(0), score.stride(1), score.stride(2), score.stride(3),
H, s_q, s_k,
HEAD_DIM=D,
**kernel_config
)
return score
@torch.compile
def get_select_indices_topk(q, k, sparsity):
score = cal_score(q, k)
block_indices, block_indices_lens = get_select_indices_topk_from_score(score, sparsity)
return block_indices, block_indices_lens
@torch.compile
def get_select_indices_topk_from_score(score, sparsity):
num_selected = int((1 - sparsity) * score.shape[-1])
block_indices = torch.topk(score, num_selected)[1]
block_indices_lens = torch.full(
(block_indices.shape[0], block_indices.shape[1], block_indices.shape[2]),
num_selected,
dtype=torch.int32,
device=block_indices.device
)
return block_indices, block_indices_lens
@torch.compile
def get_select_indices_cdf(q, k, cdf_threshold):
score = cal_score(q, k)
head_dim = q.shape[-1]
block_indices, block_indices_lens = get_select_indices_cdf_from_score(score, cdf_threshold, 1 / head_dim**0.5)
return block_indices, block_indices_lens
@torch.compile
def get_select_indices_cdf_from_score(score, cdf_threshold, sm_scale):
weights = torch.softmax(score * sm_scale, dim=-1)
B, H, Sq, Sk = weights.shape
cdf_threshold = torch.full((H,), cdf_threshold, device=weights.device).view(1, H, 1, 1).expand(B, -1, Sq, -1)
weights_sorted = torch.sort(weights, dim=-1, descending=True)
cdf = torch.cumsum(weights_sorted.values, dim=-1)
num_selected = torch.searchsorted(cdf, cdf_threshold, right=True)
return weights_sorted.indices, num_selected.squeeze(-1)
@torch.compile
def get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold):
score = cal_score(q, k)
head_dim = q.shape[-1]
block_indices, block_indices_lens = get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, 1 / head_dim**0.5)
return block_indices, block_indices_lens
@torch.compile
def get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold, sm_scale):
weights = torch.softmax(score * sm_scale, dim=-1)
B, H, Sq, Sk = weights.shape
cdf_threshold = torch.full((H,), cdf_threshold, device=weights.device).view(1, H, 1, 1).expand(B, -1, Sq, -1)
weights_sorted = torch.sort(weights, dim=-1, descending=True)
cdf = torch.cumsum(weights_sorted.values, dim=-1)
num_selected = torch.searchsorted(cdf, cdf_threshold, right=True)
# max(cdf, topk)
num_selected_topk = int((1 - sparsity) * score.shape[-1])
num_selected[num_selected < num_selected_topk] = num_selected_topk
return weights_sorted.indices, num_selected.squeeze(-1)
def get_select_indices(q, k, sparsity, cdf_threshold):
if sparsity is not None and cdf_threshold is None:
block_indices, block_indices_lens = get_select_indices_topk(q, k, sparsity)
elif sparsity is None and cdf_threshold is not None:
block_indices, block_indices_lens = get_select_indices_cdf(q, k, cdf_threshold)
elif sparsity is not None and cdf_threshold is not None:
block_indices, block_indices_lens = get_select_indices_cdf_topk(q, k, sparsity, cdf_threshold)
else:
raise ValueError
return block_indices, block_indices_lens
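# Illustrative sketch (not part of the original module; shapes are hypothetical):
# compares the three block-selection modes dispatched above. With sparsity=0.75
# and 8 kv blocks, the top-k path keeps 2 blocks per query block; the cdf path
# keeps as many blocks as needed to cover 90% of the attention mass; the combined
# path keeps at least the top-k count.
def _demo_get_select_indices():
    torch.manual_seed(0)
    B, H, Mq, Nk, D = 1, 2, 4, 8, 16
    q_cmp = torch.randn(B, H, Mq, D)
    k_cmp = torch.randn(B, H, Nk, D)
    idx_topk, lens_topk = get_select_indices(q_cmp, k_cmp, sparsity=0.75, cdf_threshold=None)
    idx_cdf, lens_cdf = get_select_indices(q_cmp, k_cmp, sparsity=None, cdf_threshold=0.9)
    idx_both, lens_both = get_select_indices(q_cmp, k_cmp, sparsity=0.75, cdf_threshold=0.9)
    assert idx_topk.shape[-1] == 2                         # int((1 - 0.75) * 8) == 2
    assert (lens_topk == 2).all()
    assert (lens_both >= 2).all()                          # cdf count is floored at the top-k count
    assert idx_cdf.shape[-1] == idx_both.shape[-1] == Nk   # cdf modes return the full sorted order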
def get_select_indices_from_score(score, sparsity, cdf_threshold):
if sparsity is not None and cdf_threshold is None:
block_indices, block_indices_lens = get_select_indices_topk_from_score(score, sparsity)
elif sparsity is None and cdf_threshold is not None:
block_indices, block_indices_lens = get_select_indices_cdf_from_score(score, cdf_threshold)
elif sparsity is not None and cdf_threshold is not None:
block_indices, block_indices_lens = get_select_indices_cdf_topk_from_score(score, sparsity, cdf_threshold)
else:
raise ValueError
return block_indices, block_indices_lens
def attn_fwd_bsa_varlen_triton(
q,
k,
v,
sm_scale,
block_indices,
block_indices_lens,
chunk_size_q,
chunk_size_k,
sparsity
):
B, H, Seq, D = q.shape
o = torch.empty_like(q)
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
config_key = 'BLOCK_N_LG=64' if chunk_size_k == 64 else 'default'
if chunk_size_k > 128:
fwd_func = _attn_fwd_bsa_varlen
kernel_config = {} if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1' else configs_fwd_bsa_varlen_preset[config_key]
else:
fwd_func = _attn_fwd_bsa_varlen_align
kernel_config = {} if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1' else configs_fwd_bsa_varlen_align_preset[config_key]
block_indices = block_indices.contiguous()
block_indices_lens = block_indices_lens.contiguous()
fwd_func[grid](
q, k, v, sm_scale, M, o,
block_indices, # [B, H, M_COMPRESS, S_MAX]
block_indices_lens, # [B, H, M_COMPRESS]
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
H, Seq,
D,
BLOCK_M=chunk_size_q,
BLOCK_N_LG=chunk_size_k,
SPARSITY=sparsity,
**kernel_config
)
LN2 = 0.6931471824645996
lse = M * LN2 # convert back to natural units (M is of base 2)
return o, lse
def attn_bwd_bsa_varlen_triton(
do,
q,
k,
v,
o,
dq,
dk,
dv,
sm_scale,
M,
block_indices,
block_indices_lens,
chunk_size_q,
chunk_size_k,
sparsity
):
RCP_LN2 = 1.4426950408889634
M = M * RCP_LN2 # ln -> log2
do = do.contiguous()
# assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
BATCH, N_HEAD, N_CTX, HEAD_DIM = q.shape
N_CTX_KV = k.shape[-2]
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) # reciprocal
arg_k = k
arg_k = arg_k * (sm_scale * RCP_LN2)
if min(chunk_size_q, chunk_size_k) >= 128:
PRE_BLOCK = 128
else:
PRE_BLOCK = min(chunk_size_q, chunk_size_k)
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do,
delta,
N_CTX,
BLOCK_M=PRE_BLOCK,
HEAD_DIM=HEAD_DIM
)
block_indices_k, block_indices_k_lens = create_indices_k_from_indices_q_varlen(
block_indices=block_indices,
N_cols_mask=N_CTX_KV // chunk_size_k
)
block_indices = block_indices.contiguous()
block_indices_lens = block_indices_lens.contiguous()
block_indices_k = block_indices_k.contiguous()
block_indices_k_lens = block_indices_k_lens.contiguous()
config_key = 'BLOCK_N_DQ_LG=64' if chunk_size_k == 64 else 'default'
kernel_config = {} if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1' else configs_bwd_dkdv_bsa_varlen_preset[config_key]
grid_dkdv = lambda args: (triton.cdiv(arg_k.shape[2], args["BLOCK_N"]), 1, arg_k.shape[0] * arg_k.shape[1])
_attn_bwd_dkdv_bsa_varlen_wrapper[grid_dkdv](
q, arg_k, v, sm_scale, # softmax scale
do,
dk, dv,
M, # lse (log2)
delta,
block_indices_k,
block_indices_k_lens,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
M.stride(0), M.stride(1), M.stride(2),
delta.stride(0), delta.stride(1), delta.stride(2),
block_indices_k.stride(0), block_indices_k.stride(1), block_indices_k.stride(2), block_indices_k.stride(3),
block_indices_k_lens.stride(0), block_indices_k_lens.stride(1), block_indices_k_lens.stride(2),
N_HEAD, N_CTX,
BLOCK_M=chunk_size_q,
BLOCK_N_DQ_LG=chunk_size_k,
HEAD_DIM=HEAD_DIM,
SPARSITY=sparsity,
**kernel_config
)
config_key = 'BLOCK_N_DQ_LG=64' if chunk_size_k == 64 else 'default'
if chunk_size_k > 128:
bwd_dq_func = _attn_bwd_dq_bsa_varlen_wrapper
kernel_config = {} if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1' else configs_bwd_dq_bsa_varlen_preset[config_key]
else:
bwd_dq_func = _attn_bwd_dq_bsa_varlen_align_wrapper
kernel_config = {} if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1' else configs_bwd_dq_bsa_varlen_align_preset[config_key]
grid_dq = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), 1, q.shape[0] * q.shape[1])
bwd_dq_func[grid_dq](
q, arg_k, v,
do,
dq,
M, # lse (log2)
delta,
block_indices,
block_indices_lens,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
M.stride(0), M.stride(1), M.stride(2),
delta.stride(0), delta.stride(1), delta.stride(2),
block_indices.stride(0), block_indices.stride(1), block_indices.stride(2), block_indices.stride(3),
block_indices_lens.stride(0), block_indices_lens.stride(1), block_indices_lens.stride(2),
N_HEAD, N_CTX,
BLOCK_M=chunk_size_q,
BLOCK_N_DQ_LG=chunk_size_k,
HEAD_DIM=HEAD_DIM,
SPARSITY=sparsity,
**kernel_config
)
@torch.compile
def make_block_indices_varlen_cp_list(block_indices, cp_size, num_blocks_k_full):
"""
Args:
block_indices: [B, H, num_blocks_q_per_cp_rank, num_blocks_k_full]
Return:
a list of [block_indices, block_indices_lens] for k from each cp_rank
- each block_indices starts from zero
- block_indices_lens indicates the valid number of elements in the last dimension of block_indices
"""
res = []
num_blocks_per_rank = num_blocks_k_full // cp_size
for i in range(cp_size):
block_indices_tmp = block_indices.clone()
min_block_idx = i * num_blocks_per_rank
block_indices_tmp -= min_block_idx
block_indices_tmp[block_indices_tmp < 0] = num_blocks_per_rank # block_indices_tmp < 0 indicate invalid indices, set them to num_blocks_per_rank in order to sort them to the tail, so that the first N elements of the block_indices indicated by block_indices_lens are valid
block_indices_tmp, _ = torch.sort(block_indices_tmp, dim=-1)
block_indices_tmp_lens = (block_indices_tmp < num_blocks_per_rank).sum(dim=-1)
res.append([block_indices_tmp, block_indices_tmp_lens])
return res
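# Illustrative sketch (not part of the original module; shapes are hypothetical):
# with 2 cp ranks and 4 kv blocks in total (2 per rank), global block indices are
# re-based to each rank's local numbering and out-of-range entries are excluded
# via the returned lengths.
def _demo_make_block_indices_varlen_cp_list():
    # 1 batch, 1 head, 1 query block selecting global kv blocks {0, 3}
    block_indices = torch.tensor([[[[0, 3]]]])
    per_rank = make_block_indices_varlen_cp_list(block_indices, cp_size=2, num_blocks_k_full=4)
    idx0, len0 = per_rank[0]  # rank 0 owns global blocks {0, 1}: keeps local block 0
    idx1, len1 = per_rank[1]  # rank 1 owns global blocks {2, 3}: keeps local block 1
    assert len0.item() == 1 and idx0[0, 0, 0, 0].item() == 0
    assert len1.item() == 1 and idx1[0, 0, 0, 0].item() == 1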
@torch.compile
def flash_attn_fwd_softmax_lse_correction(
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
):
"""Merge softmax stats of each step in Attention with context parallelism"""
max_scale = torch.max(softmax_lse, softmax_lse_per_step)
min_scale = torch.min(softmax_lse, softmax_lse_per_step)
lse_diff = min_scale - max_scale
lse_diff = lse_diff.nan_to_num(nan=0.) # handle tensor(-inf) - tensor(-inf) = tensor(nan); in the current cp implementation both cp ranks' lses can be -inf when neither rank has any selected block, and in that case the corrected lse should remain -inf
new_scale = max_scale + torch.log1p(torch.exp(lse_diff)) # a + ln(1 + e^(b - a)) = ln(e^a) + ln(1 + e^(b - a)) = ln(e^a + e^b)
softmax_lse.copy_(new_scale)
@torch.compile
def flash_attn_fwd_out_correction_init(
out_init_step: torch.Tensor, # b h s d
softmax_lse: torch.Tensor, # b h s
softmax_lse_init_step: torch.Tensor,
):
"""Merge partial outputs of the first step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_init_step - softmax_lse)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_init_step * softmax_lse_corrected_exp
return out_corrected.to(out_init_step.dtype)
@torch.compile
def flash_attn_fwd_out_correction(
out: torch.Tensor,
out_per_step: torch.Tensor,
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
):
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected)
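# Illustrative sketch (not part of the original module; assumes torch.compile is
# usable in this environment): verifies on tiny CPU tensors that merging two
# per-chunk partial attention outputs with the lse/out correction helpers above
# reproduces attention over the full KV.
def _demo_cp_output_merge():
    torch.manual_seed(0)
    B, H, S, Skv, D = 1, 2, 4, 6, 8
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, 2 * Skv, D)
    v = torch.randn(B, H, 2 * Skv, D)

    def partial_attn(q, k, v):
        scores = q @ k.transpose(-1, -2) / D ** 0.5
        return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1)

    out0, lse0 = partial_attn(q, k[..., :Skv, :], v[..., :Skv, :])
    out1, lse1 = partial_attn(q, k[..., Skv:, :], v[..., Skv:, :])

    lse = lse0.clone()
    flash_attn_fwd_softmax_lse_correction(lse, lse1)           # lse = ln(e^lse0 + e^lse1)
    out = flash_attn_fwd_out_correction_init(out0, lse, lse0)  # out0 * e^(lse0 - lse)
    flash_attn_fwd_out_correction(out, out1, lse, lse1)        # += out1 * e^(lse1 - lse)

    ref, _ = partial_attn(q, k, v)
    assert torch.allclose(out, ref, atol=1e-5)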
@torch.compile
def topk_sort(score, num_chunks_selected):
block_indices = torch.topk(score, num_chunks_selected)[1]
block_indices, _ = torch.sort(block_indices, dim=-1)
return block_indices
class _attention_bsa(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, sm_scale, use_tma=False):
# shape constraints
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
# ---------------------- gating ----------------------
q_cmp = mean_pooling_compression(q, chunk_size_q)
k_cmp = mean_pooling_compression(k, chunk_size_k)
block_indices, block_indices_lens = get_select_indices(q_cmp, k_cmp, sparsity, cdf_threshold)
# ---------------------- bsa ----------------------
o, lse = attn_fwd_bsa_varlen_triton(
q, k, v,
sm_scale, block_indices, block_indices_lens,
chunk_size_q, chunk_size_k,
sparsity
)
ctx.save_for_backward(q, k, v, o, lse, block_indices, block_indices_lens)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM = HEAD_DIM_K
ctx.chunk_size_q = chunk_size_q
ctx.chunk_size_k = chunk_size_k
ctx.use_tma = use_tma
ctx.sparsity = sparsity
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse, block_indices, block_indices_lens = ctx.saved_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
attn_bwd_bsa_varlen_triton(
do,
q,
k,
v,
o,
dq,
dk,
dv,
ctx.sm_scale,
lse,
block_indices,
block_indices_lens,
ctx.chunk_size_q,
ctx.chunk_size_k,
ctx.sparsity
)
return dq, dk, dv, None, None, None, None, None, None
flash_attn_bsa = _attention_bsa.apply
def rearrange_THW_to_3d_block(x, Nt, Nh, Nw, t, h, w, D):
B, H, _, D = x.shape
x = x.view(B, H, Nt, t, Nh, h, Nw, w, D)
x = x.permute(0, 1, 2, 4, 6, 3, 5, 7, 8) # B H Nt Nh Nw t h w D
return x.contiguous().view(B, H, Nt * Nh * Nw * t * h * w, D)
def rearrange_3d_block_to_THW(x, Nt, Nh, Nw, t, h, w, D):
B, H, _, D = x.shape
x = x.view(B, H, Nt, Nh, Nw, t, h, w, D)
x = x.permute(0, 1, 2, 5, 3, 6, 4, 7, 8) # B H Nt t Nh h Nw w D
return x.contiguous().view(B, H, Nt * t * Nh * h * Nw * w, D)
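# Illustrative sketch (not part of the original module; shapes are hypothetical):
# verifies that the two rearrangements above are exact inverses, i.e. flattening
# a T/H/W latent into contiguous 3D blocks and back is lossless.
def _demo_rearrange_roundtrip():
    torch.manual_seed(0)
    B, H, D = 1, 2, 8
    Nt, Nh, Nw, t, h, w = 2, 2, 2, 4, 4, 8      # block grid and per-block 3D shape
    S = Nt * t * Nh * h * Nw * w
    x = torch.randn(B, H, S, D)
    blocked = rearrange_THW_to_3d_block(x, Nt, Nh, Nw, t, h, w, D)
    restored = rearrange_3d_block_to_THW(blocked, Nt, Nh, Nw, t, h, w, D)
    assert torch.equal(restored, x)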
def flash_attn_bsa_3d(
q: torch.Tensor, # [B, H, Sq, D]
k: torch.Tensor, # [B, H, Skv, D]
v: torch.Tensor, # [B, H, Skv, D]
latent_shape_q,
latent_shape_k,
# bsa_params
sparsity=0.875,
cdf_threshold=None,
chunk_3d_shape_q=[4, 4, 8],
chunk_3d_shape_k=[4, 4, 8],
) -> torch.Tensor:
_, _, Sq, head_dim_q = q.shape
_, _, Sk, head_dim_k = k.shape
assert head_dim_q == head_dim_k
head_dim = head_dim_q
Tq, Hq, Wq = latent_shape_q
Tk, Hk, Wk = latent_shape_k
assert Tq * Hq * Wq == Sq
assert Tk * Hk * Wk == Sk
tq, hq, wq = chunk_3d_shape_q
tk, hk, wk = chunk_3d_shape_k
assert Tq % tq == 0 and Hq % hq == 0 and Wq % wq == 0
assert Tk % tk == 0 and Hk % hk == 0 and Wk % wk == 0
Ntq = Tq // tq
Nhq = Hq // hq
Nwq = Wq // wq
Ntk = Tk // tk
Nhk = Hk // hk
Nwk = Wk // wk
q = rearrange_THW_to_3d_block(q, Ntq, Nhq, Nwq, tq, hq, wq, q.shape[-1])
k = rearrange_THW_to_3d_block(k, Ntk, Nhk, Nwk, tk, hk, wk, k.shape[-1])
v = rearrange_THW_to_3d_block(v, Ntk, Nhk, Nwk, tk, hk, wk, v.shape[-1])
chunk_size_q = tq * hq * wq
chunk_size_k = tk * hk * wk
output = flash_attn_bsa(q, k, v, chunk_size_q, chunk_size_k, sparsity, cdf_threshold, 1 / head_dim**0.5)
output = rearrange_3d_block_to_THW(output, Ntq, Nhq, Nwq, tq, hq, wq, output.shape[-1])
return output
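# Illustrative usage sketch (not part of the original module): runs the 3D block
# sparse attention entry point on a small hypothetical video latent. Requires a
# CUDA/DCU device, since the Triton kernels execute on the GPU.
def _demo_flash_attn_bsa_3d():
    torch.manual_seed(0)
    B, H, D = 1, 4, 64
    latent_shape = [8, 16, 32]                  # T, H, W of the latent grid
    S = latent_shape[0] * latent_shape[1] * latent_shape[2]
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    out = flash_attn_bsa_3d(
        q, k, v,
        latent_shape_q=latent_shape,
        latent_shape_k=latent_shape,
        sparsity=0.875,                         # keep 1/8 of the kv blocks per query block
        chunk_3d_shape_q=[4, 4, 8],             # 4*4*8 = 128 tokens per 3D block
        chunk_3d_shape_k=[4, 4, 8],
    )
    assert out.shape == q.shape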
import triton
import triton.language as tl
import os
if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1':
autotune = triton.autotune
else:
def autotune(*args, **kwargs):
def decorator(func):
return func
return decorator
configs_gating_preset = {
'default': {
'BLOCK_M': 64,
'BLOCK_N': 64,
'num_stages': 3,
'num_warps': 8,
}
}
configs_gating = [
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
for BM in [64, 128] \
for BN in [32, 64] \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
gating_reevaluate_keys = ["M", "N"] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(configs_gating, key=gating_reevaluate_keys)
@triton.jit
def _attn_fwd_gating(
Q, K, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_oz, stride_oh, stride_om, stride_on,
H, M, N,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(M, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(HEAD_DIM, N),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(M, N),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr, boundary_check=(0,))
for start_n in range(0, N, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(K_block_ptr, boundary_check=(1,))
qk = tl.dot(q, k)
tl.store(O_block_ptr, qk.to(Out.type.element_ty), boundary_check=(0, 1))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
O_block_ptr = tl.advance(O_block_ptr, (0, BLOCK_N))
@triton.jit
def _attn_bwd_preprocess(
O, DO,
Delta, # output
N_CTX,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, HEAD_DIM)
# load
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_hz * N_CTX + off_m, delta)
import torch
def p2p_communicate(
rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
"""Point-to-point communications of KV and dKV in Attention with context parallelism"""
send_recv_ops = []
if batch_p2p_comm: # int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) (why?)
if rank % 2 == 0:
send_op = torch.distributed.P2POp(
torch.distributed.isend, send_tensor, send_dst, cp_group
)
recv_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_tensor, recv_src, cp_group
)
send_recv_ops.append(send_op)
send_recv_ops.append(recv_op)
else:
recv_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_tensor, recv_src, cp_group
)
send_op = torch.distributed.P2POp(
torch.distributed.isend, send_tensor, send_dst, cp_group
)
send_recv_ops.append(recv_op)
send_recv_ops.append(send_op)
send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
else:
if rank % 2 == 0:
send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
send_recv_ops.append(send_op)
send_recv_ops.append(recv_op)
else:
recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
send_recv_ops.append(recv_op)
send_recv_ops.append(send_op)
send_recv_reqs = send_recv_ops
return send_recv_reqs
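# Illustrative sketch (not part of the original module): exchanges a KV-sized
# tensor with the neighboring context-parallel rank using p2p_communicate.
# Assumes a 2-process launch, e.g. `torchrun --nproc_per_node=2 this_file.py`,
# one GPU/DCU per rank, and an NCCL-compatible backend.
def _demo_ring_exchange():
    import torch.distributed as dist
    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)
    cp_group = dist.new_group(ranks=list(range(world_size)))
    send_dst = (rank + 1) % world_size   # pass our KV chunk to the next rank
    recv_src = (rank - 1) % world_size   # receive the previous rank's KV chunk
    send_tensor = torch.full((4, 8), float(rank), device="cuda")
    recv_tensor = torch.empty_like(send_tensor)
    reqs = p2p_communicate(
        rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm=True
    )
    for req in reqs:
        req.wait()
    assert recv_tensor[0, 0].item() == float(recv_src)
    dist.destroy_process_group()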
import triton
import triton.language as tl
import os
from .common import autotune
"""
TRITON_REEVALUATE_KEY=1
- autotune whenever params in reevaluate keys change
- used in the benchmark script to find the best config
TRITON_AUTOTUNE_ENBALE=1
- if set to 0, autotuning is disabled and the related kernel params must be passed explicitly at the call site.
"""
configs_fwd_bsa_varlen_preset = {
'default': {
'BLOCK_N': 64,
'num_stages': 3,
'num_warps': 8,
},
'BLOCK_N_LG=64': {
'BLOCK_N': 64,
'num_stages': 3,
'num_warps': 4,
},
}
configs_fwd_bsa_varlen = [
triton.Config({'BLOCK_N': BN}, num_stages=s, num_warps=w) \
for BN in [32, 64, 128] \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
fwd_bsa_reevaluate_varlen_keys = ['N_CTX', 'BLOCK_M', 'BLOCK_N_LG', 'SPARSITY'] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(list(configs_fwd_bsa_varlen), key=fwd_bsa_reevaluate_varlen_keys)
@triton.jit
def _attn_fwd_bsa_varlen(
Q, K, V, sm_scale, M, Out,
block_indices, # [B, H, M_COMPRESS, S_MAX]
block_indices_lens, # [B, H, M_COMPRESS]
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_ok,
stride_bz, stride_bh, stride_bm, stride_bs,
stride_lz, stride_lh, stride_lm,
H, N_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N_LG: tl.constexpr,
BLOCK_N: tl.constexpr,
SPARSITY: tl.constexpr, # not used; only present to trigger re-autotuning when benchmarking
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
b_offset = off_z.to(tl.int64) * stride_bz + off_h.to(tl.int64) * stride_bh
l_offset = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=(1, 0),
)
KT_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_om, stride_ok),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
block_indices += b_offset + start_m * stride_bm
block_indices_lens += l_offset + start_m * stride_lm
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # 1/ln2; exp2(x/ln2) == exp2(ln(e^x) / ln2) == exp2(log2(e^x)) == exp(x)
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
S = tl.load(block_indices_lens)
for i in range(S):
block_id = tl.load(block_indices + i * stride_bs).to(tl.int32)
lo, hi = block_id * BLOCK_N_LG, (block_id + 1) * BLOCK_N_LG
lo = tl.multiple_of(lo, BLOCK_N)
KT_block_ptr_i = tl.advance(KT_block_ptr, (0, lo))
V_block_ptr_i = tl.advance(V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
kT = tl.load(KT_block_ptr_i)
qkT = tl.dot(q, kT)
m_ij = tl.maximum(m_i, tl.max(qkT, 1) * qk_scale)
qkT = qkT * qk_scale - m_ij[:, None]
p = tl.math.exp2(qkT)
# -- update m_i and l_i
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)
# -- update output accumulator --
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr_i)
acc = tl.dot(p.to(v.dtype), v, acc)
# update m_i and l_i
# place this at the end of the loop to reduce register pressure: https://github.com/triton-lang/triton/commit/ee6abd9
l_i = l_i * alpha + l_ij
m_i = m_ij
V_block_ptr_i = tl.advance(V_block_ptr_i, (BLOCK_N, 0))
KT_block_ptr_i = tl.advance(KT_block_ptr_i, (0, BLOCK_N))
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
configs_fwd_bsa_varlen_align_preset = {
'default': {
'num_stages': 3,
'num_warps': 8,
},
'BLOCK_N_LG=64': {
'num_stages': 3,
'num_warps': 4,
},
}
configs_fwd_bsa_varlen_align = [
triton.Config({}, num_stages=s, num_warps=w) \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
fwd_bsa_reevaluate_varlen_align_keys = ['N_CTX', 'BLOCK_M', 'BLOCK_N_LG', 'SPARSITY'] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(list(configs_fwd_bsa_varlen_align), key=fwd_bsa_reevaluate_varlen_align_keys)
@triton.jit
def _attn_fwd_bsa_varlen_align(
Q, K, V, sm_scale, M, Out,
block_indices, # [B, H, M_COMPRESS, S_MAX]
block_indices_lens, # [B, H, M_COMPRESS]
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
stride_bz, stride_bh, stride_bm, stride_bs,
stride_lz, stride_lh, stride_lm,
H, N_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N_LG: tl.constexpr,
SPARSITY: tl.constexpr, # not used; only present to trigger re-autotuning when benchmarking
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
b_offset = off_z.to(tl.int64) * stride_bz + off_h.to(tl.int64) * stride_bh
l_offset = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_N_LG, HEAD_DIM),
order=(1, 0),
)
KT_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N_LG),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
block_indices += b_offset + start_m * stride_bm
block_indices_lens += l_offset + start_m * stride_lm
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # 1/ln2; exp2(x/ln2) == exp2(ln(e^x) / ln2) == exp2(log2(e^x)) == exp(x); after scaling by 1/ln2 we can use exp2 instead of exp, and exp2 is faster
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
S = tl.load(block_indices_lens)
for i in range(S):
block_id = tl.load(block_indices + i * stride_bs).to(tl.int32)
lo = block_id * BLOCK_N_LG
lo = tl.multiple_of(lo, BLOCK_N_LG)
KT_block_ptr_i = tl.advance(KT_block_ptr, (0, lo))
V_block_ptr_i = tl.advance(V_block_ptr, (lo, 0))
# -- compute qk ----
kT = tl.load(KT_block_ptr_i)
qkT = tl.dot(q, kT)
m_ij = tl.maximum(m_i, tl.max(qkT, 1) * qk_scale)
qkT = qkT * qk_scale - m_ij[:, None]
p = tl.math.exp2(qkT)
# -- update m_i and l_i
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)
# -- update output accumulator --
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr_i)
acc = tl.dot(p.to(v.dtype), v, acc) # not yet divided by the softmax denominator; an optimization introduced in FlashAttention-2
# update m_i and l_i
# place this at the end of the loop to reduce register pressure: https://github.com/triton-lang/triton/commit/ee6abd9
l_i = l_i * alpha + l_ij # running softmax denominator so far
m_i = m_ij
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv_bsa_varlen(
dk, dv,
k, v,
Q, DO,
M, D,
block_indices,
block_indices_lens,
# shared by Q/K/V/DO.
# stride_tok, stride_d,
stride_qm, stride_qk,
stride_dom, stride_dok,
stride_mm,
stride_dm,
stride_bm,
N_CTX,
BLOCK_M1: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
QT_block_ptr = tl.make_block_ptr(
base=Q,
shape=(HEAD_DIM, N_CTX),
strides=(stride_qk, stride_qm),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_M1),
order=(0, 1),
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dom, stride_dok),
offsets=(0, 0),
block_shape=(BLOCK_M1, HEAD_DIM),
order=(1, 0),
)
S = tl.load(block_indices_lens)
for i in range(S):
block_id = tl.load(block_indices + i * stride_bm).to(tl.int32)
start_m = block_id * BLOCK_M1
start_m = tl.multiple_of(start_m, BLOCK_M1)
QT_block_ptr_i = tl.advance(QT_block_ptr, (0, start_m))
DO_block_ptr_i = tl.advance(DO_block_ptr, (start_m, 0))
qT = tl.load(QT_block_ptr_i)
# Load m before computing qk to reduce pipeline stall.
offs_m = start_m + tl.arange(0, BLOCK_M1) * stride_mm
m = tl.load(M + offs_m)
kqT = tl.dot(k, qT)
pT = tl.math.exp2(kqT - m[None, :])
do = tl.load(DO_block_ptr_i)
# Compute dV.
ppT = pT
ppT = ppT.to(v.dtype)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
offs_d = start_m + tl.arange(0, BLOCK_M1) * stride_dm
Di = tl.load(D + offs_d)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(v.dtype)
dk += tl.dot(dsT, tl.trans(qT))
return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq_bsa_varlen(
dq,
q, do,
m, d,
K, V,
N_CTX,
BLOCK_N2: tl.constexpr,
BLOCK_N_LG: tl.constexpr,
HEAD_DIM: tl.constexpr,
block_indices,
block_indices_lens,
stride_bn,
# stride_tok, stride_d,
stride_kn, stride_kk,
stride_vn, stride_vk,
):
VT_block_ptr = tl.make_block_ptr(
base=V,
shape=(HEAD_DIM, N_CTX),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N2),
order=(0, 1),
)
KT_block_ptr = tl.make_block_ptr(
base=K,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N2),
order=(0, 1),
)
S = tl.load(block_indices_lens)
for i in range(S):
block_id = tl.load(block_indices + i * stride_bn).to(tl.int32)
lo, hi = block_id * BLOCK_N_LG, (block_id + 1) * BLOCK_N_LG
lo = tl.multiple_of(lo, BLOCK_N2)
KT_block_ptr_i = tl.advance(KT_block_ptr, (0, lo))
VT_block_ptr_i = tl.advance(VT_block_ptr, (0, lo))
for start_n in range(lo, hi, BLOCK_N2):
start_n = tl.multiple_of(start_n, BLOCK_N2)
kT = tl.load(KT_block_ptr_i)
vT = tl.load(VT_block_ptr_i)
qkT = tl.dot(q, kT)
p = tl.math.exp2(qkT - m)
# Compute dP and dS.
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - d)
ds = ds.to(kT.dtype) # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py: Converting ds to q.dtype here reduces register pressure and makes it much faster for BLOCK_HEADDIM=128
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
KT_block_ptr_i = tl.advance(KT_block_ptr_i, (0, BLOCK_N2))
VT_block_ptr_i = tl.advance(VT_block_ptr_i, (0, BLOCK_N2))
return dq
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq_bsa_varlen_align(
dq,
q, do,
m, d,
K, V,
N_CTX,
BLOCK_N_LG: tl.constexpr,
HEAD_DIM: tl.constexpr,
block_indices,
block_indices_lens,
stride_bn,
stride_kn, stride_kk,
stride_vn, stride_vk,
):
VT_block_ptr = tl.make_block_ptr(
base=V,
shape=(HEAD_DIM, N_CTX),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N_LG),
order=(0, 1),
)
KT_block_ptr = tl.make_block_ptr(
base=K,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N_LG),
order=(0, 1),
)
S = tl.load(block_indices_lens)
for i in range(S):
block_id = tl.load(block_indices + i * stride_bn).to(tl.int32)
lo = block_id * BLOCK_N_LG
lo = tl.multiple_of(lo, BLOCK_N_LG)
KT_block_ptr_i = tl.advance(KT_block_ptr, (0, lo))
VT_block_ptr_i = tl.advance(VT_block_ptr, (0, lo))
kT = tl.load(KT_block_ptr_i)
vT = tl.load(VT_block_ptr_i)
qkT = tl.dot(q, kT)
p = tl.math.exp2(qkT - m)
# Compute dP and dS.
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - d)
ds = ds.to(kT.dtype) # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py: Converting ds to q.dtype here reduces register pressure and makes it much faster for BLOCK_HEADDIM=128
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
return dq
configs_bwd_dkdv_bsa_varlen_preset = {
'default': {
'BLOCK_N': 128,
'num_stages': 2,
'num_warps': 8,
},
'BLOCK_N_DQ_LG=64': {
'BLOCK_N': 64,
'num_stages': 2,
'num_warps': 4,
}
}
configs_bwd_dkdv_bsa_varlen = [
triton.Config({'BLOCK_N': BN}, num_stages=s, num_warps=w) \
for BN in [32, 64, 128] \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
bwd_dkdv_bsa_varlen_reevaluate_keys = ['N_CTX', 'BLOCK_M', 'BLOCK_N_DQ_LG', 'SPARSITY'] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(list(configs_bwd_dkdv_bsa_varlen), key=bwd_dkdv_bsa_varlen_reevaluate_keys)
@triton.jit
def _attn_bwd_dkdv_bsa_varlen_wrapper(
Q, K, V, sm_scale, # softmax scale
DO,
DK, DV,
M, # lse (log2)
D,
block_indices,
block_indices_lens,
# stride_z, stride_h, stride_tok, stride_d, # shared by Q/K/V/DO.
# qkv
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
# dk dv do
stride_dkz, stride_dkh, stride_dkn, stride_dkk,
stride_dvz, stride_dvh, stride_dvn, stride_dvk,
stride_doz, stride_doh, stride_dom, stride_dok,
# m, d
stride_mz, stride_mh, stride_mm,
stride_dz, stride_dh, stride_dm,
#
stride_bz, stride_bh, stride_bn, stride_bm, # block_indices
stride_lz, stride_lh, stride_ln, # block_indices_lens
#
H, N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_N_DQ_LG: tl.constexpr, # logical block size
HEAD_DIM: tl.constexpr,
SPARSITY: tl.constexpr, # not used in the kernel; only listed so autotune re-evaluates when the sparsity changes (benchmarking)
):
tl.static_assert(BLOCK_N_DQ_LG % BLOCK_N == 0)
start_n = tl.program_id(0)
off_hz = tl.program_id(2)
off_z = off_hz // H
off_h = off_hz % H
off_q = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
off_k = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
off_v = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
off_dk = off_z.to(tl.int64) * stride_dkz + off_h.to(tl.int64) * stride_dkh
off_dv = off_z.to(tl.int64) * stride_dvz + off_h.to(tl.int64) * stride_dvh
off_do = off_z.to(tl.int64) * stride_doz + off_h.to(tl.int64) * stride_doh
off_m = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
off_d = off_z.to(tl.int64) * stride_dz + off_h.to(tl.int64) * stride_dh
off_block_indices = off_z.to(tl.int64) * stride_bz + off_h.to(tl.int64) * stride_bh
off_block_indices_lens = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
# offset pointers for batch/head
Q += off_q
K += off_k
V += off_v
DO += off_do
DK += off_dk
DV += off_dv
M += off_m
D += off_d
block_indices += off_block_indices
block_indices_lens += off_block_indices_lens
# ---------------------------------------- [DKDV] ----------------------------------------
dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(N_CTX, HEAD_DIM),
strides=(stride_kn, stride_kk),
offsets=(start_n * BLOCK_N, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(start_n * BLOCK_N, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=(1, 0),
)
DK_block_ptr = tl.make_block_ptr(
base=DK,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dkn, stride_dkk),
offsets=(start_n * BLOCK_N, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=(1, 0),
)
DV_block_ptr = tl.make_block_ptr(
base=DV,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dvn, stride_dvk),
offsets=(start_n * BLOCK_N, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=(1, 0),
)
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
k_compress_idx = start_n * BLOCK_N // BLOCK_N_DQ_LG
block_indices_i = block_indices + k_compress_idx * stride_bn
block_indices_lens_i = block_indices_lens + k_compress_idx * stride_ln
dk, dv = _attn_bwd_dkdv_bsa_varlen(
dk, dv,
k, v,
Q, DO,
M, D,
block_indices_i,
block_indices_lens_i,
# shared by Q/K/V/DO.
stride_qm, stride_qk,
stride_dom, stride_dok,
stride_mm,
stride_dm,
#
stride_bm,
N_CTX,
BLOCK_M,
HEAD_DIM,
)
# Write back dk
dk *= sm_scale # S = scale * Q K^T, so dK = scale * dS^T Q
tl.store(DK_block_ptr, dk.to(k.dtype))
# Write back dv
tl.store(DV_block_ptr, dv.to(v.dtype))
configs_bwd_dq_bsa_varlen_preset = {
'default': {
'BLOCK_N_DQ': 64,
'num_stages': 2,
'num_warps': 8,
},
'BLOCK_N_DQ_LG=64': {
'BLOCK_N_DQ': 64,
'num_stages': 2,
'num_warps': 4,
},
}
configs_bwd_dq_bsa_varlen = [
triton.Config({'BLOCK_N_DQ': BN}, num_stages=s, num_warps=w) \
for BN in [32, 64, 128] \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
bwd_dq_bsa_varlen_reevaluate_keys = ['N_CTX', 'BLOCK_M', 'BLOCK_N_DQ_LG', 'SPARSITY'] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(list(configs_bwd_dq_bsa_varlen), key=bwd_dq_bsa_varlen_reevaluate_keys)
@triton.jit
def _attn_bwd_dq_bsa_varlen_wrapper(
Q, K, V,
DO,
DQ,
M, # lse (log2)
D,
block_indices,
block_indices_lens,
# stride_z, stride_h, stride_tok, stride_d, # shared by Q/K/V/DO.
# qkv
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
# dq do
stride_dqz, stride_dqh, stride_dqm, stride_dqk,
stride_doz, stride_doh, stride_dom, stride_dok,
# m, d
stride_mz, stride_mh, stride_mm,
stride_dz, stride_dh, stride_dm,
#
stride_bz, stride_bh, stride_bm, stride_bn, # block_indices
stride_lz, stride_lh, stride_lm, # block_indices_lens
#
H, N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_N_DQ_LG: tl.constexpr, # logical block size
BLOCK_N_DQ: tl.constexpr,
HEAD_DIM: tl.constexpr,
SPARSITY: tl.constexpr, # not used in the kernel; only listed so autotune re-evaluates when the sparsity changes (benchmarking)
):
tl.static_assert(BLOCK_N_DQ_LG % BLOCK_N_DQ == 0)
tl.static_assert(BLOCK_N_DQ_LG % BLOCK_M == 0)
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
start_m = tl.program_id(0)
off_hz = tl.program_id(2)
off_z = off_hz // H
off_h = off_hz % H
off_q = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
off_k = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
off_v = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
off_dq = off_z.to(tl.int64) * stride_dqz + off_h.to(tl.int64) * stride_dqh
off_do = off_z.to(tl.int64) * stride_doz + off_h.to(tl.int64) * stride_doh
off_m = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
off_d = off_z.to(tl.int64) * stride_dz + off_h.to(tl.int64) * stride_dh
off_block_indices = off_z.to(tl.int64) * stride_bz + off_h.to(tl.int64) * stride_bh
off_block_indices_lens = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
# offset pointers for batch/head
Q += off_q
K += off_k
V += off_v
DO += off_do
DQ += off_dq
M += off_m
D += off_d
block_indices += off_block_indices
block_indices_lens += off_block_indices_lens
# ---------------------------------------- [DQ] ----------------------------------------
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dom, stride_dok),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dqm, stride_dqk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
q = tl.load(Q_block_ptr)
do = tl.load(DO_block_ptr)
dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
start_m = start_m * BLOCK_M
offs_m = start_m + tl.arange(0, BLOCK_M) * stride_mm
offs_d = start_m + tl.arange(0, BLOCK_M) * stride_dm
m = tl.load(M + offs_m)
m = m[:, None]
d = tl.load(D + offs_d)
d = d[:, None]
block_indices_m = block_indices + (start_m // BLOCK_M) * stride_bm
block_indices_lens_m = block_indices_lens + (start_m // BLOCK_M) * stride_lm
dq = _attn_bwd_dq_bsa_varlen(
dq,
q, do,
m, d,
K, V,
N_CTX,
BLOCK_N_DQ,
BLOCK_N_DQ_LG,
HEAD_DIM,
block_indices_m,
block_indices_lens_m,
stride_bn,
stride_kn, stride_kk,
stride_vn, stride_vk,
)
# Write back dQ.
dq *= LN2
tl.store(DQ_block_ptr, dq.to(q.dtype))
configs_bwd_dq_bsa_varlen_align_preset = {
'default': {
'num_stages': 2,
'num_warps': 8,
},
'BLOCK_N_DQ_LG=64': {
'num_stages': 2,
'num_warps': 4,
},
}
configs_bwd_dq_bsa_varlen_align = [
triton.Config({}, num_stages=s, num_warps=w) \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
bwd_dq_bsa_varlen_align_reevaluate_keys = ['N_CTX', 'BLOCK_M', 'BLOCK_N_DQ_LG', 'SPARSITY'] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(list(configs_bwd_dq_bsa_varlen_align), key=bwd_dq_bsa_varlen_align_reevaluate_keys)
@triton.jit
def _attn_bwd_dq_bsa_varlen_align_wrapper(
Q, K, V,
DO,
DQ,
M, # lse (log2)
D,
block_indices,
block_indices_lens,
# qkv
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
# dq do
stride_dqz, stride_dqh, stride_dqm, stride_dqk,
stride_doz, stride_doh, stride_dom, stride_dok,
# m, d
stride_mz, stride_mh, stride_mm,
stride_dz, stride_dh, stride_dm,
#
stride_bz, stride_bh, stride_bm, stride_bn, # block_indices
stride_lz, stride_lh, stride_lm, # block_indices_lens
#
H, N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_N_DQ_LG: tl.constexpr, # logical block size
HEAD_DIM: tl.constexpr,
SPARSITY: tl.constexpr, # not used in the kernel; only listed so autotune re-evaluates when the sparsity changes (benchmarking)
):
# align variant: the logical block BLOCK_N_DQ_LG is consumed whole, so no sub-block divisibility check is needed
tl.static_assert(BLOCK_N_DQ_LG % BLOCK_M == 0)
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
start_m = tl.program_id(0)
off_hz = tl.program_id(2)
off_z = off_hz // H
off_h = off_hz % H
off_q = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
off_k = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
off_v = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
off_dq = off_z.to(tl.int64) * stride_dqz + off_h.to(tl.int64) * stride_dqh
off_do = off_z.to(tl.int64) * stride_doz + off_h.to(tl.int64) * stride_doh
off_m = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
off_d = off_z.to(tl.int64) * stride_dz + off_h.to(tl.int64) * stride_dh
off_block_indices = off_z.to(tl.int64) * stride_bz + off_h.to(tl.int64) * stride_bh
off_block_indices_lens = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
# offset pointers for batch/head
Q += off_q
K += off_k
V += off_v
DO += off_do
DQ += off_dq
M += off_m
D += off_d
block_indices += off_block_indices
block_indices_lens += off_block_indices_lens
# ---------------------------------------- [DQ] ----------------------------------------
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dom, stride_dok),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(N_CTX, HEAD_DIM),
strides=(stride_dqm, stride_dqk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
q = tl.load(Q_block_ptr)
do = tl.load(DO_block_ptr)
dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
start_m = start_m * BLOCK_M
offs_m = start_m + tl.arange(0, BLOCK_M) * stride_mm
offs_d = start_m + tl.arange(0, BLOCK_M) * stride_dm
m = tl.load(M + offs_m)
m = m[:, None]
# D (= delta) is pre-divided by ds_scale.
d = tl.load(D + offs_d)
d = d[:, None]
block_indices_m = block_indices + (start_m // BLOCK_M) * stride_bm
block_indices_lens_m = block_indices_lens + (start_m // BLOCK_M) * stride_lm
dq = _attn_bwd_dq_bsa_varlen_align(
dq,
q, do,
m, d,
K, V,
N_CTX,
BLOCK_N_DQ_LG,
HEAD_DIM,
block_indices_m,
block_indices_lens_m,
stride_bn,
stride_kn, stride_kk,
stride_vn, stride_vk,
)
# Write back dQ.
dq *= LN2
tl.store(DQ_block_ptr, dq.to(q.dtype))
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from einops import rearrange
dp_size = dp_group = cp_group = cp_stream = dp_ranks = cp_ranks = dp_rank = None
cp_size: int = 1
cp_rank: int = 0
def init_context_parallel(context_parallel_size: int = 1,
global_rank: int = 0,
world_size: int = 1,):
global dp_size, cp_size, dp_group, cp_group, dp_ranks, cp_ranks, dp_rank, cp_rank
if world_size % context_parallel_size != 0:
raise RuntimeError(f'world_size {world_size} must be a multiple of context_parallel_size {context_parallel_size}')
cp_size = context_parallel_size
dp_size = world_size//context_parallel_size
print(f'[rank {global_rank}] init_device_mesh [dp_size x cp_size]: [{dp_size} x {cp_size}]')
mesh_2d = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
print(f'[rank {global_rank}] mesh_2d: {mesh_2d}')
dp_group = mesh_2d.get_group(mesh_dim="dp")
cp_group = mesh_2d.get_group(mesh_dim="cp")
dp_ranks = torch.distributed.get_process_group_ranks(dp_group)
cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
dp_rank = dist.get_rank(group=dp_group)
cp_rank = dist.get_rank(group=cp_group)
curr_global_rank = torch.distributed.get_rank()
print(f'[rank {curr_global_rank}] [dp_rank, cp_rank]: [{dp_rank}, {cp_rank}], dp_ranks: {dp_ranks}, cp_ranks: {cp_ranks}')
def get_cp_size():
global cp_size
return cp_size
def get_dp_size():
global dp_size
return dp_size
def get_cp_stream():
global cp_stream
if cp_stream is None:
cp_stream = torch.cuda.Stream()
return cp_stream
def get_dp_group():
global dp_group
return dp_group
def get_cp_group():
global cp_group
return cp_group
def get_dp_rank():
global dp_rank
return dp_rank
def get_cp_rank():
global cp_rank
return cp_rank
def get_cp_rank_list():
global cp_ranks
if cp_ranks is None:
cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
return cp_ranks
def cp_broadcast(tensor, cp_index=0):
global dp_group
global cp_group
cp_ranks = get_cp_rank_list()
torch.distributed.broadcast(tensor, cp_ranks[cp_index], group=cp_group)
def split_tensor_in_cp_2d(input, dim_hw, split_hw):
global cp_size
dim_h, dim_w = dim_hw
split_h, split_w = split_hw
assert cp_size == split_h * split_w
seq_size_h = input.shape[dim_h]
seq_size_w = input.shape[dim_w]
if seq_size_h % split_h != 0:
raise RuntimeError(f'seq_size_h {seq_size_h} in dim_h {dim_h} must be a multiple of split_h {split_h}')
if seq_size_w % split_w != 0:
raise RuntimeError(f'seq_size_w {seq_size_w} in dim_w {dim_w} must be a multiple of split_w {split_w}')
split_seq_size_h = seq_size_h // split_h
split_seq_size_w = seq_size_w // split_w
tensor_splits_h = input.split(split_seq_size_h, dim=dim_h)
tensor_splits = []
for tensor_split_h in tensor_splits_h:
tensor_splits_hw = tensor_split_h.split(split_seq_size_w, dim=dim_w)
tensor_splits.extend(tensor_splits_hw)
cp_rank = get_cp_rank()
split_tensor = tensor_splits[cp_rank]
return split_tensor
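# Minimal single-process sketch of what split_tensor_in_cp_2d produces, assuming a
# hypothetical cp_size of 4 with split_hw=(2, 2): H and W are cut into a 2x2 grid
# and cp_rank r receives tile r in row-major (H-major) order.
def _example_split_2d():
    x = torch.arange(1 * 1 * 4 * 6 * 2).reshape(1, 1, 4, 6, 2)  # [B, T, H, W, C]
    split_h, split_w = 2, 2
    tiles = []
    for part_h in x.split(x.shape[2] // split_h, dim=2):          # cut along H
        tiles.extend(part_h.split(x.shape[3] // split_w, dim=3))  # then along W
    # tiles[r] is what cp_rank == r would receive; e.g. rank 3 holds the
    # bottom-right tile of shape [1, 1, 2, 3, 2].
    assert len(tiles) == split_h * split_w
    assert tiles[3].shape == (1, 1, 2, 3, 2)
    return tiles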
class GatherFunction2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, process_group, seq_dim_hw, shape, split_hw):
ctx.cp_group = process_group
ctx.seq_dim_hw = seq_dim_hw
ctx.split_hw = split_hw
ctx.shape = shape
ctx.cp_size = get_cp_size()
T, H, W = shape
dim_h, dim_w = seq_dim_hw
split_h, split_w = split_hw
assert H % split_h == 0 and W % split_w == 0
assert T * (H // split_h) * (W // split_w) == input.shape[1]
input = rearrange(input, "B (T H W) C -> B T H W C", T=T, H=H // split_h, W=W // split_w)
with torch.no_grad():
input = input.contiguous()
output_tensors = [torch.zeros_like(input) for _ in range(ctx.cp_size)]
dist.all_gather(output_tensors, input, group=ctx.cp_group)
output_tensors_hs = []
assert ctx.cp_size % split_w == 0
for i in range(0, ctx.cp_size // split_w):
output_tensors_hs.append(
torch.cat(output_tensors[i * split_w : (i + 1) * split_w], dim=dim_w)
)
output_tensor = torch.cat(output_tensors_hs, dim=dim_h)
output_tensor = rearrange(output_tensor, "B T H W C -> B (T H W) C")
return output_tensor
@staticmethod
def backward(ctx, grad_output):
T, H, W = ctx.shape
with torch.no_grad():
grad_output = grad_output * ctx.cp_size
grad_output = rearrange(grad_output, "B (T H W) C -> B T H W C", T=T, H=H, W=W)
grad_input = split_tensor_in_cp_2d(grad_output, ctx.seq_dim_hw, ctx.split_hw)
grad_input = rearrange(grad_input, "B T H W C -> B (T H W) C")
return grad_input, None, None, None, None
class SplitFunction2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, process_group, seq_dim_hw, split_hw):
ctx.cp_group = process_group
ctx.seq_dim_hw = seq_dim_hw
ctx.split_hw = split_hw
ctx.cp_size = get_cp_size()
output_tensor = split_tensor_in_cp_2d(input, ctx.seq_dim_hw, split_hw)
return output_tensor
@staticmethod
def backward(ctx, grad_output):
with torch.no_grad():
grad_output = grad_output / ctx.cp_size
output_tensors = [torch.zeros_like(grad_output) for _ in range(ctx.cp_size)]
dist.all_gather(output_tensors, grad_output, group=ctx.cp_group)
split_h, split_w = ctx.split_hw
dim_h, dim_w = ctx.seq_dim_hw
output_tensors_hs = []
assert ctx.cp_size % split_w == 0
for i in range(0, ctx.cp_size // split_w):
output_tensors_hs.append(
torch.cat(output_tensors[i * split_w : (i + 1) * split_w], dim=dim_w)
)
grad_input = torch.cat(output_tensors_hs, dim=dim_h)
return grad_input, None, None, None
def gather_cp_2d(input, shape, split_hw):
cp_process_group = get_cp_group()
output_tensor = GatherFunction2D.apply(input, cp_process_group, (2, 3), shape, split_hw)
return output_tensor
def split_cp_2d(input, seq_dim_hw, split_hw):
cp_process_group = get_cp_group()
output_tensor = SplitFunction2D.apply(input, cp_process_group, seq_dim_hw, split_hw)
return output_tensor
def get_optimal_split(size):
factors = []
for i in range(1, int(size**0.5) + 1):
if size % i == 0:
factors.append([i, size // i])
return min(factors, key=lambda x: abs(x[0] - x[1]))
\ No newline at end of file
import torch
import torch.distributed as dist
from ..context_parallel import context_parallel_util
def all_to_all(tensor, scatter_idx, gather_idx, group=None, gather=True):
"""Perform all-to-all communication on a tensor.
Args:
tensor (torch.Tensor): Input tensor for all-to-all communication
scatter_idx (int): Dimension to scatter, will split along this dimension and then scatter to all processes
gather_idx (int): Dimension to gather, will gather from all processes and then concatenate along this dimension
group (ProcessGroup, optional): Process group to use for communication
Returns:
torch.Tensor
"""
if not dist.is_initialized():
return tensor
world_size = dist.get_world_size(group)
ulysses_rank = context_parallel_util.get_cp_rank()
if world_size == 1:
return tensor
if scatter_idx == gather_idx:
raise ValueError("scatter_idx and gather_idx must be different")
def chunk_tensor(tensor, scatter_idx):
t_shape = list(tensor.shape)
if t_shape[scatter_idx] % world_size != 0:
raise ValueError(f"Dimension {scatter_idx} must be divisible by world size {world_size}")
chunk_size = t_shape[scatter_idx] // world_size
new_shape = list()
for i in range(len(t_shape)):
if i != scatter_idx:
new_shape.append(t_shape[i])
else:
new_shape.extend([world_size, chunk_size])
tensor = tensor.reshape(*new_shape)
# move scatter_idx to front
tensor = tensor.permute(scatter_idx, *[i for i in range(len(new_shape)) if i != scatter_idx]).contiguous()
return tensor
# chunk tensor for all_to_all
tensor = chunk_tensor(tensor, scatter_idx)
# Perform all2all
output = torch.empty_like(tensor)
dist.all_to_all_single(output, tensor, group=group)
# output: e.g., [world_size, B, chunked_H, chunked_S, D] if scatter_idx == 1, gather_idx == 2 -> [B, chunked_H, S, D]
def reorder_tensor(tensor, gather_idx):
t_shape = list(tensor.shape)
world_size = t_shape[0]
# move the leading world_size dim to position gather_idx + 1
permute_idx = list()
for i in range(1, len(t_shape)):
if i != gather_idx + 1:
permute_idx.append(i)
else:
permute_idx.extend([0, i])
tensor = tensor.permute(*permute_idx).contiguous() # permute(1,2,0,3) W B CH CS D -> B CH W CS D
# reshape tensor
new_shape = list()
if gather:
for i in range(1, len(t_shape)): # B CH CS D
if i != gather_idx + 1:
new_shape.append(t_shape[i])
else:
new_shape.append(world_size * t_shape[i]) # B CH W*CS D
tensor = tensor.reshape(*new_shape)
else:
tensor = tensor[:,ulysses_rank] # W B CS CH D -> B CS W CH D
return tensor
output = reorder_tensor(output, gather_idx)
return output
@torch.compiler.disable
def ulysses_a2a_in(query, key, value):
if context_parallel_util.get_cp_size() == 1:
return query, key, value
# [B, H, S/N, D] -> [B, H/N, S, D]
query = all_to_all(query, scatter_idx=1, gather_idx=2, group=context_parallel_util.get_cp_group())
key = all_to_all(key, scatter_idx=1, gather_idx=2, group=context_parallel_util.get_cp_group())
value = all_to_all(value, scatter_idx=1, gather_idx=2, group=context_parallel_util.get_cp_group())
return query, key, value
@torch.compiler.disable
def ulysses_a2a_out(output):
if context_parallel_util.get_cp_size() == 1:
return output
# [B, H/N, S, D] -> [B, H, S/N, D]
output = all_to_all(output, scatter_idx=2, gather_idx=1, group=context_parallel_util.get_cp_group())
return output
def ulysses_wrapper(func):
def wrapper(self, query, key, value, shape):
# Apply ulysses_a2a_in before the function call, gather sequence and split head
query, key, value = ulysses_a2a_in(query, key, value)
output = func(self, query, key, value, shape)
output = ulysses_a2a_out(output)
return output
return wrapper
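# Single-process sketch of the shape exchange performed by ulysses_a2a_in /
# ulysses_a2a_out, assuming a hypothetical context-parallel group of size 4.
# Each rank starts with a sequence shard [B, H, S/4, D]; after the all-to-all it
# holds a head shard of the full sequence [B, H/4, S, D], which is what the
# attention backends expect.
def _example_ulysses_shapes(B=1, H=8, S=16, D=4, cp=4):
    shards = [torch.randn(B, H, S // cp, D) for _ in range(cp)]  # per-rank inputs
    # Emulate all_to_all(scatter_idx=1, gather_idx=2): rank r keeps the r-th head
    # chunk from every rank and concatenates those chunks along the sequence dim.
    rank_outputs = []
    for r in range(cp):
        head_chunks = [shard.chunk(cp, dim=1)[r] for shard in shards]
        rank_outputs.append(torch.cat(head_chunks, dim=2))
    assert rank_outputs[0].shape == (B, H // cp, S, D)
    return rank_outputs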
from typing import List, Optional
import torch
import torch.nn as nn
from einops import rearrange
from .rope_3d import RotaryPositionalEmbedding
from .blocks import RMSNorm_FP32
from ..block_sparse_attention.bsa_interface import flash_attn_bsa_3d
from ..context_parallel.ulysses_wrapper import ulysses_wrapper
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
self.enable_bsa = enable_bsa
self.bsa_params = bsa_params
self.cp_split_hw = cp_split_hw
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.proj = nn.Linear(dim, dim)
self.rope_3d = RotaryPositionalEmbedding(
self.head_dim,
cp_split_hw=cp_split_hw
)
@ulysses_wrapper
def _process_attn(self, q, k, v, shape):
"""
function wrapper to do attention with q, k, v
"""
B, H, SQ, D = q.shape
_, _, SKV, _ = k.shape
if self.enable_bsa and shape[0] > 1: # bsa will not be used in image training / sampling
assert self.bsa_params is not None
_, H, W = shape
assert H % self.cp_split_hw[0] == 0 and W % self.cp_split_hw[1] == 0
H, W = H // self.cp_split_hw[0], W // self.cp_split_hw[1]
Tq = SQ // (H * W)
Tk = SKV // (H * W)
latent_shape_q = (Tq, H, W)
latent_shape_k = (Tk, H, W)
x = flash_attn_bsa_3d(q, k, v, latent_shape_q, latent_shape_k, **self.bsa_params)
elif self.enable_flashattn3:
from flash_attn_interface import flash_attn_func
q = rearrange(q, "B H S D -> B S H D").contiguous()
k = rearrange(k, "B H S D -> B S H D").contiguous()
v = rearrange(v, "B H S D -> B S H D").contiguous()
x, *_ = flash_attn_func(
q,
k,
v,
softmax_scale=self.scale,
)
x = rearrange(x, "B S H D -> B H S D")
elif self.enable_flashattn2:
from flash_attn import flash_attn_func
q = rearrange(q, "B H S D -> B S H D")
k = rearrange(k, "B H S D -> B S H D")
v = rearrange(v, "B H S D -> B S H D")
x = flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=self.scale,
)
x = rearrange(x, "B S H D -> B H S D")
elif self.enable_xformers:
import xformers.ops
# Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
# the sequence length, H the number of heads, and K the embedding size per head
q = rearrange(q, "B H M K -> B M H K")
k = rearrange(k, "B H M K -> B M H K")
v = rearrange(v, "B H M K -> B M H K")
x = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None,)
x = rearrange(x, "B M H K -> B H M K")
else:
raise RuntimeError("Unsupported attention operations.")
return x
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if return_kv:
k_cache, v_cache = k.clone(), v.clone()
q, k = self.rope_3d(q, k, shape)
# cond mode
if num_cond_latents is not None and num_cond_latents > 0:
num_cond_latents_thw = num_cond_latents * (N // shape[0])
# process the condition tokens
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
# process the noise tokens
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
x_noise = self._process_attn(q_noise, k, v, shape)
# merge x_cond and x_noise
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
else:
x = self._process_attn(q, k, v, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
if return_kv:
return x, (k_cache, v_cache)
else:
return x
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
T, H, W = shape
k_cache, v_cache = kv_cache
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
if k_cache.shape[0] == 1:
k_cache = k_cache.repeat(B, 1, 1, 1)
v_cache = v_cache.repeat(B, 1, 1, 1)
if num_cond_latents is not None and num_cond_latents > 0:
k_full = torch.cat([k_cache, k], dim=2).contiguous()
v_full = torch.cat([v_cache, v], dim=2).contiguous()
# Prepend a placeholder block so RoPE sees the correct temporal offsets for the new frames,
# then keep only the RoPE'd query part that corresponds to the current chunk.
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
q = q_padding[:, :, -N:].contiguous()
x = self._process_attn(q, k_full, v_full, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads,
enable_flashattn3=False,
enable_flashattn2=False,
enable_xformers=False,
):
super(MultiHeadCrossAttention, self).__init__()
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = nn.Linear(dim, dim)
self.kv_linear = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
def _process_cross_attn(self, x, cond, kv_seqlen):
B, N, C = x.shape
assert C == self.dim and cond.shape[2] == self.dim
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flashattn3:
from flash_attn_interface import flash_attn_varlen_func
x = flash_attn_varlen_func(
q=q[0],
k=k[0],
v=v[0],
cu_seqlens_q=torch.tensor([0] + [N] * B, device=q.device).cumsum(0).to(torch.int32),
cu_seqlens_k=torch.tensor([0] + kv_seqlen, device=q.device).cumsum(0).to(torch.int32),
max_seqlen_q=N,
max_seqlen_k=max(kv_seqlen),
)[0]
elif self.enable_flashattn2:
from flash_attn import flash_attn_varlen_func
x = flash_attn_varlen_func(
q=q[0],
k=k[0],
v=v[0],
cu_seqlens_q=torch.tensor([0] + [N] * B, device=q.device).cumsum(0).to(torch.int32),
cu_seqlens_k=torch.tensor([0] + kv_seqlen, device=q.device).cumsum(0).to(torch.int32),
max_seqlen_q=N,
max_seqlen_k=max(kv_seqlen),
)
elif self.enable_xformers:
import xformers.ops
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens([N] * B, kv_seqlen)
x = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
else:
raise RuntimeError("Unsupported attention operations.")
x = x.view(B, -1, C)
x = self.proj(x)
return x
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
"""
x: [B, N, C]
cond: [B, M, C]
"""
if num_cond_latents is None or num_cond_latents == 0:
return self._process_cross_attn(x, cond, kv_seqlen)
else:
B, N, C = x.shape
if num_cond_latents is not None and num_cond_latents > 0:
assert shape is not None, "SHOULD pass in the shape"
num_cond_latents_thw = num_cond_latents * (N // shape[0])
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
output = torch.cat([
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
output_noise
], dim=1).contiguous()
else:
raise NotImplementedError
return output
\ No newline at end of file
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1 :, :, :]
return x
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Set up causal padding
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
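# Small sketch of the causal behaviour, assuming a toy layer: with kernel_size=3
# and padding=1 the layer pads two frames at the front of the time axis and none
# at the back, so output frame t never depends on input frames after t and the
# temporal length is preserved.
def _example_causal_conv():
    conv = WanCausalConv3d(in_channels=2, out_channels=2, kernel_size=3, padding=1)
    x = torch.randn(1, 2, 5, 8, 8)  # [B, C, T, H, W]
    y = conv(x)
    assert y.shape == x.shape       # same T/H/W thanks to the causal padding
    # Perturbing only the last input frame leaves all earlier output frames intact.
    x2 = x.clone()
    x2[:, :, -1] += 1.0
    y2 = conv(x2)
    assert torch.allclose(y[:, :, :-1], y2[:, :, :-1])
    return y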
class WanRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Args:
x (torch.Tensor): Input tensor to be upsampled.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class WanResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# default to dim //2
if upsample_out_dim is None:
upsample_out_dim = dim // 2
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# the current chunk has a single frame: borrow the last frame of the previous chunk so the cache holds two frames
cache_x = torch.cat(
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
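# Shape sketch for WanResample, assuming a toy 'upsample2d' configuration: the
# temporal axis is folded into the batch, each frame is upsampled 2x spatially,
# and the channel count drops to dim // 2 by default.
def _example_resample():
    up = WanResample(dim=8, mode="upsample2d")
    x = torch.randn(1, 8, 3, 16, 16)  # [B, C, T, H, W]
    y = up(x)
    assert y.shape == (1, 4, 3, 32, 32)
    return y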
class WanResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = WanRMS_norm(in_dim, images=False)
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = WanRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
# Add residual connection
return x + h
class WanAttentionBlock(nn.Module):
r"""
Single-head self-attention applied independently to each frame.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = WanRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(q, k, v)
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
class WanMidBlock(nn.Module):
"""
Middle block for WanVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
super().__init__()
self.dim = dim
# Create the components
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(WanAttentionBlock(dim))
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
return x
class WanResidualDownBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
resnets = []
for _ in range(num_res_blocks):
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
self.downsampler = WanResample(out_dim, mode=mode)
else:
self.downsampler = None
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache, feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
is_residual: bool = False, # wan 2.2 vae use a residual downblock
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if is_residual:
self.down_blocks.append(
WanResidualDownBlock(
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
down_flag=i != len(dim_mult) - 1,
)
)
else:
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# the current chunk has a single frame: borrow the last frame of the previous chunk so the cache holds two frames
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# the current chunk has a single frame: borrow the last frame of the previous chunk so the cache holds two frames
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class WanResidualUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
temperal_upsample (bool): Whether to upsample on temporal dimension
up_flag (bool): Whether to upsample or not
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
temperal_upsample: bool = False,
up_flag: bool = False,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2,
)
else:
self.avg_shortcut = None
# create residual blocks
resnets = []
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
if up_flag:
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
else:
self.upsampler = None
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
x_copy = x.clone()
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache, feat_idx)
else:
x = self.upsampler(x)
if self.avg_shortcut is not None:
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
return x
class WanUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: Optional[str] = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
class WanDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
out_channels: int = 3,
is_residual: bool = False,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0 and not is_residual:
# wan vae 2.1
in_dim = in_dim // 2
# determine if we need upsampling
up_flag = i != len(dim_mult) - 1
# determine upsampling mode, if not upsampling, set to None
upsample_mode = None
if up_flag and temperal_upsample[i]:
upsample_mode = "upsample3d"
elif up_flag:
upsample_mode = "upsample2d"
# Create and add the upsampling block
if is_residual:
up_block = WanResidualUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
temperal_upsample=temperal_upsample[i] if up_flag else False,
up_flag=up_flag,
non_linearity=non_linearity,
)
else:
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# the current chunk has a single frame: borrow the last frame of the previous chunk so the cache holds two frames
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# the current chunk has a single frame: borrow the last frame of the previous chunk so the cache holds two frames
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() != 5:
raise ValueError(f"Invalid input shape: {x.shape}")
# x shape: [batch_size, channels, frames, height, width]
batch_size, channels, frames, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() != 5:
raise ValueError(f"Invalid input shape: {x.shape}")
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
batch_size, c_patches, frames, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, f, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
return x
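# Quick sanity sketch for patchify/unpatchify: with patch_size=2 the 2x2 spatial
# neighbourhood is folded into channels ([B, C, T, H, W] -> [B, 4C, T, H/2, W/2])
# and unpatchify inverts the operation exactly.
def _example_patchify_roundtrip():
    x = torch.randn(1, 3, 2, 8, 8)
    y = patchify(x, patch_size=2)
    assert y.shape == (1, 12, 2, 4, 4)
    assert torch.equal(unpatchify(y, patch_size=2), x)
    return y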
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: List[float] = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
],
latents_std: List[float] = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
],
is_residual: bool = False,
in_channels: int = 3,
out_channels: int = 3,
patch_size: Optional[int] = None,
scale_factor_temporal: Optional[int] = 4,
scale_factor_spatial: Optional[int] = 8,
) -> None:
super().__init__()
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
if decoder_base_dim is None:
decoder_base_dim = base_dim
self.encoder = WanEncoder3d(
in_channels=in_channels,
dim=base_dim,
z_dim=z_dim * 2,
dim_mult=dim_mult,
num_res_blocks=num_res_blocks,
attn_scales=attn_scales,
temperal_downsample=temperal_downsample,
dropout=dropout,
is_residual=is_residual,
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
dim=decoder_base_dim,
z_dim=z_dim,
dim_mult=dim_mult,
num_res_blocks=num_res_blocks,
attn_scales=attn_scales,
temperal_upsample=self.temperal_upsample,
dropout=dropout,
out_channels=out_channels,
is_residual=is_residual,
)
self.spatial_compression_ratio = scale_factor_spatial
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
        processing larger videos.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
                artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = self._cached_conv_counts["encoder"]
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
self.clear_cache()
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
        Encode a batch of videos into latents.
        Args:
            x (`torch.Tensor`): Input batch of videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
self.clear_cache()
x = self.post_quant_conv(z)
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)
return DecoderOutput(sample=out)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
        Decode a batch of video latents.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
self.clear_cache()
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
else:
tile = x[
:,
:,
1 + 4 * (k - 1) : 1 + 4 * k,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
tile = self.quant_conv(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
        Decode a batch of video latents using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
self.clear_cache()
time = []
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
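The slicing and tiling switches above trade compute for memory when encoding or decoding large videos. A hypothetical usage sketch with random weights and toy tensor sizes, assuming the `WanEncoder3d`/`WanDecoder3d` modules referenced in `__init__` are importable from this repo:

```python
import torch

vae = AutoencoderKLWan()                  # default config; real use would load pretrained weights
vae.enable_slicing()                      # decode one video per forward pass when batch > 1
vae.enable_tiling(tile_sample_min_height=256, tile_sample_stride_height=192)

video = torch.randn(1, 3, 5, 64, 64)      # [B, C, F, H, W] with F = 1 + 4k frames
latents = vae.encode(video).latent_dist.mode()   # [1, 16, 2, 8, 8] after 4x temporal / 8x spatial compression
decoded = vae.decode(latents).sample             # back to pixel space, clamped to [-1, 1]
```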
# References:
# https://github.com/hpcaitech/Open-Sora
# https://github.com/facebookresearch/DiT/blob/main/models.py
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
# https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py#L14
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp
from typing import Optional
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.dim = dim
self.hidden_dim = hidden_dim
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class RMSNorm_FP32(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class LayerNorm_FP32(nn.LayerNorm):
def __init__(self, dim, eps, elementwise_affine):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
out = F.layer_norm(
inputs.float(),
self.normalized_shape,
None if self.weight is None else self.weight.float(),
            None if self.bias is None else self.bias.float(),
self.eps
).to(origin_dtype)
return out
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
        patch_size (tuple[int]): Patch token size. Default: (2, 4, 4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
B, C, T, H, W = x.shape
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
def modulate_fp32(norm_func, x, shift, scale):
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
# ensure the modulation params be fp32
    assert shift.dtype == torch.float32 and scale.dtype == torch.float32
dtype = x.dtype
x = norm_func(x.to(torch.float32))
x = x * (scale + 1) + shift
x = x.to(dtype)
return x
class FinalLayer_FP32(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
super().__init__()
self.hidden_size = hidden_size
self.num_patch = num_patch
self.out_channels = out_channels
self.adaln_tembed_dim = adaln_tembed_dim
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
def forward(self, x, t, latent_shape):
# timestep shape: [B, T, C]
assert t.dtype == torch.float32
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast('cuda', dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, t_embed_dim, frequency_embedding_size=256):
super().__init__()
self.t_embed_dim = t_embed_dim
self.frequency_embedding_size = frequency_embedding_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
nn.SiLU(),
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations.
"""
def __init__(self, in_channels, hidden_size):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.y_proj = nn.Sequential(
nn.Linear(in_channels, hidden_size, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, caption):
B, _, N, C = caption.shape
caption = self.y_proj(caption)
return caption
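Two quick checks, not part of the original file, of the building blocks above: the SwiGLU hidden width is shrunk to 2/3 of the requested size and rounded up to a multiple of 256, and the sinusoidal timestep embedder produces one `frequency_embedding_size`-dimensional vector per timestep. The toy sizes are assumptions:

```python
import torch

ffn = FeedForwardSwiGLU(dim=512, hidden_dim=2048)
print(ffn.hidden_dim)                        # 1536 = 256 * ceil(int(2 * 2048 / 3) / 256)
print(ffn(torch.randn(2, 10, 512)).shape)    # torch.Size([2, 10, 512])

emb = TimestepEmbedder.timestep_embedding(torch.tensor([0.0, 500.0]), dim=256)
print(emb.shape)                             # torch.Size([2, 256]): [cos | sin] halves
```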
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.amp as amp
import numpy as np
from einops import rearrange
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from safetensors.torch import load_file
from .lora_utils import create_lora_network
from ..context_parallel import context_parallel_util
from .attention import Attention, MultiHeadCrossAttention
from .blocks import TimestepEmbedder, CaptionEmbedder, PatchEmbed3D, FeedForwardSwiGLU, FinalLayer_FP32, LayerNorm_FP32, modulate_fp32
class LongCatSingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: int,
adaln_tembed_dim: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params=None,
cp_split_hw=None
):
super().__init__()
self.hidden_size = hidden_size
# scale and gate modulation
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
)
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
self.attn = Attention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
self.cross_attn = MultiHeadCrossAttention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
)
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
"""
x: [B, N, C]
y: [1, N_valid_tokens, C]
t: [B, T, C_t]
        y_seqlen: list of length B giving the number of valid text tokens per sample
latent_shape: latent shape of a single item
"""
x_dtype = x.dtype
B, N, C = x.shape
        T, _, _ = latent_shape  # N may differ from T*H*W when context parallelism splits H*W.
# compute modulation params in fp32
with amp.autocast(device_type='cuda', dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
# self attn with modulation
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
if kv_cache is not None:
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
else:
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
if return_kv:
x_s, kv_cache = attn_outputs
else:
x_s = attn_outputs
with amp.autocast(device_type='cuda', dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
# cross attn
if not skip_crs_attn:
if kv_cache is not None:
num_cond_latents = None
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type='cuda', dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
if return_kv:
return x, kv_cache
else:
return x
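The block's core pattern is per-timestep adaLN modulation: `t` is projected to shift/scale/gate triples in fp32, the normalized tokens are shifted and scaled, and the sub-layer output is gated back into the residual stream. A toy sketch under assumed shapes, with the attention/FFN sub-layer omitted so the modulated tokens are gated directly:

```python
import torch
import torch.nn as nn

B, T, HW, C, C_t = 1, 2, 4, 8, 16
norm = LayerNorm_FP32(C, eps=1e-6, elementwise_affine=False)
ada = nn.Sequential(nn.SiLU(), nn.Linear(C_t, 3 * C))   # shift, scale, gate

x = torch.randn(B, T * HW, C)
t = torch.randn(B, T, C_t)
shift, scale, gate = ada(t.float()).unsqueeze(2).chunk(3, dim=-1)            # each [B, T, 1, C]
x_m = modulate_fp32(norm, x.view(B, T, HW, C), shift, scale).view(B, T * HW, C)
x = x + (gate * x_m.view(B, T, HW, C)).view(B, T * HW, C)                    # gated residual update
print(x.shape)   # torch.Size([1, 8, 8])
```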
class LongCatVideoTransformer3DModel(
ModelMixin, ConfigMixin
):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
depth: int = 48,
num_heads: int = 32,
caption_channels: int = 4096,
mlp_ratio: int = 4,
adaln_tembed_dim: int = 512,
frequency_embedding_size: int = 256,
# default params
patch_size: Tuple[int] = (1, 2, 2),
# attention config
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None,
text_tokens_zero_pad: bool = False,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.cp_split_hw = cp_split_hw
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
)
self.blocks = nn.ModuleList(
[
LongCatSingleStreamBlock(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
adaln_tembed_dim=adaln_tembed_dim,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
for i in range(depth)
]
)
self.final_layer = FinalLayer_FP32(
hidden_size,
np.prod(self.patch_size),
out_channels,
adaln_tembed_dim,
)
self.gradient_checkpointing = False
self.text_tokens_zero_pad = text_tokens_zero_pad
self.lora_dict = {}
self.active_loras = []
def load_lora(self, lora_path, lora_key, multiplier=1.0, lora_network_dim=128, lora_network_alpha=64):
lora_network_state_dict_loaded = load_file(lora_path, device="cpu")
lora_network = create_lora_network(
transformer=self,
lora_network_state_dict_loaded=lora_network_state_dict_loaded,
multiplier=multiplier,
network_dim=lora_network_dim,
network_alpha=lora_network_alpha,
)
lora_network.load_state_dict(lora_network_state_dict_loaded, strict=True)
self.lora_dict[lora_key] = lora_network
def enable_loras(self, lora_key_list=[]):
self.disable_all_loras()
module_loras = {} # {module_name: [lora1, lora2, ...]}
model_device = next(self.parameters()).device
model_dtype = next(self.parameters()).dtype
for lora_key in lora_key_list:
if lora_key in self.lora_dict:
for lora in self.lora_dict[lora_key].loras:
lora.to(model_device, dtype=model_dtype, non_blocking=True)
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
if module_name not in module_loras:
module_loras[module_name] = []
module_loras[module_name].append(lora)
self.active_loras.append(lora_key)
for module_name, loras in module_loras.items():
module = self._get_module_by_name(module_name)
if not hasattr(module, 'org_forward'):
module.org_forward = module.forward
module.forward = self._create_multi_lora_forward(module, loras)
def _create_multi_lora_forward(self, module, loras):
def multi_lora_forward(x, *args, **kwargs):
weight_dtype = x.dtype
org_output = module.org_forward(x, *args, **kwargs)
total_lora_output = 0
for lora in loras:
if lora.use_lora:
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
lx = lora.lora_up(lx)
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
total_lora_output += lora_output
return org_output + total_lora_output
return multi_lora_forward
def _get_module_by_name(self, module_name):
try:
module = self
for part in module_name.split('.'):
module = getattr(module, part)
return module
except AttributeError as e:
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
def disable_all_loras(self):
for name, module in self.named_modules():
if hasattr(module, 'org_forward'):
module.forward = module.org_forward
delattr(module, 'org_forward')
for lora_key, lora_network in self.lora_dict.items():
for lora in lora_network.loras:
lora.to("cpu")
self.active_loras.clear()
def enable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = True
def disable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = False
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states,
encoder_attention_mask=None,
num_cond_latents=0,
return_kv=False,
kv_cache_dict={},
skip_crs_attn=False,
offload_kv_cache=False
):
B, _, T, H, W = hidden_states.shape
N_t = T // self.patch_size[0]
N_h = H // self.patch_size[1]
N_w = W // self.patch_size[2]
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
# expand the shape of timestep from [B] to [B, T]
if len(timestep.shape) == 1:
timestep = timestep.unsqueeze(1).expand(-1, N_t) # [B, T]
dtype = self.x_embedder.proj.weight.dtype
hidden_states = hidden_states.to(dtype)
timestep = timestep.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type='cuda', dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
else:
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
# blocks
kv_cache_dict_ret = {}
for i, block in enumerate(self.blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
block_outputs = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, t, y_seqlens,
(N_t, N_h, N_w), num_cond_latents, return_kv, kv_cache_dict.get(i, None), skip_crs_attn
)
else:
block_outputs = block(
hidden_states, encoder_hidden_states, t, y_seqlens,
(N_t, N_h, N_w), num_cond_latents, return_kv, kv_cache_dict.get(i, None), skip_crs_attn
)
if return_kv:
hidden_states, kv_cache = block_outputs
if offload_kv_cache:
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
else:
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
else:
hidden_states = block_outputs
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
        hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w)  # [B, C_out, T, H, W]
# cast to float32 for better accuracy
hidden_states = hidden_states.to(torch.float32)
if return_kv:
return hidden_states, kv_cache_dict_ret
else:
return hidden_states
def unpatchify(self, x, N_t, N_h, N_w):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
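A small shape sketch with hypothetical toy values for the token-to-video rearrangement above, using the default `patch_size` of (1, 2, 2) so each token carries a 1x2x2 spatial patch per output channel:

```python
import torch
from einops import rearrange

B, C_out, N_t, N_h, N_w = 1, 16, 3, 4, 4
tokens = torch.randn(B, N_t * N_h * N_w, 1 * 2 * 2 * C_out)   # [B, N, T_p*H_p*W_p*C_out]
video = rearrange(
    tokens,
    "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
    N_t=N_t, N_h=N_h, N_w=N_w, T_p=1, H_p=2, W_p=2, C_out=C_out,
)
print(video.shape)   # torch.Size([1, 16, 3, 8, 8])
```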
# References:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/bmaltais/kohya_ss
import math
import functools
from collections import defaultdict
from typing import Optional
import torch
class LoRAUPParallel(torch.nn.Module):
def __init__(self, blocks):
super().__init__()
self.blocks = torch.nn.ModuleList(blocks)
def forward(self, x):
assert x.shape[-1] % len(self.blocks) == 0
xs = torch.chunk(x, len(self.blocks), dim=-1)
out = torch.cat([self.blocks[i](xs[i]) for i in range(len(self.blocks))], dim=-1)
return out
class LoRAModule(torch.nn.Module):
"""
    Replaces the forward method of the original Linear layer instead of replacing the module itself.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
n_seperate=1
):
super().__init__()
self.lora_name = lora_name
assert org_module.__class__.__name__ == "Linear"
in_dim = org_module.in_features
out_dim = org_module.out_features
if n_seperate > 1:
assert out_dim % n_seperate == 0
self.lora_dim = lora_dim
if n_seperate > 1:
self.lora_down = torch.nn.Linear(in_dim, n_seperate * self.lora_dim, bias=False)
self.lora_up = LoRAUPParallel([torch.nn.Linear(self.lora_dim, out_dim // n_seperate, bias=False) for _ in range(n_seperate)])
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
alpha_scale = alpha / self.lora_dim
self.register_buffer("alpha_scale", torch.tensor(alpha_scale))
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
if n_seperate > 1:
for block in self.lora_up.blocks:
torch.nn.init.zeros_(block.weight)
else:
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.use_lora = True
def set_use_lora(self, use_lora):
self.use_lora = use_lora
class LoRANetwork(torch.nn.Module):
LORA_PREFIX = "lora"
LORA_HYPHEN = "___lorahyphen___"
def __init__(
self,
model,
lora_network_state_dict_loaded,
multiplier: float = 1.0,
lora_dim: int = 128,
alpha: float = 64,
) -> None:
super().__init__()
self.multiplier = multiplier
self.use_lora = True
self.lora_dim = lora_dim
self.alpha = alpha
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
lora_module_names = set()
for key in lora_network_state_dict_loaded.keys():
if key.endswith("lora_down.weight"):
lora_name = key.split(".lora_down.weight")[0]
lora_module_names.add(lora_name)
loras = []
for lora_name in lora_module_names:
            # Recover the real module name used in the model
module_name = lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
            # Look up the module
try:
module = model
for part in module_name.split('.'):
module = getattr(module, part)
except Exception as e:
print(f"Cannot find module: {module_name}, error: {e}")
continue
if module.__class__.__name__ != "Linear":
continue
            # Infer n_seperate from the number of lora_up blocks
n_seperate = 1
prefix = lora_name + ".lora_up.blocks"
n_blocks = sum(1 for k in lora_network_state_dict_loaded if k.startswith(prefix))
if n_blocks > 0:
n_seperate = n_blocks
dim = self.lora_dim
alpha = self.alpha
lora = LoRAModule(
lora_name,
module,
self.multiplier,
dim,
alpha,
n_seperate=n_seperate
)
loras.append(lora)
self.loras = loras
for lora in self.loras:
self.add_module(lora.lora_name, lora)
print(f"create LoRA for model: {len(self.loras)} modules.")
# assertion
names = set()
for lora in self.loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def disapply_to(self):
for lora in self.loras:
lora.disapply_to()
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.loras:
lora.multiplier = self.multiplier
def set_use_lora(self, use_lora):
self.use_lora = use_lora
for lora in self.loras:
lora.set_use_lora(use_lora)
def prepare_optimizer_params(self, lr):
self.requires_grad_(True)
all_params = []
params = []
for lora in self.loras:
params.extend(lora.parameters())
param_data = {"params": params}
param_data["lr"] = lr
all_params.append(param_data)
return all_params
def create_lora_network(
transformer,
lora_network_state_dict_loaded,
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
):
network = LoRANetwork(
transformer,
lora_network_state_dict_loaded,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
)
return network
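A minimal sketch with toy layer sizes, not part of the original file, of how a `LoRAModule`'s low-rank delta combines with the `Linear` it wraps; because `lora_up` is zero-initialized, the delta is exactly zero until training updates it:

```python
import torch

base = torch.nn.Linear(32, 64, bias=False)
lora = LoRAModule("lora___lorahyphen___toy", base, multiplier=1.0, lora_dim=4, alpha=2)

x = torch.randn(2, 32)
delta = lora.lora_up(lora.lora_down(x)) * lora.multiplier * lora.alpha_scale
print(torch.allclose(base(x) + delta, base(x)))   # True: zero-initialized lora_up
```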
# References:
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/rotary_positional_embedding.py
import numpy as np
from functools import lru_cache
import torch
import torch.nn as nn
from einops import rearrange, repeat
from ..context_parallel import context_parallel_util
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class RotaryPositionalEmbedding(nn.Module):
def __init__(self,
head_dim,
cp_split_hw=None
):
"""Rotary positional embedding for 3D
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
            head_dim: Dimension of each attention head (must be a multiple of 8).
            cp_split_hw: Optional (H, W) split factors used for context parallelism.
"""
super().__init__()
self.head_dim = head_dim
        assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.'
self.cp_split_hw = cp_split_hw
        # We assume the longest side of the grid is no larger than 512, i.e., 512 * 8 = 4096 input pixels
self.base = 10000
self.freqs_dict = {}
def register_grid_size(self, grid_size):
if grid_size not in self.freqs_dict:
self.freqs_dict.update({
grid_size: self.precompute_freqs_cis_3d(grid_size)
})
def precompute_freqs_cis_3d(self, grid_size):
num_frames, height, width = grid_size
dim_t = self.head_dim - 4 * (self.head_dim // 6)
dim_h = 2 * (self.head_dim // 6)
dim_w = 2 * (self.head_dim // 6)
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
grid_t = torch.from_numpy(grid_t).float()
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
with torch.no_grad():
freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
return freqs
def forward(self, q, k, grid_size):
"""3D RoPE.
Args:
            q: [B, head, seq, head_dim]
            k: [B, head, seq, head_dim]
            grid_size: (T, H, W) token grid used to look up the precomputed frequencies
Returns:
query and key with the same shape as input.
"""
if grid_size not in self.freqs_dict:
self.register_grid_size(grid_size)
freqs_cis = self.freqs_dict[grid_size].to(q.device)
q_, k_ = q.float(), k.float()
freqs_cis = freqs_cis.float().to(q.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
q_ = (q_ * cos) + (rotate_half(q_) * sin)
k_ = (k_ * cos) + (rotate_half(k_) * sin)
return q_.type_as(q), k_.type_as(k)
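A small sketch with toy sizes, passing `cp_split_hw=(1, 1)` so no context-parallel split is assumed, showing that the 3D RoPE above only rotates q/k and therefore preserves their per-token norms:

```python
import torch

rope = RotaryPositionalEmbedding(head_dim=16, cp_split_hw=(1, 1))
T, H, W = 2, 3, 3
q = torch.randn(1, 2, T * H * W, 16)   # [B, heads, seq, head_dim]
k = torch.randn(1, 2, T * H * W, 16)
q_rot, k_rot = rope(q, k, grid_size=(T, H, W))
print(q_rot.shape)                                                      # torch.Size([1, 2, 18, 16])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))    # True: pure rotation
```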