Commit 0e56f303 authored by mashun's avatar mashun
Browse files

pyramid-flow

parents
Pipeline #2007 canceled with stages
# Xcode
.DS_Store
.idea
# tyte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
onnx_model/*.onnx
onnx_model/antelope/*.onnx
logs/
prompts/
# Distribution / packaging
.Python
build/
develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.pt2/
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
.bak
pyramid_flow_model
temp_dit
temp_dit_no_ar
temp_vae
*.mp4
datasets
\ No newline at end of file
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-py3.10-dtk24.04.3-ubuntu20.04
MIT License
Copyright (c) 2024 Yang Jin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Pyramid-Flow
## 论文
`PYRAMIDAL FLOW MATCHING FOR EFFICIENT VIDEO GENERATIVE MODELING`
* https://arxiv.org/pdf/2410.05954
## 模型结构
该项目采用flux.1的模型结构,增加了人物结构的稳定性。
![alt text](readme_imgs/arch.png)
## 算法原理
该算法主要关注时间-空间问题的解决,如下:
1、将生成过程分解为多个金字塔阶段,每个阶段在更低分辨率的压缩空间中进行操作,最终阶段才在原始分辨率下进行,从而减少冗余计算。
2、使用逐渐增加分辨率的压缩历史作为条件,进一步减少训练所需的视频标记数量,提高训练效率。
<img src="readme_imgs/alg.png" style="zoom:100%">
## 环境配置
### Docker(方法一)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-py3.10-dtk24.04.3-ubuntu20.04
docker run --shm-size 50g --network=host --name=pyramid-flow --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
### Dockerfile(方法二)
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 50g --network=host --name=pyramid-flow --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
### Anaconda (方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装:
https://developer.hpccube.com/tool/
DTK驱动:dtk24.04.3
python:python3.10
torch: 2.1.0
torchvision: 0.16.0
Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应
2、其它非特殊库参照requirements.txt安装
pip install -r requirements.txt
## 数据集
除本项目提供的数据集外,也可自行准备其他可用数据集。
video: [VIDGEN-1M](http://113.200.138.88:18080/aidatasets/fudan-fuxi/VIDGEN-1M)
image: [text-to-image-2M](https://hf-mirror.com/datasets/jackyhate/text-to-image-2M/tree/) | [SCNet高速下载通道](http://113.200.138.88:18080/aidatasets/jackyhate/text-to-image-2M)
### VAE数据格式
```
# For Video
{"video": video_path}
# For Image
{"image": image_path}
```
数据处理脚本
```bash
cd extra_utils
python generate_vae_annotation.py \
--data_root="/path/to/[image|video]" \
--data_type="[image|video]" \
--save_path="/path/to/save/xxx.jsonl"
```
### DiT数据格式
```
{"video": video_path, "text": text prompt, "latent": extracted video vae latent, "text_fea": extracted text feature}
```
数据处理脚本
```bash
cd extra_utils
python get_video_text.py \
--video_root="/path/to/video_root" \
--caption_json_path="/path/to/caption_file" \
--save_root="/path/to/save_file_root" \
--video_latent_root="/path/to/save/video_latent_root" \
--text_fea_root="/path/to/save/text_feature_root"
```
注意:该脚本仅适用于给定数据集,更多数据处理脚本见`extra_utils`
## 训练
### VAE
```bash
bash scripts/train_causal_video_vae.sh
```
注意:需要在该文件中修改相应参数。
### DiT
#### 数据准备
1、提取视频vae-latent
```bash
bash scripts/extract_vae_latent.sh
```
2、提取T5文本特征(可选)
```bash
bash scripts/extract_text_feature.sh
```
注意:在运行前需确保相应参数正确。
### run
```bash
# 使用时间金字塔的自回归视频生成训练
bash scripts/train_pyramid_flow.sh
```
```bash
# 使用 pyramid-flow 进行全序列扩散训练
bash scripts/train_pyramid_flow_without_ar.sh
```
注意:在运行前需确保相应参数正确,详情见`参考资料`
## 推理
```bash
# 多卡推理 - 增加推理速度
HIP_VISIBLE_DEVICES=0,1 bash scripts/inference_multigpu.sh
```
```bash
# 运行video_generation_demo.ipynb代码时需启动jupyter服务
jupyter notebook --no-browser --ip=0.0.0.0 --allow-root
```
### webui
```bash
python app.py
```
## result
|类别|输入|结果|
|:---:|:---:|:---:|
|t2v|a cat on the moon, salt desert, cinematic style, shot on 35mm film, vivid colors| ![video](readme_imgs/result.gif)|
|i2v|<img src="assets/the_great_wall.jpg" zoom="10%" >|![](readme_imgs/result2.gif)
### 精度
与Nvidia GPU精度一致。
## 应用场景
### 算法类别
`AIGC`
### 热点应用行业
`电商,教育,广媒`
## 预训练权重
[huggingface](https://huggingface.co/rain1011/pyramid-flow-miniflux/tree/main) | [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/rain1011/pyramid-flow-miniflux)
[vgg_lpips](https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.0/vgg.pth)
### 权重文件结构
```
pyramid_flow_model/
└── pyramid-flow-miniflux
├── vgg.pth
├── causal_video_vae
│   ├── config.json
│   └── diffusion_pytorch_model.bin
├── diffusion_transformer_384p
│   ├── config.json
│   └── diffusion_pytorch_model.safetensors
├── diffusion_transformer_image
│   ├── config.json
│   └── diffusion_pytorch_model.safetensors
├── README.md
├── text_encoder
│   ├── config.json
│   └── model.safetensors
├── text_encoder_2
│   ├── config.json
│   ├── model-00001-of-00002.safetensors
│   ├── model-00002-of-00002.safetensors
│   └── model.safetensors.index.json
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
└── tokenizer_2
├── special_tokens_map.json
├── spiece.model
├── tokenizer_config.json
└── tokenizer.json
```
## 源码仓库及问题反馈
* https://developer.sourcefind.cn/codes/modelzoo/pyramid-flow_pytorch
## 参考资料
* https://blog.csdn.net/2401_84760322/article/details/141558082
<div align="center">
# ⚡️Pyramid Flow⚡️
[[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page ✨]](https://pyramid-flow.github.io) [[miniFLUX Model 🚀]](https://huggingface.co/rain1011/pyramid-flow-miniflux) [[SD3 Model ⚡️]](https://huggingface.co/rain1011/pyramid-flow-sd3) [[demo 🤗](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow)]
</div>
This is the official repository for Pyramid Flow, a training-efficient **Autoregressive Video Generation** method based on **Flow Matching**. By training only on **open-source datasets**, it can generate high-quality 10-second videos at 768p resolution and 24 FPS, and naturally supports image-to-video generation.
<table class="center" border="0" style="width: 100%; text-align: left;">
<tr>
<th>10s, 768p, 24fps</th>
<th>5s, 768p, 24fps</th>
<th>Image-to-video</th>
</tr>
<tr>
<td><video src="https://github.com/user-attachments/assets/9935da83-ae56-4672-8747-0f46e90f7b2b" autoplay muted loop playsinline></video></td>
<td><video src="https://github.com/user-attachments/assets/3412848b-64db-4d9e-8dbf-11403f6d02c5" autoplay muted loop playsinline></video></td>
<td><video src="https://github.com/user-attachments/assets/3bd7251f-7b2c-4bee-951d-656fdb45f427" autoplay muted loop playsinline></video></td>
</tr>
</table>
## News
* `2024.10.29` ⚡️⚡️⚡️ We release [training code for VAE](#1-training-vae), [finetuning code for DiT](#2-finetuning-dit) and [new model checkpoints](https://huggingface.co/rain1011/pyramid-flow-miniflux) with FLUX structure trained from scratch.
> We have switched the model structure from SD3 to a mini FLUX to fix human structure issues, please try our 1024p image checkpoint and 384p video checkpoint (up to 5s). The new miniflux model shows great improvement on human structure and motion stability. We will release 768p video checkpoint in a few days.
* `2024.10.13` ✨✨✨ [Multi-GPU inference](#3-multi-gpu-inference) and [CPU offloading](#cpu-offloading) are supported. Use it with **less than 8GB** of GPU memory, with great speedup on multiple GPUs.
* `2024.10.11` 🤗🤗🤗 [Hugging Face demo](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) is available. Thanks [@multimodalart](https://huggingface.co/multimodalart) for the commit!
* `2024.10.10` 🚀🚀🚀 We release the [technical report](https://arxiv.org/abs/2410.05954), [project page](https://pyramid-flow.github.io) and [model checkpoint](https://huggingface.co/rain1011/pyramid-flow-sd3) of Pyramid Flow.
## Table of Contents
* [Introduction](#introduction)
* [Installation](#installation)
* [Inference](#inference)
1. [Quick Start with Gradio](#1-quick-start-with-gradio)
2. [Inference Code](#2-inference-code)
3. [Multi-GPU Inference](#3-multi-gpu-inference)
4. [Usage Tips](#4-usage-tips)
* [Training](#Training)
1. [Training VAE](#training-vae)
2. [Finetuning DiT](#finetuning-dit)
* [Gallery](#gallery)
* [Comparison](#comparison)
* [Acknowledgement](#acknowledgement)
* [Citation](#citation)
## Introduction
![motivation](assets/motivation.jpg)
Existing video diffusion models operate at full resolution, spending a lot of computation on very noisy latents. By contrast, our method harnesses the flexibility of flow matching ([Lipman et al., 2023](https://openreview.net/forum?id=PqvMRDCJT9t); [Liu et al., 2023](https://openreview.net/forum?id=XVjTT1nw5z); [Albergo & Vanden-Eijnden, 2023](https://openreview.net/forum?id=li7qeBbCR1t)) to interpolate between latents of different resolutions and noise levels, allowing for simultaneous generation and decompression of visual content with better computational efficiency. The entire framework is end-to-end optimized with a single DiT ([Peebles & Xie, 2023](http://openaccess.thecvf.com/content/ICCV2023/html/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.html)), generating high-quality 10-second videos at 768p resolution and 24 FPS within 20.7k A100 GPU training hours.
## Installation
We recommend setting up the environment with conda. The codebase currently uses Python 3.8.10 and PyTorch 2.1.2 ([guide](https://pytorch.org/get-started/previous-versions/#v212)), and we are actively working to support a wider range of versions.
```bash
git clone https://github.com/jy0205/Pyramid-Flow
cd Pyramid-Flow
# create env using conda
conda create -n pyramid python==3.8.10
conda activate pyramid
pip install -r requirements.txt
```
Then, download the model from [Huggingface](https://huggingface.co/rain1011) (there are two variants: [miniFLUX](https://huggingface.co/rain1011/pyramid-flow-miniflux) or [SD3](https://huggingface.co/rain1011/pyramid-flow-sd3)). The miniFLUX models support 1024p image and 384p video generation, and the SD3-based models support 768p and 384p video generation. The 384p checkpoint generates 5-second video at 24FPS, while the 768p checkpoint generates up to 10-second video at 24FPS.
```python
from huggingface_hub import snapshot_download
model_path = 'PATH' # The local directory to save downloaded checkpoint
snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
```
## Inference
### 1. Quick start with Gradio
To get started, first install [Gradio](https://www.gradio.app/guides/quickstart), set your model path at [#L36](https://github.com/jy0205/Pyramid-Flow/blob/3777f8b84bddfa2aa2b497ca919b3f40567712e6/app.py#L36), and then run on your local machine:
```bash
python app.py
```
The Gradio demo will be opened in a browser. Thanks to [@tpc2233](https://github.com/tpc2233) the commit, see [#48](https://github.com/jy0205/Pyramid-Flow/pull/48) for details.
Or, try it out effortlessly on [Hugging Face Space 🤗](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) created by [@multimodalart](https://huggingface.co/multimodalart). Due to GPU limits, this online demo can only generate 25 frames (export at 8FPS or 24FPS). Duplicate the space to generate longer videos.
### 2. Inference Code
To use our model, please follow the inference code in `video_generation_demo.ipynb` at [this link](https://github.com/jy0205/Pyramid-Flow/blob/main/video_generation_demo.ipynb). We strongly recommend you to try the latest published pyramid-miniflux, which shows great improvement on human structure and motion stability. Set the param `model_name` to `pyramid_flux` to use. We further simplify it into the following two-step procedure. First, load the downloaded model:
```python
import torch
from PIL import Image
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import load_image, export_to_video
torch.cuda.set_device(0)
model_dtype, torch_dtype = 'bf16', torch.bfloat16 # Use bf16 (not support fp16 yet)
model = PyramidDiTForVideoGeneration(
'PATH', # The downloaded checkpoint dir
model_name="pyramid_flux",
model_dtype,
model_variant='diffusion_transformer_384p', # SD3 supports 'diffusion_transformer_768p'
)
model.vae.enable_tiling()
# model.vae.to("cuda")
# model.dit.to("cuda")
# model.text_encoder.to("cuda")
# if you're not using sequential offloading bellow uncomment the lines above ^
model.enable_sequential_cpu_offload()
```
Then, you can try text-to-video generation on your own prompts. Noting that the 384p version only support 5s now (set temp up to 16)!
```python
prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
video_num_inference_steps=[10, 10, 10],
height=384,
width=640,
temp=16, # temp=16: 5s, temp=31: 10s
guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
video_guidance_scale=5.0, # The guidance for the other video latent
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
)
export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
```
As an autoregressive model, our model also supports (text conditioned) image-to-video generation:
```python
image = Image.open('assets/the_great_wall.jpg').convert("RGB").resize((640, 384))
prompt = "FPV flying over the Great Wall"
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
frames = model.generate_i2v(
prompt=prompt,
input_image=image,
num_inference_steps=[10, 10, 10],
temp=16,
video_guidance_scale=4.0,
output_type="pil",
save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
)
export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
```
#### CPU offloading
We also support two types of CPU offloading to reduce GPU memory requirements. Note that they may sacrifice efficiency.
* Adding a `cpu_offloading=True` parameter to the generate function allows inference with **less than 12GB** of GPU memory. This feature was contributed by [@Ednaordinary](https://github.com/Ednaordinary), see [#23](https://github.com/jy0205/Pyramid-Flow/pull/23) for details.
* Calling `model.enable_sequential_cpu_offload()` before the above procedure allows inference with **less than 8GB** of GPU memory. This feature was contributed by [@rodjjo](https://github.com/rodjjo), see [#75](https://github.com/jy0205/Pyramid-Flow/pull/75) for details.
#### MPS backend
Thanks to [@niw](https://github.com/niw), Apple Silicon users (e.g. MacBook Pro with M2 24GB) can also try our model using the MPS backend! Please see [#113](https://github.com/jy0205/Pyramid-Flow/pull/113) for the details.
### 3. Multi-GPU Inference
For users with multiple GPUs, we provide an [inference script](https://github.com/jy0205/Pyramid-Flow/blob/main/scripts/inference_multigpu.sh) that uses sequence parallelism to save memory on each GPU. This also brings a big speedup, taking only 2.5 minutes to generate a 5s, 768p, 24fps video on 4 A100 GPUs (vs. 5.5 minutes on a single A100 GPU). Run it on 2 GPUs with the following command:
```bash
CUDA_VISIBLE_DEVICES=0,1 sh scripts/inference_multigpu.sh
```
It currently supports 2 or 4 GPUs, with more configurations available in the original script. You can also launch a [multi-GPU Gradio demo](https://github.com/jy0205/Pyramid-Flow/blob/main/scripts/app_multigpu_engine.sh) created by [@tpc2233](https://github.com/tpc2233), see [#59](https://github.com/jy0205/Pyramid-Flow/pull/59) for details.
> Spoiler: We didn't even use sequence parallelism in training, thanks to our efficient pyramid flow designs.
### 4. Usage tips
* The `guidance_scale` parameter controls the visual quality. We suggest using a guidance within [7, 9] for the 768p checkpoint during text-to-video generation, and 7 for the 384p checkpoint.
* The `video_guidance_scale` parameter controls the motion. A larger value increases the dynamic degree and mitigates the autoregressive generation degradation, while a smaller value stabilizes the video.
* For 10-second video generation, we recommend using a guidance scale of 7 and a video guidance scale of 5.
## Training
### 1. Training VAE
The hardware requirements for training VAE are at least 8 A100 GPUs. Please refer to [this document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE.md). This is a [MAGVIT-v2](https://arxiv.org/abs/2310.05737) like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on this part of VAE training code.
### 2. Finetuning DiT
The hardware requirements for finetuning DiT are at least 8 A100 GPUs. Please refer to [this document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT.md). We provide instructions for both autoregressive and non-autoregressive versions of Pyramid Flow. The former is more research oriented and the latter is more stable (but less efficient without temporal pyramid).
## Gallery
The following video examples are generated at 5s, 768p, 24fps. For more results, please visit our [project page](https://pyramid-flow.github.io).
<table class="center" border="0" style="width: 100%; text-align: left;">
<tr>
<td><video src="https://github.com/user-attachments/assets/5b44a57e-fa08-4554-84a2-2c7a99f2b343" autoplay muted loop playsinline></video></td>
<td><video src="https://github.com/user-attachments/assets/5afd5970-de72-40e2-900d-a20d18308e8e" autoplay muted loop playsinline></video></td>
</tr>
<tr>
<td><video src="https://github.com/user-attachments/assets/1d44daf8-017f-40e9-bf18-1e19c0a8983b" autoplay muted loop playsinline></video></td>
<td><video src="https://github.com/user-attachments/assets/7f5dd901-b7d7-48cc-b67a-3c5f9e1546d2" autoplay muted loop playsinline></video></td>
</tr>
</table>
## Comparison
On VBench ([Huang et al., 2024](https://huggingface.co/spaces/Vchitect/VBench_Leaderboard)), our method surpasses all the compared open-source baselines. Even with only public video data, it achieves comparable performance to commercial models like Kling ([Kuaishou, 2024](https://kling.kuaishou.com/en)) and Gen-3 Alpha ([Runway, 2024](https://runwayml.com/research/introducing-gen-3-alpha)), especially in the quality score (84.74 vs. 84.11 of Gen-3) and motion smoothness.
![vbench](assets/vbench.jpg)
We conduct an additional user study with 20+ participants. As can be seen, our method is preferred over open-source models such as [Open-Sora](https://github.com/hpcaitech/Open-Sora) and [CogVideoX-2B](https://github.com/THUDM/CogVideo) especially in terms of motion smoothness.
![user_study](assets/user_study.jpg)
## Acknowledgement
We are grateful for the following awesome projects when implementing Pyramid Flow:
* [SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) and [Flux 1.0](https://huggingface.co/black-forest-labs/FLUX.1-dev): State-of-the-art image generation models based on flow matching.
* [Diffusion Forcing](https://boyuan.space/diffusion-forcing) and [GameNGen](https://gamengen.github.io): Next-token prediction meets full-sequence diffusion.
* [WebVid-10M](https://github.com/m-bain/webvid), [OpenVid-1M](https://github.com/NJU-PCALab/OpenVid-1M) and [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): Large-scale datasets for text-to-video generation.
* [CogVideoX](https://github.com/THUDM/CogVideo): An open-source text-to-video generation model that shares many training details.
* [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2): An open-source video LLM for our video recaptioning.
## Citation
Consider giving this repository a star and cite Pyramid Flow in your publications if it helps your research.
```
@article{jin2024pyramidal,
title={Pyramidal Flow Matching for Efficient Video Generative Modeling},
author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
jounal={arXiv preprint arXiv:2410.05954},
year={2024}
}
```
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
{"image": "SAM_filter/000424/sa_4749867.jpg", "text": "a cityscape with a large body of water, such as a lake or a river, in the foreground"}
{"image": "SAM_filter/000311/sa_3490721.jpg", "text": "a large, stately building with a white and blue color scheme, which gives it a grand and elegant appearance"}
{"image": "SAM_filter/000273/sa_3059407.jpg", "text": "a close-up of a green bag containing a package of Japanese soybeans, along with a bottle of sake, a traditional Japanese alcoholic beverage"}
{"image": "SAM_filter/000745/sa_8344729.jpg", "text": "a large, old-fashioned building with a red and white color scheme"}
{"image": "SAM_filter/000832/sa_9310794.jpg", "text": "a cityscape with a large tower, likely the Eiffel Tower, as the main focal point"}
{"image": "SAM_filter/000427/sa_4779422.jpg", "text": "a large cruise ship, specifically a Royal Caribbean cruise ship, docked at a pier in a harbor"}
{"image": "SAM_filter/000105/sa_1178255.jpg", "text": "a close-up view of a computer screen with a magnifying glass placed over it"}
{"image": "SAM_filter/000765/sa_8560467.jpg", "text": "a tree with a sign attached to it, which is located in a lush green field"}
{"image": "SAM_filter/000216/sa_2417372.jpg", "text": "a large airport terminal with a long blue and white rope-style security line"}
{"image": "SAM_filter/000385/sa_4308806.jpg", "text": "a close-up of a cell phone screen displaying a blue and white logo, which appears to be a bank logo"}
{"image": "SAM_filter/000931/sa_10425835.jpg", "text": "a large body of water, possibly a lake, with a lush green landscape surrounding it"}
{"image": "SAM_filter/000364/sa_4079002.jpg", "text": "a large, empty airport terminal with a long row of gray metal chairs arranged in a straight line"}
{"image": "SAM_filter/000474/sa_5306222.jpg", "text": "a large, modern building with a tall, glass structure, which is likely a museum"}
{"image": "SAM_filter/000584/sa_6536849.jpg", "text": "a city street scene with a black car parked in a parking lot, a building with a balcony, and a city skyline in the background"}
{"image": "SAM_filter/000188/sa_2104485.jpg", "text": "a large jet fighter airplane flying through the sky, captured in a high-quality photograph"}
{"image": "SAM_filter/000219/sa_2458908.jpg", "text": "a stone structure with a tall tower, which is situated in a lush green garden"}
{"image": "SAM_filter/000440/sa_4929413.jpg", "text": "a large city street with a mix of architectural styles, including a Gothic-style building and a modern building"}
{"image": "SAM_filter/000739/sa_8279296.jpg", "text": "a vintage blue and white bus parked on the side of a dirt road, with a building in the background"}
{"image": "SAM_filter/000809/sa_9052304.jpg", "text": "a large, old stone building with a clock tower, which is situated in a small town"}
{"image": "SAM_filter/000294/sa_3300200.jpg", "text": "a table with various utensils, including a bowl, spoon, and fork, placed on a wooden surface"}
\ No newline at end of file
{"video": "webvid10m/train/010451_010500/23388121.mp4", "text": "the serene beauty of a valley with a river, mountains, and clouds", "latent": "webvid10m/train/010451_010500/23388121-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/010451_010500/23388121-text.pt"}
{"video": "pexels/8440980-uhd_3840_2160_25fps.mp4", "text": "A group of people, including two men and two women, are seen sitting at a table, smiling and waving at the camera, and appear to be in a good mood", "latent": "pexels/8440980-uhd_3840_2160_25fps-latent-384-2.pt", "text_fea": "text_feature/pexels/8440980-uhd_3840_2160_25fps-text.pt"}
{"video": "webvid10m/train/176251_176300/1011015221.mp4", "text": "an aerial view of a large wheat field with a road running through it, and a car driving on the road", "latent": "webvid10m/train/176251_176300/1011015221-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/176251_176300/1011015221-text.pt"}
{"video": "webvid10m/train/005801_005850/22143805.mp4", "text": "a close-up of paint mixing in water, creating swirling patterns", "latent": "webvid10m/train/005801_005850/22143805-latent-384-8.pt", "text_fea": "text_feature/webvid10m/train/005801_005850/22143805-text.pt"}
{"video": "OpenVid-1M/videos/qsXY7FkNFwE_2_0to743.mp4", "text": "A baby girl in a pink shirt and striped pants sits in a high chair, eats a piece of bread, and looks at the camera", "latent": "OpenVid-1M/videos/qsXY7FkNFwE_2_0to743-latent-384-0.pt", "text_fea": "text_feature/OpenVid-1M/videos/qsXY7FkNFwE_2_0to743-text.pt"}
{"video": "webvid10m/train/134901_134950/1037990273.mp4", "text": "a field of green wheat waving in the wind", "latent": "webvid10m/train/134901_134950/1037990273-latent-384-6.pt", "text_fea": "text_feature/webvid10m/train/134901_134950/1037990273-text.pt"}
{"video": "pexels/5263258-uhd_2160_4096_30fps.mp4", "text": "A dog sits patiently in front of its bowl, waiting for it to be filled with food", "latent": "pexels/5263258-uhd_2160_4096_30fps-latent-384-6.pt", "text_fea": "text_feature/pexels/5263258-uhd_2160_4096_30fps-text.pt"}
{"video": "webvid10m/train/117851_117900/6461432.mp4", "text": "A ladybug crawls along a blade of grass in a serene natural setting", "latent": "webvid10m/train/117851_117900/6461432-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/117851_117900/6461432-text.pt"}
{"video": "webvid10m/train/053051_053100/1058396656.mp4", "text": "a group of construction workers working on a rooftop, with a supervisor overseeing the work", "latent": "webvid10m/train/053051_053100/1058396656-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/053051_053100/1058396656-text.pt"}
{"video": "webvid10m/train/073651_073700/1021916425.mp4", "text": "an aerial view of a beautiful coastline with rocky islands, blue water, and a white cloud in the sky", "latent": "webvid10m/train/073651_073700/1021916425-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/073651_073700/1021916425-text.pt"}
{"video": "webvid10m/train/027051_027100/1032549941.mp4", "text": "a young woman waking up in bed, smiling at the camera, and then lying back down on the bed", "latent": "webvid10m/train/027051_027100/1032549941-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/027051_027100/1032549941-text.pt"}
{"video": "pexels/5564564-uhd_3840_2160_24fps.mp4", "text": "a person rolling out dough on a table using a rolling pin", "latent": "pexels/5564564-uhd_3840_2160_24fps-latent-384-8.pt", "text_fea": "text_feature/pexels/5564564-uhd_3840_2160_24fps-text.pt"}
{"video": "webvid10m/train/073701_073750/24008116.mp4", "text": "a cityscape with a moon in the sky, and the camera pans across the city", "latent": "webvid10m/train/073701_073750/24008116-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/073701_073750/24008116-text.pt"}
{"video": "webvid10m/train/118351_118400/23370991.mp4", "text": "a group of dolphins swimming in the ocean, with a person on a boat nearby", "latent": "webvid10m/train/118351_118400/23370991-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/118351_118400/23370991-text.pt"}
{"video": "webvid10m/train/022001_022050/1023013066.mp4", "text": "a bird's eye view of a beachfront city, highlighting the hotels, pools, and proximity to the ocean", "latent": "webvid10m/train/022001_022050/1023013066-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/022001_022050/1023013066-text.pt"}
{"video": "webvid10m/train/004601_004650/1015979020.mp4", "text": "a bridge over a body of water, with a boat passing under it", "latent": "webvid10m/train/004601_004650/1015979020-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/004601_004650/1015979020-text.pt"}
{"video": "webvid10m/train/149701_149750/1034525579.mp4", "text": "a group of owls and a moon, with the moon appearing to grow larger as the video progresses", "latent": "webvid10m/train/149701_149750/1034525579-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/149701_149750/1034525579-text.pt"}
\ No newline at end of file
import os
import uuid
import gradio as gr
import torch
import PIL
from PIL import Image
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download
import threading
# Global model cache
model_cache = {}
# Lock to ensure thread-safe access to the model cache
model_cache_lock = threading.Lock()
# Configuration
model_name = "pyramid_flux" # or pyramid_mmdit
model_repo = "rain1011/pyramid-flow-sd3" if model_name == "pyramid_mmdit" else "rain1011/pyramid-flow-miniflux"
model_dtype = "bf16" # Support bf16 and fp32
variants = {
'high': 'diffusion_transformer_768p', # For high-resolution version
'low': 'diffusion_transformer_384p' # For low-resolution version
}
required_file = 'config.json' # Ensure config.json is present
width_high = 1280
height_high = 768
width_low = 640
height_low = 384
cpu_offloading = True # enable cpu_offloading by default
# Get the current working directory and create a folder to store the model
current_directory = os.getcwd()
model_path = os.path.join(current_directory, "pyramid_flow_model") # Directory to store the model
# Download the model if not already present
def download_model_from_hf(model_repo, model_dir, variants, required_file):
need_download = False
if not os.path.exists(model_dir):
print(f"[INFO] Model directory '{model_dir}' does not exist. Initiating download...")
need_download = True
else:
# Check if all required files exist for each variant
for variant_key, variant_dir in variants.items():
variant_path = os.path.join(model_dir, variant_dir)
file_path = os.path.join(variant_path, required_file)
if not os.path.exists(file_path):
print(f"[WARNING] Required file '{required_file}' missing in '{variant_path}'.")
need_download = True
break
if need_download:
print(f"[INFO] Downloading model from '{model_repo}' to '{model_dir}'...")
try:
snapshot_download(
repo_id=model_repo,
local_dir=model_dir,
local_dir_use_symlinks=False,
repo_type='model'
)
print("[INFO] Model download complete.")
except Exception as e:
print(f"[ERROR] Failed to download the model: {e}")
raise
else:
print(f"[INFO] All required model files are present in '{model_dir}'. Skipping download.")
# Download model from Hugging Face if not present
download_model_from_hf(model_repo, model_path, variants, required_file)
# Function to initialize the model based on user options
def initialize_model(variant):
print(f"[INFO] Initializing model with variant='{variant}', using bf16 precision...")
# Determine the correct variant directory
variant_dir = variants['high'] if variant == '768p' else variants['low']
base_path = model_path # Pass the base model path
print(f"[DEBUG] Model base path: {base_path}")
# Verify that config.json exists in the variant directory
config_path = os.path.join(model_path, variant_dir, 'config.json')
if not os.path.exists(config_path):
print(f"[ERROR] config.json not found in '{os.path.join(model_path, variant_dir)}'.")
raise FileNotFoundError(f"config.json not found in '{os.path.join(model_path, variant_dir)}'.")
if model_dtype == "bf16":
torch_dtype_selected = torch.bfloat16
else:
torch_dtype_selected = torch.float32
# Initialize the model
try:
# TODO: remove this check code after miniflux 768 version is released
if model_name == "pyramid_flux":
if variant_dir == "diffusion_transformer_768p":
raise NotImplementedError("The pyramid_flux does not support high resolution now, we will release it after finishing training. \
You can modify the model_name to pyramid_mmdit to support 768p version generation")
model = PyramidDiTForVideoGeneration(
base_path, # Pass the base model path
model_name=model_name, # set to pyramid_flux or pyramid_mmdit
model_dtype=model_dtype, # Use bf16
model_variant=variant_dir, # Pass the variant directory name
cpu_offloading=cpu_offloading, # Pass the CPU offloading flag
)
# Always enable tiling for the VAE
model.vae.enable_tiling()
# Remove manual device placement when using CPU offloading
# The components will be moved to the appropriate devices automatically
if torch.cuda.is_available():
torch.cuda.set_device(0)
# Manual device replacement when not using CPU offloading
if not cpu_offloading:
model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")
else:
print("[WARNING] CUDA is not available. Proceeding without GPU.")
print("[INFO] Model initialized successfully.")
return model, torch_dtype_selected
except Exception as e:
print(f"[ERROR] Error initializing model: {e}")
raise
# Function to get the model from cache or initialize it
def initialize_model_cached(variant):
key = variant
# Check if the model is already in the cache
if key not in model_cache:
with model_cache_lock:
# Double-checked locking to prevent race conditions
if key not in model_cache:
model, dtype = initialize_model(variant)
model_cache[key] = (model, dtype)
return model_cache[key]
def resize_crop_image(img: PIL.Image.Image, tgt_width, tgt_height):
ori_width, ori_height = img.width, img.height
scale = max(tgt_width / ori_width, tgt_height / ori_height)
resized_width = round(ori_width * scale)
resized_height = round(ori_height * scale)
img = img.resize((resized_width, resized_height), resample=PIL.Image.LANCZOS)
left = (resized_width - tgt_width) / 2
top = (resized_height - tgt_height) / 2
right = (resized_width + tgt_width) / 2
bottom = (resized_height + tgt_height) / 2
# Crop the center of the image
img = img.crop((left, top, right, bottom))
return img
# Function to generate text-to-video
def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, progress=gr.Progress()):
progress(0, desc="Loading model")
print("[DEBUG] generate_text_to_video called.")
variant = '768p' if resolution == "768p" else '384p'
height = height_high if resolution == "768p" else height_low
width = width_high if resolution == "768p" else width_low
def progress_callback(i, m):
progress(i/m)
# Initialize model based on user options using cached function
try:
model, torch_dtype_selected = initialize_model_cached(variant)
except Exception as e:
print(f"[ERROR] Model initialization failed: {e}")
return f"Model initialization failed: {e}"
try:
print("[INFO] Starting text-to-video generation...")
with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
video_num_inference_steps=[10, 10, 10],
height=height,
width=width,
temp=temp,
guidance_scale=guidance_scale,
video_guidance_scale=video_guidance_scale,
output_type="pil",
cpu_offloading=cpu_offloading,
save_memory=True,
callback=progress_callback,
)
print("[INFO] Text-to-video generation completed.")
except Exception as e:
print(f"[ERROR] Error during text-to-video generation: {e}")
return f"Error during video generation: {e}"
video_path = f"{str(uuid.uuid4())}_text_to_video_sample.mp4"
try:
export_to_video(frames, video_path, fps=24)
print(f"[INFO] Video exported to {video_path}.")
except Exception as e:
print(f"[ERROR] Error exporting video: {e}")
return f"Error exporting video: {e}"
return video_path
# Function to generate image-to-video
def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolution, progress=gr.Progress()):
progress(0, desc="Loading model")
print("[DEBUG] generate_image_to_video called.")
variant = '768p' if resolution == "768p" else '384p'
height = height_high if resolution == "768p" else height_low
width = width_high if resolution == "768p" else width_low
try:
image = resize_crop_image(image, width, height)
print("[INFO] Image resized and cropped successfully.")
except Exception as e:
print(f"[ERROR] Error processing image: {e}")
return f"Error processing image: {e}"
def progress_callback(i, m):
progress(i/m)
# Initialize model based on user options using cached function
try:
model, torch_dtype_selected = initialize_model_cached(variant)
except Exception as e:
print(f"[ERROR] Model initialization failed: {e}")
return f"Model initialization failed: {e}"
try:
print("[INFO] Starting image-to-video generation...")
with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
frames = model.generate_i2v(
prompt=prompt,
input_image=image,
num_inference_steps=[10, 10, 10],
temp=temp,
video_guidance_scale=video_guidance_scale,
output_type="pil",
cpu_offloading=cpu_offloading,
save_memory=True,
callback=progress_callback,
)
print("[INFO] Image-to-video generation completed.")
except Exception as e:
print(f"[ERROR] Error during image-to-video generation: {e}")
return f"Error during video generation: {e}"
video_path = f"{str(uuid.uuid4())}_image_to_video_sample.mp4"
try:
export_to_video(frames, video_path, fps=24)
print(f"[INFO] Video exported to {video_path}.")
except Exception as e:
print(f"[ERROR] Error exporting video: {e}")
return f"Error exporting video: {e}"
return video_path
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown(
"""
# Pyramid Flow Video Generation Demo
Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
[[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
"""
)
# Shared settings
with gr.Row():
resolution_dropdown = gr.Dropdown(
choices=["768p", "384p"],
value="384p",
label="Model Resolution"
)
with gr.Tab("Text-to-Video"):
with gr.Row():
with gr.Column():
text_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
temp_slider = gr.Slider(1, 16, value=16, step=1, label="Duration")
guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
txt_generate = gr.Button("Generate Video")
with gr.Column():
txt_output = gr.Video(label="Generated Video")
gr.Examples(
examples=[
["A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors", 16, 7.0, 5.0, "384p"],
["Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes", 16, 7.0, 5.0, "384p"],
# ["Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours", 31, 9.0, 5.0, "768p"],
],
inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown],
outputs=[txt_output],
fn=generate_text_to_video,
cache_examples='lazy',
)
with gr.Tab("Image-to-Video"):
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Input Image")
image_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
image_temp_slider = gr.Slider(2, 16, value=16, step=1, label="Duration")
image_video_guidance_scale_slider = gr.Slider(1.0, 7.0, value=4.0, step=0.1, label="Video Guidance Scale")
img_generate = gr.Button("Generate Video")
with gr.Column():
img_output = gr.Video(label="Generated Video")
gr.Examples(
examples=[
['assets/the_great_wall.jpg', 'FPV flying over the Great Wall', 16, 4.0, "384p"]
],
inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown],
outputs=[img_output],
fn=generate_image_to_video,
cache_examples='lazy',
)
# Update generate functions to include resolution options
txt_generate.click(
generate_text_to_video,
inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown],
outputs=txt_output
)
img_generate.click(
generate_image_to_video,
inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown],
outputs=img_output
)
# Launch Gradio app
demo.launch(share=True)
import os
import uuid
import gradio as gr
import subprocess
import tempfile
import shutil
def run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt):
"""
Runs the external multi-GPU inference script and returns the path to the generated video.
"""
# Create a temporary directory to store inputs and outputs
with tempfile.TemporaryDirectory() as tmpdir:
output_video = os.path.join(tmpdir, f"{uuid.uuid4()}_output.mp4")
# Path to the external shell script
script_path = "./scripts/app_multigpu_engine.sh" # Updated script path
# Prepare the command
cmd = [
script_path,
str(gpus),
variant,
model_path,
't2v', # Task is always 't2v' since 'i2v' is removed
str(temp),
str(guidance_scale),
str(video_guidance_scale),
resolution,
output_video,
prompt # Pass the prompt directly as an argument
]
try:
# Run the external script
subprocess.run(cmd, check=True)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Error during video generation: {e}")
# After generation, move the video to a permanent location
final_output = os.path.join("generated_videos", f"{uuid.uuid4()}_output.mp4")
os.makedirs("generated_videos", exist_ok=True)
shutil.move(output_video, final_output)
return final_output
def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, gpus):
model_path = "./pyramid_flow_model" # Use the model path as specified
# Determine variant based on resolution
if resolution == "768p":
variant = "diffusion_transformer_768p"
else:
variant = "diffusion_transformer_384p"
return run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt)
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown(
"""
# Pyramid Flow Video Generation Demo
Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
[[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
"""
)
# Shared settings
with gr.Row():
gpus_dropdown = gr.Dropdown(
choices=[2, 4],
value=4,
label="Number of GPUs"
)
resolution_dropdown = gr.Dropdown(
choices=["768p", "384p"],
value="768p",
label="Model Resolution"
)
with gr.Tab("Text-to-Video"):
with gr.Row():
with gr.Column():
text_prompt = gr.Textbox(
label="Prompt (Less than 128 words)",
placeholder="Enter a text prompt for the video",
lines=2
)
temp_slider = gr.Slider(1, 31, value=16, step=1, label="Duration")
guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
txt_generate = gr.Button("Generate Video")
with gr.Column():
txt_output = gr.Video(label="Generated Video")
gr.Examples(
examples=[
[
"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
16,
9.0,
5.0,
"768p",
4
],
[
"Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes",
16,
9.0,
5.0,
"768p",
4
],
[
"Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours",
31,
9.0,
5.0,
"768p",
4
],
],
inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, gpus_dropdown],
outputs=[txt_output],
fn=generate_text_to_video,
cache_examples='lazy',
)
# Update generate function for Text-to-Video
txt_generate.click(
generate_text_to_video,
inputs=[
text_prompt,
temp_slider,
guidance_scale_slider,
video_guidance_scale_slider,
resolution_dropdown,
gpus_dropdown
],
outputs=txt_output
)
# Launch Gradio app
demo.launch(share=True)
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import cv2\n",
"import torch\n",
"import numpy as np\n",
"import PIL\n",
"from PIL import Image\n",
"from einops import rearrange\n",
"from video_vae import CausalVideoVAELossWrapper\n",
"from torchvision import transforms as pth_transforms\n",
"from torchvision.transforms.functional import InterpolationMode\n",
"from IPython.display import Image as ipython_image\n",
"from diffusers.utils import load_image, export_to_video, export_to_gif\n",
"from IPython.display import HTML"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model_path = \"pyramid_flow_model/pyramid-flow-miniflux/causal_video_vae\" # The video-vae checkpoint dir\n",
"model_path = \"temp_vae\"\n",
"model_dtype = 'bf16'\n",
"\n",
"device_id = 3\n",
"torch.cuda.set_device(device_id)\n",
"\n",
"model = CausalVideoVAELossWrapper(\n",
" model_path,\n",
" model_dtype,\n",
" interpolate=False, \n",
" add_discriminator=False,\n",
")\n",
"model = model.to(\"cuda\")\n",
"\n",
"if model_dtype == \"bf16\":\n",
" torch_dtype = torch.bfloat16 \n",
"elif model_dtype == \"fp16\":\n",
" torch_dtype = torch.float16\n",
"else:\n",
" torch_dtype = torch.float32\n",
"\n",
"def image_transform(images, resize_width, resize_height):\n",
" transform_list = pth_transforms.Compose([\n",
" pth_transforms.Resize((resize_height, resize_width), InterpolationMode.BICUBIC, antialias=True),\n",
" pth_transforms.ToTensor(),\n",
" pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
" ])\n",
" return torch.stack([transform_list(image) for image in images])\n",
"\n",
"\n",
"def get_transform(width, height, new_width=None, new_height=None, resize=False,):\n",
" transform_list = []\n",
"\n",
" if resize:\n",
" if new_width is None:\n",
" new_width = width // 8 * 8\n",
" if new_height is None:\n",
" new_height = height // 8 * 8\n",
" transform_list.append(pth_transforms.Resize((new_height, new_width), InterpolationMode.BICUBIC, antialias=True))\n",
" \n",
" transform_list.extend([\n",
" pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
" ])\n",
" transform_list = pth_transforms.Compose(transform_list)\n",
"\n",
" return transform_list\n",
"\n",
"\n",
"def load_video_and_transform(video_path, frame_number, new_width=None, new_height=None, max_frames=600, sample_fps=24, resize=False):\n",
" try:\n",
" video_capture = cv2.VideoCapture(video_path)\n",
" fps = video_capture.get(cv2.CAP_PROP_FPS)\n",
" frames = []\n",
" pil_frames = []\n",
" while True:\n",
" flag, frame = video_capture.read()\n",
" if not flag:\n",
" break\n",
" \n",
" pil_frames.append(np.ascontiguousarray(frame[:, :, ::-1]))\n",
" frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
" frame = torch.from_numpy(frame)\n",
" frame = frame.permute(2, 0, 1)\n",
" frames.append(frame)\n",
" if len(frames) >= max_frames:\n",
" break\n",
"\n",
" video_capture.release()\n",
" interval = max(int(fps / sample_fps), 1)\n",
" pil_frames = pil_frames[::interval][:frame_number]\n",
" frames = frames[::interval][:frame_number]\n",
" frames = torch.stack(frames).float() / 255\n",
" width = frames.shape[-1]\n",
" height = frames.shape[-2]\n",
" video_transform = get_transform(width, height, new_width, new_height, resize=resize)\n",
" frames = video_transform(frames)\n",
" pil_frames = [Image.fromarray(frame).convert(\"RGB\") for frame in pil_frames]\n",
"\n",
" if resize:\n",
" if new_width is None:\n",
" new_width = width // 32 * 32\n",
" if new_height is None:\n",
" new_height = height // 32 * 32\n",
" pil_frames = [frame.resize((new_width or width, new_height or height), PIL.Image.BICUBIC) for frame in pil_frames]\n",
" return frames, pil_frames\n",
" except Exception:\n",
" return None\n",
"\n",
"\n",
"def show_video(ori_path, rec_path, width=\"100%\"):\n",
" html = ''\n",
" if ori_path is not None:\n",
" html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
" <source src=\"{ori_path}\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\"\n",
" \n",
" html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
" <source src=\"{rec_path}\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\"\n",
" return HTML(html)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Image Reconstruction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_path = 'datasets/text_to_image/flux_1024_10k_00000000.jpg'\n",
"\n",
"image = Image.open(image_path).convert(\"RGB\")\n",
"resize_width = image.width // 8 * 8\n",
"resize_height = image.height // 8 * 8\n",
"input_image_tensor = image_transform([image], resize_width, resize_height)\n",
"input_image_tensor = input_image_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
" latent = model.encode_latent(input_image_tensor.to(\"cuda\"), sample=True)\n",
" rec_images = model.decode_latent(latent)\n",
"\n",
"display(image)\n",
"display(rec_images[0])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Video Reconstruction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"video_path = 'datasets/vidgen/_-iHMLLwX0o-Scene-0037.mp4'\n",
"\n",
"frame_number = 57 # x*8 + 1\n",
"width = 640\n",
"height = 384\n",
"\n",
"video_frames_tensor, pil_video_frames = load_video_and_transform(video_path, frame_number, new_width=width, new_height=height, resize=True)\n",
"video_frames_tensor = video_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
"print(video_frames_tensor.shape)\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
" latent = model.encode_latent(video_frames_tensor.to(\"cuda\"), sample=False, window_size=8, temporal_chunk=True)\n",
" rec_frames = model.decode_latent(latent.float(), window_size=2, temporal_chunk=True)\n",
"\n",
"export_to_video(pil_video_frames, './ori_video.mp4', fps=24)\n",
"export_to_video(rec_frames, \"./rec_video.mp4\", fps=24)\n",
"show_video('./ori_video.mp4', \"./rec_video.mp4\", \"60%\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
from .dataset_cls import (
ImageTextDataset,
LengthGroupedVideoTextDataset,
ImageDataset,
VideoDataset,
)
from .dataloaders import (
create_image_text_dataloaders,
create_length_grouped_video_text_dataloader,
create_mixed_dataloaders,
)
\ No newline at end of file
import torch
import torchvision
import numpy as np
import math
import random
import time
class Bucketeer:
def __init__(
self, dataloader,
sizes=[(256, 256), (192, 384), (192, 320), (384, 192), (320, 192)],
is_infinite=True, epoch=0,
):
# Ratios and Sizes : (w h)
self.sizes = sizes
self.batch_size = dataloader.batch_size
self._dataloader = dataloader
self.iterator = iter(dataloader)
self.sampler = dataloader.sampler
self.buckets = {s: [] for s in self.sizes}
self.is_infinite = is_infinite
self._epoch = epoch
def get_available_batch(self):
available_size = []
for b in self.buckets:
if len(self.buckets[b]) >= self.batch_size:
available_size.append(b)
if len(available_size) == 0:
return None
else:
b = random.choice(available_size)
batch = self.buckets[b][:self.batch_size]
self.buckets[b] = self.buckets[b][self.batch_size:]
return batch
def __next__(self):
batch = self.get_available_batch()
while batch is None:
try:
elements = next(self.iterator)
except StopIteration:
# To make it infinity
if self.is_infinite:
self._epoch += 1
if hasattr(self._dataloader.sampler, "set_epoch"):
self._dataloader.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self.iterator = iter(self._dataloader)
elements = next(self.iterator)
else:
raise StopIteration
for dct in elements:
try:
img = dct['video']
size = (img.shape[-1], img.shape[-2])
self.buckets[size].append({**{'video': img}, **{k:dct[k] for k in dct if k != 'video'}})
except Exception as e:
continue
batch = self.get_available_batch()
out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
def __iter__(self):
return self
def __len__(self):
return len(self.iterator)
class TemporalLengthBucketeer:
def __init__(
self, dataloader, max_frames=16, epoch=0,
):
self.batch_size = dataloader.batch_size
self._dataloader = dataloader
self.iterator = iter(dataloader)
self.buckets = {temp: [] for temp in range(1, max_frames + 1)}
self._epoch = epoch
def get_available_batch(self):
available_size = []
for b in self.buckets:
if len(self.buckets[b]) >= self.batch_size:
available_size.append(b)
if len(available_size) == 0:
return None
else:
b = random.choice(available_size)
batch = self.buckets[b][:self.batch_size]
self.buckets[b] = self.buckets[b][self.batch_size:]
return batch
def __next__(self):
batch = self.get_available_batch()
while batch is None:
try:
elements = next(self.iterator)
except StopIteration:
# To make it infinity
self._epoch += 1
if hasattr(self._dataloader.sampler, "set_epoch"):
self._dataloader.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self.iterator = iter(self._dataloader)
elements = next(self.iterator)
for dct in elements:
try:
video_latent = dct['video']
temp = video_latent.shape[2]
self.buckets[temp].append({**{'video': video_latent}, **{k:dct[k] for k in dct if k != 'video'}})
except Exception as e:
continue
batch = self.get_available_batch()
out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
out = {k: torch.cat(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
if 'prompt_embed' in out:
# Loading the pre-extrcted textual features
prompt_embeds = out['prompt_embed'].clone()
del out['prompt_embed']
prompt_attention_mask = out['prompt_attention_mask'].clone()
del out['prompt_attention_mask']
pooled_prompt_embeds = out['pooled_prompt_embed'].clone()
del out['pooled_prompt_embed']
out['text'] = {
'prompt_embeds' : prompt_embeds,
'prompt_attention_mask': prompt_attention_mask,
'pooled_prompt_embeds': pooled_prompt_embeds,
}
return out
def __iter__(self):
return self
def __len__(self):
return len(self.iterator)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment