Commit 9ff47a7e authored by mashun1

latte
.vscode
preprocess
train_datasets
results
pretrained_models
test
*pyc*
share_ckpts
FROM image.sourcefind.cn:5000/dcu/admin/base/dtk:23.10-ubuntu20.04-py310
MIT License
Copyright (c) 2023 Xin Ma
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Latte
## Paper
**Latte: Latent Diffusion Transformer for Video Generation**
* https://arxiv.org/abs/2401.03048v1
## Model Architecture

The model uses a Transformer as the denoising network. Roughly, the input video is embedded and turned into tokens, a stack of `Transformer Blocks` captures the spatio-temporal information, and a `Layer Norm` followed by a `Linear and Reshape` head produces the `Noise` and `Variance` outputs. The four variants in the figure below differ in how they extract spatio-temporal information.
![alt text](readme_imgs/image-1.png)
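Below is a minimal PyTorch sketch of that flow, for illustration only (it is not the repository's `models/latte.py`, and all layer sizes are made up): tokens pass through Transformer blocks, then `Layer Norm` and a `Linear and Reshape` head split the output into `Noise` and `Variance`.

```python
import torch
import torch.nn as nn

class TinyLatteSketch(nn.Module):
    """Illustrative only: tokens -> Transformer blocks -> LayerNorm -> Linear/Reshape."""
    def __init__(self, dim=256, depth=4, heads=8, patch_dim=4 * 8 * 8):
        super().__init__()
        self.to_tokens = nn.Linear(patch_dim, dim)           # embed video patches as tokens
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)    # spatio-temporal modeling
        self.norm = nn.LayerNorm(dim)                        # "Layer Norm"
        self.head = nn.Linear(dim, patch_dim * 2)            # "Linear and Reshape"

    def forward(self, patches):                              # (B, num_tokens, patch_dim)
        x = self.blocks(self.to_tokens(patches))
        x = self.head(self.norm(x))
        noise, variance = x.chunk(2, dim=-1)                 # split into noise and variance
        return noise, variance

# a toy batch: 16 frames, 4x4 patches per frame, 4*8*8 values per patch (made-up numbers)
tokens = torch.randn(2, 16 * 4 * 4, 4 * 8 * 8)
noise, variance = TinyLatteSketch()(tokens)
```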
## Algorithm

The key idea is to replace the commonly used `Unet` backbone with a `Transformer` as the denoising model. Compared with a `Unet`, the `Transformer` improves model speed while modeling spatio-temporal information well.
![alt text](readme_imgs/image-2.png)
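To make the "denoiser swap" concrete, here is a generic DDPM-style noise-prediction step (an illustrative sketch, not the loss in this repo's `train.py`); `denoiser` can be any network that takes the noisy latents and a timestep, i.e. a Transformer instead of a `Unet`:

```python
import torch
import torch.nn.functional as F

def diffusion_training_step(denoiser, latents, num_timesteps=1000):
    """One noise-prediction step with a pluggable denoiser (hypothetical helper)."""
    betas = torch.linspace(1e-4, 0.02, num_timesteps)          # linear beta schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    b = latents.shape[0]
    t = torch.randint(0, num_timesteps, (b,))                  # random timestep per sample
    noise = torch.randn_like(latents)
    a = alphas_cumprod[t].view(b, *([1] * (latents.dim() - 1)))
    noisy = a.sqrt() * latents + (1 - a).sqrt() * noise        # forward process q(x_t | x_0)

    pred = denoiser(noisy, t)                                  # the Transformer predicts the noise
    return F.mse_loss(pred, noise)
```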
## Environment Setup

### Docker (Option 1)

```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/dtk:23.10-ubuntu20.04-py310
docker run --shm-size 10g --network=host --name=latte --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal -v <absolute path to this project>:/home/ -it <your IMAGE ID> bash
pip install torch-2.1.0a0%2Bgit793d2b5.abi0.dtk2310-cp310-cp310-manylinux2014_x86_64.whl  # downloaded from the developer community
pip install -r requirements.txt
pip install torchvision==0.16.0 --no-deps
pip install timm --no-deps
```
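After installation, a quick sanity check helps confirm that the DTK PyTorch wheel is active and the DCU devices are visible (assuming the DTK/ROCm build exposes devices through the standard `torch.cuda` interface):

```python
import torch

print(torch.__version__)           # expect the 2.1.0a0 DTK build installed above
print(torch.cuda.is_available())   # True if the DCU devices are visible
print(torch.cuda.device_count())
```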
### Dockerfile (Option 2)

```bash
# run from the directory containing the Dockerfile
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 10g --network=host --name=latte --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal -v <absolute path to this project>:/home/ -it <your IMAGE ID> bash
pip install torch-2.1.0a0%2Bgit793d2b5.abi0.dtk2310-cp310-cp310-manylinux2014_x86_64.whl  # downloaded from the developer community
pip install -r requirements.txt
pip install torchvision==0.16.0 --no-deps
pip install timm --no-deps
```
## Datasets

| Name | URL | Access |
|:---|:---|:---|
| UCF101 | https://www.crcv.ucf.edu/research/data-sets/ucf101/ | None |
| FaceForensics | https://github.com/ondyari/FaceForensics/tree/original | Request form required |
| Taichi | https://github.com/AliaksandrSiarohin/first-order-model/blob/master/data/taichi-loading/README.md | None |
| SkyTimelapse | https://drive.google.com/file/d/1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo/view | None |

Data layout. The example below shows UCF-101 only (a tiny sample); prepare the full datasets with the same structure.
```
UCF-101_tiny
├── ApplyEyeMakeup
│   └── v_ApplyEyeMakeup_g01_c01.avi
├── ApplyLipstick
│   └── v_ApplyLipstick_g01_c01.avi
├── Archery
│   └── v_Archery_g01_c01.avi
├── BabyCrawling
│   └── v_BabyCrawling_g01_c01.avi
├── BalanceBeam
├── .....
```
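A quick way to confirm the layout matches the structure above is to count clips per class directory (a small hypothetical check, using the `train_datasets/UCF-101_tiny` path from the UCF-101 configs):

```python
import os

root = "train_datasets/UCF-101_tiny"   # data_path used by configs/ucf101/*.yaml
for class_name in sorted(os.listdir(root)):
    class_dir = os.path.join(root, class_name)
    if os.path.isdir(class_dir):
        n = sum(f.endswith(".avi") for f in os.listdir(class_dir))
        print(f"{class_name}: {n} clips")
```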
## Training

```bash
# Train on UCF-101
torchrun --nnodes=1 --nproc_per_node=N train.py --config ./configs/ucf101/ucf101_train.yaml
# Train on a Slurm cluster
sbatch slurm_scripts/ucf101.slurm
# Joint video-image training
torchrun --nnodes=1 --nproc_per_node=N train_with_img.py --config ./configs/ucf101/ucf101_img_train.yaml
```

Note: the corresponding pretrained models must be prepared before training; see `Inference - Model Download`.
## Inference

### Model Download

* https://hf-mirror.com/maxin-cn/Latte/tree/main
* https://hf-mirror.com/PixArt-alpha/PixArt-XL-2-512x512/tree/main/transformer

Arrange the downloaded weights as follows:
```
share_ckpts/
├── ffs.pt
├── skytimelapse.pt
├── t2v.pt
└── ...

pretrained_models/
├── sd-vae-ft-ema
│   ├── config.json
│   └── diffusion_pytorch_model.bin
├── ....
├── t2v_required_models
│   ├── model_index.json
│   ├── scheduler
│   │   └── scheduler_config.json
│   ├── text_encoder
│   │   ├── config.json
│   │   ├── model-00001-of-00004.safetensors
│   │   ├── model-00002-of-00004.safetensors
│   │   ├── model-00003-of-00004.safetensors
│   │   ├── model-00004-of-00004.safetensors
│   │   └── model.safetensors.index.json
│   ├── tokenizer
│   │   ├── added_tokens.json
│   │   ├── special_tokens_map.json
│   │   ├── spiece.model
│   │   └── tokenizer_config.json
│   ├── transformer
│   │   ├── config.json
│   │   └── diffusion_pytorch_model.safetensors
│   └── vae
│       ├── config.json
│       └── diffusion_pytorch_model.safetensors
└── vae
    ├── config.json
    └── diffusion_pytorch_model.bin
```
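One possible way to fetch these files is with `huggingface_hub` pointed at the mirror (an illustrative sketch; the exact file list of `maxin-cn/Latte` may differ, and downloading manually from the URLs above works just as well):

```python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"   # must be set before importing huggingface_hub

from huggingface_hub import snapshot_download

# video checkpoints (ffs.pt, skytimelapse.pt, t2v.pt, ...) into share_ckpts/
snapshot_download(repo_id="maxin-cn/Latte", allow_patterns=["*.pt"], local_dir="share_ckpts")

# text-to-video auxiliary models into pretrained_models/t2v_required_models/
snapshot_download(repo_id="maxin-cn/Latte", allow_patterns=["t2v_required_models/*"],
                  local_dir="pretrained_models")
```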
### Commands

```bash
# FaceForensics (face videos)
# generate a single video
bash sample/ffs.sh
# generate multiple videos (DDP)
bash sample/ffs_ddp.sh
# Sky (sky timelapse videos)
bash sample/sky.sh
bash sample/sky_ddp.sh
# Taichi (tai chi videos)
bash sample/taichi.sh
bash sample/taichi_ddp.sh
# UCF101 (action videos)
bash sample/ucf101.sh
bash sample/ucf101_ddp.sh
# text-to-video
bash sample/t2v.sh
```
## Results

![alt text](readme_imgs/test.gif)

### Accuracy

Metric: FVD
| Device | UCF-101 | SkyTimelapse |
|:---|:---|:---|
| DCU | xxx | xxx |
| GPU | xxx | xxx |
## Application Scenarios

### Algorithm Category

`AIGC`

### Key Application Industries

`Media, Research, Education`

## Source Repository and Issue Reporting
* https://developer.hpccube.com/codes/modelzoo/latte_pytorch
## References
* https://github.com/Vchitect/Latte
## Latte: Latent Diffusion Transformer for Video Generation<br><sub>Official PyTorch Implementation</sub>
### [Paper](https://arxiv.org/abs/2401.03048v1) | [Project Page](https://maxin-cn.github.io/latte_project/)
This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring
latent diffusion models with transformers (Latte). You can find more visualizations on our [project page](https://maxin-cn.github.io/latte_project/).
> [**Latte: Latent Diffusion Transformer for Video Generation**](https://maxin-cn.github.io/latte_project/)<br>
> [Xin Ma](https://maxin-cn.github.io/), [Yaohui Wang](https://wyhsirius.github.io/), [Gengyun Jia](https://scholar.google.com/citations?user=_04pkGgAAAAJ&hl=zh-CN), [Xinyuan Chen](https://scholar.google.com/citations?user=3fWSC8YAAAAJ), [Ziwei Liu](https://liuziwei7.github.io/), [Yuan-Fang Li](https://users.monash.edu/~yli/), [Cunjian Chen](https://cunjian.github.io/), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=zh-CN)
> <br>Department of Data Science \& AI, Faculty of Information Technology, Monash University <br> Shanghai Artificial Intelligence Laboratory, S-Lab, Nanyang Technological University<br> Nanjing University of Posts and Telecommunications
We propose a novel Latent Diffusion Transformer, namely Latte, for video generation. Latte first extracts spatio-temporal tokens from input videos and then adopts a series of Transformer blocks to model video distribution in the latent space. In order to model a substantial number of tokens extracted from videos, four efficient variants are introduced from the perspective of decomposing the spatial and temporal dimensions of input videos. To improve the quality of generated videos, we determine the best practices of Latte through rigorous experimental analysis, including video clip patch embedding, model variants, timestep-class information injection, temporal positional embedding, and learning strategies. Our comprehensive evaluation demonstrates that Latte achieves state-of-the-art performance across four standard video generation datasets, i.e., FaceForensics, SkyTimelapse, UCF101, and Taichi-HD. In addition, we extend Latte to text-to-video generation (T2V) task, where Latte achieves comparable results compared to recent T2V models. We strongly believe that Latte provides valuable insights for future research on incorporating Transformers into diffusion models for video generation.
![The architecture of Latte](visuals/architecture.svg)
This repository contains:
* 🪐 A simple PyTorch [implementation](models/latte.py) of Latte
* ⚡️ Pre-trained Latte models trained on FaceForensics, SkyTimelapse, Taichi-HD and UCF101 (256x256)
* 🛸 A Latte [training script](train.py) using PyTorch DDP
## Setup
First, download and set up the repo:
```bash
git clone https://github.com/maxin-cn/Latte.git
cd Latte
```
We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
```bash
conda env create -f environment.yml
conda activate latte
```
## Sampling
**Pre-trained Latte checkpoints.** You can sample from our pre-trained Latte models with [`sample.py`](sample/sample.py). Weights for our pre-trained Latte model can be found [here](https://huggingface.co/maxin-cn/Latte). If you want to try generating videos from text, please download [`t2v_required_models`](https://huggingface.co/maxin-cn/Latte/tree/main/t2v_required_models). The script has various arguments to adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our model on FaceForensics, you can use:
```bash
bash sample/ffs.sh
```
or if you want to sample hundreds of videos, you can use the following script with PyTorch DDP:
```bash
bash sample/ffs_ddp.sh
```
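Sampling hyper-parameters such as `num_sampling_steps` and `cfg_scale` live in the corresponding YAML configs (e.g. the FaceForensics sample config under `configs/ffs/`). A sketch of tweaking them programmatically, assuming the configs are plain OmegaConf-compatible YAML and the file name below is correct:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/ffs/ffs_sample.yaml")     # assumed path of the FFS sample config
cfg.num_sampling_steps = 50                             # fewer DDPM steps for a quick preview
cfg.cfg_scale = 4.0                                     # stronger classifier-free guidance
OmegaConf.save(cfg, "configs/ffs/ffs_sample_fast.yaml")
# then point sample/ffs.sh at the new config file
```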
## Training
We provide a training script for Latte in [`train.py`](train.py). This script can be used to train class-conditional and unconditional
Latte models. To launch Latte (256x256) training with `N` GPUs on the FaceForensics dataset:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --config ./configs/ffs/ffs_train.yaml
```
Or, if you have a cluster that uses Slurm, you can also train Latte using the following script:
```bash
sbatch slurm_scripts/ffs.slurm
```
We also provide a video-image joint training script, [`train_with_img.py`](train_with_img.py). Similar to [`train.py`](train.py), it can also be used to train class-conditional and unconditional
Latte models. For example, if you want to train a Latte model on the FaceForensics dataset, you can use:
```bash
torchrun --nnodes=1 --nproc_per_node=N train_with_img.py --config ./configs/ffs/ffs_img_train.yaml
```
## Citation
If you find this work useful for your research, please consider citing it.
```bibtex
@article{ma2024latte,
title={Latte: Latent Diffusion Transformer for Video Generation},
author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Liu, Ziwei and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
journal={arXiv preprint arXiv:2401.03048},
year={2024}
}
```
## Acknowledgments
Latte has been greatly inspired by the following amazing works and teams: [DiT](https://github.com/facebookresearch/DiT) and [U-ViT](https://github.com/baofff/U-ViT); we thank all of their contributors for open-sourcing.
## License
The code and model weights are licensed under [LICENSE](LICENSE).
# dataset
dataset: "ffs_img"
data_path: "/path/to/datasets/preprocessed_ffs/train/videos/"
frame_data_path: "/path/to/datasets/preprocessed_ffs/train/images/"
frame_data_txt: "/path/to/datasets/preprocessed_ffs/train_list.txt"
pretrained_model_path: "/path/to/pretrained/Latte/"
# save and load
results_dir: "./results_img"
pretrained:
# model config:
model: LatteIMG-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True # important
extras: 1 # [1, 2, 78]
# train config:
save_ceph: True # important
use_image_num: 8
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 500000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 100
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:
# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# path:
ckpt: # will be overwritten
save_img_path: "./sample_videos" # will be overwritten
pretrained_model_path: "pretrained_models"
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
frame_interval: 2
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
num_classes:
# model speedup
use_compile: False
use_fp16: True
# sample config:
seed:
sample_method: 'ddpm'
num_sampling_steps: 250
cfg_scale: 1.0
negative_name:
# ddp sample config
per_proc_batch_size: 2
num_fvd_samples: 2048
# dataset
dataset: "ffs"
data_path: "/path/to/datasets/preprocess_ffs/train/videos/" # s
pretrained_model_path: "/path/to/pretrained/Latte/"
# save and load
results_dir: "./results"
pretrained:
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True # important
extras: 1 # [1, 2, 78]
# train config:
save_ceph: True # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 20000
local_batch_size: 5 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 100
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:
# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# dataset
dataset: "sky_img"
data_path: "/path/to/datasets/sky_timelapse/sky_train/" # s/p
pretrained_model_path: "/path/to/pretrained/Latte/"
# save and load
results_dir: "./results_img"
pretrained:
# model config:
model: LatteIMG-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
# train config:
save_ceph: True # important
use_image_num: 8 # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 20000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:
# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# path:
ckpt: # will be overwritten
save_img_path: "./sample_videos/" # will be overwritten
pretrained_model_path: "/path/to/pretrained/Latte/"
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
frame_interval: 2
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
num_classes:
# model speedup
use_compile: False
use_fp16: True
# sample config:
seed:
sample_method: 'ddpm'
num_sampling_steps: 250
cfg_scale: 1.0
run_time: 12
num_sample: 1
negative_name:
# ddp sample config
per_proc_batch_size: 1
num_fvd_samples: 2
# dataset
dataset: "sky"
data_path: "/path/to/datasets/sky_timelapse/sky_train/"
pretrained_model_path: "/path/to/pretrained/Latte/"
# save and load
results_dir: "./results"
pretrained:
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
# train config:
save_ceph: True # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 20000
local_batch_size: 5 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:
# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# path:
ckpt: share_ckpts/t2v.pt
save_img_path: "./sample_videos/t2v"
pretrained_model_path: "pretrained_models/t2v_required_models"
# model config:
model: LatteT2V
video_length: 16
image_size: [512, 512]
# beta schedule
beta_start: 0.0001
beta_end: 0.02
beta_schedule: "linear"
variance_type: "learned_range"
# model speedup
use_compile: False
use_fp16: True
# sample config:
seed:
run_time: 0
guidance_scale: 7.5
sample_method: 'PNDM'
num_sampling_steps: 50
enable_temporal_attentions: True
text_prompt: [
'Yellow and black tropical fish dart through the sea.',
    'An epic tornado attacking above a glowing city at night.',
'Slow pan upward of blazing oak fire in an indoor fireplace.',
'a cat wearing sunglasses and working as a lifeguard at pool.',
'Sunset over the sea.',
'A dog in astronaut suit and sunglasses floating in space.',
]
# dataset
dataset: "taichi_img"
data_path: "/path/to/datasets/taichi"
pretrained_model_path: "/path/to/pretrained/Latte/"
# save and load
results_dir: "./results_img"
pretrained:
# model config:
model: LatteIMG-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
# train config:
load_from_ceph: False # important
use_image_num: 8
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 500000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:
# low VRAM and speed up training TODO
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# path:
ckpt: # will be overwritten
save_img_path: "./sample_videos/" # will be overwritten
pretrained_model_path: "/path/to/pretrained/Latte/"
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
frame_interval: 2
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
num_classes:
# model speedup
use_compile: False
use_fp16: True
# sample config:
seed:
sample_method: 'ddpm'
num_sampling_steps: 250
cfg_scale: 1.0
negative_name:
# ddp sample config
per_proc_batch_size: 1
num_fvd_samples: 2
# dataset
dataset: "taichi"
data_path: "/path/to/datasets/taichi"
pretrained_model_path: "/path/to/pretrained/Latte/"
# save and load
results_dir: "./results"
pretrained:
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
# train config:
load_from_ceph: False # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 500000
local_batch_size: 5 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:
# low VRAM and speed up training TODO
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# dataset
dataset: "ucf101_img"
data_path: "train_datasets/UCF-101_tiny"
frame_data_txt: "/path/to/datasets/UCF101/train_256_list.txt"
pretrained_model_path: "pretrained_models"
# save and load
results_dir: "./results_img"
pretrained:
# model config:
model: LatteIMG-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 2 # [1, 2, 78] # important
# train config:
save_ceph: True # important
use_image_num: 8 # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 100000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes: 101
# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
# path:
ckpt:
save_img_path: "./sample_videos/"
pretrained_model_path: "/path/to/pretrained/Latte/"
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 2 # [1, 2, 78]
num_classes: 101
# model speedup
use_compile: False
use_fp16: True
# sample config:
seed:
sample_method: 'ddpm'
num_sampling_steps: 250
cfg_scale: 7.0
run_time: 12
num_sample: 1
sample_names: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
negative_name: 101
# ddp sample config
per_proc_batch_size: 2
num_fvd_samples: 2
# dataset
dataset: "ucf101"
data_path: "train_datasets/UCF-101_tiny"
pretrained_model_path: "pretrained_models"
# save and load
results_dir: "./results"
pretrained:
# model config:
model: Latte-XL/2
num_frames: 16
image_size: 128 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 2 # [1, 2, 78] # important
# train config:
save_ceph: True # important
learning_rate: 1e-4
ckpt_every: 100
clip_max_norm: 0.1
start_clip_iter: 100000
local_batch_size: 5 # important
max_train_steps: 1000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes: 101
# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
from .sky_datasets import Sky
from torchvision import transforms
from .taichi_datasets import Taichi
from datasets import video_transforms
from .ucf101_datasets import UCF101
from .ffs_datasets import FaceForensics
from .ffs_image_datasets import FaceForensicsImages
from .sky_image_datasets import SkyImages
from .ucf101_image_datasets import UCF101Images
from .taichi_image_datasets import TaichiImages
def get_dataset(args):
temporal_sample = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) # 16 1
if args.dataset == 'ffs':
transform_ffs = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(args.image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return FaceForensics(args, transform=transform_ffs, temporal_sample=temporal_sample)
elif args.dataset == 'ffs_img':
transform_ffs = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(args.image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return FaceForensicsImages(args, transform=transform_ffs, temporal_sample=temporal_sample)
elif args.dataset == 'ucf101':
transform_ucf101 = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(args.image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return UCF101(args, transform=transform_ucf101, temporal_sample=temporal_sample)
elif args.dataset == 'ucf101_img':
transform_ucf101 = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(args.image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return UCF101Images(args, transform=transform_ucf101, temporal_sample=temporal_sample)
elif args.dataset == 'taichi':
transform_taichi = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return Taichi(args, transform=transform_taichi, temporal_sample=temporal_sample)
elif args.dataset == 'taichi_img':
transform_taichi = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomHorizontalFlipVideo(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return TaichiImages(args, transform=transform_taichi, temporal_sample=temporal_sample)
elif args.dataset == 'sky':
transform_sky = transforms.Compose([
video_transforms.ToTensorVideo(),
video_transforms.CenterCropResizeVideo(args.image_size),
# video_transforms.RandomHorizontalFlipVideo(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return Sky(args, transform=transform_sky, temporal_sample=temporal_sample)
elif args.dataset == 'sky_img':
transform_sky = transforms.Compose([
video_transforms.ToTensorVideo(),
video_transforms.CenterCropResizeVideo(args.image_size),
# video_transforms.RandomHorizontalFlipVideo(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
return SkyImages(args, transform=transform_sky, temporal_sample=temporal_sample)
else:
raise NotImplementedError(args.dataset)
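# Usage sketch (not part of the original module): `args` is expected to expose at least
# dataset, data_path, num_frames, frame_interval and image_size, e.g. a training config
# loaded with OmegaConf:
#
#   from omegaconf import OmegaConf
#   args = OmegaConf.load("configs/ucf101/ucf101_train.yaml")
#   dataset = get_dataset(args)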
import os
import json
import torch
import decord
import torchvision
import numpy as np
from PIL import Image
from einops import rearrange
from typing import Dict, List, Tuple
class_labels_map = None
cls_sample_cnt = None
def temporal_sampling(frames, start_idx, end_idx, num_samples):
"""
Given the start and end frame index, sample num_samples frames between
the start and end with equal interval.
Args:
frames (tensor): a tensor of video frames, dimension is
`num video frames` x `channel` x `height` x `width`.
start_idx (int): the index of the start frame.
end_idx (int): the index of the end frame.
num_samples (int): number of frames to sample.
Returns:
        frames (tensor): a tensor of temporally sampled video frames, dimension is
`num clip frames` x `channel` x `height` x `width`.
"""
index = torch.linspace(start_idx, end_idx, num_samples)
index = torch.clamp(index, 0, frames.shape[0] - 1).long()
frames = torch.index_select(frames, 0, index)
return frames
def numpy2tensor(x):
return torch.from_numpy(x)
def get_filelist(file_path):
Filelist = []
for home, dirs, files in os.walk(file_path):
for filename in files:
Filelist.append(os.path.join(home, filename))
# Filelist.append( filename)
return Filelist
def load_annotation_data(data_file_path):
with open(data_file_path, 'r') as data_file:
return json.load(data_file)
def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
global class_labels_map, cls_sample_cnt
if class_labels_map is not None:
return class_labels_map, cls_sample_cnt
else:
cls_sample_cnt = {}
class_labels_map = load_annotation_data(anno_pth)
for cls in class_labels_map:
cls_sample_cnt[cls] = 0
return class_labels_map, cls_sample_cnt
def load_annotations(ann_file, num_class, num_samples_per_cls):
dataset = []
class_to_idx, cls_sample_cnt = get_class_labels(num_class)
with open(ann_file, 'r') as fin:
for line in fin:
line_split = line.strip().split('\t')
sample = {}
idx = 0
# idx for frame_dir
frame_dir = line_split[idx]
sample['video'] = frame_dir
idx += 1
# idx for label[s]
label = [x for x in line_split[idx:]]
assert label, f'missing label in line: {line}'
assert len(label) == 1
class_name = label[0]
class_index = int(class_to_idx[class_name])
# choose a class subset of whole dataset
if class_index < num_class:
sample['label'] = class_index
if cls_sample_cnt[class_name] < num_samples_per_cls:
dataset.append(sample)
cls_sample_cnt[class_name]+=1
return dataset
class DecordInit(object):
"""Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
def __init__(self, num_threads=1, **kwargs):
self.num_threads = num_threads
self.ctx = decord.cpu(0)
self.kwargs = kwargs
def __call__(self, filename):
"""Perform the Decord initialization.
Args:
results (dict): The resulting dict to be modified and passed
to the next transform in pipeline.
"""
reader = decord.VideoReader(filename,
ctx=self.ctx,
num_threads=self.num_threads)
return reader
def __repr__(self):
        # DecordInit defines no `sr` attribute, so only report num_threads here.
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
return repr_str
class FaceForensics(torch.utils.data.Dataset):
"""Load the FaceForensics video files
Args:
        target_video_len (int): the number of video frames to load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(self,
configs,
transform=None,
temporal_sample=None):
self.configs = configs
self.data_path = configs.data_path
self.video_lists = get_filelist(configs.data_path)
self.transform = transform
self.temporal_sample = temporal_sample
self.target_video_len = self.configs.num_frames
self.v_decoder = DecordInit()
def __getitem__(self, index):
path = self.video_lists[index]
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
total_frames = len(vframes)
# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert end_frame_ind - start_frame_ind >= self.target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
video = vframes[frame_indice]
        # video transform data preprocessing
video = self.transform(video) # T C H W
return {'video': video, 'video_name': 1}
def __len__(self):
return len(self.video_lists)
if __name__ == '__main__':
pass