Commit 6cb0d1ce authored by dengjb

update
*.mp4
*.DS_Store
weights/
__pycache__/
MIT License
Copyright (c) 2025 Meituan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# LongCat-Video
## Paper
[LongCat-Video Technical Report](https://arxiv.org/abs/2510.22200)
## Model Introduction
LongCat-Video is a foundational video generation model with 13.6B parameters that delivers strong performance across a range of video generation tasks, and it particularly excels at efficient, high-quality long video generation.
**Key Features**
- 🌟 **Unified architecture for multiple tasks:** LongCat-Video unifies Text-to-Video, Image-to-Video, and Video-Continuation within a single video generation framework. One model natively supports all of these tasks and consistently delivers strong performance on each of them.
- 🌟 **Long video generation:** LongCat-Video is natively pretrained on the Video-Continuation task and can generate videos up to several minutes long without color drifting or quality degradation.
- 🌟 **Efficient inference:** LongCat-Video adopts a coarse-to-fine generation strategy along both the temporal and spatial axes and can generate 720p, 30fps videos within minutes. Block Sparse Attention further improves inference efficiency, especially for high-resolution generation.
- 🌟 **Strong performance with multi-reward RLHF:** Powered by multi-reward Group Relative Policy Optimization (GRPO), comprehensive evaluations on both internal and public benchmarks show that LongCat-Video is comparable to leading open-source video generation models as well as the latest commercial solutions.
![alt text](image.png)
## Environment Dependencies
| Software | Version |
| :------: | :------: |
| DTK | 25.04.2 |
| python | 3.10.12 |
| transformers | 4.57.1 |
| vllm | 0.11.0+das.opt1.alpha.8e22ded.dtk25042 |
| torch | 2.5.1+das.opt1.dtk25042 |
| triton | 3.1+das.opt1.3c5d12d.dtk25041 |
| flash_attn | 2.6.1+das.opt1.dtk2504 |
| flash_mla | 1.0.0+das.opt1.dtk25042 |
Currently, only the following image is supported:
- Adjust the mount paths passed via `-v` according to your actual code and model locations
```bash
docker run -it --shm-size 60g --network=host --name longcat_video --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /opt/hyhal/:/opt/hyhal/:ro -v /path/your_code_path/:/path/your_code_path/ image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-py3.10 bash
```
More images are available for download from [光源](https://sourcefind.cn/#/service-list).
The DCU-specific deep learning libraries required by this project can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
```bash
pip install -r requirements.txt
```
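As a quick sanity check of the environment (a minimal sketch; on DCU builds the version strings carry the `das`/`dtk` suffixes shown in the table above), you can confirm the key packages and devices are visible:
```python
# Minimal environment check: print the versions of the key dependencies listed in the
# table above and confirm the DCU devices are visible to torch.
import torch
import transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("devices available:", torch.cuda.is_available(), torch.cuda.device_count())
```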
## Dataset
None for now.
## Training
None for now.
## Inference
### PyTorch
#### Single-Node Inference
See the run.sh script for reference.
##### Text-to-Video
```shell
# Single-GPU inference
torchrun run_demo_text_to_video.py --checkpoint_dir=./weights/LongCat-Video
# Multi-GPU inference
torchrun --nproc_per_node=8 run_demo_text_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video
```
##### Image-to-Video
```shell
# Single-GPU inference
torchrun run_demo_image_to_video.py --checkpoint_dir=./weights/LongCat-Video
# Multi-GPU inference
torchrun --nproc_per_node=8 run_demo_image_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video
```
##### Long-Video Generation
```shell
# Single-GPU inference
torchrun run_demo_long_video.py --checkpoint_dir=./weights/LongCat-Video
# Multi-GPU inference
torchrun --nproc_per_node=8 run_demo_long_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video
```
## Demo
<div align="center">
<video src="https://github.com/user-attachments/assets/00fa63f0-9c4e-461a-a79e-c662ad596d7d" width="2264" height="384"> </video>
</div>
### Precision
Inference precision on DCU is consistent with GPU; inference framework: transformers.
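This parity can be spot-checked with a small script. The sketch below assumes you dump the same intermediate tensor (e.g. the denoised latents) from a DCU run and a GPU reference run; the file names are hypothetical:
```python
# Sketch: compare an output tensor saved from a DCU run against a GPU reference run.
# "latents_dcu.pt" and "latents_gpu.pt" are hypothetical file names.
import torch
import torch.nn.functional as F

dcu = torch.load("latents_dcu.pt").float().flatten()
gpu = torch.load("latents_gpu.pt").float().flatten()
print("max abs diff:", (dcu - gpu).abs().max().item())
print("cosine similarity:", F.cosine_similarity(dcu, gpu, dim=0).item())
```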
## Pretrained Weights
| Model | Weight Size | DCU Model | Minimum Cards Required | Download |
|:-----:|:----------:|:----------:|:---------------------:|:----------:|
| LongCat-Video | 13.6 B | K100AI | 8 | 🤗 [Huggingface](https://huggingface.co/meituan-longcat/LongCat-Video) |
## Source Repository & Issue Feedback
- https://developer.sourcefind.cn/codes/modelzoo/longcat-video-pytorch
## References
- https://github.com/meituan-longcat/LongCat-Video
# LongCat-Video
<div align="center">
<img src="assets/longcat-video_logo.svg" width="45%" alt="LongCat-Video" />
</div>
<hr>
<div align="center" style="line-height: 1;">
<a href='https://meituan-longcat.github.io/LongCat-Video/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://github.com/meituan-longcat/LongCat-Video/blob/main/longcatvideo_tech_report.pdf'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/meituan-longcat/LongCat-Video'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
</div>
<div align="center" style="line-height: 1;">
<a href='https://github.com/meituan-longcat/LongCat-Flash-Chat/blob/main/figures/wechat_official_accounts.png'><img src='https://img.shields.io/badge/WeChat-LongCat-brightgreen?logo=wechat&logoColor=white'></a>
<a href='https://x.com/Meituan_LongCat'><img src='https://img.shields.io/badge/Twitter-LongCat-white?logo=x&logoColor=white'></a>
</div>
<div align="center" style="line-height: 1;">
<a href='LICENSE'><img src='https://img.shields.io/badge/License-MIT-f5de53?&color=f5de53'></a>
</div>
## Model Introduction
We introduce LongCat-Video, a foundational video generation model with 13.6B parameters, delivering strong performance across *Text-to-Video*, *Image-to-Video*, and *Video-Continuation* generation tasks. It particularly excels in efficient and high-quality long video generation, representing our first step toward world models.
### Key Features
- 🌟 **Unified architecture for multiple tasks**: LongCat-Video unifies *Text-to-Video*, *Image-to-Video*, and *Video-Continuation* tasks within a single video generation framework. It natively supports all these tasks with a single model and consistently delivers strong performance across each individual task.
- 🌟 **Long video generation**: LongCat-Video is natively pretrained on *Video-Continuation* tasks, enabling it to produce minutes-long videos without color drifting or quality degradation.
- 🌟 **Efficient inference**: LongCat-Video generates $720p$, $30fps$ videos within minutes by employing a coarse-to-fine generation strategy along both the temporal and spatial axes. Block Sparse Attention further enhances efficiency, particularly at high resolutions.
- 🌟 **Strong performance with multi-reward RLHF**: Powered by multi-reward Group Relative Policy Optimization (GRPO), comprehensive evaluations on both internal and public benchmarks demonstrate that LongCat-Video achieves performance comparable to leading open-source video generation models as well as the latest commercial solutions.
For more detail, please refer to the comprehensive [***LongCat-Video Technical Report***](https://github.com/meituan-longcat/LongCat-Video/blob/main/longcatvideo_tech_report.pdf).
## 🎥 Teaser Video
<div align="center">
<video src="https://github.com/user-attachments/assets/00fa63f0-9c4e-461a-a79e-c662ad596d7d" width="2264" height="384"> </video>
</div>
## Quick Start
### Installation
Clone the repo:
```shell
git clone --single-branch --branch main https://github.com/meituan-longcat/LongCat-Video
cd LongCat-Video
```
Install dependencies:
```shell
# create conda environment
conda create -n longcat-video python=3.10
conda activate longcat-video
# install torch (configure according to your CUDA version)
pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# install flash-attn-2
pip install ninja
pip install psutil
pip install packaging
pip install flash_attn==2.7.4.post1
# install other requirements
pip install -r requirements.txt
```
FlashAttention-2 is enabled in the model config by default; you can also change the model config ("./weights/LongCat-Video/dit/config.json") to use FlashAttention-3 or xformers once installed.
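For reference, a backend switch could look like the sketch below; the key names mirror the model's `enable_flashattn2` / `enable_flashattn3` / `enable_xformers` arguments, and whether they appear verbatim in `config.json` is an assumption:
```python
# Sketch: toggle the attention backend in the DiT config. Key names are assumed to
# match the enable_flashattn2 / enable_flashattn3 / enable_xformers model arguments.
import json

cfg_path = "./weights/LongCat-Video/dit/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["enable_flashattn2"] = False
cfg["enable_flashattn3"] = True  # requires the FlashAttention-3 package (flash_attn_interface)

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```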
### Model Download
| Models | Download Link |
| --- | --- |
| LongCat-Video | 🤗 [Huggingface](https://huggingface.co/meituan-longcat/LongCat-Video) |
Download models using huggingface-cli:
```shell
pip install "huggingface_hub[cli]"
huggingface-cli download meituan-longcat/LongCat-Video --local-dir ./weights/LongCat-Video
```
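Alternatively, the same weights can be fetched from Python with `huggingface_hub`:
```python
# Equivalent download using the huggingface_hub Python API.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meituan-longcat/LongCat-Video",
    local_dir="./weights/LongCat-Video",
)
```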
### Run Text-to-Video
```shell
# Single-GPU inference
torchrun run_demo_text_to_video.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_text_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Image-to-Video
```shell
# Single-GPU inference
torchrun run_demo_image_to_video.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_image_to_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Video-Continuation
```shell
# Single-GPU inference
torchrun run_demo_video_continuation.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_video_continuation.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Long-Video Generation
```shell
# Single-GPU inference
torchrun run_demo_long_video.py --checkpoint_dir=./weights/LongCat-Video --enable_compile
# Multi-GPU inference
torchrun --nproc_per_node=2 run_demo_long_video.py --context_parallel_size=2 --checkpoint_dir=./weights/LongCat-Video --enable_compile
```
### Run Streamlit
```shell
# Single-GPU inference
streamlit run ./run_streamlit.py --server.fileWatcherType none --server.headless=false
```
## Evaluation Results
### Text-to-Video
The *Text-to-Video* MOS evaluation results on our internal benchmark.
| **MOS score** | **Veo3** | **PixVerse-V5** | **Wan 2.2-T2V-A14B** | **LongCat-Video** |
|---------------|-------------------|--------------------|-------------|-------------|
| **Accessibility** | Proprietary | Proprietary | Open Source | Open Source |
| **Architecture** | - | - | MoE | Dense |
| **# Total Params** | - | - | 28B | 13.6B |
| **# Activated Params** | - | - | 14B | 13.6B |
| Text-Alignment↑ | 3.99 | 3.81 | 3.70 | 3.76 |
| Visual Quality↑ | 3.23 | 3.13 | 3.26 | 3.25 |
| Motion Quality↑ | 3.86 | 3.81 | 3.78 | 3.74 |
| Overall Quality↑ | 3.48 | 3.36 | 3.35 | 3.38 |
### Image-to-Video
The *Image-to-Video* MOS evaluation results on our internal benchmark.
| **MOS score** | **Seedance 1.0** | **Hailuo-02** | **Wan 2.2-I2V-A14B** | **LongCat-Video** |
|---------------|-------------------|--------------------|-------------|-------------|
| **Accessibility** | Proprietary | Proprietary | Open Source | Open Source |
| **Architecture** | - | - | MoE | Dense |
| **# Total Params** | - | - | 28B | 13.6B |
| **# Activated Params** | - | - | 14B | 13.6B |
| Image-Alignment↑ | 4.12 | 4.18 | 4.18 | 4.04 |
| Text-Alignment↑ | 3.70 | 3.85 | 3.33 | 3.49 |
| Visual Quality↑ | 3.22 | 3.18 | 3.23 | 3.27 |
| Motion Quality↑ | 3.77 | 3.80 | 3.79 | 3.59 |
| Overall Quality↑ | 3.35 | 3.27 | 3.26 | 3.17 |
## License Agreement
The **model weights** are released under the **MIT License**.
Any contributions to this repository are licensed under the MIT License, unless otherwise stated. This license does not grant any rights to use Meituan trademarks or patents.
See the [LICENSE](LICENSE) file for the full license text.
## Usage Considerations
This model has not been specifically designed or comprehensively evaluated for every possible downstream application.
Developers should take into account the known limitations of large language models, including performance variations across different languages, and carefully assess accuracy, safety, and fairness before deploying the model in sensitive or high-risk scenarios.
It is the responsibility of developers and downstream users to understand and comply with all applicable laws and regulations relevant to their use case, including but not limited to data protection, privacy, and content safety requirements.
Nothing in this Model Card should be interpreted as altering or restricting the terms of the MIT License under which the model is released.
## Citation
We kindly encourage citation of our work if you find it useful.
```
@misc{meituan2025longcatvideotechnicalreport,
title={LongCat-Video Technical Report},
author={Meituan LongCat Team},
year={2025},
eprint={2510.22200},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.22200},
}
```
## Acknowledgements
We would like to thank the contributors to the [Wan](https://huggingface.co/Wan-AI), [UMT5-XXL](https://huggingface.co/google/umt5-xxl), [Diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
## Contact
Please contact us at <a href="mailto:longcat-team@meituan.com">longcat-team@meituan.com</a> or join our WeChat Group if you have any questions.
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
fill="none"
version="1.1"
width="594.24902"
height="100"
viewBox="0 0 594.24902 100"
id="svg7"
sodipodi:docname="longcat-video_logo.svg"
inkscape:version="1.4.2 (ebf0e940, 2025-05-08)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs7" />
<sodipodi:namedview
id="namedview7"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="1.858108"
inkscape:cx="276.08729"
inkscape:cy="6.996364"
inkscape:window-width="2304"
inkscape:window-height="1243"
inkscape:window-x="0"
inkscape:window-y="25"
inkscape:window-maximized="1"
inkscape:current-layer="g7"
inkscape:export-bgcolor="#ffffffff" />
<g
id="g7">
<g
id="g6">
<g
id="g3" />
<g
id="g4">
<path
d="m 4.62614,81.999556 c -1.321058,0 -2.279197,-1.2581 -1.9280502,-2.5317 L 19.566,18.291386 c 0.721,-2.615181 3.781,-3.769321 6.0495,-2.281758 l 22.1949,14.554028 c 1.33,0.8721 3.0502,0.8735 4.3816,0.0034 L 74.481,16.001644 c 2.2706,-1.483766 5.3282,-0.32536 6.0457,2.290412 l 16.7797,61.1784 c 0.3491,1.273 -0.6088,2.5291 -1.9288,2.5291 H 73.9544 c 3.8999,-4.5155 6.0457,-10.2829 6.0457,-16.2494 v -0.6963 c 0,-5.8361 -2.123,-11.473 -5.9728,-15.8593 l -2.7494,-13.8003 c -0.1615,-0.8108 -0.8732,-1.3947 -1.7,-1.3947 -0.3751,0 -0.74,0.1216 -1.0401,0.3466 l -10.3659,7.7745 c -0.7382,0.5536 -1.693,0.73 -2.5803,0.4764 -3.6546,-1.0441 -7.5285,-1.0441 -11.1831,0 -0.8873,0.2536 -1.8421,0.0772 -2.5803,-0.4764 l -10.3705,-7.7779 c -0.297,-0.2228 -0.6584,-0.3432 -1.0297,-0.3432 -0.8275,0 -1.5372,0.5904 -1.6877,1.404 l -2.6674,14.4205 c -3.8919,3.9576 -6.0728,9.286 -6.0728,14.8366 v 1.3191 c 0,5.8204 2.0821,11.4489 5.8699,15.8681 l 0.1301,0.1517 z"
fill-rule="evenodd"
fill="#29e154"
fill-opacity="1"
id="path3" />
</g>
<g
id="g5">
<path
d="m 39,70 h 6 V 56 h -5.090909 z m 22,0 H 55 V 56 h 5.0909 z"
fill="#000000"
fill-opacity="1"
id="path4" />
<path
d="M 37.93296,71 H 46 V 55 H 38.97192 Z M 54,71 h 8.067 L 61.0281,55 H 54 Z M 44,70 H 39 L 39.909091,56 H 45 V 70 Z M 60.9351,69 61,70 H 59.9979 56 55 V 56 h 5.0909 z"
fill-rule="evenodd"
fill="#ffffff"
fill-opacity="1"
id="path5" />
</g>
</g>
<text
xml:space="preserve"
style="font-size:58.6667px;text-align:start;writing-mode:lr-tb;direction:ltr;text-anchor:start;fill:#000000"
x="112.31583"
y="72.60601"
id="text7"><tspan
sodipodi:role="line"
id="tspan7"
x="112.31583"
y="72.60601"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-size:58.6667px;font-family:meituan;-inkscape-font-specification:'meituan Bold';stroke:#ffffff;stroke-opacity:1">LongCat-Video</tspan></text>
</g>
</svg>
image.png (binary image, 102 KB)
import triton
import triton.language as tl
import os
if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1':
autotune = triton.autotune
else:
def autotune(*args, **kwargs):
def decorator(func):
return func
return decorator
configs_gating_preset = {
'default': {
'BLOCK_M': 64,
'BLOCK_N': 64,
'num_stages': 3,
'num_warps': 8,
}
}
configs_gating = [
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
for BM in [64, 128] \
for BN in [32, 64] \
for s in [2, 3, 4, 5] \
for w in [4, 8] \
]
gating_reevaluate_keys = ["M", "N"] if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1' else []
@autotune(configs_gating, key=gating_reevaluate_keys)
@triton.jit
def _attn_fwd_gating(
Q, K, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_oz, stride_oh, stride_om, stride_on,
H, M, N,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(M, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(HEAD_DIM, N),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(M, N),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr, boundary_check=(0,))
for start_n in range(0, N, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(K_block_ptr, boundary_check=(1,))
qk = tl.dot(q, k)
tl.store(O_block_ptr, qk.to(Out.type.element_ty), boundary_check=(0, 1))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
O_block_ptr = tl.advance(O_block_ptr, (0, BLOCK_N))
@triton.jit
def _attn_bwd_preprocess(
O, DO,
Delta, # output
N_CTX,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, HEAD_DIM)
# load
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_hz * N_CTX + off_m, delta)
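# --- Illustrative host-side launcher (a sketch, not part of the original file) ---
# Shows how _attn_fwd_gating above could be invoked for Q of shape [Z, H, M, D] and
# K of shape [Z, H, N, D], producing the [Z, H, M, N] block of Q @ K^T scores.
# Assumes autotuning is disabled (TRITON_AUTOTUNE_ENBALE != '1'), so block sizes and
# launch meta-parameters come from configs_gating_preset, and assumes D >= BLOCK_N.
import torch

def launch_attn_fwd_gating(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    Z, H, M, D = q.shape
    N = k.shape[2]
    out = torch.empty((Z, H, M, N), device=q.device, dtype=q.dtype)
    cfg = configs_gating_preset['default']
    grid = (triton.cdiv(M, cfg['BLOCK_M']), Z * H)
    _attn_fwd_gating[grid](
        q, k, out,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        H, M, N,
        HEAD_DIM=D,
        BLOCK_M=cfg['BLOCK_M'], BLOCK_N=cfg['BLOCK_N'],
        num_stages=cfg['num_stages'], num_warps=cfg['num_warps'],
    )
    return out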
import torch
def p2p_communicate(
rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
"""Point-to-point communications of KV and dKV in Attention with context parallelism"""
send_recv_ops = []
if batch_p2p_comm: # int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) -- why?
if rank % 2 == 0:
send_op = torch.distributed.P2POp(
torch.distributed.isend, send_tensor, send_dst, cp_group
)
recv_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_tensor, recv_src, cp_group
)
send_recv_ops.append(send_op)
send_recv_ops.append(recv_op)
else:
recv_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_tensor, recv_src, cp_group
)
send_op = torch.distributed.P2POp(
torch.distributed.isend, send_tensor, send_dst, cp_group
)
send_recv_ops.append(recv_op)
send_recv_ops.append(send_op)
send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
else:
if rank % 2 == 0:
send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
send_recv_ops.append(send_op)
send_recv_ops.append(recv_op)
else:
recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
send_recv_ops.append(recv_op)
send_recv_ops.append(send_op)
send_recv_reqs = send_recv_ops
return send_recv_reqs
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from einops import rearrange
dp_size = dp_group = cp_group = cp_stream = dp_ranks = cp_ranks = dp_rank = None
cp_size: int = 1
cp_rank: int = 0
def init_context_parallel(context_parallel_size: int = 1,
global_rank: int = 0,
world_size: int = 1,):
global dp_size, cp_size, dp_group, cp_group, dp_ranks, cp_ranks, dp_rank, cp_rank
if world_size % context_parallel_size != 0:
raise RuntimeError(f'world_size {world_size} must be multiple of context_parallel_size {context_parallel_size}')
cp_size = context_parallel_size
dp_size = world_size//context_parallel_size
print(f'[rank {global_rank}] init_device_mesh [dp_size x cp_size]: [{dp_size} x {cp_size}]')
mesh_2d = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
print(f'[rank {global_rank}] mesh_2d: {mesh_2d}')
dp_group = mesh_2d.get_group(mesh_dim="dp")
cp_group = mesh_2d.get_group(mesh_dim="cp")
dp_ranks = torch.distributed.get_process_group_ranks(dp_group)
cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
dp_rank = dist.get_rank(group=dp_group)
cp_rank = dist.get_rank(group=cp_group)
curr_global_rank = torch.distributed.get_rank()
print(f'[rank {curr_global_rank}] [dp_rank, cp_rank]: [{dp_rank}, {cp_rank}], dp_ranks: {dp_ranks}, cp_ranks: {cp_ranks}')
def get_cp_size():
global cp_size
return cp_size
def get_dp_size():
global dp_size
return dp_size
def get_cp_stream():
global cp_stream
if cp_stream is None:
cp_stream = torch.cuda.Stream()
return cp_stream
def get_dp_group():
global dp_group
return dp_group
def get_cp_group():
global cp_group
return cp_group
def get_dp_rank():
global dp_rank
return dp_rank
def get_cp_rank():
global cp_rank
return cp_rank
def get_cp_rank_list():
global cp_ranks
if cp_ranks is None:
cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
return cp_ranks
def cp_broadcast(tensor, cp_index=0):
global dp_group
global cp_group
cp_ranks = get_cp_rank_list()
torch.distributed.broadcast(tensor, cp_ranks[cp_index], group=cp_group)
def split_tensor_in_cp_2d(input, dim_hw, split_hw):
global cp_size
dim_h, dim_w = dim_hw
split_h, split_w = split_hw
assert cp_size == split_h * split_w
seq_size_h = input.shape[dim_h]
seq_size_w = input.shape[dim_w]
if seq_size_h % split_h != 0:
raise RuntimeError(f'seq_size_h {seq_size_h} in dim_h {dim_h} must be multiple of split_h {split_h}!!!')
if seq_size_w % split_w != 0:
raise RuntimeError(f'seq_size_w {seq_size_w} in dim_w {dim_w} must be multiple of split_w {split_w}!!!')
split_seq_size_h = seq_size_h // split_h
split_seq_size_w = seq_size_w // split_w
tensor_splits_h = input.split(split_seq_size_h, dim=dim_h)
tensor_splits = []
for tensor_split_h in tensor_splits_h:
tensor_splits_hw = tensor_split_h.split(split_seq_size_w, dim=dim_w)
tensor_splits.extend(tensor_splits_hw)
cp_rank = get_cp_rank()
split_tensor = tensor_splits[cp_rank]
return split_tensor
class GatherFunction2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, process_group, seq_dim_hw, shape, split_hw):
ctx.cp_group = process_group
ctx.seq_dim_hw = seq_dim_hw
ctx.split_hw = split_hw
ctx.shape = shape
ctx.cp_size = get_cp_size()
T, H, W = shape
dim_h, dim_w = seq_dim_hw
split_h, split_w = split_hw
assert H % split_h == 0 and W % split_w == 0
assert T * (H // split_h) * (W // split_w) == input.shape[1]
input = rearrange(input, "B (T H W) C -> B T H W C", T=T, H=H // split_h, W=W // split_w)
with torch.no_grad():
input = input.contiguous()
output_tensors = [torch.zeros_like(input) for _ in range(ctx.cp_size)]
dist.all_gather(output_tensors, input, group=ctx.cp_group)
output_tensors_hs = []
assert ctx.cp_size % split_w == 0
for i in range(0, ctx.cp_size // split_w):
output_tensors_hs.append(
torch.cat(output_tensors[i * split_w : (i + 1) * split_w], dim=dim_w)
)
output_tensor = torch.cat(output_tensors_hs, dim=dim_h)
output_tensor = rearrange(output_tensor, "B T H W C -> B (T H W) C")
return output_tensor
@staticmethod
def backward(ctx, grad_output):
T, H, W = ctx.shape
with torch.no_grad():
grad_output = grad_output * ctx.cp_size
grad_output = rearrange(grad_output, "B (T H W) C -> B T H W C", T=T, H=H, W=W)
grad_input = split_tensor_in_cp_2d(grad_output, ctx.seq_dim_hw, ctx.split_hw)
grad_input = rearrange(grad_input, "B T H W C -> B (T H W) C")
return grad_input, None, None, None, None
class SplitFunction2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, process_group, seq_dim_hw, split_hw):
ctx.cp_group = process_group
ctx.seq_dim_hw = seq_dim_hw
ctx.split_hw = split_hw
ctx.cp_size = get_cp_size()
output_tensor = split_tensor_in_cp_2d(input, ctx.seq_dim_hw, split_hw)
return output_tensor
@staticmethod
def backward(ctx, grad_output):
with torch.no_grad():
grad_output = grad_output / ctx.cp_size
output_tensors = [torch.zeros_like(grad_output) for _ in range(ctx.cp_size)]
dist.all_gather(output_tensors, grad_output, group=ctx.cp_group)
split_h, split_w = ctx.split_hw
dim_h, dim_w = ctx.seq_dim_hw
output_tensors_hs = []
assert ctx.cp_size % split_w == 0
for i in range(0, ctx.cp_size // split_w):
output_tensors_hs.append(
torch.cat(output_tensors[i * split_w : (i + 1) * split_w], dim=dim_w)
)
grad_input = torch.cat(output_tensors_hs, dim=dim_h)
return grad_input, None, None, None
def gather_cp_2d(input, shape, split_hw):
cp_process_group = get_cp_group()
output_tensor = GatherFunction2D.apply(input, cp_process_group, (2, 3), shape, split_hw)
return output_tensor
def split_cp_2d(input, seq_dim_hw, split_hw):
cp_process_group = get_cp_group()
output_tensor = SplitFunction2D.apply(input, cp_process_group, seq_dim_hw, split_hw)
return output_tensor
def get_optimal_split(size):
factors = []
for i in range(1, int(size**0.5) + 1):
if size % i == 0:
factors.append([i, size // i])
return min(factors, key=lambda x: abs(x[0] - x[1]))
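# --- Usage sketch (not part of the original file) ---
# get_optimal_split returns the most balanced [split_h, split_w] factor pair of `size`;
# a pair of this form is what the split_hw arguments of split_tensor_in_cp_2d /
# split_cp_2d above expect (an assumption about the intended usage) for tiling the
# latent H x W token grid across context-parallel ranks.
if __name__ == "__main__":
    assert get_optimal_split(2) == [1, 2]
    assert get_optimal_split(6) == [2, 3]
    assert get_optimal_split(8) == [2, 4]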
import torch
import torch.distributed as dist
from ..context_parallel import context_parallel_util
def all_to_all(tensor, scatter_idx, gather_idx, group=None, gather=True):
"""Perform all-to-all communication on a tensor.
Args:
tensor (torch.Tensor): Input tensor for all-to-all communication
scatter_idx (int): Dimension to scatter, will split along this dimension and then scatter to all processes
gather_idx (int): Dimension to gather, will gather from all processes and then concatenate along this dimension
group (ProcessGroup, optional): Process group to use for communication
Returns:
torch.Tensor
"""
if not dist.is_initialized():
return tensor
world_size = dist.get_world_size(group)
ulysses_rank = context_parallel_util.get_cp_rank()
if world_size == 1:
return tensor
if scatter_idx == gather_idx:
raise ValueError("scatter_idx and gather_idx must be different")
def chunk_tensor(tensor, scatter_idx):
t_shape = list(tensor.shape)
if t_shape[scatter_idx] % world_size != 0:
raise ValueError(f"Dimension {scatter_idx} must be divisible by world size {world_size}")
chunk_size = t_shape[scatter_idx] // world_size
new_shape = list()
for i in range(len(t_shape)):
if i != scatter_idx:
new_shape.append(t_shape[i])
else:
new_shape.extend([world_size, chunk_size])
tensor = tensor.reshape(*new_shape)
# move scatter_idx to front
tensor = tensor.permute(scatter_idx, *[i for i in range(len(new_shape)) if i != scatter_idx]).contiguous()
return tensor
# chunk tensor for all_to_all
tensor = chunk_tensor(tensor, scatter_idx)
# Perform all2all
output = torch.empty_like(tensor)
dist.all_to_all_single(output, tensor, group=group)
# output: e.g., [world_size, B, chunked_H, chunked_S, D] if scatter_idx == 1, gather_idx == 2 -> [B, chunked_H, S, D]
def reorder_tensor(tensor, gather_idx):
t_shape = list(tensor.shape)
world_size = t_shape[0]
# insert front to gather_idx + 1
permute_idx = list()
for i in range(1, len(t_shape)):
if i != gather_idx + 1:
permute_idx.append(i)
else:
permute_idx.extend([0, i])
tensor = tensor.permute(*permute_idx).contiguous() # permute(1,2,0,3) W B CH CS D -> B CH W CS D
# reshape tensor
new_shape = list()
if gather:
for i in range(1, len(t_shape)): # B CH CS D
if i != gather_idx + 1:
new_shape.append(t_shape[i])
else:
new_shape.append(world_size * t_shape[i]) # B CH W*CS D
tensor = tensor.reshape(*new_shape)
else:
tensor = tensor[:,ulysses_rank] # W B CS CH D -> B CS W CH D
return tensor
output = reorder_tensor(output, gather_idx)
return output
@torch.compiler.disable
def ulysses_a2a_in(query, key, value):
if context_parallel_util.get_cp_size() == 1:
return query, key, value
# [B, H, S/N, D] -> [B, H/N, S, D]
query = all_to_all(query, scatter_idx=1, gather_idx=2, group=context_parallel_util.get_cp_group())
key = all_to_all(key, scatter_idx=1, gather_idx=2, group=context_parallel_util.get_cp_group())
value = all_to_all(value, scatter_idx=1, gather_idx=2, group=context_parallel_util.get_cp_group())
return query, key, value
@torch.compiler.disable
def ulysses_a2a_out(output):
if context_parallel_util.get_cp_size() == 1:
return output
# [B, H/N, S, D] -> [B, H, S/N, D]
output = all_to_all(output, scatter_idx=2, gather_idx=1, group=context_parallel_util.get_cp_group())
return output
def ulysses_wrapper(func):
def wrapper(self, query, key, value, shape):
# Apply ulysses_a2a_in before the function call, gather sequence and split head
query, key, value = ulysses_a2a_in(query, key, value)
output = func(self, query, key, value, shape)
output = ulysses_a2a_out(output)
return output
return wrapper
from typing import List, Optional
import torch
import torch.nn as nn
from einops import rearrange
from .rope_3d import RotaryPositionalEmbedding
from .blocks import RMSNorm_FP32
from ..block_sparse_attention.bsa_interface import flash_attn_bsa_3d
from ..context_parallel.ulysses_wrapper import ulysses_wrapper
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
self.enable_bsa = enable_bsa
self.bsa_params = bsa_params
self.cp_split_hw = cp_split_hw
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.proj = nn.Linear(dim, dim)
self.rope_3d = RotaryPositionalEmbedding(
self.head_dim,
cp_split_hw=cp_split_hw
)
@ulysses_wrapper
def _process_attn(self, q, k, v, shape):
"""
function wrapper to do attention with q, k, v
"""
B, H, SQ, D = q.shape
_, _, SKV, _ = k.shape
if self.enable_bsa and shape[0] > 1: # bsa will not be used in image training / sampling
assert self.bsa_params is not None
_, H, W = shape
assert H % self.cp_split_hw[0] == 0 and W % self.cp_split_hw[1] == 0
H, W = H // self.cp_split_hw[0], W // self.cp_split_hw[1]
Tq = SQ // (H * W)
Tk = SKV // (H * W)
latent_shape_q = (Tq, H, W)
latent_shape_k = (Tk, H, W)
x = flash_attn_bsa_3d(q, k, v, latent_shape_q, latent_shape_k, **self.bsa_params)
elif self.enable_flashattn3:
from flash_attn_interface import flash_attn_func
q = rearrange(q, "B H S D -> B S H D").contiguous()
k = rearrange(k, "B H S D -> B S H D").contiguous()
v = rearrange(v, "B H S D -> B S H D").contiguous()
x, *_ = flash_attn_func(
q,
k,
v,
softmax_scale=self.scale,
)
x = rearrange(x, "B S H D -> B H S D")
elif self.enable_flashattn2:
from flash_attn import flash_attn_func
q = rearrange(q, "B H S D -> B S H D")
k = rearrange(k, "B H S D -> B S H D")
v = rearrange(v, "B H S D -> B S H D")
x = flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=self.scale,
)
x = rearrange(x, "B S H D -> B H S D")
elif self.enable_xformers:
import xformers.ops
# Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
# the sequence length, H the number of heads, and K the embedding size per head
q = rearrange(q, "B H M K -> B M H K")
k = rearrange(k, "B H M K -> B M H K")
v = rearrange(v, "B H M K -> B M H K")
x = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None,)
x = rearrange(x, "B M H K -> B H M K")
else:
raise RuntimeError("Unsupported attention operations.")
return x
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if return_kv:
k_cache, v_cache = k.clone(), v.clone()
q, k = self.rope_3d(q, k, shape)
# cond mode
if num_cond_latents is not None and num_cond_latents > 0:
num_cond_latents_thw = num_cond_latents * (N // shape[0])
# process the condition tokens
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
# process the noise tokens
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
x_noise = self._process_attn(q_noise, k, v, shape)
# merge x_cond and x_noise
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
else:
x = self._process_attn(q, k, v, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
if return_kv:
return x, (k_cache, v_cache)
else:
return x
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
T, H, W = shape
k_cache, v_cache = kv_cache
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
if k_cache.shape[0] == 1:
k_cache = k_cache.repeat(B, 1, 1, 1)
v_cache = v_cache.repeat(B, 1, 1, 1)
if num_cond_latents is not None and num_cond_latents > 0:
k_full = torch.cat([k_cache, k], dim=2).contiguous()
v_full = torch.cat([v_cache, v], dim=2).contiguous()
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
q = q_padding[:, :, -N:].contiguous()
x = self._process_attn(q, k_full, v_full, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads,
enable_flashattn3=False,
enable_flashattn2=False,
enable_xformers=False,
):
super(MultiHeadCrossAttention, self).__init__()
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = nn.Linear(dim, dim)
self.kv_linear = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
def _process_cross_attn(self, x, cond, kv_seqlen):
B, N, C = x.shape
assert C == self.dim and cond.shape[2] == self.dim
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flashattn3:
from flash_attn_interface import flash_attn_varlen_func
x = flash_attn_varlen_func(
q=q[0],
k=k[0],
v=v[0],
cu_seqlens_q=torch.tensor([0] + [N] * B, device=q.device).cumsum(0).to(torch.int32),
cu_seqlens_k=torch.tensor([0] + kv_seqlen, device=q.device).cumsum(0).to(torch.int32),
max_seqlen_q=N,
max_seqlen_k=max(kv_seqlen),
)[0]
elif self.enable_flashattn2:
from flash_attn import flash_attn_varlen_func
x = flash_attn_varlen_func(
q=q[0],
k=k[0],
v=v[0],
cu_seqlens_q=torch.tensor([0] + [N] * B, device=q.device).cumsum(0).to(torch.int32),
cu_seqlens_k=torch.tensor([0] + kv_seqlen, device=q.device).cumsum(0).to(torch.int32),
max_seqlen_q=N,
max_seqlen_k=max(kv_seqlen),
)
elif self.enable_xformers:
import xformers.ops
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens([N] * B, kv_seqlen)
x = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
else:
raise RuntimeError("Unsupported attention operations.")
x = x.view(B, -1, C)
x = self.proj(x)
return x
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
"""
x: [B, N, C]
cond: [B, M, C]
"""
if num_cond_latents is None or num_cond_latents == 0:
return self._process_cross_attn(x, cond, kv_seqlen)
else:
B, N, C = x.shape
if num_cond_latents is not None and num_cond_latents > 0:
assert shape is not None, "SHOULD pass in the shape"
num_cond_latents_thw = num_cond_latents * (N // shape[0])
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
output = torch.cat([
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
output_noise
], dim=1).contiguous()
else:
raise NotImplementedError
return output
# References:
# https://github.com/hpcaitech/Open-Sora
# https://github.com/facebookresearch/DiT/blob/main/models.py
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
# https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py#L14
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp
from typing import Optional
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.dim = dim
self.hidden_dim = hidden_dim
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
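# Worked example (sketch, not part of the original file): with dim=4096 and
# hidden_dim=4*4096=16384 (the defaults the transformer below passes in), the SwiGLU
# width becomes int(2*16384/3)=10922, rounded up to the next multiple of 256 -> 11008.
if __name__ == "__main__":
    hidden_dim = int(2 * 16384 / 3)                     # 10922
    hidden_dim = 256 * ((hidden_dim + 256 - 1) // 256)  # 11008
    assert hidden_dim == 11008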
class RMSNorm_FP32(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class LayerNorm_FP32(nn.LayerNorm):
def __init__(self, dim, eps, elementwise_affine):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
out = F.layer_norm(
inputs.float(),
self.normalized_shape,
None if self.weight is None else self.weight.float(),
None if self.bias is None else self.bias.float() ,
self.eps
).to(origin_dtype)
return out
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
B, C, T, H, W = x.shape
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
def modulate_fp32(norm_func, x, shift, scale):
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
# ensure the modulation params be fp32
assert shift.dtype == torch.float32 and scale.dtype == torch.float32
dtype = x.dtype
x = norm_func(x.to(torch.float32))
x = x * (scale + 1) + shift
x = x.to(dtype)
return x
class FinalLayer_FP32(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
super().__init__()
self.hidden_size = hidden_size
self.num_patch = num_patch
self.out_channels = out_channels
self.adaln_tembed_dim = adaln_tembed_dim
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
def forward(self, x, t, latent_shape):
# timestep shape: [B, T, C]
assert t.dtype == torch.float32
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast('cuda', dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, t_embed_dim, frequency_embedding_size=256):
super().__init__()
self.t_embed_dim = t_embed_dim
self.frequency_embedding_size = frequency_embedding_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
nn.SiLU(),
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations.
"""
def __init__(self, in_channels, hidden_size):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.y_proj = nn.Sequential(
nn.Linear(in_channels, hidden_size, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, caption):
B, _, N, C = caption.shape
caption = self.y_proj(caption)
return caption
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.amp as amp
import numpy as np
from einops import rearrange
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from safetensors.torch import load_file
from .lora_utils import create_lora_network
from ..context_parallel import context_parallel_util
from .attention import Attention, MultiHeadCrossAttention
from .blocks import TimestepEmbedder, CaptionEmbedder, PatchEmbed3D, FeedForwardSwiGLU, FinalLayer_FP32, LayerNorm_FP32, modulate_fp32
class LongCatSingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: int,
adaln_tembed_dim: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params=None,
cp_split_hw=None
):
super().__init__()
self.hidden_size = hidden_size
# scale and gate modulation
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
)
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
self.attn = Attention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
self.cross_attn = MultiHeadCrossAttention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
)
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
"""
x: [B, N, C]
y: [1, N_valid_tokens, C]
t: [B, T, C_t]
y_seqlen: [B]; type of a list
latent_shape: latent shape of a single item
"""
x_dtype = x.dtype
B, N, C = x.shape
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
with amp.autocast(device_type='cuda', dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
# self attn with modulation
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
if kv_cache is not None:
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
else:
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
if return_kv:
x_s, kv_cache = attn_outputs
else:
x_s = attn_outputs
with amp.autocast(device_type='cuda', dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
# cross attn
if not skip_crs_attn:
if kv_cache is not None:
num_cond_latents = None
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type='cuda', dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
if return_kv:
return x, kv_cache
else:
return x
class LongCatVideoTransformer3DModel(
ModelMixin, ConfigMixin
):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
depth: int = 48,
num_heads: int = 32,
caption_channels: int = 4096,
mlp_ratio: int = 4,
adaln_tembed_dim: int = 512,
frequency_embedding_size: int = 256,
# default params
patch_size: Tuple[int] = (1, 2, 2),
# attention config
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None,
text_tokens_zero_pad: bool = False,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.cp_split_hw = cp_split_hw
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
)
self.blocks = nn.ModuleList(
[
LongCatSingleStreamBlock(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
adaln_tembed_dim=adaln_tembed_dim,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
for i in range(depth)
]
)
self.final_layer = FinalLayer_FP32(
hidden_size,
np.prod(self.patch_size),
out_channels,
adaln_tembed_dim,
)
self.gradient_checkpointing = False
self.text_tokens_zero_pad = text_tokens_zero_pad
self.lora_dict = {}
self.active_loras = []
def load_lora(self, lora_path, lora_key, multiplier=1.0, lora_network_dim=128, lora_network_alpha=64):
lora_network_state_dict_loaded = load_file(lora_path, device="cpu")
lora_network = create_lora_network(
transformer=self,
lora_network_state_dict_loaded=lora_network_state_dict_loaded,
multiplier=multiplier,
network_dim=lora_network_dim,
network_alpha=lora_network_alpha,
)
lora_network.load_state_dict(lora_network_state_dict_loaded, strict=True)
self.lora_dict[lora_key] = lora_network
def enable_loras(self, lora_key_list=[]):
self.disable_all_loras()
module_loras = {} # {module_name: [lora1, lora2, ...]}
model_device = next(self.parameters()).device
model_dtype = next(self.parameters()).dtype
for lora_key in lora_key_list:
if lora_key in self.lora_dict:
for lora in self.lora_dict[lora_key].loras:
lora.to(model_device, dtype=model_dtype, non_blocking=True)
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
if module_name not in module_loras:
module_loras[module_name] = []
module_loras[module_name].append(lora)
self.active_loras.append(lora_key)
for module_name, loras in module_loras.items():
module = self._get_module_by_name(module_name)
if not hasattr(module, 'org_forward'):
module.org_forward = module.forward
module.forward = self._create_multi_lora_forward(module, loras)
def _create_multi_lora_forward(self, module, loras):
def multi_lora_forward(x, *args, **kwargs):
weight_dtype = x.dtype
org_output = module.org_forward(x, *args, **kwargs)
total_lora_output = 0
for lora in loras:
if lora.use_lora:
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
lx = lora.lora_up(lx)
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
total_lora_output += lora_output
return org_output + total_lora_output
return multi_lora_forward
def _get_module_by_name(self, module_name):
try:
module = self
for part in module_name.split('.'):
module = getattr(module, part)
return module
except AttributeError as e:
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
def disable_all_loras(self):
for name, module in self.named_modules():
if hasattr(module, 'org_forward'):
module.forward = module.org_forward
delattr(module, 'org_forward')
for lora_key, lora_network in self.lora_dict.items():
for lora in lora_network.loras:
lora.to("cpu")
self.active_loras.clear()
def enable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = True
def disable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = False
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states,
encoder_attention_mask=None,
num_cond_latents=0,
return_kv=False,
kv_cache_dict={},
skip_crs_attn=False,
offload_kv_cache=False
):
B, _, T, H, W = hidden_states.shape
N_t = T // self.patch_size[0]
N_h = H // self.patch_size[1]
N_w = W // self.patch_size[2]
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
# expand the shape of timestep from [B] to [B, T]
if len(timestep.shape) == 1:
timestep = timestep.unsqueeze(1).expand(-1, N_t) # [B, T]
dtype = self.x_embedder.proj.weight.dtype
hidden_states = hidden_states.to(dtype)
timestep = timestep.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type='cuda', dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
else:
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
# blocks
kv_cache_dict_ret = {}
for i, block in enumerate(self.blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
block_outputs = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, t, y_seqlens,
(N_t, N_h, N_w), num_cond_latents, return_kv, kv_cache_dict.get(i, None), skip_crs_attn
)
else:
block_outputs = block(
hidden_states, encoder_hidden_states, t, y_seqlens,
(N_t, N_h, N_w), num_cond_latents, return_kv, kv_cache_dict.get(i, None), skip_crs_attn
)
if return_kv:
hidden_states, kv_cache = block_outputs
if offload_kv_cache:
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
else:
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
else:
hidden_states = block_outputs
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
# cast to float32 for better accuracy
hidden_states = hidden_states.to(torch.float32)
if return_kv:
return hidden_states, kv_cache_dict_ret
else:
return hidden_states
def unpatchify(self, x, N_t, N_h, N_w):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
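# --- Shape sketch (not part of the original file) ---
# Illustrates what unpatchify does for the default patch_size (1, 2, 2) and
# out_channels 16: tokens of shape [B, N_t*N_h*N_w, 1*2*2*16] unfold back into a
# latent video of shape [B, 16, N_t, 2*N_h, 2*N_w].
if __name__ == "__main__":
    B, N_t, N_h, N_w, C_out = 1, 4, 8, 8, 16
    tokens = torch.randn(B, N_t * N_h * N_w, 1 * 2 * 2 * C_out)
    video = rearrange(
        tokens,
        "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
        N_t=N_t, N_h=N_h, N_w=N_w, T_p=1, H_p=2, W_p=2, C_out=C_out,
    )
    assert video.shape == (B, C_out, N_t, 2 * N_h, 2 * N_w)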
# References:
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/rotary_positional_embedding.py
import numpy as np
from functools import lru_cache
import torch
import torch.nn as nn
from einops import rearrange, repeat
from ..context_parallel import context_parallel_util
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class RotaryPositionalEmbedding(nn.Module):
def __init__(self,
head_dim,
cp_split_hw=None
):
"""Rotary positional embedding for 3D
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
dim: Dimension of embedding
base: Base value for exponential
"""
super().__init__()
self.head_dim = head_dim
assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.'
self.cp_split_hw = cp_split_hw
# We assume that the longest side of the grid will not be larger than 512, i.e., 512 * 8 = 4096 input pixels
self.base = 10000
self.freqs_dict = {}
def register_grid_size(self, grid_size):
if grid_size not in self.freqs_dict:
self.freqs_dict.update({
grid_size: self.precompute_freqs_cis_3d(grid_size)
})
def precompute_freqs_cis_3d(self, grid_size):
num_frames, height, width = grid_size
dim_t = self.head_dim - 4 * (self.head_dim // 6)
dim_h = 2 * (self.head_dim // 6)
dim_w = 2 * (self.head_dim // 6)
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
grid_t = torch.from_numpy(grid_t).float()
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
with torch.no_grad():
freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
return freqs
def forward(self, q, k, grid_size):
"""3D RoPE.
Args:
query: [B, head, seq, head_dim]
key: [B, head, seq, head_dim]
Returns:
query and key with the same shape as input.
"""
if grid_size not in self.freqs_dict:
self.register_grid_size(grid_size)
freqs_cis = self.freqs_dict[grid_size].to(q.device)
q_, k_ = q.float(), k.float()
freqs_cis = freqs_cis.float().to(q.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
q_ = (q_ * cos) + (rotate_half(q_) * sin)
k_ = (k_ * cos) + (rotate_half(k_) * sin)
return q_.type_as(q), k_.type_as(k)
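# --- Worked example (sketch, not part of the original file) ---
# How precompute_freqs_cis_3d splits head_dim across the (t, h, w) axes.
# For head_dim = 128 (hidden_size 4096 / 32 heads): head_dim // 6 = 21, so
# dim_h = dim_w = 42 and dim_t = 128 - 4 * 21 = 44; 44 + 42 + 42 = 128.
if __name__ == "__main__":
    head_dim = 128
    dim_t = head_dim - 4 * (head_dim // 6)
    dim_h = dim_w = 2 * (head_dim // 6)
    assert (dim_t, dim_h, dim_w) == (44, 42, 42)
    assert dim_t + dim_h + dim_w == head_dim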