(wanvideo)=
# WanVideo
## Inference T2V with WanVideo
First, download the model:
```bash
python scripts/huggingface/download_hf.py --repo_id=Wan-AI/Wan2.1-T2V-1.3B-Diffusers --local_dir=YOUR_LOCAL_DIR --repo_type=model
```
or
```bash
python scripts/huggingface/download_hf.py --repo_id=Wan-AI/Wan2.1-T2V-14B-Diffusers --local_dir=YOUR_LOCAL_DIR --repo_type=model
```
Then run the inference using:
```bash
sh scripts/inference/v1_inference_wan.sh
```
Remember to set `MODEL_BASE` and `num_gpus` accordingly.
## Inference I2V with WanVideo
First, download the model:
```bash
python scripts/huggingface/download_hf.py --repo_id=Wan-AI/Wan2.1-I2V-14B-480P-Diffusers --local_dir=YOUR_LOCAL_DIR --repo_type=model
```
or
```bash
python scripts/huggingface/download_hf.py --repo_id=Wan-AI/Wan2.1-I2V-14B-720P-Diffusers --local_dir=YOUR_LOCAL_DIR --repo_type=model
```
Then run the inference using:
```bash
sh scripts/inference/v1_inference_wan_i2v.sh
```
Remember to set `MODEL_BASE` and `num_gpus` accordingly.
(sta-demo)=
# 🔍 Demo
Below is a demo of 2D STA with window size (6, 6) operating on a (10, 10) image.
<div style="text-align: center;">
<video controls width="800">
<source src="https://github.com/user-attachments/assets/f3b6dd79-7b43-4b60-a0fa-3d6495ec5747" type="video/mp4">
Your browser does not support the video tag.
</video>
</div>
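The sliding-tile pattern can be pictured on the tile grid itself: each query tile attends to a fixed-size window of key tiles that is centered on it but shifted so it never crosses the grid boundary. A minimal sketch of that indexing follows; the clamping rule is an assumption made for illustration, not the exact kernel logic.
```python
# Hypothetical 2D illustration of sliding-tile attention indexing.
# For a query tile (qi, qj) on a grid x grid tile layout, return the ranges
# of key tiles inside a win x win window that is centered on the query but
# shifted to stay fully inside the grid (assumed behavior, for illustration).
def sta_window_2d(qi: int, qj: int, grid: int = 10, win: int = 6):
    half = win // 2
    top = min(max(qi - half, 0), grid - win)
    left = min(max(qj - half, 0), grid - win)
    return (top, top + win), (left, left + win)


if __name__ == "__main__":
    print(sta_window_2d(0, 0))  # corner query tile -> ((0, 6), (0, 6))
    print(sta_window_2d(5, 5))  # central query tile -> ((2, 8), (2, 8))
```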
(sta-installation)=
# 🔧 Installation
You can install the Sliding Tile Attention package with:
```
pip install st_attn==0.0.4
```
# Building from Source
We test our code with PyTorch 2.5.0 and CUDA >= 12.4. Currently, we only provide an implementation for H100 GPUs.
First, install a compiler with C++20 support for ThunderKittens:
```bash
cd csrc/sliding_tile_attention/
sudo apt update
sudo apt install gcc-11 g++-11
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ /usr/bin/g++-11
sudo apt update
sudo apt install clang-11
```
Install STA:
```bash
export CUDA_HOME=/usr/local/cuda-12.4
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH
git submodule update --init --recursive
python setup.py install
```
# 🧪 Test
```bash
python test/test_sta.py
```
# 📋 Usage
```python
from st_attn import sliding_tile_attention
# assuming video size (T, H, W) = (30, 48, 80), text tokens = 256 with padding.
# q, k, v: [batch_size, num_heads, seq_length, head_dim], seq_length = T*H*W + 256
# a tile is a cube of size (6, 8, 8)
# window_size in tiles: [(window_t, window_h, window_w), (..)...]. For example, window size (3, 3, 3) means a query can attend to (3x6, 3x8, 3x8) = (18, 24, 24) tokens out of the total 30x48x80 video.
# text_length: int ranging from 0 to 256
# If your attention contains text tokens (Hunyuan)
out = sliding_tile_attention(q, k, v, window_size, text_length)
# If your attention does not contain text tokens (StepVideo)
out = sliding_tile_attention(q, k, v, window_size, 0, False)
```
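To make the shapes concrete, here is a small sketch of calling the kernel with random tensors. The head count, dtype, and per-head window list below are assumptions chosen for illustration (one `(window_t, window_h, window_w)` tuple per head), not values taken from a real model.
```python
import torch
from st_attn import sliding_tile_attention

# Assumed sizes matching the comments above: a (30, 48, 80) video latent
# plus 256 padded text tokens.
batch_size, num_heads, head_dim = 1, 24, 128
seq_length = 30 * 48 * 80 + 256  # 115456

q = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

window_size = [(3, 3, 3)] * num_heads  # one window tuple per head (assumed layout)
text_length = 200  # number of real (non-padding) text tokens

out = sliding_tile_attention(q, k, v, window_size, text_length)
print(out.shape)  # (1, 24, 115456, 128)
```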
(v0-data-preprocess)=
# 🧱 Data Preprocess
To save GPU memory, we precompute text embeddings and VAE latents to eliminate the need to load the text encoder and VAE during training.
We provide a sample dataset to help you get started. Download the source media using the following command:
```bash
python scripts/huggingface/download_hf.py --repo_id=FastVideo/Image-Vid-Finetune-Src --local_dir=data/Image-Vid-Finetune-Src --repo_type=dataset
```
To preprocess the dataset for fine-tuning or distillation, run:
```
bash scripts/preprocess/preprocess_mochi_data.sh # for mochi
bash scripts/preprocess/preprocess_hunyuan_data.sh # for hunyuan
```
The preprocessed dataset will be stored in `Image-Vid-Finetune-Mochi` or `Image-Vid-Finetune-HunYuan` correspondingly.
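After preprocessing, the output folder contains `latent/`, `prompt_embed/`, `prompt_attention_mask/`, and a `videos2caption.json` index (see the `LatentDataset` implementation later in this document). A minimal sketch of loading one preprocessed sample, with the directory name as a placeholder:
```python
import json
import os

import torch

output_dir = "data/Image-Vid-Finetune-Mochi"  # placeholder; use your own OUTPUT_DIR

with open(os.path.join(output_dir, "videos2caption.json")) as f:
    annotations = json.load(f)

sample = annotations[0]
latent = torch.load(os.path.join(output_dir, "latent", sample["latent_path"]),
                    map_location="cpu", weights_only=True)
prompt_embed = torch.load(os.path.join(output_dir, "prompt_embed", sample["prompt_embed_path"]),
                          map_location="cpu", weights_only=True)

print(sample["caption"], latent.shape, prompt_embed.shape)
```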
## Process your own dataset
If you wish to create your own dataset for finetuning or distillation, please structure your video dataset in the following format:
```
path_to_dataset_folder/
├── media/
│   ├── 0.jpg
│   ├── 1.mp4
│   └── 2.jpg
├── video2caption.json
└── merge.txt
```
Format the JSON file as a list, where each item represents a media source:
For image media,
```
{
    "path": "0.jpg",
    "cap": ["captions"]
}
```
For video media,
```
{
    "path": "1.mp4",
    "resolution": {
        "width": 848,
        "height": 480
    },
    "fps": 30.0,
    "duration": 6.033333333333333,
    "cap": [
        "caption"
    ]
}
```
Use a text file (merge.txt) to record the media source folder and the path to the JSON metadata file:
```
path_to_media_source_folder,path_to_json_file
```
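To make the layout concrete, here is a small, hypothetical helper that writes the metadata JSON and the `merge.txt` described above; the file names and captions are placeholders, and in practice the video metadata (fps, duration, resolution) should come from probing the files rather than being hard-coded.
```python
import json
import os

dataset_dir = "path_to_dataset_folder"  # placeholder
media_dir = os.path.join(dataset_dir, "media")

# One entry per media file, following the image/video formats shown above.
entries = [
    {"path": "0.jpg", "cap": ["a caption for the image"]},
    {
        "path": "1.mp4",
        "resolution": {"width": 848, "height": 480},
        "fps": 30.0,
        "duration": 6.03,
        "cap": ["a caption for the video"],
    },
]

json_path = os.path.join(dataset_dir, "video2caption.json")
with open(json_path, "w") as f:
    json.dump(entries, f, indent=4)

# merge.txt: one line of "<media folder>,<metadata json>".
with open(os.path.join(dataset_dir, "merge.txt"), "w") as f:
    f.write(f"{media_dir},{json_path}\n")
```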
Adjust the `DATA_MERGE_PATH` and `OUTPUT_DIR` in `scripts/preprocess/preprocess_****_data.sh` accordingly and run:
```
bash scripts/preprocess/preprocess_****_data.sh
```
The preprocessed data will be written to `OUTPUT_DIR`, and the generated `videos2caption.json` can be used in the finetune and distill scripts.
(v0-distill)=
# 🎯 Distill
Our distillation recipe is based on the [Phased Consistency Model](https://github.com/G-U-N/Phased-Consistency-Model). We did not find significant improvement using multi-phase distillation, so we keep a one-phase setup similar to the original latent consistency model's recipe.
We use the [MixKit](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/all_mixkit) dataset for distillation. To avoid running the text encoder and VAE during training, we preprocess all data to generate text embeddings and VAE latents.
Preprocessing instructions can be found in [data_preprocess.md](#v0-data-preprocess). For convenience, we also provide preprocessed data that can be downloaded directly using the following command:
```bash
python scripts/huggingface/download_hf.py --repo_id=FastVideo/HD-Mixkit-Finetune-Hunyuan --local_dir=data/HD-Mixkit-Finetune-Hunyuan --repo_type=dataset
```
Next, download the original model weights with:
```bash
python scripts/huggingface/download_hf.py --repo_id=FastVideo/hunyuan --local_dir=data/hunyuan --repo_type=model # original hunyuan
python scripts/huggingface/download_hf.py --repo_id=genmo/mochi-1-preview --local_dir=data/mochi --repo_type=model # original mochi
```
To launch the distillation process, use the following commands:
```
bash scripts/distill/distill_hunyuan.sh # for hunyuan
bash scripts/distill/distill_mochi.sh # for mochi
```
We also provide an optional script for distillation with adversarial loss, located at `fastvideo/distill_adv.py`. Although we tried adversarial loss, we did not observe significant improvements.
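For reference, the per-step distillation objective used in the training loop included later in this document is a pseudo-Huber distance between the student prediction and the teacher-derived target (with `huber_c = 0.001`); a standalone sketch:
```python
import torch


def pseudo_huber_loss(model_pred: torch.Tensor, target: torch.Tensor, huber_c: float = 1e-3) -> torch.Tensor:
    # Behaves like L2 for small errors and like L1 for large ones, which makes
    # the consistency-distillation target less sensitive to outliers.
    return torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c ** 2) - huber_c)


# Toy usage with random tensors standing in for student/teacher predictions.
pred = torch.randn(2, 4, 8, 8)
target = torch.randn(2, 4, 8, 8)
print(pseudo_huber_loss(pred, target))
```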
(v0-finetune)=
# 🧠 Finetune
## ⚡ Full Finetune
Ensure your data is prepared and preprocessed in the format specified in [data_preprocess.md](#v0-data-preprocess). For convenience, we also provide preprocessed Black Myth Wukong data for Mochi that can be downloaded directly:
```bash
python scripts/huggingface/download_hf.py --repo_id=FastVideo/Mochi-Black-Myth --local_dir=data/Mochi-Black-Myth --repo_type=dataset
```
Download the original model weights as specified in the [Distill Section](#v0-distill), then run the finetune with:
```
bash scripts/finetune/finetune_mochi.sh # for mochi
```
**Note that for finetuning, we did not tune the hyperparameters in the provided script.**
## ⚡ Lora Finetune
Hunyuan supports Lora fine-tuning of videos up to 720p. Demos and prompts for Black-Myth-Wukong can be found [here](https://huggingface.co/FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight). You can download the Lora weights with:
```bash
python scripts/huggingface/download_hf.py --repo_id=FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight --local_dir=data/Hunyuan-Black-Myth-Wukong-lora-weight --repo_type=model
```
### Minimum Hardware Requirement
- 40 GB GPU memory each for 2 GPUs with lora.
- 30 GB GPU memory each for 2 GPUs with CPU offload and lora.
Currently, both Mochi and Hunyuan models support Lora finetuning through diffusers. To generate personalized videos from your own dataset, you'll need to follow three main steps: dataset preparation, finetuning, and inference.
### Dataset Preparation
We provide scripts to help you get started training on your own characters!
Run the following to organize your dataset and produce the videos2caption.json before preprocessing. Specify your video folder and the corresponding caption folder (caption files should be .txt files with the same name as their videos):
```
python scripts/dataset_preparation/prepare_json_file.py --video_dir data/input_videos/ --prompt_dir data/captions/ --output_path data/output_folder/videos2caption.json --verbose
```
We also provide a script to resize your videos:
```
python scripts/data_preprocess/resize_videos.py
```
### Finetuning
After dataset preparation and preprocessing, you can start finetuning your model with Lora:
```
bash scripts/finetune/finetune_hunyuan_hf_lora.sh
```
### Inference
For inference with a Lora checkpoint, run the following script with the additional parameter `--lora_checkpoint_dir`:
```
bash scripts/inference/inference_hunyuan_hf.sh
```
**We also provide scripts for Mochi in the same directory.**
### Finetune with Both Image and Video
Our codebase supports finetuning with both images and videos.
```bash
bash scripts/finetune/finetune_hunyuan.sh
bash scripts/finetune/finetune_mochi_lora_mix.sh
```
For Image-Video Mixture Fine-tuning, make sure to enable the `--group_frame` option in your script.
# Basic
The `VideoGenerator` class provides the main Python interface for using FastVideo's inference pipeline.
```python
from fastvideo.v1.configs.pipelines import PipelineConfig
from fastvideo.v1.configs.sample import SamplingParam
from fastvideo.v1.entrypoints.video_generator import VideoGenerator

__all__ = ["VideoGenerator", "PipelineConfig", "SamplingParam"]
```
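A rough usage sketch of this interface follows; the constructor and method names (`from_pretrained`, `generate_video`) and their arguments are assumptions for illustration and may differ from the actual API.
```python
# Hypothetical usage; method names and arguments are assumptions, not
# confirmed by this document.
from fastvideo.v1.entrypoints.video_generator import VideoGenerator

generator = VideoGenerator.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # any supported model id
    num_gpus=1,
)
generator.generate_video("A curious raccoon explores a moonlit garden")
```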
import argparse
import json
import os
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from fastvideo.utils.load import load_text_encoder, load_vae
logger = get_logger(__name__)
class T5dataset(Dataset):
def __init__(
self,
json_path,
vae_debug,
):
self.json_path = json_path
self.vae_debug = vae_debug
with open(self.json_path, "r") as f:
train_dataset = json.load(f)
self.train_dataset = sorted(train_dataset, key=lambda x: x["latent_path"])
def __getitem__(self, idx):
caption = self.train_dataset[idx]["caption"]
filename = self.train_dataset[idx]["latent_path"].split(".")[0]
length = self.train_dataset[idx]["length"]
if self.vae_debug:
latents = torch.load(
os.path.join(args.output_dir, "latent", self.train_dataset[idx]["latent_path"]),
map_location="cpu",
)
else:
latents = []
return dict(caption=caption, latents=latents, filename=filename, length=length)
def __len__(self):
return len(self.train_dataset)
def main(args):
local_rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
print("world_size", world_size, "local rank", local_rank)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
videoprocessor = VideoProcessor(vae_scale_factor=8)
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True)
latents_json_path = os.path.join(args.output_dir, "videos2caption_temp.json")
train_dataset = T5dataset(latents_json_path, args.vae_debug)
text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
vae.enable_tiling()
sampler = DistributedSampler(train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True)
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
json_data = []
for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
with torch.inference_mode():
with torch.autocast("cuda", dtype=autocast_type):
prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(prompt=data["caption"], )
if args.vae_debug:
latents = data["latents"]
video = vae.decode(latents.to(device), return_dict=False)[0]
video = videoprocessor.postprocess_video(video)
for idx, video_name in enumerate(data["filename"]):
prompt_embed_path = os.path.join(args.output_dir, "prompt_embed", video_name + ".pt")
video_path = os.path.join(args.output_dir, "video", video_name + ".mp4")
prompt_attention_mask_path = os.path.join(args.output_dir, "prompt_attention_mask",
video_name + ".pt")
# save latent
torch.save(prompt_embeds[idx], prompt_embed_path)
torch.save(prompt_attention_mask[idx], prompt_attention_mask_path)
print(f"sample {video_name} saved")
if args.vae_debug:
export_to_video(video[idx], video_path, fps=fps)
item = {}
item["length"] = int(data["length"][idx])
item["latent_path"] = video_name + ".pt"
item["prompt_embed_path"] = video_name + ".pt"
item["prompt_attention_mask"] = video_name + ".pt"
item["caption"] = data["caption"][idx]
json_data.append(item)
dist.barrier()
local_data = json_data
gathered_data = [None] * world_size
dist.all_gather_object(gathered_data, local_data)
if local_rank == 0:
# os.remove(latents_json_path)
all_json_data = [item for sublist in gathered_data for item in sublist]
with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f:
json.dump(all_json_data, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# dataset & dataloader
parser.add_argument("--model_path", type=str, default="data/mochi")
parser.add_argument("--model_type", type=str, default="mochi")
# text encoder & vae & diffusion model
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=1,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--vae_debug", action="store_true")
args = parser.parse_args()
main(args)
import argparse
import json
import os
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from fastvideo.dataset import getdataset
from fastvideo.utils.load import load_vae
logger = get_logger(__name__)
def main(args):
local_rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
print("world_size", world_size, "local rank", local_rank)
train_dataset = getdataset(args)
sampler = DistributedSampler(train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True)
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
encoder_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
vae.enable_tiling()
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
json_data = []
for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0):
with torch.inference_mode():
with torch.autocast("cuda", dtype=autocast_type):
latents = vae.encode(data["pixel_values"].to(encoder_device))["latent_dist"].sample()
for idx, video_path in enumerate(data["path"]):
video_name = os.path.basename(video_path).split(".")[0]
latent_path = os.path.join(args.output_dir, "latent", video_name + ".pt")
torch.save(latents[idx].to(torch.bfloat16), latent_path)
item = {}
item["length"] = latents[idx].shape[1]
item["latent_path"] = video_name + ".pt"
item["caption"] = data["text"][idx]
json_data.append(item)
print(f"{video_name} processed")
dist.barrier()
local_data = json_data
gathered_data = [None] * world_size
dist.all_gather_object(gathered_data, local_data)
if local_rank == 0:
all_json_data = [item for sublist in gathered_data for item in sublist]
with open(os.path.join(args.output_dir, "videos2caption_temp.json"), "w") as f:
json.dump(all_json_data, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# dataset & dataloader
parser.add_argument("--model_path", type=str, default="data/mochi")
parser.add_argument("--model_type", type=str, default="mochi")
parser.add_argument("--data_merge_path", type=str, required=True)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=16,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
parser.add_argument("--max_height", type=int, default=480)
parser.add_argument("--max_width", type=int, default=848)
parser.add_argument("--video_length_tolerance_range", type=int, default=2.0)
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
parser.add_argument("--dataset", default="t2v")
parser.add_argument("--train_fps", type=int, default=30)
parser.add_argument("--use_image_num", type=int, default=0)
parser.add_argument("--text_max_length", type=int, default=256)
parser.add_argument("--speed_factor", type=float, default=1.0)
parser.add_argument("--drop_short_ratio", type=float, default=1.0)
# text encoder & vae & diffusion model
parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
parser.add_argument("--cfg", type=float, default=0.0)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
)
args = parser.parse_args()
main(args)
import argparse
import os
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from fastvideo.utils.load import load_text_encoder
logger = get_logger(__name__)
def main(args):
local_rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
print("world_size", world_size, "local rank", local_rank)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16
# output_dir/validation/prompt_attention_mask
# output_dir/validation/prompt_embed
os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
os.makedirs(
os.path.join(args.output_dir, "validation", "prompt_attention_mask"),
exist_ok=True,
)
os.makedirs(os.path.join(args.output_dir, "validation", "prompt_embed"), exist_ok=True)
with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
lines = file.readlines()
prompts = [line.strip() for line in lines]
for prompt in prompts:
with torch.inference_mode():
with torch.autocast("cuda", dtype=autocast_type):
prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(prompt)
file_name = prompt.split(".")[0]
prompt_embed_path = os.path.join(args.output_dir, "validation", "prompt_embed", f"{file_name}.pt")
prompt_attention_mask_path = os.path.join(
args.output_dir,
"validation",
"prompt_attention_mask",
f"{file_name}.pt",
)
torch.save(prompt_embeds[0], prompt_embed_path)
torch.save(prompt_attention_mask[0], prompt_attention_mask_path)
print(f"sample {file_name} saved")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# dataset & dataloader
parser.add_argument("--model_path", type=str, default="data/mochi")
parser.add_argument("--model_type", type=str, default="mochi")
parser.add_argument("--validation_prompt_txt", type=str)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
args = parser.parse_args()
main(args)
from torchvision import transforms
from torchvision.transforms import Lambda
from transformers import AutoTokenizer
from fastvideo.dataset.t2v_datasets import T2V_dataset
from fastvideo.dataset.transform import CenterCropResizeVideo, Normalize255, TemporalRandomCrop
def getdataset(args):
temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x
norm_fun = Lambda(lambda x: 2.0 * x - 1.0)
resize_topcrop = [
CenterCropResizeVideo((args.max_height, args.max_width), top_crop=True),
]
resize = [
CenterCropResizeVideo((args.max_height, args.max_width)),
]
transform = transforms.Compose([
# Normalize255(),
*resize,
])
transform_topcrop = transforms.Compose([
Normalize255(),
*resize_topcrop,
norm_fun,
])
# tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir)
tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir)
if args.dataset == "t2v":
return T2V_dataset(
args,
transform=transform,
temporal_sample=temporal_sample,
tokenizer=tokenizer,
transform_topcrop=transform_topcrop,
)
raise NotImplementedError(args.dataset)
if __name__ == "__main__":
import random
from accelerate import Accelerator
from tqdm import tqdm
from fastvideo.dataset.t2v_datasets import dataset_prog
args = type(
"args",
(),
{
"ae": "CausalVAEModel_4x8x8",
"dataset": "t2v",
"attention_mode": "xformers",
"use_rope": True,
"text_max_length": 300,
"max_height": 320,
"max_width": 240,
"num_frames": 1,
"use_image_num": 0,
"interpolation_scale_t": 1,
"interpolation_scale_h": 1,
"interpolation_scale_w": 1,
"cache_dir": "../cache_dir",
"image_data": "/storage/ongoing/new/Open-Sora-Plan-bak/7.14bak/scripts/train_data/image_data.txt",
"video_data": "1",
"train_fps": 24,
"drop_short_ratio": 1.0,
"use_img_from_vid": False,
"speed_factor": 1.0,
"cfg": 0.1,
"text_encoder_name": "google/mt5-xxl",
"dataloader_num_workers": 10,
},
)
accelerator = Accelerator()
dataset = getdataset(args)
num = len(dataset_prog.img_cap_list)
zero = 0
for idx in tqdm(range(num)):
image_data = dataset_prog.img_cap_list[idx]
caps = [i["cap"] if isinstance(i["cap"], list) else [i["cap"]] for i in image_data]
try:
caps = [[random.choice(i)] for i in caps]
except Exception as e:
print(e)
# import ipdb;ipdb.set_trace()
print(image_data)
zero += 1
continue
assert caps[0] is not None and len(caps[0]) > 0
print(num, zero)
import ipdb
ipdb.set_trace()
print("end")
import json
import os
import random
import torch
from torch.utils.data import Dataset
class LatentDataset(Dataset):
def __init__(
self,
json_path,
num_latent_t,
cfg_rate,
):
# data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
self.json_path = json_path
self.cfg_rate = cfg_rate
self.datase_dir_path = os.path.dirname(json_path)
self.video_dir = os.path.join(self.datase_dir_path, "video")
self.latent_dir = os.path.join(self.datase_dir_path, "latent")
self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed")
self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path, "prompt_attention_mask")
with open(self.json_path, "r") as f:
self.data_anno = json.load(f)
# json.load(f) already keeps the order
# self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
self.num_latent_t = num_latent_t
# just zero embeddings [256, 4096]
self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)
# 256 zeros
self.uncond_prompt_mask = torch.zeros(256).bool()
self.lengths = [data_item["length"] if "length" in data_item else 1 for data_item in self.data_anno]
def __getitem__(self, idx):
latent_file = self.data_anno[idx]["latent_path"]
prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"]
# load
latent = torch.load(
os.path.join(self.latent_dir, latent_file),
map_location="cpu",
weights_only=True,
)
latent = latent.squeeze(0)[:, -self.num_latent_t:]
if random.random() < self.cfg_rate:
prompt_embed = self.uncond_prompt_embed
prompt_attention_mask = self.uncond_prompt_mask
else:
prompt_embed = torch.load(
os.path.join(self.prompt_embed_dir, prompt_embed_file),
map_location="cpu",
weights_only=True,
)
prompt_attention_mask = torch.load(
os.path.join(self.prompt_attention_mask_dir, prompt_attention_mask_file),
map_location="cpu",
weights_only=True,
)
return latent, prompt_embed, prompt_attention_mask
def __len__(self):
return len(self.data_anno)
def latent_collate_function(batch):
# return latent, prompt, latent_attn_mask, text_attn_mask
# latent_attn_mask: # b t h w
# text_attn_mask: b 1 l
# needs to check if the latent/prompt' size and apply padding & attn mask
latents, prompt_embeds, prompt_attention_masks = zip(*batch)
# calculate max shape
max_t = max([latent.shape[1] for latent in latents])
max_h = max([latent.shape[2] for latent in latents])
max_w = max([latent.shape[3] for latent in latents])
# padding
latents = [
torch.nn.functional.pad(
latent,
(
0,
max_t - latent.shape[1],
0,
max_h - latent.shape[2],
0,
max_w - latent.shape[3],
),
) for latent in latents
]
# attn mask
latent_attn_mask = torch.ones(len(latents), max_t, max_h, max_w)
# set to 0 if padding
for i, latent in enumerate(latents):
latent_attn_mask[i, latent.shape[1]:, :, :] = 0
latent_attn_mask[i, :, latent.shape[2]:, :] = 0
latent_attn_mask[i, :, :, latent.shape[3]:] = 0
prompt_embeds = torch.stack(prompt_embeds, dim=0)
prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0)
latents = torch.stack(latents, dim=0)
return latents, prompt_embeds, latent_attn_mask, prompt_attention_masks
if __name__ == "__main__":
dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt", num_latent_t=28, cfg_rate=0.0)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function)
for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader:
print(
latent.shape,
prompt_embed.shape,
latent_attn_mask.shape,
prompt_attention_mask.shape,
)
import pdb
pdb.set_trace()
import json
import math
import os
import random
from collections import Counter
from os.path import join as opj
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
from fastvideo.utils.dataset_utils import DecordInit
from fastvideo.utils.logging_ import main_print
class SingletonMeta(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]
class DataSetProg(metaclass=SingletonMeta):
def __init__(self):
self.cap_list = []
self.elements = []
self.num_workers = 1
self.n_elements = 0
self.worker_elements = dict()
self.n_used_elements = dict()
def set_cap_list(self, num_workers, cap_list, n_elements):
self.num_workers = num_workers
self.cap_list = cap_list
self.n_elements = n_elements
self.elements = list(range(n_elements))
random.shuffle(self.elements)
print(f"n_elements: {len(self.elements)}", flush=True)
for i in range(self.num_workers):
self.n_used_elements[i] = 0
per_worker = int(math.ceil(len(self.elements) / float(self.num_workers)))
start = i * per_worker
end = min(start + per_worker, len(self.elements))
self.worker_elements[i] = self.elements[start:end]
def get_item(self, work_info):
if work_info is None:
worker_id = 0
else:
worker_id = work_info.id
idx = self.worker_elements[worker_id][self.n_used_elements[worker_id] % len(self.worker_elements[worker_id])]
self.n_used_elements[worker_id] += 1
return idx
dataset_prog = DataSetProg()
def filter_resolution(h, w, max_h_div_w_ratio=17 / 16, min_h_div_w_ratio=8 / 16):
if h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio:
return True
return False
class T2V_dataset(Dataset):
def __init__(self, args, transform, temporal_sample, tokenizer, transform_topcrop):
self.data = args.data_merge_path
self.num_frames = args.num_frames
self.train_fps = args.train_fps
self.use_image_num = args.use_image_num
self.transform = transform
self.transform_topcrop = transform_topcrop
self.temporal_sample = temporal_sample
self.tokenizer = tokenizer
self.text_max_length = args.text_max_length
self.cfg = args.cfg
self.speed_factor = args.speed_factor
self.max_height = args.max_height
self.max_width = args.max_width
self.drop_short_ratio = args.drop_short_ratio
assert self.speed_factor >= 1
self.v_decoder = DecordInit()
self.video_length_tolerance_range = args.video_length_tolerance_range
self.support_Chinese = True
if "mt5" not in args.text_encoder_name:
self.support_Chinese = False
cap_list = self.get_cap_list()
assert len(cap_list) > 0
cap_list, self.sample_num_frames = self.define_frame_index(cap_list)
self.lengths = self.sample_num_frames
n_elements = len(cap_list)
dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, n_elements)
print(f"video length: {len(dataset_prog.cap_list)}", flush=True)
def set_checkpoint(self, n_used_elements):
for i in range(len(dataset_prog.n_used_elements)):
dataset_prog.n_used_elements[i] = n_used_elements
def __len__(self):
return dataset_prog.n_elements
def __getitem__(self, idx):
data = self.get_data(idx)
return data
def get_data(self, idx):
path = dataset_prog.cap_list[idx]["path"]
if path.endswith(".mp4"):
return self.get_video(idx)
else:
return self.get_image(idx)
def get_video(self, idx):
video_path = dataset_prog.cap_list[idx]["path"]
assert os.path.exists(video_path), f"file {video_path} does not exist!"
frame_indices = dataset_prog.cap_list[idx]["sample_frame_index"]
torchvision_video, _, metadata = torchvision.io.read_video(video_path, output_format="TCHW")
video = torchvision_video[frame_indices]
video = self.transform(video)
video = rearrange(video, "t c h w -> c t h w")
video = video.to(torch.uint8)
assert video.dtype == torch.uint8
h, w = video.shape[-2:]
assert (
h / w <= 17 / 16 and h / w >= 8 / 16
), f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}"
video = video.float() / 127.5 - 1.0
text = dataset_prog.cap_list[idx]["cap"]
if not isinstance(text, list):
text = [text]
text = [random.choice(text)]
text = text[0] if random.random() > self.cfg else ""
text_tokens_and_mask = self.tokenizer(
text,
max_length=self.text_max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = text_tokens_and_mask["input_ids"]
cond_mask = text_tokens_and_mask["attention_mask"]
return dict(
pixel_values=video,
text=text,
input_ids=input_ids,
cond_mask=cond_mask,
path=video_path,
)
def get_image(self, idx):
image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...]
image = Image.open(image_data["path"]).convert("RGB") # [h, w, c]
image = torch.from_numpy(np.array(image)) # [h, w, c]
image = rearrange(image, "h w c -> c h w").unsqueeze(0) # [1 c h w]
# for i in image:
# h, w = i.shape[-2:]
# assert h / w <= 17 / 16 and h / w >= 8 / 16, f'Only image with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But found ratio is {round(h / w, 2)} with the shape of {i.shape}'
image = (self.transform_topcrop(image) if "human_images" in image_data["path"] else self.transform(image)
) # [1 C H W] -> num_img [1 C H W]
image = image.transpose(0, 1) # [1 C H W] -> [C 1 H W]
image = image.float() / 127.5 - 1.0
caps = (image_data["cap"] if isinstance(image_data["cap"], list) else [image_data["cap"]])
caps = [random.choice(caps)]
text = caps
input_ids, cond_mask = [], []
text = text[0] if random.random() > self.cfg else ""
text_tokens_and_mask = self.tokenizer(
text,
max_length=self.text_max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = text_tokens_and_mask["input_ids"] # 1, l
cond_mask = text_tokens_and_mask["attention_mask"] # 1, l
return dict(
pixel_values=image,
text=text,
input_ids=input_ids,
cond_mask=cond_mask,
path=image_data["path"],
)
def define_frame_index(self, cap_list):
new_cap_list = []
sample_num_frames = []
cnt_too_long = 0
cnt_too_short = 0
cnt_no_cap = 0
cnt_no_resolution = 0
cnt_resolution_mismatch = 0
cnt_movie = 0
cnt_img = 0
for i in cap_list:
path = i["path"]
cap = i.get("cap", None)
# ======no caption=====
if cap is None:
cnt_no_cap += 1
continue
if path.endswith(".mp4"):
# ======no fps and duration=====
duration = i.get("duration", None)
fps = i.get("fps", None)
if fps is None or duration is None:
continue
# ======resolution mismatch=====
resolution = i.get("resolution", None)
if resolution is None:
cnt_no_resolution += 1
continue
else:
if (resolution.get("height", None) is None or resolution.get("width", None) is None):
cnt_no_resolution += 1
continue
height, width = i["resolution"]["height"], i["resolution"]["width"]
aspect = self.max_height / self.max_width
hw_aspect_thr = 1.5
is_pick = filter_resolution(
height,
width,
max_h_div_w_ratio=hw_aspect_thr * aspect,
min_h_div_w_ratio=1 / hw_aspect_thr * aspect,
)
if not is_pick:
print("resolution mismatch")
cnt_resolution_mismatch += 1
continue
# import ipdb;ipdb.set_trace()
i["num_frames"] = math.ceil(fps * duration)
# max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
if i["num_frames"] / fps > self.video_length_tolerance_range * (
self.num_frames / self.train_fps *
self.speed_factor): # too long video is not suitable for this training stage (self.num_frames)
cnt_too_long += 1
continue
# resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
frame_interval = fps / self.train_fps
start_frame_idx = 0
frame_indices = np.arange(start_frame_idx, i["num_frames"], frame_interval).astype(int)
# comment out it to enable dynamic frames training
if (len(frame_indices) < self.num_frames and random.random() < self.drop_short_ratio):
cnt_too_short += 1
continue
# too long video will be temporal-crop randomly
if len(frame_indices) > self.num_frames:
begin_index, end_index = self.temporal_sample(len(frame_indices))
frame_indices = frame_indices[begin_index:end_index]
# frame_indices = frame_indices[:self.num_frames] # head crop
i["sample_frame_index"] = frame_indices.tolist()
new_cap_list.append(i)
i["sample_num_frames"] = len(i["sample_frame_index"]) # will use in dataloader(group sampler)
sample_num_frames.append(i["sample_num_frames"])
elif path.endswith(".jpg"): # image
cnt_img += 1
new_cap_list.append(i)
i["sample_num_frames"] = 1
sample_num_frames.append(i["sample_num_frames"])
else:
raise NameError(
f"Unknown file extension {path.split('.')[-1]}, only support .mp4 for video and .jpg for image")
# import ipdb;ipdb.set_trace()
main_print(
f"no_cap: {cnt_no_cap}, too_long: {cnt_too_long}, too_short: {cnt_too_short}, "
f"no_resolution: {cnt_no_resolution}, resolution_mismatch: {cnt_resolution_mismatch}, "
f"Counter(sample_num_frames): {Counter(sample_num_frames)}, cnt_movie: {cnt_movie}, cnt_img: {cnt_img}, "
f"before filter: {len(cap_list)}, after filter: {len(new_cap_list)}")
return new_cap_list, sample_num_frames
def decord_read(self, path, frame_indices):
decord_vr = self.v_decoder(path)
video_data = decord_vr.get_batch(frame_indices).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
return video_data
def read_jsons(self, data):
cap_lists = []
with open(data, "r") as f:
folder_anno = [i.strip().split(",") for i in f.readlines() if len(i.strip()) > 0]
print(folder_anno)
for folder, anno in folder_anno:
with open(anno, "r") as f:
sub_list = json.load(f)
for i in range(len(sub_list)):
sub_list[i]["path"] = opj(folder, sub_list[i]["path"])
cap_lists += sub_list
return cap_lists
def get_cap_list(self):
cap_lists = self.read_jsons(self.data)
return cap_lists
import numbers
import random
import numpy as np
import torch
from PIL import Image
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y:crop_y + image_size, crop_x:crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i:i + h, j:j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(
clip,
size=target_size,
mode=interpolation_mode,
align_corners=True,
antialias=True,
)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(
clip,
scale_factor=scale_,
mode=interpolation_mode,
align_corners=True,
antialias=True,
)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def center_crop_th_tw(clip, th, tw, top_crop):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
# import ipdb;ipdb.set_trace()
h, w = clip.size(-2), clip.size(-1)
tr = th / tw
if h / w > tr:
new_h = int(w * tr)
new_w = w
else:
new_h = h
new_w = int(h / tr)
i = 0 if top_crop else int(round((h - new_h) / 2.0))
j = int(round((w - new_w) / 2.0))
return crop(clip, i, j, new_h, new_w)
def random_shift_crop(clip):
"""
Slide along the long edge, with the short edge as crop size
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
short_edge = h
else:
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
return crop(clip, i, j, th, tw)
def normalize_video(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class SpatialStrideCropVideo:
def __init__(self, stride):
self.stride = stride
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: cropped video clip by stride.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = h // self.stride * self.stride, w // self.stride * self.stride
return 0, 0, th, tw # from top-left
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class LongSideResizeVideo:
"""
First use the long side,
then resize to the specified size
"""
def __init__(
self,
size,
skip_low_resolution=False,
interpolation_mode="bilinear",
):
self.size = size
self.skip_low_resolution = skip_low_resolution
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized video clip.
size is (T, C, 512, *) or (T, C, *, 512)
"""
_, _, h, w = clip.shape
if self.skip_low_resolution and max(h, w) <= self.size:
return clip
if h > w:
w = int(w * self.size / h)
h = self.size
else:
h = int(h * self.size / w)
w = self.size
resize_clip = resize(clip, target_size=(h, w), interpolation_mode=self.interpolation_mode)
return resize_clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class CenterCropResizeVideo:
"""
First use the short side for cropping length,
center crop video, then resize to the specified size
"""
def __init__(
self,
size,
top_crop=False,
interpolation_mode="bilinear",
):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
self.top_crop = top_crop
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# clip_center_crop = center_crop_using_short_edge(clip)
clip_center_crop = center_crop_th_tw(clip, self.size[0], self.size[1], top_crop=self.top_crop)
# import ipdb;ipdb.set_trace()
clip_center_crop_resize = resize(
clip_center_crop,
target_size=self.size,
interpolation_mode=self.interpolation_mode,
)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class UCFCenterCropVideo:
"""
First scale to the specified size in equal proportion to the short edge,
then center cropping
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
"""
Slide along the long edge, with the short edge as crop size, then resize to the desired size.
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class Normalize:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class Normalize255:
"""
Convert tensor data type from uint8 to float and divide value by 255.0
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return normalize_video(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
class DynamicSampleDuration(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, t_stride, extra_1):
self.t_stride = t_stride
self.extra_1 = extra_1
def __call__(self, t, h, w):
if self.extra_1:
t = t - 1
truncate_t_list = list(range(t + 1))[t // 2:][::self.t_stride] # need half at least
truncate_t = random.choice(truncate_t_list)
if self.extra_1:
truncate_t = truncate_t + 1
return 0, truncate_t
if __name__ == "__main__":
import os
import numpy as np
import torchvision.io as io
from torchvision import transforms
from torchvision.utils import save_image
vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
trans = transforms.Compose([
Normalize255(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(
select_vframes_trans[i],
os.path.join("./test000", "%04d.png" % i),
normalize=True,
value_range=(-1, 1),
)
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
from copy import deepcopy
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from peft import LoraConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.distill.solver import EulerSolver, extract_into_tensor
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
def main_print(content):
if int(os.environ["LOCAL_RANK"]) <= 0:
print(content)
def reshard_fsdp(model):
for m in FSDP.fsdp_modules(model):
if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True)
def get_norm(model_pred, norms, gradient_accumulation_steps):
fro_norm = (
torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore
gradient_accumulation_steps)
largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / gradient_accumulation_steps)
absolute_mean = torch.mean(torch.abs(model_pred)) / gradient_accumulation_steps
absolute_max = torch.max(torch.abs(model_pred)) / gradient_accumulation_steps
dist.all_reduce(fro_norm, op=dist.ReduceOp.AVG)
dist.all_reduce(largest_singular_value, op=dist.ReduceOp.AVG)
dist.all_reduce(absolute_mean, op=dist.ReduceOp.AVG)
norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore
norms["largest singular value"] += torch.mean(largest_singular_value).item()
norms["absolute mean"] += absolute_mean.item()
norms["absolute max"] += absolute_max.item()
def distill_one_step(
transformer,
model_type,
teacher_transformer,
ema_transformer,
optimizer,
lr_scheduler,
loader,
noise_scheduler,
solver,
noise_random_generator,
gradient_accumulation_steps,
sp_size,
max_grad_norm,
uncond_prompt_embed,
uncond_prompt_mask,
num_euler_timesteps,
multiphase,
not_apply_cfg_solver,
distill_cfg,
ema_decay,
pred_decay_weight,
pred_decay_type,
hunyuan_teacher_disable_cfg,
):
total_loss = 0.0
optimizer.zero_grad()
model_pred_norm = {
"fro": 0.0, # codespell:ignore
"largest singular value": 0.0,
"absolute mean": 0.0,
"absolute max": 0.0,
}
for _ in range(gradient_accumulation_steps):
(
latents,
encoder_hidden_states,
latents_attention_mask,
encoder_attention_mask,
) = next(loader)
model_input = normalize_dit_input(model_type, latents)
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
index = torch.randint(0, num_euler_timesteps, (bsz, ), device=model_input.device).long()
if sp_size > 1:
broadcast(index)
# Add noise according to flow matching.
# sigmas = get_sigmas(start_timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
sigmas = extract_into_tensor(solver.sigmas, index, model_input.shape)
sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, model_input.shape)
timesteps = (sigmas * noise_scheduler.config.num_train_timesteps).view(-1)
# if squeeze to [], unsqueeze to [1]
timesteps_prev = (sigmas_prev * noise_scheduler.config.num_train_timesteps).view(-1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
# Predict the noise residual
with torch.autocast("cuda", dtype=torch.bfloat16):
teacher_kwargs = {
"hidden_states": noisy_model_input,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timesteps,
"encoder_attention_mask": encoder_attention_mask, # B, L
"return_dict": False,
}
if hunyuan_teacher_disable_cfg:
teacher_kwargs["guidance"] = torch.tensor([1000.0],
device=noisy_model_input.device,
dtype=torch.bfloat16)
model_pred = transformer(**teacher_kwargs)[0]
# if accelerator.is_main_process:
model_pred, end_index = solver.euler_style_multiphase_pred(noisy_model_input, model_pred, index, multiphase)
with torch.no_grad():
w = distill_cfg
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_teacher_output = teacher_transformer(
noisy_model_input,
encoder_hidden_states,
timesteps,
encoder_attention_mask, # B, L
return_dict=False,
)[0].float()
if not_apply_cfg_solver:
uncond_teacher_output = cond_teacher_output
else:
# Get teacher model prediction on noisy_latents and unconditional embedding
with torch.autocast("cuda", dtype=torch.bfloat16):
uncond_teacher_output = teacher_transformer(
noisy_model_input,
uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
timesteps,
uncond_prompt_mask.unsqueeze(0).expand(bsz, -1),
return_dict=False,
)[0].float()
teacher_output = uncond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.euler_step(noisy_model_input, teacher_output, index)
# Get the target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", dtype=torch.bfloat16):
if ema_transformer is not None:
target_pred = ema_transformer(
x_prev.float(),
encoder_hidden_states,
timesteps_prev,
encoder_attention_mask, # B, L
return_dict=False,
)[0]
else:
target_pred = transformer(
x_prev.float(),
encoder_hidden_states,
timesteps_prev,
encoder_attention_mask, # B, L
return_dict=False,
)[0]
target, end_index = solver.euler_style_multiphase_pred(x_prev, target_pred, index, multiphase, True)
huber_c = 0.001
# loss = loss.mean()
loss = (torch.mean(torch.sqrt((model_pred.float() - target.float())**2 + huber_c**2) - huber_c) /
gradient_accumulation_steps)
if pred_decay_weight > 0:
if pred_decay_type == "l1":
pred_decay_loss = (torch.mean(torch.sqrt(model_pred.float()**2)) * pred_decay_weight /
gradient_accumulation_steps)
loss += pred_decay_loss
elif pred_decay_type == "l2":
# essentially an L2 penalty on the prediction
pred_decay_loss = (torch.mean(model_pred.float()**2) * pred_decay_weight / gradient_accumulation_steps)
loss += pred_decay_loss
else:
assert NotImplementedError("pred_decay_type is not implemented")
# calculate model_pred norm and mean
get_norm(model_pred.detach().float(), model_pred_norm, gradient_accumulation_steps)
loss.backward()
avg_loss = loss.detach().clone()
dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
total_loss += avg_loss.item()
# update ema
if ema_transformer is not None:
reshard_fsdp(ema_transformer)
for p_averaged, p_model in zip(ema_transformer.parameters(), transformer.parameters()):
with torch.no_grad():
p_averaged.copy_(torch.lerp(p_averaged.detach(), p_model.detach(), 1 - ema_decay))
grad_norm = transformer.clip_grad_norm_(max_grad_norm)
optimizer.step()
lr_scheduler.step()
return total_loss, grad_norm.item(), model_pred_norm
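# The distillation objective above is a pseudo-Huber loss, sqrt((pred - target)^2 + c^2) - c,
# which behaves like a squared error near zero and like an absolute error for large residuals.
# The helper below is a minimal, self-contained sketch of that formula; it is not called by the
# training loop, and the default huber_c simply mirrors the constant hard-coded in distill_one_step.
def pseudo_huber_loss_sketch(model_pred, target, huber_c=0.001):
    return torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c ** 2) - huber_c)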
def main(args):
torch.backends.cuda.matmul.allow_tf32 = True
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)
device = torch.cuda.current_device()
initialize_sequence_parallel_state(args.sp_size)
# If passed along, set the training seed now.
if args.seed is not None:
# TODO: t within the same seq parallel group should be the same. Noise should be different.
set_seed(args.seed + rank)
# We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
noise_random_generator = None
# Handle the repository creation
if rank <= 0 and args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
# Create model:
main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
transformer = load_transformer(
args.model_type,
args.dit_model_name_or_path,
args.pretrained_model_name_or_path,
torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
)
teacher_transformer = deepcopy(transformer)
if args.use_ema:
ema_transformer = deepcopy(transformer)
else:
ema_transformer = None
if args.use_lora:
assert args.model_type == "mochi", "LoRA is only supported for Mochi model."
transformer.requires_grad_(False)
transformer_lora_config = LoraConfig(
r=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=True,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
main_print(
f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
transformer,
args.fsdp_sharding_startegy,
args.use_lora,
args.use_cpu_offload,
args.master_weight_type,
)
if args.use_lora:
transformer.config.lora_rank = args.lora_rank
transformer.config.lora_alpha = args.lora_alpha
transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
transformer._no_split_modules = no_split_modules
fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
transformer = FSDP(
transformer,
**fsdp_kwargs,
)
teacher_transformer = FSDP(
teacher_transformer,
**fsdp_kwargs,
)
if args.use_ema:
ema_transformer = FSDP(
ema_transformer,
**fsdp_kwargs,
)
main_print("--> model loaded")
if args.gradient_checkpointing:
apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
apply_fsdp_checkpointing(teacher_transformer, no_split_modules, args.selective_checkpointing)
if args.use_ema:
apply_fsdp_checkpointing(ema_transformer, no_split_modules, args.selective_checkpointing)
# Set model as trainable.
transformer.train()
teacher_transformer.requires_grad_(False)
if args.use_ema:
ema_transformer.requires_grad_(False)
noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=args.shift)
if args.scheduler_type == "pcm_linear_quadratic":
linear_steps = int(noise_scheduler.config.num_train_timesteps * args.linear_range)
sigmas = linear_quadratic_schedule(
noise_scheduler.config.num_train_timesteps,
args.linear_quadratic_threshold,
linear_steps,
)
sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
else:
sigmas = noise_scheduler.sigmas
solver = EulerSolver(
sigmas.numpy()[::-1],
noise_scheduler.config.num_train_timesteps,
euler_timesteps=args.num_euler_timesteps,
)
solver.to(device)
params_to_optimize = transformer.parameters()
params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=args.learning_rate,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
eps=1e-8,
)
init_steps = 0
if args.resume_from_lora_checkpoint:
transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
optimizer)
main_print(f"optimizer: {optimizer}")
# todo add lr scheduler
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * world_size,
num_training_steps=args.max_train_steps * world_size,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
last_epoch=init_steps - 1,
)
train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
uncond_prompt_embed = train_dataset.uncond_prompt_embed
uncond_prompt_mask = train_dataset.uncond_prompt_mask
sampler = (LengthGroupedSampler(
args.train_batch_size,
rank=rank,
world_size=world_size,
lengths=train_dataset.lengths,
group_frame=args.group_frame,
group_resolution=args.group_resolution,
) if (args.group_frame or args.group_resolution) else DistributedSampler(
train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
collate_fn=latent_collate_function,
pin_memory=True,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
drop_last=True,
)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if rank <= 0:
project = args.tracker_project_name or "fastvideo"
wandb.init(project=project, config=args)
# Train!
total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
main_print("***** Running training *****")
main_print(f" Num examples = {len(train_dataset)}")
main_print(f" Dataloader size = {len(train_dataloader)}")
main_print(f" Num Epochs = {args.num_train_epochs}")
main_print(f" Resume training from step {init_steps}")
main_print(f" Instantaneous batch size per device = {args.train_batch_size}")
main_print(f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
main_print(f" Total optimization steps = {args.max_train_steps}")
main_print(
f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
)
# print dtype
main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}")
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
assert NotImplementedError("resume_from_checkpoint is not supported now.")
# TODO
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=init_steps,
desc="Steps",
# Only show the progress bar once on each machine.
disable=local_rank > 0,
)
loader = sp_parallel_dataloader_wrapper(
train_dataloader,
device,
args.train_batch_size,
args.sp_size,
args.train_sp_batch_size,
)
step_times = deque(maxlen=100)
# todo future
for i in range(init_steps):
next(loader)
# log_validation(args, transformer, device,
# torch.bfloat16, 0, scheduler_type=args.scheduler_type, shift=args.shift, num_euler_timesteps=args.num_euler_timesteps, linear_quadratic_threshold=args.linear_quadratic_threshold,ema=False)
def get_num_phases(multi_phased_distill_schedule, step):
# step-phase,step-phase
multi_phases = multi_phased_distill_schedule.split(",")
phase = multi_phases[-1].split("-")[-1]
for step_phases in multi_phases:
phase_step, phase = step_phases.split("-")
if step <= int(phase_step):
return int(phase)
return int(phase)
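# The schedule string is parsed as comma-separated "boundary_step-num_phases" pairs; for example
# (illustrative values only), "4000-4,7000-2,10000-1" would use 4 phases for steps up to 4000,
# 2 phases for steps up to 7000, and 1 phase afterwards (the last pair's phase count is the
# fallback once every boundary has been passed).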
for step in range(init_steps + 1, args.max_train_steps + 1):
start_time = time.time()
assert args.multi_phased_distill_schedule is not None
num_phases = get_num_phases(args.multi_phased_distill_schedule, step)
loss, grad_norm, pred_norm = distill_one_step(
transformer,
args.model_type,
teacher_transformer,
ema_transformer,
optimizer,
lr_scheduler,
loader,
noise_scheduler,
solver,
noise_random_generator,
args.gradient_accumulation_steps,
args.sp_size,
args.max_grad_norm,
uncond_prompt_embed,
uncond_prompt_mask,
args.num_euler_timesteps,
num_phases,
args.not_apply_cfg_solver,
args.distill_cfg,
args.ema_decay,
args.pred_decay_weight,
args.pred_decay_type,
args.hunyuan_teacher_disable_cfg,
)
step_time = time.time() - start_time
step_times.append(step_time)
avg_step_time = sum(step_times) / len(step_times)
progress_bar.set_postfix({
"loss": f"{loss:.4f}",
"step_time": f"{step_time:.2f}s",
"grad_norm": grad_norm,
"phases": num_phases,
})
progress_bar.update(1)
# if rank <= 0:
# wandb.log(
# {
# "train_loss": loss,
# "learning_rate": lr_scheduler.get_last_lr()[0],
# "step_time": step_time,
# "avg_step_time": avg_step_time,
# "grad_norm": grad_norm,
# "pred_fro_norm": pred_norm["fro"], # codespell:ignore
# "pred_largest_singular_value": pred_norm["largest singular value"],
# "pred_absolute_mean": pred_norm["absolute mean"],
# "pred_absolute_max": pred_norm["absolute max"],
# },
# step=step,
# )
if step % args.checkpointing_steps == 0:
if args.use_lora:
# Save LoRA weights
save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step)
else:
# Save the full model checkpoint (EMA weights if enabled)
if args.use_ema:
save_checkpoint(ema_transformer, rank, args.output_dir, step)
else:
save_checkpoint(transformer, rank, args.output_dir, step)
dist.barrier()
if args.log_validation and step % args.validation_steps == 0:
log_validation(
args,
transformer,
device,
torch.bfloat16,
step,
scheduler_type=args.scheduler_type,
shift=args.shift,
num_euler_timesteps=args.num_euler_timesteps,
linear_quadratic_threshold=args.linear_quadratic_threshold,
linear_range=args.linear_range,
ema=False,
)
if args.use_ema:
log_validation(
args,
ema_transformer,
device,
torch.bfloat16,
step,
scheduler_type=args.scheduler_type,
shift=args.shift,
num_euler_timesteps=args.num_euler_timesteps,
linear_quadratic_threshold=args.linear_quadratic_threshold,
linear_range=args.linear_range,
ema=True,
)
if args.use_lora:
save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps)
else:
save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
if get_sequence_parallel_state():
destroy_sequence_parallel_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default="mochi", help="The type of model to train.")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=10,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=16,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
# text encoder & vae & diffusion model
parser.add_argument("--pretrained_model_name_or_path", type=str)
parser.add_argument("--dit_model_name_or_path", type=str)
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
# diffusion setting
parser.add_argument("--ema_decay", type=float, default=0.95)
parser.add_argument("--ema_start_step", type=int, default=0)
parser.add_argument("--cfg", type=float, default=0.1)
# validation & logs
parser.add_argument("--validation_prompt_dir", type=str)
parser.add_argument("--validation_sampling_steps", type=str, default="64")
parser.add_argument("--validation_guidance_scale", type=str, default="4.5")
parser.add_argument("--validation_steps", type=float, default=64)
parser.add_argument("--log_validation", action="store_true")
parser.add_argument("--tracker_project_name", type=str, default=None)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."),
)
parser.add_argument("--shift", type=float, default=1.0)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--resume_from_lora_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
)
# optimizer & scheduler & Training
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
default=10,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument("--selective_checkpointing", type=float, default=1.0)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
)
parser.add_argument(
"--use_cpu_offload",
action="store_true",
help="Whether to use CPU offload for param & gradient & optimizer states.",
)
parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
parser.add_argument(
"--train_sp_batch_size",
type=int,
default=1,
help="Batch size for sequence parallel training",
)
parser.add_argument(
"--use_lora",
action="store_true",
default=False,
help="Whether to use LoRA for finetuning.",
)
parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
parser.add_argument("--fsdp_sharding_startegy", default="full")
# lr_scheduler
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
)
parser.add_argument("--num_euler_timesteps", type=int, default=100)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of cycles in the learning rate scheduler.",
)
parser.add_argument(
"--lr_power",
type=float,
default=1.0,
help="Power factor of the polynomial scheduler.",
)
parser.add_argument(
"--not_apply_cfg_solver",
action="store_true",
help="Whether to apply the cfg_solver.",
)
parser.add_argument("--distill_cfg", type=float, default=3.0, help="Distillation coefficient.")
# ["euler_linear_quadratic", "pcm", "pcm_linear_qudratic"]
parser.add_argument("--scheduler_type", type=str, default="pcm", help="The scheduler type to use.")
parser.add_argument(
"--linear_quadratic_threshold",
type=float,
default=0.025,
help="Threshold for linear quadratic scheduler.",
)
parser.add_argument(
"--linear_range",
type=float,
default=0.5,
help="Range for linear quadratic scheduler.",
)
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay to apply.")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA.")
parser.add_argument("--multi_phased_distill_schedule", type=str, default=None)
parser.add_argument("--pred_decay_weight", type=float, default=0.0)
parser.add_argument("--pred_decay_type", default="l1")
parser.add_argument("--hunyuan_teacher_disable_cfg", action="store_true")
parser.add_argument(
"--master_weight_type",
type=str,
default="fp32",
help="Weight type to use - fp32 or bf16.",
)
args = parser.parse_args()
main(args)
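# Example launch (hypothetical script path and placeholder values; every flag below is defined by
# the argparse options above, and torchrun supplies the LOCAL_RANK / RANK / WORLD_SIZE environment
# variables that main() reads):
#   torchrun --nproc_per_node=8 fastvideo/distill.py \
#     --model_type mochi \
#     --data_json_path YOUR_DATA_JSON \
#     --pretrained_model_name_or_path YOUR_MODEL_DIR \
#     --output_dir YOUR_OUTPUT_DIR \
#     --multi_phased_distill_schedule "4000-1" \
#     --max_train_steps 4000 \
#     --gradient_checkpointing --use_ema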
import torch.nn as nn
from diffusers.utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DiscriminatorHead(nn.Module):
def __init__(self, input_channel, output_channel=1):
super().__init__()
inner_channel = 1024
self.conv1 = nn.Sequential(
nn.Conv2d(input_channel, inner_channel, 1, 1, 0),
nn.GroupNorm(32, inner_channel),
nn.LeakyReLU(inplace=True),  # use LeakyReLU instead of the GELU used in the paper to save memory
)
self.conv2 = nn.Sequential(
nn.Conv2d(inner_channel, inner_channel, 1, 1, 0),
nn.GroupNorm(32, inner_channel),
nn.LeakyReLU(inplace=True),  # use LeakyReLU instead of the GELU used in the paper to save memory
)
self.conv_out = nn.Conv2d(inner_channel, output_channel, 1, 1, 0)
def forward(self, x):
b, twh, c = x.shape
t = twh // (30 * 53)
x = x.view(-1, 30 * 53, c)
x = x.permute(0, 2, 1)
x = x.view(b * t, c, 30, 53)
x = self.conv1(x)
x = self.conv2(x) + x
x = self.conv_out(x)
return x
class Discriminator(nn.Module):
def __init__(
self,
stride=8,
num_h_per_head=1,
adapter_channel_dims=[3072],
total_layers=48,
):
super().__init__()
adapter_channel_dims = adapter_channel_dims * (total_layers // stride)
self.stride = stride
self.num_h_per_head = num_h_per_head
self.head_num = len(adapter_channel_dims)
self.heads = nn.ModuleList([
nn.ModuleList([DiscriminatorHead(adapter_channel) for _ in range(self.num_h_per_head)])
for adapter_channel in adapter_channel_dims
])
def forward(self, features):
outputs = []
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
assert len(features) == len(self.heads)
for i in range(0, len(features)):
for h in self.heads[i]:
# out = torch.utils.checkpoint.checkpoint(
# create_custom_forward(h),
# features[i],
# use_reentrant=False
# )
out = h(features[i])
outputs.append(out)
return outputs
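# A minimal shape-check sketch for the discriminator above (assumed values: 6 intermediate DiT
# feature maps of width 3072, and the 30x53 latent spatial grid that DiscriminatorHead.forward
# hard-codes). It is illustrative only and not used anywhere else in this module.
def _discriminator_shape_check_sketch():
    import torch
    disc = Discriminator(stride=8, adapter_channel_dims=[3072], total_layers=48)  # 48 // 8 = 6 heads
    b, t, c = 1, 2, 3072
    features = [torch.randn(b, t * 30 * 53, c) for _ in range(6)]
    outputs = disc(features)
    # one logit map of shape (b * t, 1, 30, 53) per head
    assert all(o.shape == (b * t, 1, 30, 53) for o in outputs)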
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging
from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class PCMFMSchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
class PCMFMScheduler(SchedulerMixin, ConfigMixin):
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
pcm_timesteps: int = 50,
linear_quadratic=False,
linear_quadratic_threshold=0.025,
linear_range=0.5,
):
if linear_quadratic:
linear_steps = int(num_train_timesteps * linear_range)
sigmas = linear_quadratic_schedule(num_train_timesteps, linear_quadratic_threshold, linear_steps)
sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
else:
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.euler_timesteps = (np.arange(1, pcm_timesteps + 1) *
(num_train_timesteps // pcm_timesteps)).round().astype(np.int64) - 1
self.sigmas = sigmas.numpy()[::-1][self.euler_timesteps]
self.sigmas = torch.from_numpy((self.sigmas[::-1].copy()))
self.timesteps = self.sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
inference_indices = np.linspace(0, self.config.pcm_timesteps, num=num_inference_steps, endpoint=False)
inference_indices = np.floor(inference_indices).astype(np.int64)
inference_indices = torch.from_numpy(inference_indices).long()
self.sigmas_ = self.sigmas[inference_indices]
timesteps = self.sigmas_ * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas_ = torch.cat([self.sigmas_, torch.zeros(1, device=self.sigmas_.device)])
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[PCMFMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`PCMFMSchedulerOutput`] or a tuple.
Returns:
[`PCMFMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`PCMFMSchedulerOutput`] is returned, otherwise a tuple
is returned where the first element is the sample tensor.
"""
if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)):
raise ValueError(("Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."), )
if self.step_index is None:
self._init_step_index(timestep)
sample = sample.to(torch.float32)
sigma = self.sigmas_[self.step_index]
denoised = sample - model_output * sigma
derivative = (sample - denoised) / sigma
dt = self.sigmas_[self.step_index + 1] - sigma
prev_sample = sample + derivative * dt
prev_sample = prev_sample.to(model_output.dtype)
self._step_index += 1
if not return_dict:
return (prev_sample, )
return PCMFMSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
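# A minimal sampling-loop sketch for PCMFMScheduler (assumptions: `model` is any callable that
# predicts the flow-matching velocity for a latent and a timestep; default constructor arguments
# are used and the step count is illustrative).
def _pcmfm_sampling_sketch(model, latents, num_inference_steps=4, device="cpu"):
    scheduler = PCMFMScheduler()
    scheduler.set_timesteps(num_inference_steps, device=device)
    for t in scheduler.timesteps:
        model_output = model(latents, t)
        # each step moves the sample one Euler increment toward sigma = 0
        latents = scheduler.step(model_output, t, latents, return_dict=False)[0]
    return latents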
class EulerSolver:
def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
self.step_ratio = timesteps // euler_timesteps
self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * self.step_ratio).round().astype(np.int64) - 1
self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist())
self.sigmas = sigmas[self.euler_timesteps]
self.sigmas_prev = np.asarray([sigmas[0]] +
sigmas[self.euler_timesteps[:-1]].tolist()) # either use sigma0 or 0
self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long()
self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long()
self.sigmas = torch.from_numpy(self.sigmas)
self.sigmas_prev = torch.from_numpy(self.sigmas_prev)
def to(self, device):
self.euler_timesteps = self.euler_timesteps.to(device)
self.euler_timesteps_prev = self.euler_timesteps_prev.to(device)
self.sigmas = self.sigmas.to(device)
self.sigmas_prev = self.sigmas_prev.to(device)
return self
def euler_step(self, sample, model_pred, timestep_index):
sigma = extract_into_tensor(self.sigmas, timestep_index, model_pred.shape)
sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index, model_pred.shape)
x_prev = sample + (sigma_prev - sigma) * model_pred
return x_prev
def euler_style_multiphase_pred(
self,
sample,
model_pred,
timestep_index,
multiphase,
is_target=False,
):
inference_indices = np.linspace(0, len(self.euler_timesteps), num=multiphase, endpoint=False)
inference_indices = np.floor(inference_indices).astype(np.int64)
inference_indices = (torch.from_numpy(inference_indices).long().to(self.euler_timesteps.device))
expanded_timestep_index = timestep_index.unsqueeze(1).expand(-1, inference_indices.size(0))
valid_indices_mask = expanded_timestep_index >= inference_indices
last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1)
last_valid_index = inference_indices.size(0) - 1 - last_valid_index
timestep_index_end = inference_indices[last_valid_index]
if is_target:
sigma = extract_into_tensor(self.sigmas_prev, timestep_index, sample.shape)
else:
sigma = extract_into_tensor(self.sigmas, timestep_index, sample.shape)
sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index_end, sample.shape)
x_prev = sample + (sigma_prev - sigma) * model_pred
return x_prev, timestep_index_end
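# A minimal sketch of how EulerSolver is driven during distillation (assumed shapes, and a toy
# ascending sigma schedule standing in for `noise_scheduler.sigmas.numpy()[::-1]` used above).
def _euler_solver_sketch():
    sigmas = np.linspace(0.0, 1.0, 1000, dtype=np.float32)  # noise level grows with the index
    solver = EulerSolver(sigmas, timesteps=1000, euler_timesteps=50)
    sample = torch.randn(2, 4, 8, 8)
    model_pred = torch.randn_like(sample)
    index = torch.randint(0, 50, (2,)).long()
    # one Euler step from sigmas[index] toward the less noisy sigmas_prev[index]
    return solver.euler_step(sample, model_pred, index)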