Commit 0e56f303 by mashun: pyramid-flow
#!/bin/bash
# This script runs Pyramid-Flow video generation training (with the temporal pyramid and autoregressive training)
# It enables autoregressive video generation training with the temporal pyramid
# Make sure that NUM_FRAMES % VIDEO_SYNC_GROUP == 0 and GPUS % VIDEO_SYNC_GROUP == 0
export HIP_VISIBLE_DEVICES=4,5,6,7
GPUS=4 # The number of GPUs
SHARD_STRATEGY=zero2 # zero2 or zero3
VIDEO_SYNC_GROUP=4 # One of 4, 8, or 16. The number of processes that receive the same input video; used for temporal pyramid AR training.
MODEL_NAME=pyramid_flux # The model name, `pyramid_flux` or `pyramid_mmdit`
MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/ # The downloaded checkpoint dir. IMPORTANT: it must match MODEL_NAME, flux or mmdit (SD3)
VARIANT=diffusion_transformer_384p # The DiT Variant
OUTPUT_DIR=./temp_dit # The checkpoint saving dir
BATCH_SIZE=4 # It should satisfy batch_size % 4 == 0
GRAD_ACCU_STEPS=1
RESOLUTION="384p" # 384p or 768p
NUM_FRAMES=8 # e.g., 8 for 2s, 16 for 5s, 32 for 10s
ANNO_FILE=annotation/customs/video_text.jsonl # The video annotation file path
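# Note: the exact keys expected by the training dataset are defined in dataset/LengthGroupedVideoTextDataset,
# which is not shown here. As a rough illustration only (hypothetical keys, inferred from the companion
# text-feature and VAE-latent extraction scripts), one annotation line might look like:
# {"video": "videos/0001.mp4", "text": "a dog running on the beach", "latent": "latents/0001.pt", "text_fea": "text_feas/0001.pt"}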
# For the 768p version, make sure to add the args: --gradient_checkpointing
torchrun --nproc_per_node $GPUS \
train/train_pyramid_flow.py \
--num_workers 8 \
--task t2v \
--use_fsdp \
--fsdp_shard_strategy $SHARD_STRATEGY \
--use_temporal_causal \
--use_temporal_pyramid \
--interp_condition_pos \
--sync_video_input \
--video_sync_group $VIDEO_SYNC_GROUP \
--load_text_encoder \
--model_name $MODEL_NAME \
--model_path $MODEL_PATH \
--model_dtype bf16 \
--model_variant $VARIANT \
--schedule_shift 1.0 \
--gradient_accumulation_steps $GRAD_ACCU_STEPS \
--output_dir $OUTPUT_DIR \
--batch_size $BATCH_SIZE \
--max_frames $NUM_FRAMES \
--resolution $RESOLUTION \
--anno_file $ANNO_FILE \
--frame_per_unit 1 \
--lr_scheduler constant_with_warmup \
--opt adamw \
--opt_beta1 0.9 \
--opt_beta2 0.95 \
--seed 42 \
--weight_decay 1e-4 \
--clip_grad 1.0 \
--lr 5e-5 \
--warmup_steps 1000 \
--epochs 2 \
--iters_per_epoch 1000 \
--report_to tensorboard \
--print_freq 40 \
--save_ckpt_freq 1
#!/bin/bash
# This script runs Pyramid-Flow image and video generation training (without the temporal pyramid or autoregressive training)
# Since the spatial pyramid and temporal pyramid designs are decoupled, we can use only the spatial pyramid flow
# to train with full-sequence diffusion, which is also more effective than the standard flow matching training strategy
GPUS=4 # The number of GPUs
TASK=t2i # t2i or t2v
SHARD_STRATEGY=zero2 # zero2 or zero3
MODEL_NAME=pyramid_flux # The model name, `pyramid_flux` or `pyramid_mmdit`
MODEL_PATH=/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux/ # The downloaded checkpoint dir. IMPORTANT: it must match MODEL_NAME, flux or mmdit (SD3)
VARIANT=diffusion_transformer_image # The DiT Variant, diffusion_transformer_image or diffusion_transformer_384p
OUTPUT_DIR=./temp_dit_no_ar # The checkpoint saving dir
NUM_FRAMES=8 # e.g., 8 for 2s, 16 for 5s, 32 for 10s
BATCH_SIZE=4 # It should satisfy batch_size % 4 == 0
RESOLUTION="384p" # 384p or 768p
ANNO_FILE=annotation/customs/image_text.jsonl # The annotation file path
torchrun --nproc_per_node $GPUS \
train/train_pyramid_flow.py \
--num_workers 8 \
--task $TASK \
--load_vae \
--use_fsdp \
--fsdp_shard_strategy $SHARD_STRATEGY \
--use_flash_attn \
--load_text_encoder \
--model_name $MODEL_NAME \
--model_path $MODEL_PATH \
--model_dtype bf16 \
--model_variant $VARIANT \
--schedule_shift 1.0 \
--gradient_accumulation_steps 1 \
--output_dir $OUTPUT_DIR \
--batch_size $BATCH_SIZE \
--max_frames $NUM_FRAMES \
--resolution $RESOLUTION \
--anno_file $ANNO_FILE \
--frame_per_unit 1 \
--lr_scheduler constant_with_warmup \
--opt adamw \
--opt_beta1 0.9 \
--opt_beta2 0.95 \
--seed 42 \
--weight_decay 1e-4 \
--clip_grad 1.0 \
--lr 1e-4 \
--warmup_steps 100 \
--epochs 2 \
--iters_per_epoch 200 \
--report_to tensorboard \
--print_freq 40 \
--save_ckpt_freq 1
import os
import torch
import sys
sys.path.append(os.path.abspath('.'))
import argparse
import datetime
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from collections import OrderedDict
from einops import rearrange
import json
import jsonlines
from tqdm import tqdm
from torch.utils.data import DataLoader, DistributedSampler
from trainer_misc import init_distributed_mode
from pyramid_dit import (
SD3TextEncoderWithMask,
FluxTextEncoderWithMask,
)
def get_args():
parser = argparse.ArgumentParser('PyTorch multi-process text feature extraction script', add_help=False)
parser.add_argument('--batch_size', default=4, type=int)
parser.add_argument('--anno_file', type=str, default='', help="The video annotation file")
parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16 or fp16")
parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The Model Architecture Name", choices=["pyramid_flux", "pyramid_mmdit"])
parser.add_argument('--model_path', default='', type=str, help='The pre-trained weight path')
return parser.parse_args()
class VideoTextDataset(Dataset):
def __init__(self, anno_file):
super().__init__()
self.annotation = []
with jsonlines.open(anno_file, 'r') as reader:
for item in tqdm(reader):
self.annotation.append(item) # Each item is a dict with the keys: text, text_fea
def __getitem__(self, index):
try:
anno = self.annotation[index]
text = anno['text']
text_fea_path = anno['text_fea'] # The text feature saving path
text_fea_save_dir = os.path.split(text_fea_path)[0]
if not os.path.exists(text_fea_save_dir):
os.makedirs(text_fea_save_dir, exist_ok=True)
return text, text_fea_path
except Exception as e:
print(f'Error with {e}')
return None, None
def __len__(self):
return len(self.annotation)
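# For reference, each line of the annotation jsonl is expected to be a JSON object with at least the keys
# used above ('text' and 'text_fea'). A minimal illustrative line (paths are placeholders):
# {"text": "a dog running on the beach", "text_fea": "text_feas/0001.pt"}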
def build_data_loader(args):
def collate_fn(batch):
text_list = []
output_path_list = []
for text, text_fea_path in batch:
if text is not None:
text_list.append(text)
output_path_list.append(text_fea_path)
return {'text': text_list, 'output': output_path_list}
dataset = VideoTextDataset(args.anno_file)
sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
loader = DataLoader(
dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
sampler=sampler, shuffle=False, collate_fn=collate_fn, drop_last=False
)
return loader
def build_model(args):
model_dtype = args.model_dtype
model_name = args.model_name
model_path = args.model_path
if model_dtype == 'bf16':
torch_dtype = torch.bfloat16
elif model_dtype == 'fp16':
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
if model_name == "pyramid_flux":
text_encoder = FluxTextEncoderWithMask(model_path, torch_dtype=torch_dtype)
elif model_name == "pyramid_mmdit":
text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
else:
raise NotImplementedError
return text_encoder
def main():
args = get_args()
init_distributed_mode(args)
# fix the seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
device = torch.device('cuda')
rank = args.rank
model = build_model(args)
model.to(device)
if args.model_dtype == "bf16":
torch_dtype = torch.bfloat16
elif args.model_dtype == "fp16":
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
data_loader = build_data_loader(args)
torch.distributed.barrier()
task_queue = []
for sample in tqdm(data_loader):
texts = sample['text']
outputs = sample['output']
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
prompt_embeds, prompt_attention_masks, pooled_prompt_embeds = model(texts, device)
for output_path, prompt_embed, prompt_attention_mask, pooled_prompt_embed in zip(outputs, prompt_embeds, prompt_attention_masks, pooled_prompt_embeds):
output_dict = {
'prompt_embed': prompt_embed.unsqueeze(0).cpu().clone(),
'prompt_attention_mask': prompt_attention_mask.unsqueeze(0).cpu().clone(),
'pooled_prompt_embed': pooled_prompt_embed.unsqueeze(0).cpu().clone(),
}
torch.save(output_dict, output_path)
torch.distributed.barrier()
if __name__ == '__main__':
main()
import os
import sys
sys.path.append(os.path.abspath('.'))
import argparse
import datetime
import numpy as np
import time
import torch
import io
import json
import jsonlines
import cv2
import math
import random
from pathlib import Path
from tqdm import tqdm
from concurrent import futures
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from collections import OrderedDict
from torchvision import transforms as pth_transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from trainer_misc import init_distributed_mode
from video_vae import CausalVideoVAELossWrapper
def get_transform(width, height, new_width=None, new_height=None, resize=False,):
transform_list = []
if resize:
# rescale according to the largest ratio
scale = max(new_width / width, new_height / height)
resized_width = round(width * scale)
resized_height = round(height * scale)
transform_list.append(pth_transforms.Resize((resized_height, resized_width), InterpolationMode.BICUBIC, antialias=True))
transform_list.append(pth_transforms.CenterCrop((new_height, new_width)))
transform_list.extend([
pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_list = pth_transforms.Compose(transform_list)
return transform_list
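# Worked example of the resize + center-crop logic above (values are illustrative):
# a 1280x720 source with new_width=640, new_height=384 gives
#   scale = max(640/1280, 384/720) ~= 0.5333, so the frames are resized to 683x384 (rounded),
# and CenterCrop then trims the width to the target 640x384.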
def load_video_and_transform(video_path, frame_indexs, frame_number, new_width=None, new_height=None, resize=False):
video_capture = None
frame_indexs_set = set(frame_indexs)
try:
video_capture = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
frames = []
frame_index = 0
while True:
flag, frame = video_capture.read()
if not flag:
break
if frame_index > frame_indexs[-1]:
break
if frame_index in frame_indexs_set:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = torch.from_numpy(frame)
frame = frame.permute(2, 0, 1)
frames.append(frame)
frame_index += 1
video_capture.release()
if len(frames) == 0:
print(f"Empty video {video_path}")
return None
frames = frames[:frame_number]
duration = ((len(frames) - 1) // 8) * 8 + 1 # trim so the frame count has the form 8*k + 1 (e.g., 97, 121)
frames = frames[:duration]
frames = torch.stack(frames).float() / 255
video_transform = get_transform(frames.shape[-1], frames.shape[-2], new_width, new_height, resize=resize)
frames = video_transform(frames).permute(1, 0, 2, 3)
return frames
except Exception as e:
print(f"Loading video: {video_path} exception {e}")
if video_capture is not None:
video_capture.release()
return None
class VideoDataset(Dataset):
def __init__(self, anno_file, width, height, num_frames):
super().__init__()
self.annotation = []
self.width = width
self.height = height
self.num_frames = num_frames
with jsonlines.open(anno_file, 'r') as reader:
for item in tqdm(reader):
self.annotation.append(item)
tot_len = len(self.annotation)
print(f"Loaded {tot_len} videos in total")
def process_one_video(self, video_item):
videos_per_task = []
video_path = video_item['video']
output_latent_path = video_item['latent']
# The sampled frame indices of the video; if not specified, load frames [0, num_frames)
frame_indexs = video_item['frames'] if 'frames' in video_item else list(range(self.num_frames))
try:
video_frames_tensors = load_video_and_transform(
video_path,
frame_indexs,
frame_number=self.num_frames, # The num_frames to encode
new_width=self.width,
new_height=self.height,
resize=True
)
if video_frames_tensors is None:
return videos_per_task
video_frames_tensors = video_frames_tensors.unsqueeze(0)
videos_per_task.append({'video': video_path, 'input': video_frames_tensors, 'output': output_latent_path})
except Exception as e:
print(f"Load video tensor ERROR: {e}")
return videos_per_task
def __getitem__(self, index):
try:
video_item = self.annotation[index]
videos_per_task = self.process_one_video(video_item)
except Exception as e:
print(f'Error with {e}')
videos_per_task = []
return videos_per_task
def __len__(self):
return len(self.annotation)
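# For reference, each annotation line is expected to provide the keys used in process_one_video above:
# 'video' (the source video path), 'latent' (where to save the encoded latent), and optionally 'frames'
# (the explicit frame indices to sample). An illustrative line (paths are placeholders):
# {"video": "videos/0001.mp4", "latent": "latents/0001.pt", "frames": [0, 1, 2, 3]}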
def get_args():
parser = argparse.ArgumentParser('PyTorch multi-process VAE latent extraction script', add_help=False)
parser.add_argument('--batch_size', default=4, type=int)
parser.add_argument('--model_path', default='', type=str, help='The pre-trained weight path')
parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16 or fp16")
parser.add_argument('--anno_file', type=str, default='', help="The video annotation file")
parser.add_argument('--width', type=int, default=640, help="The video width")
parser.add_argument('--height', type=int, default=384, help="The video height")
parser.add_argument('--num_frames', type=int, default=121, help="The frame number to encode")
parser.add_argument('--save_memory', action='store_true', help="Enable VAE tiling to save GPU memory")
return parser.parse_args()
def build_model(args):
model_path = args.model_path
model_dtype = args.model_dtype
model = CausalVideoVAELossWrapper(model_path, model_dtype=model_dtype, interpolate=False, add_discriminator=False)
model = model.eval()
return model
def build_data_loader(args):
def collate_fn(batch):
return_batch = {'input' : [], 'output': []}
for videos_ in batch:
for video_input in videos_:
return_batch['input'].append(video_input['input'])
return_batch['output'].append(video_input['output'])
return return_batch
dataset = VideoDataset(args.anno_file, args.width, args.height, args.num_frames)
sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
loader = DataLoader(
dataset, batch_size=args.batch_size, num_workers=6, pin_memory=True,
sampler=sampler, shuffle=False, collate_fn=collate_fn, drop_last=False, prefetch_factor=2,
)
return loader
def save_tensor(tensor, output_path):
try:
torch.save(tensor.clone(), output_path)
except Exception as e:
print(f"Failed to save tensor to {output_path}: {e}")
def main():
args = get_args()
init_distributed_mode(args)
device = torch.device('cuda')
rank = args.rank
model = build_model(args)
model.to(device)
if args.model_dtype == "bf16":
torch_dtype = torch.bfloat16
elif args.model_dtype == "fp16":
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
data_loader = build_data_loader(args)
torch.distributed.barrier()
window_size = 16
temporal_chunk = True
task_queue = []
if args.save_memory:
# Enable VAE tiling to reduce GPU memory cost
model.vae.enable_tiling()
with futures.ThreadPoolExecutor(max_workers=16) as executor:
for sample in tqdm(data_loader):
input_video_list = sample['input']
output_path_list = sample['output']
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
for video_input, output_path in zip(input_video_list, output_path_list):
video_latent = model.encode_latent(video_input.to(device), sample=True, window_size=window_size, temporal_chunk=temporal_chunk, tile_sample_min_size=256)
video_latent = video_latent.to(torch_dtype).cpu()
task_queue.append(executor.submit(save_tensor, video_latent, output_path))
for future in futures.as_completed(task_queue):
res = future.result()
torch.distributed.barrier()
if __name__ == "__main__":
main()
import os
import sys
sys.path.append(os.path.abspath('.'))
import argparse
import datetime
import numpy as np
import time
import torch
import logging
import json
import math
import random
import diffusers
import transformers
from pathlib import Path
from packaging import version
from copy import deepcopy
from dataset import (
ImageTextDataset,
LengthGroupedVideoTextDataset,
create_image_text_dataloaders,
create_length_grouped_video_text_dataloader
)
from pyramid_dit import (
PyramidDiTForVideoGeneration,
JointTransformerBlock,
FluxSingleTransformerBlock,
FluxTransformerBlock,
)
from trainer_misc import (
init_distributed_mode,
setup_for_distributed,
create_optimizer,
train_one_epoch_with_fsdp,
constant_scheduler,
cosine_scheduler,
)
from trainer_misc import (
is_sequence_parallel_initialized,
init_sequence_parallel_group,
get_sequence_parallel_proc_num,
init_sync_input_group,
get_sync_input_group,
)
from collections import OrderedDict
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
FullStateDictConfig,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
ShardingStrategy,
BackwardPrefetch,
MixedPrecision,
CPUOffload,
StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.t5.modeling_t5 import T5Block
import accelerate
from accelerate import Accelerator
from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
from accelerate import FullyShardedDataParallelPlugin
from diffusers.utils import is_wandb_available
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from diffusers.optimization import get_scheduler
logger = get_logger(__name__)
def get_args():
parser = argparse.ArgumentParser('Pyramid-Flow Multi-process Training script', add_help=False)
parser.add_argument('--task', default='t2v', type=str, choices=["t2v", "t2i"], help="Training image generation or video generation")
parser.add_argument('--batch_size', default=4, type=int, help="The per device batch size")
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--print_freq', default=20, type=int)
parser.add_argument('--iters_per_epoch', default=2000, type=int)
parser.add_argument('--save_ckpt_freq', default=20, type=int)
# Model parameters
parser.add_argument('--ema_update', action='store_true')
parser.add_argument('--ema_decay', default=0.9999, type=float, metavar='MODEL', help='ema decay rate')
parser.add_argument('--load_ema_model', default='', type=str, help='The ema model checkpoint loading')
parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The Model Architecture Name", choices=["pyramid_flux", "pyramid_mmdit"])
parser.add_argument('--model_path', default='', type=str, help='The pre-trained dit weight path')
parser.add_argument('--model_variant', default='diffusion_transformer_384p', type=str, help='The dit model variant', choices=['diffusion_transformer_768p', 'diffusion_transformer_384p', 'diffusion_transformer_image'])
parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16 or fp16", choices=['bf16', 'fp16'])
parser.add_argument('--load_model_ema_to_cpu', action='store_true')
# FSDP config
parser.add_argument('--use_fsdp', action='store_true')
parser.add_argument('--fsdp_shard_strategy', default='zero2', type=str, choices=['zero2', 'zero3'])
# The training manner config
parser.add_argument('--use_flash_attn', action='store_true')
parser.add_argument('--use_temporal_causal', action='store_true', default=True)
parser.add_argument('--interp_condition_pos', action='store_true', default=True)
parser.add_argument('--sync_video_input', action='store_true', help="whether to sync the video input")
parser.add_argument('--load_text_encoder', action='store_true', help="whether to load the text encoder during training")
parser.add_argument('--load_vae', action='store_true', help="whether to load the video vae during training")
# Sequence Parallel config
parser.add_argument('--use_sequence_parallel', action='store_true')
parser.add_argument('--sp_group_size', default=1, type=int, help="The group size of sequence parallel")
parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for video training; -1 (default) means all processes are used")
# Model input config
parser.add_argument('--max_frames', default=16, type=int, help='number of max video frames')
parser.add_argument('--frame_per_unit', default=1, type=int, help="The number of frames per training unit")
parser.add_argument('--schedule_shift', default=1.0, type=float, help="The flow matching schedule shift")
parser.add_argument('--corrupt_ratio', default=1/3, type=float, help="The corruption ratio for the clean history in AR training")
# Dataset config
parser.add_argument('--anno_file', default='', type=str, help="The annotation jsonl file")
parser.add_argument('--resolution', default='384p', type=str, help="The input resolution", choices=['384p', '768p'])
# Training set config
parser.add_argument('--dit_pretrained_weight', default='', type=str, help='The pretrained dit checkpoint')
parser.add_argument('--vae_pretrained_weight', default='', type=str, help='The pretrained VAE checkpoint path')
parser.add_argument('--not_add_normalize', action='store_true')
parser.add_argument('--use_temporal_pyramid', action='store_true', help="Whether to use the AR temporal pyramid training for video generation")
parser.add_argument('--gradient_checkpointing', action='store_true')
parser.add_argument('--gradient_checkpointing_ratio', type=float, default=0.75, help="The ratio of transformer blocks used for gradient_checkpointing")
parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--video_sync_group', default=8, type=int, help="The number of processes that receive the same input video, used for temporal pyramid AR training. \
This contributes to stable AR training. We recommend setting it to 4, 8, or 16. If you have enough GPUs, set it equal to max_frames (16 for 5s, 32 for 10s), \
and make sure that `max_frames % video_sync_group == 0`")
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt_beta1', default=0.9, type=float, metavar='BETA1',
help='Optimizer beta1 (default: 0.9)')
parser.add_argument('--opt_beta2', default=0.999, type=float, metavar='BETA2',
help='Optimizer beta2 (default: 0.999)')
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=1e-4,
help='weight decay (default: 1e-4)')
parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
help='learning rate (default: 5e-5)')
parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument(
"--lr_scheduler", type=str, default="constant_with_warmup",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
help='steps to warmup LR; overrides --warmup_epochs when > 0')
# Dataset parameters
parser.add_argument('--output_dir', type=str, default='',
help='path where to save, empty for no saving')
parser.add_argument('--logging_dir', type=str, default='log', help='path where to tensorboard log')
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
# Distributed Training parameters
parser.add_argument('--device', default='cuda', type=str,
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--auto_resume', action='store_true')
parser.set_defaults(auto_resume=True)
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--global_step', default=0, type=int, metavar='N', help='The global optimization step')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training', type=str)
return parser.parse_args()
def build_model_runner(args):
model_dtype = args.model_dtype
model_path = args.model_path
model_name = args.model_name
model_variant = args.model_variant
print(f"Load the {model_name} model checkpoint from path: {model_path}, using dtype {model_dtype}")
sample_ratios = [1, 2, 1] # The sample_ratios of each stage
assert args.batch_size % int(sum(sample_ratios)) == 0, "The batch size should be divisible by sum(sample_ratios)"
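# With the default sample_ratios=[1, 2, 1] and batch_size=4, a training batch is presumably apportioned
# 1/2/1 across the three pyramid stages (the exact split is handled inside PyramidDiTForVideoGeneration).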
runner = PyramidDiTForVideoGeneration(
model_path,
model_dtype,
model_name=model_name,
use_gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_ratio=args.gradient_checkpointing_ratio,
return_log=True,
model_variant=model_variant,
timestep_shift=args.schedule_shift,
stages=[1, 2, 4], # using 3 stages
stage_range=[0, 1/3, 2/3, 1],
sample_ratios=sample_ratios, # The sample proportion in a training batch
use_mixed_training=True,
use_flash_attn=args.use_flash_attn,
load_text_encoder=args.load_text_encoder,
load_vae=args.load_vae,
max_temporal_length=args.max_frames,
frame_per_unit=args.frame_per_unit,
use_temporal_causal=args.use_temporal_causal,
corrupt_ratio=args.corrupt_ratio,
interp_condition_pos=args.interp_condition_pos,
video_sync_group=args.video_sync_group,
)
if args.dit_pretrained_weight:
dit_pretrained_weight = args.dit_pretrained_weight
print(f"Loading the pre-trained DiT checkpoint from {dit_pretrained_weight}")
runner.load_checkpoint(dit_pretrained_weight)
if args.vae_pretrained_weight:
vae_pretrained_weight = args.vae_pretrained_weight
print(f"Loading the pre-trained VAE checkpoint from {vae_pretrained_weight}")
runner.load_vae_checkpoint(vae_pretrained_weight)
return runner
def auto_resume(args, accelerator):
if len(args.resume) > 0:
path = args.resume
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint does not exist. Starting a new training run."
)
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
return initial_global_step
def build_fsdp_plugin(args):
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP if args.fsdp_shard_strategy == 'zero2' else ShardingStrategy.FULL_SHARD,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
auto_wrap_policy=ModuleWrapPolicy([FluxSingleTransformerBlock, FluxTransformerBlock, JointTransformerBlock, T5Block, CLIPEncoderLayer]),
cpu_offload=CPUOffload(offload_params=False),
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
)
return fsdp_plugin
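# Note: 'zero2' maps to ShardingStrategy.SHARD_GRAD_OP (shard gradients and optimizer states, keep full
# parameters on each rank), while 'zero3' maps to ShardingStrategy.FULL_SHARD (additionally shard the
# parameters themselves), mirroring DeepSpeed ZeRO stage 2 / stage 3.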
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
# Initialize the environment variables through the MPI run
init_distributed_mode(args, init_pytorch_ddp=False) # set `init_pytorch_ddp` to False, since accelerate will initialize the process group later
if args.use_fsdp:
fsdp_plugin = build_fsdp_plugin(args)
else:
fsdp_plugin = None
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.model_dtype,
log_with=args.report_to,
project_config=accelerator_project_config,
fsdp_plugin=fsdp_plugin,
)
# Suppress printing on non-main processes
setup_for_distributed(accelerator.is_main_process)
# If sequence parallel is used
if args.use_sequence_parallel:
assert args.sp_group_size > 1, "Sequence Parallel needs group size > 1"
init_sequence_parallel_group(args)
print(f"Using sequence parallel, the parallel size is {args.sp_group_size}")
if args.sp_proc_num == -1:
args.sp_proc_num = accelerator.num_processes # if not specified, all processes are used for video training
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed, device_specific=True)
device = accelerator.device
# building model
runner = build_model_runner(args)
# For mixed-precision training, we cast all non-trainable weights to half precision,
# since these weights are only used for inference and keeping them in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Cast the frozen VAE and text encoder to weight_dtype; they are only used for inference
# and are moved to the device after accelerator.prepare.
if runner.vae:
logger.info(f"Rank {args.rank}: Casting VAE to {weight_dtype}", main_process_only=False)
runner.vae.to(dtype=weight_dtype)
if runner.text_encoder:
logger.info(f"Rank {args.rank}: Casting TextEncoder to {weight_dtype}", main_process_only=False)
runner.text_encoder.to(dtype=weight_dtype)
# building dataloader
global_rank = accelerator.process_index
anno_file = args.anno_file
if args.task == 't2i':
# For image generation training
if args.resolution == '384p':
image_ratios = [1/1, 3/5, 5/3]
image_sizes = [(512, 512), (384, 640), (640, 384)]
else:
assert args.resolution == '768p'
image_ratios = [1/1, 3/5, 5/3]
image_sizes = [(1024, 1024), (768, 1280), (1280, 768)]
image_text_dataset = ImageTextDataset(
anno_file,
add_normalize=not args.not_add_normalize,
ratios=image_ratios,
sizes=image_sizes,
)
train_dataloader = create_image_text_dataloaders(
image_text_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
multi_aspect_ratio=True,
epoch=args.seed,
sizes=image_sizes,
use_distributed=True,
world_size=accelerator.num_processes,
rank=global_rank,
)
else:
assert args.task == 't2v'
# For video generation training
video_text_dataset = LengthGroupedVideoTextDataset(
anno_file,
max_frames=args.max_frames,
resolution=args.resolution,
load_vae_latent=not args.load_vae,
load_text_fea=not args.load_text_encoder,
)
if args.sync_video_input:
assert args.sp_proc_num % args.video_sync_group == 0, "sp_proc_num should be divisible by video_sync_group"
assert args.max_frames % args.video_sync_group == 0, "max_frames should be divisible by video_sync_group"
train_dataloader = create_length_grouped_video_text_dataloader(
video_text_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
max_frames=args.max_frames,
world_size=args.sp_proc_num // args.video_sync_group,
rank=global_rank // args.video_sync_group,
epoch=args.seed,
use_distributed=True,
)
else:
train_dataloader = create_length_grouped_video_text_dataloader(
video_text_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
max_frames=args.max_frames,
world_size=args.sp_proc_num,
rank=global_rank,
epoch=args.seed,
use_distributed=True,
)
accelerator.wait_for_everyone()
logger.info("Building dataset finished")
# building ema model
model_ema = deepcopy(runner.dit) if args.ema_update else None
if model_ema:
model_ema.eval()
# the EMA model is not updated by gradients
if model_ema:
model_ema.to(dtype=weight_dtype)
for param in model_ema.parameters():
param.requires_grad = False
# report model details
n_learnable_parameters = sum(p.numel() for p in runner.dit.parameters() if p.requires_grad)
n_fix_parameters = sum(p.numel() for p in runner.dit.parameters() if not p.requires_grad)
logger.info(f'total number of learnable params: {n_learnable_parameters / 1e6} M')
logger.info(f'total number of fixed params: {n_fix_parameters / 1e6} M')
# `accelerate` 0.16.0 will have better support for customized saving
# Register Hook to load and save model_ema
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if model_ema:
model_ema_state = model_ema.state_dict()
torch.save(model_ema_state, os.path.join(output_dir, 'pytorch_model_ema.bin'))
def load_model_hook(models, input_dir):
if model_ema:
model_ema_path = os.path.join(input_dir, 'pytorch_model_ema.bin')
if os.path.exists(model_ema_path):
model_ema_state = torch.load(model_ema_path, map_location='cpu')
load_res = model_ema.load_state_dict(model_ema_state)
print(f"Loading ema weights {load_res}")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Create the Optimizer
optimizer = create_optimizer(args, runner.dit)
logger.info(f"optimizer: {optimizer}")
# Create the LR scheduler
num_training_steps_per_epoch = args.iters_per_epoch
args.max_train_steps = args.epochs * num_training_steps_per_epoch
warmup_iters = args.warmup_epochs * num_training_steps_per_epoch
if args.warmup_steps > 0:
warmup_iters = args.warmup_steps
logger.info(f"LRScheduler: {args.lr_scheduler}, Warmup steps: {warmup_iters * args.gradient_accumulation_steps}")
if args.lr_scheduler == 'cosine':
lr_schedule_values = cosine_scheduler(
args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
)
elif args.lr_scheduler == 'constant_with_warmup':
lr_schedule_values = constant_scheduler(
args.lr, args.epochs, num_training_steps_per_epoch,
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
)
else:
raise NotImplementedError(f"Not Implemented for scheduler {args.lr_scheduler}")
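# Note: lr_schedule_values is a precomputed per-step array of learning rates; it is indexed by the global
# step inside train_one_epoch_with_fsdp rather than being stepped like a torch LR scheduler.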
# Wrap the model and optimizer with accelerate
logger.info(f'before accelerator.prepare')
if fsdp_plugin is not None:
logger.info(f'show fsdp configs:')
print('accelerator.state.fsdp_plugin.use_orig_params', accelerator.state.fsdp_plugin.use_orig_params)
print('accelerator.state.fsdp_plugin.sync_module_states', accelerator.state.fsdp_plugin.sync_module_states)
print('accelerator.state.fsdp_plugin.forward_prefetch', accelerator.state.fsdp_plugin.forward_prefetch)
print('accelerator.state.fsdp_plugin.mixed_precision_policy', accelerator.state.fsdp_plugin.mixed_precision_policy)
print('accelerator.state.fsdp_plugin.backward_prefetch', accelerator.state.fsdp_plugin.backward_prefetch)
# Only wrapping the trained dit and huge text encoder
runner.dit, optimizer = accelerator.prepare(runner.dit, optimizer)
# Move the VAE, text encoder, and EMA model to the GPU
if runner.vae:
runner.vae.to(device)
if runner.text_encoder:
runner.text_encoder.to(device)
logger.info(f'after accelerator.prepare')
logger.info(f'{runner.dit}')
if model_ema and (not args.load_model_ema_to_cpu):
model_ema.to(device)
if accelerator.is_main_process:
accelerator.init_trackers(os.path.basename(args.output_dir), config=vars(args))
# Report the training info
total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info("LR = %.8f" % args.lr)
logger.info("Min LR = %.8f" % args.min_lr)
logger.info("Weigth Decay = %.8f" % args.weight_decay)
logger.info("Batch size = %d" % total_batch_size)
logger.info("Number of training steps = %d" % (num_training_steps_per_epoch * args.epochs))
logger.info("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
# Auto resume the checkpoint
initial_global_step = auto_resume(args, accelerator)
first_epoch = initial_global_step // num_training_steps_per_epoch
# Start Train!
start_time = time.time()
accelerator.wait_for_everyone()
for epoch in range(first_epoch, args.epochs):
train_stats = train_one_epoch_with_fsdp(
runner,
model_ema,
accelerator,
args.model_dtype,
train_dataloader,
optimizer,
lr_schedule_values,
device,
epoch,
args.clip_grad,
start_steps=epoch * num_training_steps_per_epoch,
args=args,
print_freq=args.print_freq,
iters_per_epoch=num_training_steps_per_epoch,
ema_decay=args.ema_decay,
use_temporal_pyramid=args.use_temporal_pyramid,
)
if args.output_dir:
if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
if accelerator.sync_gradients:
global_step = num_training_steps_per_epoch * (epoch + 1)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path, safe_serialization=False)
logger.info(f"Saved state to {save_path}")
accelerator.wait_for_everyone()
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch, 'n_parameters': n_learnable_parameters}
if args.output_dir and accelerator.is_main_process:
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == '__main__':
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
opts = get_args()
if opts.output_dir:
Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
main(opts)
import sys
import os
sys.path.append(os.path.abspath('.'))
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import random
from pathlib import Path
from collections import OrderedDict
from dataset import (
ImageDataset,
VideoDataset,
create_mixed_dataloaders,
)
from trainer_misc import (
NativeScalerWithGradNormCount,
create_optimizer,
train_one_epoch,
auto_load_model,
save_model,
init_distributed_mode,
cosine_scheduler,
)
from video_vae import CausalVideoVAELossWrapper
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import utils
def get_args():
parser = argparse.ArgumentParser('Pytorch Multi-process Training script for Video VAE', add_help=False)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--print_freq', default=20, type=int)
parser.add_argument('--iters_per_epoch', default=2000, type=int)
parser.add_argument('--save_ckpt_freq', default=20, type=int)
# Model parameters
parser.add_argument('--ema_update', action='store_true')
parser.add_argument('--ema_decay', default=0.99, type=float, metavar='MODEL', help='ema decay for quantizer')
parser.add_argument('--model_path', default='', type=str, help='The vae weight path')
parser.add_argument('--model_dtype', default='bf16', help="The Model Dtype: bf16 or fp16")
# Use context parallelism to distribute multiple video clips across devices
parser.add_argument('--use_context_parallel', action='store_true')
parser.add_argument('--context_size', default=2, type=int, help="The context length size")
parser.add_argument('--resolution', default=256, type=int, help="The input resolution for VAE training")
parser.add_argument('--max_frames', default=24, type=int, help='number of max video frames')
parser.add_argument('--use_image_video_mixed_training', action='store_true', help="Whether to use the mixed image and video training")
# The loss weights
parser.add_argument('--lpips_ckpt', default="/home/jinyang06/models/vae/video_vae_baseline/vgg_lpips.pth", type=str, help="The LPIPS checkpoint path")
parser.add_argument('--disc_start', default=0, type=int, help="The start iteration for adding GAN Loss")
parser.add_argument('--logvar_init', default=0.0, type=float, help="The log var init" )
parser.add_argument('--kl_weight', default=1e-6, type=float, help="The KL loss weight")
parser.add_argument('--pixelloss_weight', default=1.0, type=float, help="The pixel reconstruction loss weight")
parser.add_argument('--perceptual_weight', default=1.0, type=float, help="The perception loss weight")
parser.add_argument('--disc_weight', default=0.1, type=float, help="The GAN loss weight")
parser.add_argument('--pretrained_vae_weight', default='', type=str, help='The pretrained vae ckpt path')
parser.add_argument('--not_add_normalize', action='store_true')
parser.add_argument('--add_discriminator', action='store_true')
parser.add_argument('--freeze_encoder', action='store_true')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=1e-4,
help='weight decay (default: 1e-4)')
parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
weight decay. We use a cosine schedule for WD.
(Set the same value with args.weight_decay to keep weight decay no change)""")
parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
help='learning rate (default: 5e-5)')
parser.add_argument('--lr_disc', type=float, default=1e-5, metavar='LR',
help='learning rate (default: 1e-5) of the discriminator')
parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
help='steps to warmup LR (overrides --warmup_epochs when > 0)')
# Dataset parameters
parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--image_anno', default='', type=str, help="The image data annotation file path")
parser.add_argument('--video_anno', default='', type=str, help="The video data annotation file path")
parser.add_argument('--image_mix_ratio', default=0.1, type=float, help="The image data proportion in the training batch")
# Distributed Training parameters
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
parser.set_defaults(auto_resume=True)
parser.add_argument('--dist_eval', action='store_true', default=True,
help='Enabling distributed evaluation')
parser.add_argument('--disable_eval', action='store_true', default=False)
parser.add_argument('--eval', action='store_true', default=False, help="Perform evaluation only")
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--global_step', default=0, type=int, metavar='N', help='The global optimization step')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser.parse_args()
def build_model(args):
model_dtype = args.model_dtype
model_path = args.model_path
print(f"Load the base VideoVAE checkpoint from path: {model_path}, using dtype {model_dtype}")
model = CausalVideoVAELossWrapper(
model_path,
model_dtype='fp32', # For training, we use mixed-precision training
disc_start=args.disc_start,
logvar_init=args.logvar_init,
kl_weight=args.kl_weight,
pixelloss_weight=args.pixelloss_weight,
perceptual_weight=args.perceptual_weight,
disc_weight=args.disc_weight,
interpolate=False,
add_discriminator=args.add_discriminator,
freeze_encoder=args.freeze_encoder,
load_loss_module=True,
lpips_ckpt=args.lpips_ckpt,
)
if args.pretrained_vae_weight:
pretrained_vae_weight = args.pretrained_vae_weight
print(f"Loading the vae checkpoint from {pretrained_vae_weight}")
model.load_checkpoint(pretrained_vae_weight)
return model
def main(args):
init_distributed_mode(args)
# If enabled, distribute multiple video clips to different devices
if args.use_context_parallel:
utils.initialize_context_parallel(args.context_size)
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
model = build_model(args)
world_size = utils.get_world_size()
global_rank = utils.get_rank()
num_training_steps_per_epoch = args.iters_per_epoch
log_writer = None
# building dataset and dataloaders
image_gpus = max(1, int(world_size * args.image_mix_ratio))
if args.use_image_video_mixed_training:
video_gpus = world_size - image_gpus
else:
# only use video data
video_gpus = world_size
image_gpus = 0
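# Worked example (illustrative): with world_size=8 and image_mix_ratio=0.1, image_gpus = max(1, int(0.8)) = 1,
# so one rank loads image data and the remaining 7 ranks load video data when mixed training is enabled.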
if global_rank < video_gpus:
training_dataset = VideoDataset(args.video_anno, resolution=args.resolution,
max_frames=args.max_frames, add_normalize=not args.not_add_normalize)
else:
training_dataset = ImageDataset(args.image_anno, resolution=args.resolution,
max_frames=args.max_frames // 4, add_normalize=not args.not_add_normalize)
data_loader_train = create_mixed_dataloaders(
training_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
epoch=args.seed,
world_size=world_size,
rank=global_rank,
image_mix_ratio=args.image_mix_ratio,
)
torch.distributed.barrier()
model.to(device)
model_without_ddp = model
n_learnable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_fix_parameters = sum(p.numel() for p in model.parameters() if not p.requires_grad)
for name, p in model.named_parameters():
if not p.requires_grad:
print(name)
print(f'total number of learnable params: {n_learnable_parameters / 1e6} M')
print(f'total number of fixed params: {n_fix_parameters / 1e6} M')
total_batch_size = args.batch_size * utils.get_world_size()
print("LR = %.8f" % args.lr)
print("Min LR = %.8f" % args.min_lr)
print("Weigth Decay = %.8f" % args.weight_decay)
print("Batch size = %d" % total_batch_size)
print("Number of training steps = %d" % (num_training_steps_per_epoch * args.epochs))
print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
optimizer = create_optimizer(args, model_without_ddp.vae)
optimizer_disc = create_optimizer(args, model_without_ddp.loss.discriminator) if args.add_discriminator else None
loss_scaler = NativeScalerWithGradNormCount(enabled=True if args.model_dtype == "fp16" else False)
loss_scaler_disc = NativeScalerWithGradNormCount(enabled=True if args.model_dtype == "fp16" else False) if args.add_discriminator else None
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
model_without_ddp = model.module
print("Use step level LR & WD scheduler!")
lr_schedule_values = cosine_scheduler(
args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
)
lr_schedule_values_disc = cosine_scheduler(
args.lr_disc, args.min_lr, args.epochs, num_training_steps_per_epoch,
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
) if args.add_discriminator else None
auto_load_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, optimizer_disc=optimizer_disc,
)
print(f"Start training for {args.epochs} epochs, the global iterations is {args.global_step}")
start_time = time.time()
torch.distributed.barrier()
for epoch in range(args.start_epoch, args.epochs):
train_stats = train_one_epoch(
model,
args.model_dtype,
data_loader_train,
optimizer,
optimizer_disc,
device,
epoch,
loss_scaler,
loss_scaler_disc,
args.clip_grad,
log_writer=log_writer,
start_steps=epoch * num_training_steps_per_epoch,
lr_schedule_values=lr_schedule_values,
lr_schedule_values_disc=lr_schedule_values_disc,
args=args,
print_freq=args.print_freq,
iters_per_epoch=num_training_steps_per_epoch,
)
if args.output_dir:
if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, save_ckpt_freq=args.save_ckpt_freq, optimizer_disc=optimizer_disc
)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch, 'n_parameters': n_learnable_parameters}
if args.output_dir and utils.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
opts = get_args()
if opts.output_dir:
Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
main(opts)
from .utils import (
create_optimizer,
get_rank,
get_world_size,
is_main_process,
is_dist_avail_and_initialized,
init_distributed_mode,
setup_for_distributed,
cosine_scheduler,
constant_scheduler,
NativeScalerWithGradNormCount,
auto_load_model,
save_model,
)
from .sp_utils import (
is_sequence_parallel_initialized,
init_sequence_parallel_group,
get_sequence_parallel_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sequence_parallel_group_rank,
get_sequence_parallel_proc_num,
init_sync_input_group,
get_sync_input_group,
)
from .communicate import all_to_all
from .fsdp_trainer import train_one_epoch_with_fsdp
from .vae_ddp_trainer import train_one_epoch
import torch
import torch.nn as nn
import math
import torch.distributed as dist
def _all_to_all(
input_: torch.Tensor,
world_size: int,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
concat_output: bool,
):
if world_size == 1:
return input_
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
if concat_output:
return torch.cat(output_list, dim=gather_dim).contiguous()
else:
# For multi-GPU inference, the latents on each GPU are identical, so keep only the first chunk
return output_list[0]
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim, concat_output):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.world_size = world_size
ctx.concat_output = concat_output
output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim, concat_output)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = _all_to_all(
grad_output,
ctx.world_size,
ctx.process_group,
ctx.gather_dim,
ctx.scatter_dim,
ctx.concat_output,
)
return (
grad_output,
None,  # process_group
None,  # world_size
None,  # scatter_dim
None,  # gather_dim
None,  # concat_output
)
def all_to_all(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
world_size: int = 1,
scatter_dim: int = 2,
gather_dim: int = 1,
concat_output: bool = True,
):
return _AllToAll.apply(input_, process_group, world_size, scatter_dim, gather_dim, concat_output)
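# Shape sketch (assuming concat_output=True and a 4-D activation): with world size W, an input of shape
# (B, S, H, D), scatter_dim=2, and gather_dim=1, each rank sends W chunks of shape (B, S, H/W, D) and
# receives W chunks that are concatenated along dim 1, producing an output of shape (B, S*W, H/W, D).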
import math
import sys
from typing import Iterable
import torch
import torch.nn as nn
import accelerate
from .utils import MetricLogger, SmoothedValue
def update_ema_for_dit(model, model_ema, accelerator, decay):
"""Apply exponential moving average update.
The weights are updated in-place as follows:
w_ema = w_ema * decay + (1 - decay) * w
Args:
model: the active model being optimized
model_ema: the running-average (EMA) model
accelerator: the accelerate Accelerator used to gather the full model state dict
decay: exponential decay parameter
"""
with torch.no_grad():
msd = accelerator.get_state_dict(model)
for k, ema_v in model_ema.state_dict().items():
if k in msd:
model_v = msd[k].detach().to(ema_v.device, dtype=ema_v.dtype)
ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
def get_decay(optimization_step: int, ema_decay: float) -> float:
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - 1)
if step <= 0:
return 0.0
cur_decay_value = (1 + step) / (10 + step)
cur_decay_value = min(cur_decay_value, ema_decay)
cur_decay_value = max(cur_decay_value, 0.0)
return cur_decay_value
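# Worked example: at optimization_step=2 the decay is (1+1)/(10+1) ~= 0.18, at step 101 it is
# (1+100)/(10+100) ~= 0.918, and it approaches 1 for large steps but is capped at ema_decay.
# (Currently unused in train_one_epoch_with_fsdp below, where a constant ema_decay is applied.)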
def train_one_epoch_with_fsdp(
runner,
model_ema: torch.nn.Module,
accelerator: accelerate.Accelerator,
model_dtype: str,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
lr_schedule_values,
device: torch.device,
epoch: int,
clip_grad: float = 1.0,
start_steps=None,
args=None,
print_freq=20,
iters_per_epoch=2000,
ema_decay=0.9999,
use_temporal_pyramid=True,
):
runner.dit.train()
metric_logger = MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
train_loss = 0.0
print("Start training epoch {}, {} iters per inner epoch. Training dtype {}".format(epoch, iters_per_epoch, model_dtype))
for step in metric_logger.log_every(range(iters_per_epoch), print_freq, header):
if step >= iters_per_epoch:
break
if lr_schedule_values is not None:
for i, param_group in enumerate(optimizer.param_groups):
param_group["lr"] = lr_schedule_values[start_steps] * param_group.get("lr_scale", 1.0)
for _ in range(args.gradient_accumulation_steps):
with accelerator.accumulate(runner.dit):
# Fetch the data sample and move the input to the device
samples = next(data_loader)
video = samples['video'].to(accelerator.device)
text = samples['text']
identifier = samples['identifier']
# Perform the forward pass using accelerate
loss, log_loss = runner(video, text, identifier,
use_temporal_pyramid=use_temporal_pyramid, accelerator=accelerator)
# Check if the loss is nan
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value), force=True)
sys.exit(1)
avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
accelerator.backward(loss)
# clip the gradient
if accelerator.sync_gradients:
params_to_clip = runner.dit.parameters()
grad_norm = accelerator.clip_grad_norm_(params_to_clip, clip_grad)
# To deal with abnormal data points
if train_loss >= 2.0:
print(f"Abnormal data sample: extremely high loss {train_loss}, skipping the parameter update", force=True)
# zero out the gradient, do not update
optimizer.zero_grad()
train_loss = 0.001 # fix the loss for logging
else:
optimizer.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
# Update every 100 steps
if model_ema is not None and start_steps % 100 == 0:
# cur_ema_decay = get_decay(start_steps, ema_decay)
cur_ema_decay = ema_decay
update_ema_for_dit(runner.dit, model_ema, accelerator, decay=cur_ema_decay)
start_steps += 1
# Report to tensorboard
accelerator.log({"train_loss": train_loss}, step=start_steps)
metric_logger.update(loss=train_loss)
train_loss = 0.0
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])
metric_logger.update(lr=max_lr)
metric_logger.update(min_lr=min_lr)
weight_decay_value = None
for group in optimizer.param_groups:
if group["weight_decay"] > 0:
weight_decay_value = group["weight_decay"]
metric_logger.update(weight_decay=weight_decay_value)
metric_logger.update(grad_norm=grad_norm)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
import os
import torch
import torch.distributed as dist
from .utils import is_dist_avail_and_initialized, get_rank
SEQ_PARALLEL_GROUP = None
SEQ_PARALLEL_SIZE = None
SEQ_PARALLEL_PROC_NUM = None # how many processes participate in sequence parallelism
SYNC_INPUT_GROUP = None
SYNC_INPUT_SIZE = None
def is_sequence_parallel_initialized():
if SEQ_PARALLEL_GROUP is None:
return False
else:
return True
def init_sequence_parallel_group(args):
global SEQ_PARALLEL_GROUP
global SEQ_PARALLEL_SIZE
global SEQ_PARALLEL_PROC_NUM
assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
SEQ_PARALLEL_SIZE = args.sp_group_size
print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if args.sp_proc_num == -1:
SEQ_PARALLEL_PROC_NUM = world_size
else:
SEQ_PARALLEL_PROC_NUM = args.sp_proc_num
assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The number of sequence-parallel processes must be divisible by the group size"
for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
group = torch.distributed.new_group(ranks)
if rank in ranks:
SEQ_PARALLEL_GROUP = group
break
def init_sync_input_group(args):
global SYNC_INPUT_GROUP
global SYNC_INPUT_SIZE
assert SYNC_INPUT_GROUP is None, "sync input group is already initialized"
assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
SYNC_INPUT_SIZE = args.max_frames
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
for i in range(0, world_size, SYNC_INPUT_SIZE):
ranks = list(range(i, i + SYNC_INPUT_SIZE))
group = torch.distributed.new_group(ranks)
if rank in ranks:
SYNC_INPUT_GROUP = group
break
def get_sequence_parallel_group():
assert SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
return SEQ_PARALLEL_GROUP
def get_sync_input_group():
return SYNC_INPUT_GROUP
def get_sequence_parallel_world_size():
assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
return SEQ_PARALLEL_SIZE
def get_sequence_parallel_rank():
assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
rank = get_rank()
cp_rank = rank % SEQ_PARALLEL_SIZE
return cp_rank
def get_sequence_parallel_group_rank():
assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
rank = get_rank()
cp_group_rank = rank // SEQ_PARALLEL_SIZE
return cp_group_rank
def get_sequence_parallel_proc_num():
return SEQ_PARALLEL_PROC_NUM
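# Sketch of the rank bookkeeping above, assuming world_size = 8 and
# args.sp_group_size = 4 (illustrative values). Ranks are grouped as
# [0, 1, 2, 3] and [4, 5, 6, 7]; for global rank 5:
#
#     get_sequence_parallel_rank()        # -> 5 % 4  == 1 (position inside its group)
#     get_sequence_parallel_group_rank()  # -> 5 // 4 == 1 (which group it belongs to)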
import io
import os
import math
import time
import json
import glob
from collections import defaultdict, deque, OrderedDict
import datetime
import numpy as np
from pathlib import Path
import argparse
import torch
from torch import optim as optim
import torch.distributed as dist
try:
from torch._six import inf
except ImportError:
from torch import inf
from tensorboardX import SummaryWriter
try:
# `has_apex` is checked in create_optimizer's fused-optimizer branch below
import apex
has_apex = True
except ImportError:
has_apex = False
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def init_distributed_mode(args, init_pytorch_ddp=True):
if int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')) > 0:
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
os.environ["LOCAL_RANK"] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
os.environ["RANK"] = os.environ['OMPI_COMM_WORLD_RANK']
os.environ["WORLD_SIZE"] = os.environ['OMPI_COMM_WORLD_SIZE']
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
args.dist_backend = 'nccl'
args.dist_url = "env://"
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
if init_pytorch_ddp:
# Init DDP Group, for script without using accelerate framework
torch.cuda.set_device(args.gpu)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank, timeout=datetime.timedelta(days=365))
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
start_warmup_value=0, warmup_steps=-1):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_steps > 0:
warmup_iters = warmup_steps
print("Set warmup steps = %d" % warmup_iters)
if warmup_iters > 0: # also honor warmup_steps when warmup_epochs == 0 (mirrors constant_scheduler)
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
iters = np.arange(epochs * niter_per_ep - warmup_iters)
schedule = np.array(
[final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
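# Sketch of a per-iteration learning-rate table built with this helper, using
# illustrative values (2 epochs of 1000 iters, 100 linear warmup steps from 0
# to 5e-5, then a cosine decay down to 1e-6):
#
#     lr_schedule_values = cosine_scheduler(
#         base_value=5e-5, final_value=1e-6,
#         epochs=2, niter_per_ep=1000, warmup_steps=100,
#     )
#     assert len(lr_schedule_values) == 2 * 1000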
def constant_scheduler(base_value, epochs, niter_per_ep, warmup_epochs=0,
start_warmup_value=1e-6, warmup_steps=-1):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_steps > 0:
warmup_iters = warmup_steps
print("Set warmup steps = %d" % warmup_iters)
if warmup_iters > 0:
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
iters = epochs * niter_per_ep - warmup_iters
schedule = np.array([base_value] * iters)
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
def get_parameter_groups(model, weight_decay=1e-5, base_lr=1e-4, skip_list=(), get_num_layer=None, get_layer_scale=None, **kwargs):
parameter_group_names = {}
parameter_group_vars = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(kwargs.get('filter_name', [])) > 0:
flag = False
for filter_n in kwargs.get('filter_name', []):
if filter_n in name:
print(f"filter {name} because of the pattern {filter_n}")
flag = True
if flag:
continue
default_scale=1.
if param.ndim <= 1 or name.endswith(".bias") or name in skip_list: # 1-D params (biases, norm weights) and skipped names get no weight decay
group_name = "no_decay"
this_weight_decay = 0.
else:
group_name = "decay"
this_weight_decay = weight_decay
if get_num_layer is not None:
layer_id = get_num_layer(name)
group_name = "layer_%d_%s" % (layer_id, group_name)
else:
layer_id = None
if group_name not in parameter_group_names:
if get_layer_scale is not None:
scale = get_layer_scale(layer_id)
else:
scale = default_scale
parameter_group_names[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr": base_lr,
"lr_scale": scale,
}
parameter_group_vars[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr": base_lr,
"lr_scale": scale,
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
return list(parameter_group_vars.values())
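# Sketch with a hypothetical toy model: biases, 1-D parameters (e.g. norm
# weights) and any name returned by model.no_weight_decay() land in the
# "no_decay" group (weight_decay = 0); everything else goes to "decay".
#
#     toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
#     groups = get_parameter_groups(toy, weight_decay=1e-4, base_lr=5e-5)
#     # -> two groups: "decay" holds 0.weight; "no_decay" holds 0.bias,
#     #    1.weight and 1.bias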
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, **kwargs):
opt_lower = args.opt.lower()
weight_decay = args.weight_decay
skip = {}
if skip_list is not None:
skip = skip_list
elif hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
print(f"Skip weight decay name marked in model: {skip}")
parameters = get_parameter_groups(model, weight_decay, args.lr, skip, get_num_layer, get_layer_scale, **kwargs)
weight_decay = 0.
if 'fused' in opt_lower:
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
opt_args['eps'] = args.opt_eps
if hasattr(args, 'opt_beta1') and args.opt_beta1 is not None:
opt_args['betas'] = (args.opt_beta1, args.opt_beta2)
print('Optimizer config:', opt_args)
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if opt_lower == 'sgd' or opt_lower == 'nesterov':
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentum':
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'adadelta':
optimizer = optim.Adadelta(parameters, **opt_args)
elif opt_lower == 'rmsprop':
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
else:
raise ValueError(f"Invalid optimizer: {opt_lower}")
return optimizer
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
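# Sketch of how SmoothedValue behaves (illustrative values): the deque gives
# windowed statistics while total/count track the global average.
#
#     v = SmoothedValue(window_size=3, fmt='{median:.2f} ({global_avg:.2f})')
#     for x in [1.0, 2.0, 3.0, 4.0]:
#         v.update(x)
#     # the window now holds [2.0, 3.0, 4.0]; str(v) -> "3.00 (2.50)"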
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, optimizer_disc=None):
output_dir = Path(args.output_dir)
if args.auto_resume and len(args.resume) == 0:
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint.pth'))
if len(all_checkpoints) > 0:
args.resume = os.path.join(output_dir, 'checkpoint.pth')
else:
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
latest_ckpt = -1
for ckpt in all_checkpoints:
t = ckpt.split('-')[-1].split('.')[0]
if t.isdigit():
latest_ckpt = max(int(t), latest_ckpt)
if latest_ckpt >= 0:
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
print("Auto resume checkpoint: %s" % args.resume)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model']) # strict=True by default; pass strict=False to allow partial loads
print("Resume checkpoint %s" % args.resume)
if ('optimizer' in checkpoint) and ('epoch' in checkpoint) and (optimizer is not None):
optimizer.load_state_dict(checkpoint['optimizer'])
print(f"Resume checkpoint at epoch {checkpoint['epoch']}, the global optmization step is {checkpoint['step']}")
args.start_epoch = checkpoint['epoch'] + 1
args.global_step = checkpoint['step'] + 1
if model_ema is not None:
if 'model_ema' in checkpoint:
ema_load_res = model_ema.load_state_dict(checkpoint["model_ema"])
print(f"EMA Model Resume results: {ema_load_res}")
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
if ('optimizer_disc' in checkpoint) and (optimizer_disc is not None):
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, optimizer_disc=None, save_ckpt_freq=1):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
checkpoint_paths = [output_dir / 'checkpoint.pth']
if epoch == 'best':
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name),]
elif (epoch + 1) % save_ckpt_freq == 0:
checkpoint_paths.append(output_dir / ('checkpoint-%s.pth' % epoch_name))
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'epoch': epoch,
'step' : args.global_step,
'args': args,
}
if optimizer is not None:
to_save['optimizer'] = optimizer.state_dict()
if loss_scaler is not None:
to_save['scaler'] = loss_scaler.state_dict()
if model_ema is not None:
to_save['model_ema'] = model_ema.state_dict()
if optimizer_disc is not None:
to_save['optimizer_disc'] = optimizer_disc.state_dict()
save_on_master(to_save, checkpoint_path)
def get_grad_norm_(parameters, norm_type: float = 2.0, layer_names=None) -> torch.Tensor:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
layer_norm = torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters])
total_norm = torch.norm(layer_norm, norm_type)
if layer_names is not None:
if torch.isnan(total_norm) or torch.isinf(total_norm) or total_norm > 1.0:
value_top, name_top = torch.topk(layer_norm, k=5)
print(f"Top norm value: {value_top}")
print(f"Top norm name: {[layer_names[i][7:] for i in name_top.tolist()]}")
return total_norm
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"
def __init__(self, enabled=True):
print(f"Set the loss scaled to {enabled}")
self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, layer_names=None):
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters, layer_names=layer_names)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
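# Sketch of a single optimization step driven by the scaler above (the model,
# optimizer and loss names are hypothetical):
#
#     loss_scaler = NativeScalerWithGradNormCount(enabled=True)
#     grad_norm = loss_scaler(
#         loss, optimizer,
#         clip_grad=1.0, parameters=model.parameters(),
#         update_grad=True,
#     )
#     optimizer.zero_grad()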
import math
import sys
from typing import Iterable
import torch
import torch.nn as nn
from .utils import (
MetricLogger,
SmoothedValue,
)
def train_one_epoch(
model: torch.nn.Module,
model_dtype: str,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
optimizer_disc: torch.optim.Optimizer,
device: torch.device,
epoch: int,
loss_scaler,
loss_scaler_disc,
clip_grad: float = 0,
log_writer=None,
lr_scheduler=None,
start_steps=None,
lr_schedule_values=None,
lr_schedule_values_disc=None,
args=None,
print_freq=20,
iters_per_epoch=2000,
):
# The trainer for causal video vae
model.train()
metric_logger = MetricLogger(delimiter=" ")
if optimizer is not None:
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
if optimizer_disc is not None:
metric_logger.add_meter('disc_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('disc_min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
if model_dtype == 'bf16':
_dtype = torch.bfloat16
else:
_dtype = torch.float16
print("Start training epoch {}, {} iters per inner epoch.".format(epoch, iters_per_epoch))
for step in metric_logger.log_every(range(iters_per_epoch), print_freq, header):
if step >= iters_per_epoch:
break
it = start_steps + step # global training iteration
if lr_schedule_values is not None:
for i, param_group in enumerate(optimizer.param_groups):
if lr_schedule_values is not None:
param_group["lr"] = lr_schedule_values[it] * param_group.get("lr_scale", 1.0)
if optimizer_disc is not None:
for i, param_group in enumerate(optimizer_disc.param_groups):
if lr_schedule_values_disc is not None:
param_group["lr"] = lr_schedule_values_disc[it] * param_group.get("lr_scale", 1.0)
samples = next(data_loader)
samples['video'] = samples['video'].to(device, non_blocking=True)
with torch.cuda.amp.autocast(enabled=True, dtype=_dtype):
rec_loss, gan_loss, log_loss = model(samples['video'], args.global_step, identifier=samples['identifier'])
###################################################################################################
# The update of rec_loss
if rec_loss is not None:
loss_value = rec_loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value), force=True)
sys.exit(1)
optimizer.zero_grad()
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
grad_norm = loss_scaler(rec_loss, optimizer, clip_grad=clip_grad,
parameters=model.module.vae.parameters(), create_graph=is_second_order)
if "scale" in loss_scaler.state_dict():
loss_scale_value = loss_scaler.state_dict()["scale"]
else:
loss_scale_value = 1
metric_logger.update(vae_loss=loss_value)
metric_logger.update(loss_scale=loss_scale_value)
###################################################################################################
# The update of gan_loss
if gan_loss is not None:
gan_loss_value = gan_loss.item()
if not math.isfinite(gan_loss_value):
print("The gan discriminator Loss is {}, stopping training".format(gan_loss_value), force=True)
sys.exit(1)
optimizer_disc.zero_grad()
is_second_order = hasattr(optimizer_disc, 'is_second_order') and optimizer_disc.is_second_order
disc_grad_norm = loss_scaler_disc(gan_loss, optimizer_disc, clip_grad=clip_grad,
parameters=model.module.loss.discriminator.parameters(), create_graph=is_second_order)
if "scale" in loss_scaler_disc.state_dict():
disc_loss_scale_value = loss_scaler_disc.state_dict()["scale"]
else:
disc_loss_scale_value = 1
metric_logger.update(disc_loss=gan_loss_value)
metric_logger.update(disc_loss_scale=disc_loss_scale_value)
metric_logger.update(disc_grad_norm=disc_grad_norm)
min_lr = 10.
max_lr = 0.
for group in optimizer_disc.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])
metric_logger.update(disc_lr=max_lr)
metric_logger.update(disc_min_lr=min_lr)
torch.cuda.synchronize()
new_log_loss = {k.split('/')[-1]:v for k, v in log_loss.items() if k not in ['total_loss']}
metric_logger.update(**new_log_loss)
if rec_loss is not None:
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])
metric_logger.update(lr=max_lr)
metric_logger.update(min_lr=min_lr)
weight_decay_value = None
for group in optimizer.param_groups:
if group["weight_decay"] > 0:
weight_decay_value = group["weight_decay"]
metric_logger.update(weight_decay=weight_decay_value)
metric_logger.update(grad_norm=grad_norm)
if log_writer is not None:
log_writer.update(**new_log_loss, head="train/loss")
log_writer.update(lr=max_lr, head="opt")
log_writer.update(min_lr=min_lr, head="opt")
log_writer.update(weight_decay=weight_decay_value, head="opt")
log_writer.update(grad_norm=grad_norm, head="opt")
log_writer.set_step()
if lr_scheduler is not None:
lr_scheduler.step_update(start_steps + step)
args.global_step = args.global_step + 1
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
import os
import torch
import PIL.Image
import numpy as np
from torch import nn
import torch.distributed as dist
import timm.models.hub as timm_hub
"""Modified from https://github.com/CompVis/taming-transformers.git"""
import hashlib
import requests
from tqdm import tqdm
try:
import piq
except ImportError:
pass
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def is_context_parallel_initialized():
if _CONTEXT_PARALLEL_GROUP is None:
return False
else:
return True
def set_context_parallel_group(size, group):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_SIZE = size
def initialize_context_parallel(context_parallel_size):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
_CONTEXT_PARALLEL_SIZE = context_parallel_size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
for i in range(0, world_size, context_parallel_size):
ranks = range(i, i + context_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
break
def get_context_parallel_group():
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_world_size():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
return _CONTEXT_PARALLEL_SIZE
def get_context_parallel_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = get_rank()
cp_rank = rank % _CONTEXT_PARALLEL_SIZE
return cp_rank
def get_context_parallel_group_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = get_rank()
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
return cp_group_rank
def download_cached_file(url, check_hash=True, progress=False):
"""
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
"""
def get_cached_file_path():
# a hack to sync the file path across processes
parts = torch.hub.urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
return cached_file
if is_main_process():
timm_hub.download_cached_file(url, check_hash, progress)
if is_dist_avail_and_initialized():
dist.barrier()
return get_cached_file_path()
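# Sketch of the intended usage (the URL is hypothetical): only the main process
# downloads, the other ranks block on the barrier, then all ranks resolve the
# same cached path.
#
#     ckpt_path = download_cached_file("https://example.com/weights.pth",
#                                      check_hash=True, progress=True)
#     state_dict = torch.load(ckpt_path, map_location="cpu")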
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
l.weight.data = l.weight.data.to(torch.float16)
if l.bias is not None:
l.bias.data = l.bias.data.to(torch.float16)
model.apply(_convert_weights_to_fp16)
def convert_weights_to_bf16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_bf16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
l.weight.data = l.weight.data.to(torch.bfloat16)
if l.bias is not None:
l.bias.data = l.bias.data.to(torch.bfloat16)
model.apply(_convert_weights_to_bf16)
def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'):
import json
import jsonlines
print("Dump result")
# Make the temp dir for saving results
if not os.path.exists(result_dir):
if is_main_process():
os.makedirs(result_dir)
if is_dist_avail_and_initialized():
torch.distributed.barrier()
result_file = os.path.join(
result_dir, "%s_rank%d.json" % (filename, get_rank())
)
final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")
json.dump(result, open(result_file, "w"))
if is_dist_avail_and_initialized():
torch.distributed.barrier()
if is_main_process():
# print("rank %d starts merging results." % get_rank())
# combine results from all processes
result = []
for rank in range(get_world_size()):
result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
res = json.load(open(result_file, "r"))
result += res
# print("Remove duplicate")
if remove_duplicate:
result_new = []
id_set = set()
for res in result:
if res[remove_duplicate] not in id_set:
id_set.add(res[remove_duplicate])
result_new.append(res)
result = result_new
if save_format == 'json':
json.dump(result, open(final_result_file, "w"))
else:
assert save_format == 'jsonl', "Only json and jsonl formats are supported"
with jsonlines.open(final_result_file, "w") as writer:
writer.write_all(result)
# print("result file saved to %s" % final_result_file)
return final_result_file
# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
h, w = input.shape[-2:]
factors = (h / size[0], w / size[1])
# First, we have to determine sigma
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
sigmas = (
max((factors[0] - 1.0) / 2.0, 0.001),
max((factors[1] - 1.0) / 2.0, 0.001),
)
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
# But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
# Make sure it is odd
if (ks[0] % 2) == 0:
ks = ks[0] + 1, ks[1]
if (ks[1] % 2) == 0:
ks = ks[0], ks[1] + 1
input = _gaussian_blur2d(input, ks, sigmas)
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
return output
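# Sketch with illustrative shapes: downscaling a (1, 3, 768, 1280) frame to
# (384, 640) gives factors (2.0, 2.0), hence sigmas (0.5, 0.5) and an odd
# 3x3 Gaussian kernel applied before the bicubic interpolation.
#
#     frames = torch.randn(1, 3, 768, 1280)
#     small = _resize_with_antialiasing(frames, (384, 640))  # -> (1, 3, 384, 640)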
def _compute_padding(kernel_size):
"""Compute padding tuple."""
# 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
if len(kernel_size) < 2:
raise AssertionError(kernel_size)
computed = [k - 1 for k in kernel_size]
# for even kernels we need to do asymmetric padding :(
out_padding = 2 * len(kernel_size) * [0]
for i in range(len(kernel_size)):
computed_tmp = computed[-(i + 1)]
pad_front = computed_tmp // 2
pad_rear = computed_tmp - pad_front
out_padding[2 * i + 0] = pad_front
out_padding[2 * i + 1] = pad_rear
return out_padding
def _filter2d(input, kernel):
# prepare kernel
b, c, h, w = input.shape
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
height, width = tmp_kernel.shape[-2:]
padding_shape: list[int] = _compute_padding([height, width])
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
# kernel and input tensor reshape to align element-wise or batch-wise params
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
# convolve the tensor with the kernel.
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
out = output.view(b, c, h, w)
return out
def _gaussian(window_size: int, sigma):
if isinstance(sigma, float):
sigma = torch.tensor([[sigma]])
batch_size = sigma.shape[0]
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
if window_size % 2 == 0:
x = x + 0.5
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
return gauss / gauss.sum(-1, keepdim=True)
def _gaussian_blur2d(input, kernel_size, sigma):
if isinstance(sigma, tuple):
sigma = torch.tensor([sigma], dtype=input.dtype)
else:
sigma = sigma.to(dtype=input.dtype)
ky, kx = int(kernel_size[0]), int(kernel_size[1])
bs = sigma.shape[0]
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
out_x = _filter2d(input, kernel_x[..., None, :])
out = _filter2d(out_x, kernel_y[..., None])
return out
URL_MAP = {
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}
CKPT_MAP = {
"vgg_lpips": "vgg.pth"
}
MD5_MAP = {
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
print(md5_hash(path))
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class KeyNotFoundError(Exception):
def __init__(self, cause, keys=None, visited=None):
self.cause = cause
self.keys = keys
self.visited = visited
messages = list()
if keys is not None:
messages.append("Key not found: {}".format(keys))
if visited is not None:
messages.append("Visited: {}".format(visited))
messages.append("Cause:\n{}".format(cause))
message = "\n".join(messages)
super().__init__(message)
def retrieve(
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
"""Given a nested list or dict return the desired value at key expanding
callable nodes if necessary and :attr:`expand` is ``True``. The expansion
is done in-place.
Parameters
----------
list_or_dict : list or dict
Possibly nested list or dictionary.
key : str
key/to/value, path like string describing all keys necessary to
consider to get to the desired value. List indices can also be
passed here.
splitval : str
String that defines the delimiter between keys of the
different depth levels in `key`.
default : obj
Value returned if :attr:`key` is not found.
expand : bool
Whether to expand callable nodes on the path or not.
Returns
-------
The desired value or if :attr:`default` is not ``None`` and the
:attr:`key` is not found returns ``default``.
Raises
------
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
``None``.
"""
keys = key.split(splitval)
success = True
try:
visited = []
parent = None
last_key = None
for key in keys:
if callable(list_or_dict):
if not expand:
raise KeyNotFoundError(
ValueError(
"Trying to get past callable node with expand=False."
),
keys=keys,
visited=visited,
)
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
last_key = key
parent = list_or_dict
try:
if isinstance(list_or_dict, dict):
list_or_dict = list_or_dict[key]
else:
list_or_dict = list_or_dict[int(key)]
except (KeyError, IndexError, ValueError) as e:
raise KeyNotFoundError(e, keys=keys, visited=visited)
visited += [key]
# final expansion of retrieved value
if expand and callable(list_or_dict):
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
except KeyNotFoundError as e:
if default is None:
raise e
else:
list_or_dict = default
success = False
if not pass_success:
return list_or_dict
else:
return list_or_dict, success
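# Sketch of retrieve() on a hypothetical nested config: path-like keys walk
# dicts and lists, and `default` is returned when a key is missing.
#
#     cfg = {"model": {"params": [{"lr": 5e-5}]}}
#     retrieve(cfg, "model/params/0/lr")                  # -> 5e-05
#     retrieve(cfg, "model/params/0/decay", default=0.0)  # -> 0.0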
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import PIL\n",
"from PIL import Image\n",
"from IPython.display import HTML\n",
"from pyramid_dit import PyramidDiTForVideoGeneration\n",
"from IPython.display import Image as ipython_image\n",
"from diffusers.utils import load_image, export_to_video, export_to_gif"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"variant='diffusion_transformer_384p' # For low resolution\n",
"# variant='diffusion_transformer_768p' # For high resolution, the pyramid_flux does not support high res now, use pyramid_mmdit instead\n",
"\n",
"model_name = \"pyramid_flux\" # or \"pyramid_mmdit\"\n",
"\n",
"if model_name == \"pyramid_flux\":\n",
" assert variant != \"diffusion_transformer_768p\", \"The pyramid_flux does not support high resolution now, we will release it after finishing training\"\n",
"\n",
"model_path = \"/home/modelzoo/Pyramid-Flow/pyramid_flow_model/pyramid-flow-miniflux\" # The downloaded checkpoint dir\n",
"model_dtype = 'bf16'\n",
"\n",
"device_id = 0\n",
"torch.cuda.set_device(device_id)\n",
"\n",
"model = PyramidDiTForVideoGeneration(\n",
" model_path,\n",
" model_dtype,\n",
" model_name=model_name,\n",
" model_variant=variant,\n",
")\n",
"\n",
"model.vae.to(\"cuda\")\n",
"model.dit.to(\"cuda\")\n",
"model.text_encoder.to(\"cuda\")\n",
"\n",
"model.vae.enable_tiling()\n",
"\n",
"if model_dtype == \"bf16\":\n",
" torch_dtype = torch.bfloat16 \n",
"elif model_dtype == \"fp16\":\n",
" torch_dtype = torch.float16\n",
"else:\n",
" torch_dtype = torch.float32\n",
"\n",
"\n",
"def resize_crop_image(img: PIL.Image.Image, tgt_width, tgt_height):\n",
" ori_width, ori_height = img.width, img.height\n",
" scale = max(tgt_width / ori_width, tgt_height / ori_height)\n",
" resized_width = round(ori_width * scale)\n",
" resized_height = round(ori_height * scale)\n",
" img = img.resize((resized_width, resized_height), resample=PIL.Image.LANCZOS)\n",
"\n",
" left = (resized_width - tgt_width) / 2\n",
" top = (resized_height - tgt_height) / 2\n",
" right = (resized_width + tgt_width) / 2\n",
" bottom = (resized_height + tgt_height) / 2\n",
"\n",
" # Crop the center of the image\n",
" img = img.crop((left, top, right, bottom))\n",
" \n",
" return img\n",
"\n",
"\n",
"def show_video(ori_path, rec_path, width=\"100%\"):\n",
" html = ''\n",
" if ori_path is not None:\n",
" html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
" <source src=\"{ori_path}\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\"\n",
" \n",
" html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
" <source src=\"{rec_path}\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\"\n",
" return HTML(html)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Text-to-Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = 'a woman is walking'\n",
"prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
"\n",
"# used for 384p model variant\n",
"width = 640\n",
"height = 384\n",
"\n",
"# used for 768p model variant\n",
"# width = 1280\n",
"# height = 768\n",
"\n",
"temp = 16 # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
"# For the 384p version, only supports maximum 5s generation (temp = 16)\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
" frames = model.generate(\n",
" prompt=prompt,\n",
" num_inference_steps=[20, 20, 20],\n",
" video_num_inference_steps=[10, 10, 10],\n",
" height=height,\n",
" width=width,\n",
" temp=temp,\n",
" guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant\n",
" video_guidance_scale=5.0, # The guidance for the other video latent\n",
" output_type=\"pil\",\n",
" save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
" )\n",
"\n",
"export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
"show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Image-to-Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_path = 'assets/the_great_wall.jpg'\n",
"image = Image.open(image_path).convert(\"RGB\")\n",
"\n",
"# used for 384p model variant\n",
"width = 640\n",
"height = 384\n",
"\n",
"# used for 768p model variant\n",
"# width = 1280\n",
"# height = 768\n",
"\n",
"temp = 16\n",
"image = image.resize((width, height))\n",
"image = resize_crop_image(image, width, height)\n",
"\n",
"display(image)\n",
"\n",
"prompt = \"FPV flying over the Great Wall\"\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
" frames = model.generate_i2v(\n",
" prompt=prompt,\n",
" input_image=image,\n",
" num_inference_steps=[10, 10, 10],\n",
" temp=temp,\n",
" guidance_scale=7.0,\n",
" video_guidance_scale=4.0,\n",
" output_type=\"pil\",\n",
" save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
" )\n",
"\n",
"export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
"show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
from .modeling_loss import LPIPSWithDiscriminator
from .modeling_causal_vae import CausalVideoVAE
from .causal_video_vae_wrapper import CausalVideoVAELossWrapper
import torch
import os
import torch.nn as nn
from collections import OrderedDict
from .modeling_causal_vae import CausalVideoVAE
from .modeling_loss import LPIPSWithDiscriminator
from einops import rearrange
from PIL import Image
from IPython import embed
from utils import (
is_context_parallel_initialized,
get_context_parallel_group,
get_context_parallel_world_size,
get_context_parallel_rank,
get_context_parallel_group_rank,
)
from .context_parallel_ops import (
conv_scatter_to_context_parallel_region,
conv_gather_from_context_parallel_region,
)
class CausalVideoVAELossWrapper(nn.Module):
"""
The training and inference wrapper for the causal video VAE
"""
def __init__(self, model_path, model_dtype='fp32', disc_start=0, logvar_init=0.0, kl_weight=1.0,
pixelloss_weight=1.0, perceptual_weight=1.0, disc_weight=0.5, interpolate=True,
add_discriminator=True, freeze_encoder=False, load_loss_module=False, lpips_ckpt=None, **kwargs,
):
super().__init__()
if model_dtype == 'bf16':
torch_dtype = torch.bfloat16
elif model_dtype == 'fp16':
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
self.vae = CausalVideoVAE.from_pretrained(model_path, torch_dtype=torch_dtype, interpolate=False)
self.vae_scale_factor = self.vae.config.scaling_factor
if freeze_encoder:
print("Freeze the parameters of vae encoder")
for parameter in self.vae.encoder.parameters():
parameter.requires_grad = False
for parameter in self.vae.quant_conv.parameters():
parameter.requires_grad = False
self.add_discriminator = add_discriminator
self.freeze_encoder = freeze_encoder
# Used for training
if load_loss_module:
self.loss = LPIPSWithDiscriminator(disc_start, logvar_init=logvar_init, kl_weight=kl_weight,
pixelloss_weight=pixelloss_weight, perceptual_weight=perceptual_weight, disc_weight=disc_weight,
add_discriminator=add_discriminator, using_3d_discriminator=False, disc_num_layers=4, lpips_ckpt=lpips_ckpt)
else:
self.loss = None
self.disc_start = disc_start
# print(self.loss)
# exit()
def load_checkpoint(self, checkpoint_path, **kwargs):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'model' in checkpoint:
checkpoint = checkpoint['model']
vae_checkpoint = OrderedDict()
disc_checkpoint = OrderedDict()
for key in checkpoint.keys():
if key.startswith('vae.'):
new_key = key.split('.')
new_key = '.'.join(new_key[1:])
vae_checkpoint[new_key] = checkpoint[key]
if key.startswith('loss.discriminator'):
new_key = key.split('.')
new_key = '.'.join(new_key[2:])
disc_checkpoint[new_key] = checkpoint[key]
vae_ckpt_load_result = self.vae.load_state_dict(vae_checkpoint, strict=False)
print(f"Load vae checkpoint from {checkpoint_path}, load result: {vae_ckpt_load_result}")
disc_ckpt_load_result = self.loss.discriminator.load_state_dict(disc_checkpoint, strict=False)
print(f"Load disc checkpoint from {checkpoint_path}, load result: {disc_ckpt_load_result}")
def forward(self, x, step, identifier=['video']):
xdim = x.ndim
if xdim == 4:
x = x.unsqueeze(2) # (B, C, H, W) -> (B, C, 1, H, W)
if 'video' in identifier:
# The input is video
assert 'image' not in identifier
else:
# The input is image
assert 'video' not in identifier
# We arrange multiple images into a 5D tensor for compatibility with the video input,
# so we reformulate the images as 1-frame video tensors
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = x.unsqueeze(2) # [(b t) c 1 h w]
if is_context_parallel_initialized():
assert self.training, "Only supports during training now"
cp_world_size = get_context_parallel_world_size()
global_src_rank = get_context_parallel_group_rank() * cp_world_size
# sync the input and split
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
batch_x = conv_scatter_to_context_parallel_region(x, dim=2, kernel_size=1)
else:
batch_x = x
posterior, reconstruct = self.vae(batch_x, freeze_encoder=self.freeze_encoder,
is_init_image=True, temporal_chunk=False,)
# The reconstruct loss
reconstruct_loss, rec_log = self.loss(
batch_x, reconstruct, posterior,
optimizer_idx=0, global_step=step, last_layer=self.vae.get_last_layer(),
)
if step < self.disc_start:
return reconstruct_loss, None, rec_log
# The loss to train the discriminator
gan_loss, gan_log = self.loss(batch_x, reconstruct, posterior, optimizer_idx=1,
global_step=step, last_layer=self.vae.get_last_layer(),
)
loss_log = {**rec_log, **gan_log}
return reconstruct_loss, gan_loss, loss_log
def encode(self, x, sample=False, is_init_image=True,
temporal_chunk=False, window_size=16, tile_sample_min_size=256,):
# x: (B, C, T, H, W) or (B, C, H, W)
B = x.shape[0]
xdim = x.ndim
if xdim == 4:
# The input is an image
x = x.unsqueeze(2)
if sample:
x = self.vae.encode(
x, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
window_size=window_size, tile_sample_min_size=tile_sample_min_size,
).latent_dist.sample()
else:
x = self.vae.encode(
x, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
window_size=window_size, tile_sample_min_size=tile_sample_min_size,
).latent_dist.mode()
return x
def decode(self, x, is_init_image=True, temporal_chunk=False,
window_size=2, tile_sample_min_size=256,):
# x: (B, C, T, H, W) or (B, C, H, W)
B = x.shape[0]
xdim = x.ndim
if xdim == 4:
# The input is an image
x = x.unsqueeze(2)
x = self.vae.decode(
x, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
window_size=window_size, tile_sample_min_size=tile_sample_min_size,
).sample
return x
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def reconstruct(
self, x, sample=False, return_latent=False, is_init_image=True,
temporal_chunk=False, window_size=16, tile_sample_min_size=256, **kwargs
):
assert x.shape[0] == 1
xdim = x.ndim
encode_window_size = window_size
decode_window_size = window_size // self.vae.downsample_scale
# Encode
x = self.encode(
x, sample, is_init_image, temporal_chunk, encode_window_size, tile_sample_min_size,
)
encode_latent = x
# Decode
x = self.decode(
x, is_init_image, temporal_chunk, decode_window_size, tile_sample_min_size
)
output_image = x.float()
output_image = (output_image / 2 + 0.5).clamp(0, 1)
# Convert to PIL images
output_image = rearrange(output_image, "B C T H W -> (B T) C H W")
output_image = output_image.cpu().permute(0, 2, 3, 1).numpy()
output_images = self.numpy_to_pil(output_image)
if return_latent:
return output_images, encode_latent
return output_images
# encode vae latent
def encode_latent(self, x, sample=False, is_init_image=True,
temporal_chunk=False, window_size=16, tile_sample_min_size=256,):
# Encode
latent = self.encode(
x, sample, is_init_image, temporal_chunk, window_size, tile_sample_min_size,
)
return latent
# decode vae latent
def decode_latent(self, latent, is_init_image=True,
temporal_chunk=False, window_size=2, tile_sample_min_size=256,):
x = self.decode(
latent, is_init_image, temporal_chunk, window_size, tile_sample_min_size
)
output_image = x.float()
output_image = (output_image / 2 + 0.5).clamp(0, 1)
# Convert to PIL images
output_image = rearrange(output_image, "B C T H W -> (B T) C H W")
output_image = output_image.cpu().permute(0, 2, 3, 1).numpy()
output_images = self.numpy_to_pil(output_image)
return output_images
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
# Adapted from CogVideoX
import torch
import torch.nn as nn
import math
from utils import (
get_context_parallel_group,
get_context_parallel_rank,
get_context_parallel_world_size,
get_context_parallel_group_rank,
)
def _conv_split(input_, dim=2, kernel_size=1):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
if cp_rank == 0:
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
].transpose(dim, 0)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
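# Sketch with illustrative shapes: for a context-parallel world size of 2 and
# kernel_size = 1, a latent of shape (B, C, 17, H, W) split along dim=2 gives
# dim_size = (17 - 1) // 2 = 8, so
#     rank 0 -> frames [0:9)   (its slice plus the kernel_size leading frames)
#     rank 1 -> frames [9:17)
# Later ranks obtain the overlapping frames their causal convolutions need from
# the previous rank via _cp_pass_from_previous_rank below.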
def _conv_gather(input_, dim=2, kernel_size=1):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
else:
input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _cp_pass_from_previous_rank(input_, dim, kernel_size):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
input_ = torch.cat([torch.zeros_like(input_[:1])] * (kernel_size - 1) + [input_], dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
return input_
def _drop_from_previous_rank(input_, dim, kernel_size):
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
return input_
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_split(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_gather(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
class _CPConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _cp_pass_from_previous_rank(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
def cp_pass_from_previous_rank(input_, dim, kernel_size):
return _CPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from diffusers.utils import logging
from diffusers.models.attention_processor import Attention
from .modeling_resnet import (
Downsample2D, ResnetBlock2D, CausalResnetBlock3D, Upsample2D,
TemporalDownsample2x, TemporalUpsample2x,
CausalDownsample2x, CausalTemporalDownsample2x,
CausalUpsample2x, CausalTemporalUpsample2x,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_input_layer(
in_channels: int,
out_channels: int,
norm_num_groups: int,
layer_type: str,
norm_type: str = 'group',
affine: bool = True,
):
if layer_type == 'conv':
input_layer = nn.Conv3d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
)
elif layer_type == 'pixel_shuffle':
input_layer = nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(in_channels * 4, out_channels, kernel_size=1),
)
else:
raise NotImplementedError(f"Not support input layer {layer_type}")
return input_layer
def get_output_layer(
in_channels: int,
out_channels: int,
norm_num_groups: int,
layer_type: str,
norm_type: str = 'group',
affine: bool = True,
):
if layer_type == 'norm_act_conv':
output_layer = nn.Sequential(
nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6, affine=affine),
nn.SiLU(),
nn.Conv3d(in_channels, out_channels, 3, stride=1, padding=1),
)
elif layer_type == 'pixel_shuffle':
output_layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
nn.PixelShuffle(2),
)
else:
raise NotImplementedError(f"Not support output layer {layer_type}")
return output_layer
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int = None,
temb_channels: int = None,
add_spatial_downsample: bool = None,
add_temporal_downsample: bool = None,
resnet_eps: float = 1e-6,
resnet_act_fn: str = 'silu',
resnet_groups: Optional[int] = None,
downsample_padding: Optional[int] = None,
resnet_time_scale_shift: str = "default",
attention_head_dim: Optional[int] = None,
dropout: float = 0.0,
norm_affline: bool = True,
norm_layer: str = 'layer',
):
if down_block_type == "DownEncoderBlock2D":
return DownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_spatial_downsample=add_spatial_downsample,
add_temporal_downsample=add_temporal_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "DownEncoderBlockCausal3D":
return DownEncoderBlockCausal3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_spatial_downsample=add_spatial_downsample,
add_temporal_downsample=add_temporal_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
prev_output_channel: int = None,
temb_channels: int = None,
add_spatial_upsample: bool = None,
add_temporal_upsample: bool = None,
resnet_eps: float = 1e-6,
resnet_act_fn: str = 'silu',
resolution_idx: Optional[int] = None,
resnet_groups: Optional[int] = None,
resnet_time_scale_shift: str = "default",
attention_head_dim: Optional[int] = None,
dropout: float = 0.0,
interpolate: bool = True,
norm_affline: bool = True,
norm_layer: str = 'layer',
) -> nn.Module:
if up_block_type == "UpDecoderBlock2D":
return UpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_spatial_upsample=add_spatial_upsample,
add_temporal_upsample=add_temporal_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
interpolate=interpolate,
)
elif up_block_type == "UpDecoderBlockCausal3D":
return UpDecoderBlockCausal3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_spatial_upsample=add_spatial_upsample,
add_temporal_upsample=add_temporal_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
interpolate=interpolate,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock2D(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
Args:
in_channels (`int`): The number of input channels.
temb_channels (`int`): The number of temporal embedding channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
model on tasks with long-range temporal dependencies.
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
Whether to use pre-normalization for the resnet blocks.
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
attention_head_dim (`int`, *optional*, defaults to 1):
Dimension of a single attention head. The number of attention heads is determined based on this value and
the number of input channels.
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
if attn_groups is None:
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
logger.warning(
f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
# Spatial attention
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=attn_groups,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
t = hidden_states.shape[2]
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
hidden_states = attn(hidden_states, temb=temb)
hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')
hidden_states = resnet(hidden_states, temb)
return hidden_states
class CausalUNetMidBlock2D(nn.Module):
"""
A causal 3D UNet mid-block [`CausalUNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
Args:
in_channels (`int`): The number of input channels.
temb_channels (`int`): The number of temporal embedding channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
model on tasks with long-range temporal dependencies.
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
Whether to use pre-normalization for the resnet blocks.
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
attention_head_dim (`int`, *optional*, defaults to 1):
Dimension of a single attention head. The number of attention heads is determined based on this value and
the number of input channels.
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
if attn_groups is None:
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
# there is always at least one resnet
resnets = [
CausalResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
logger.warning(
f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
# Spatial attention
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=attn_groups,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
CausalResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
t = hidden_states.shape[2]
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
hidden_states = attn(hidden_states, temb=temb)
hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')
hidden_states = resnet(hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_spatial_downsample: bool = True,
add_temporal_downsample: bool = False,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
CausalResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_spatial_downsample:
self.downsamplers = nn.ModuleList(
[
CausalDownsample2x(
out_channels, use_conv=True, out_channels=out_channels,
)
]
)
else:
self.downsamplers = None
if add_temporal_downsample:
self.temporal_downsamplers = nn.ModuleList(
[
CausalTemporalDownsample2x(
out_channels, use_conv=True, out_channels=out_channels,
)
]
)
else:
self.temporal_downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if self.temporal_downsamplers is not None:
for temporal_downsampler in self.temporal_downsamplers:
hidden_states = temporal_downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return hidden_states
class DownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_spatial_downsample: bool = True,
add_temporal_downsample: bool = False,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_spatial_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
if add_temporal_downsample:
self.temporal_downsamplers = nn.ModuleList(
[
TemporalDownsample2x(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding,
)
]
)
else:
self.temporal_downsamplers = None
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if self.temporal_downsamplers is not None:
for temporal_downsampler in self.temporal_downsamplers:
hidden_states = temporal_downsampler(hidden_states)
return hidden_states
class UpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_spatial_upsample: bool = True,
add_temporal_upsample: bool = False,
temb_channels: Optional[int] = None,
interpolate: bool = True,
):
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_spatial_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
else:
self.upsamplers = None
if add_temporal_upsample:
self.temporal_upsamplers = nn.ModuleList([TemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
else:
self.temporal_upsamplers = None
self.resolution_idx = resolution_idx
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, is_image: bool = False,
) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
if self.temporal_upsamplers is not None:
for temporal_upsampler in self.temporal_upsamplers:
hidden_states = temporal_upsampler(hidden_states, is_image=is_image)
return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_spatial_upsample: bool = True,
add_temporal_upsample: bool = False,
temb_channels: Optional[int] = None,
interpolate: bool = True,
):
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
CausalResnetBlock3D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_spatial_upsample:
self.upsamplers = nn.ModuleList([CausalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
else:
self.upsamplers = None
if add_temporal_upsample:
self.temporal_upsamplers = nn.ModuleList([CausalTemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
else:
self.temporal_upsamplers = None
self.resolution_idx = resolution_idx
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
is_init_image=True, temporal_chunk=False,
) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if self.temporal_upsamplers is not None:
for temporal_upsampler in self.temporal_upsamplers:
hidden_states = temporal_upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return hidden_states
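# --- Illustrative construction sketch (not part of the original module) ---
# A small example, with arbitrary channel sizes chosen only for illustration,
# of building one causal encoder block and one causal decoder block through the
# factory helpers defined above.
def _example_build_causal_blocks():
    down = get_down_block(
        "DownEncoderBlockCausal3D",
        num_layers=2,
        in_channels=128,
        out_channels=256,
        resnet_groups=32,
        add_spatial_downsample=True,
        add_temporal_downsample=True,
    )
    up = get_up_block(
        "UpDecoderBlockCausal3D",
        num_layers=3,
        in_channels=256,
        out_channels=128,
        resnet_groups=32,
        add_spatial_upsample=True,
        add_temporal_upsample=True,
        temb_channels=None,
    )
    return down, up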
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from collections import deque
from einops import rearrange
from timm.models.layers import trunc_normal_
from torch import Tensor
from utils import (
is_context_parallel_initialized,
get_context_parallel_group,
get_context_parallel_world_size,
get_context_parallel_rank,
get_context_parallel_group_rank,
)
from .context_parallel_ops import (
conv_scatter_to_context_parallel_region,
conv_gather_from_context_parallel_region,
cp_pass_from_previous_rank,
)
def divisible_by(num, den):
return (num % den) == 0
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def is_odd(n):
return not divisible_by(n, 2)
class CausalGroupNorm(nn.GroupNorm):
def forward(self, x: Tensor) -> Tensor:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = super().forward(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
pad_mode: str = 'constant',
**kwargs
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.time_kernel_size = time_kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop('dilation', 1)
self.pad_mode = pad_mode
if isinstance(stride, int):
stride = (stride, 1, 1)
time_pad = dilation * (time_kernel_size - 1)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.temporal_stride = stride[0]
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)
self.cache_front_feat = deque()
def _clear_context_parallel_cache(self):
del self.cache_front_feat
self.cache_front_feat = deque()
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def context_parallel_forward(self, x):
x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)
x = F.pad(x, self.time_uncausal_padding, mode='constant')
cp_rank = get_context_parallel_rank()
if cp_rank != 0:
if self.temporal_stride == 2 and self.time_kernel_size == 3:
x = x[:,:,1:]
x = self.conv(x)
return x
def forward(self, x, is_init_image=True, temporal_chunk=False):
# temporal_chunk: whether to run chunk-wise inference, reusing the cached boundary frames
if is_context_parallel_initialized():
return self.context_parallel_forward(x)
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
if not temporal_chunk:
x = F.pad(x, self.time_causal_padding, mode=pad_mode)
else:
assert not self.training, "The feature cache should not be used in training"
if is_init_image:
# Encode the first chunk
x = F.pad(x, self.time_causal_padding, mode=pad_mode)
self._clear_context_parallel_cache()
self.cache_front_feat.append(x[:, :, -2:].clone().detach())
else:
x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
video_front_context = self.cache_front_feat.pop()
self._clear_context_parallel_cache()
if self.temporal_stride == 1 and self.time_kernel_size == 3:
x = torch.cat([video_front_context, x], dim=2)
elif self.temporal_stride == 2 and self.time_kernel_size == 3:
x = torch.cat([video_front_context[:,:,-1:], x], dim=2)
self.cache_front_feat.append(x[:, :, -2:].clone().detach())
x = self.conv(x)
return x
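# --- Illustrative inference sketch (not part of the original module) ---
# How the temporal-chunk path above is typically driven: the first chunk is
# passed with `is_init_image=True`, which applies the causal left padding and
# caches the last two frames; every later chunk is passed with
# `is_init_image=False` and is stitched to that cache, keeping the convolution
# causal across chunk boundaries. Chunk shapes here are assumptions.
def _example_chunked_causal_conv(chunks):
    conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=3)
    conv.eval()  # the feature cache is asserted to be inference-only
    outputs = []
    with torch.no_grad():
        for idx, chunk in enumerate(chunks):  # each chunk: (B, 3, T_i, H, W)
            outputs.append(conv(chunk, is_init_image=(idx == 0), temporal_chunk=True))
    return torch.cat(outputs, dim=2)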
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_enc_dec import (
DecoderOutput, DiagonalGaussianDistribution,
CausalVaeDecoder, CausalVaeEncoder,
)
from .modeling_causal_conv import CausalConv3d
from utils import (
is_context_parallel_initialized,
get_context_parallel_group,
get_context_parallel_world_size,
get_context_parallel_rank,
get_context_parallel_group_rank,
)
from .context_parallel_ops import (
conv_scatter_to_context_parallel_region,
conv_gather_from_context_parallel_region,
)
class CausalVideoVAE(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without losing too much precision, in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
# encoder related parameters
encoder_in_channels: int = 3,
encoder_out_channels: int = 4,
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 2),
encoder_down_block_types: Tuple[str, ...] = (
"DownEncoderBlockCausal3D",
"DownEncoderBlockCausal3D",
"DownEncoderBlockCausal3D",
"DownEncoderBlockCausal3D",
),
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
encoder_spatial_down_sample: Tuple[bool, ...] = (True, True, True, False),
encoder_temporal_down_sample: Tuple[bool, ...] = (True, True, True, False),
encoder_block_dropout: Tuple[int, ...] = (0.0, 0.0, 0.0, 0.0),
encoder_act_fn: str = "silu",
encoder_norm_num_groups: int = 32,
encoder_double_z: bool = True,
encoder_type: str = 'causal_vae_conv',
# decoder related
decoder_in_channels: int = 4,
decoder_out_channels: int = 3,
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3),
decoder_up_block_types: Tuple[str, ...] = (
"UpDecoderBlockCausal3D",
"UpDecoderBlockCausal3D",
"UpDecoderBlockCausal3D",
"UpDecoderBlockCausal3D",
),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
decoder_spatial_up_sample: Tuple[bool, ...] = (True, True, True, False),
decoder_temporal_up_sample: Tuple[bool, ...] = (True, True, True, False),
decoder_block_dropout: Tuple[int, ...] = (0.0, 0.0, 0.0, 0.0),
decoder_act_fn: str = "silu",
decoder_norm_num_groups: int = 32,
decoder_type: str = 'causal_vae_conv',
sample_size: int = 256,
scaling_factor: float = 0.18215,
add_post_quant_conv: bool = True,
interpolate: bool = False,
downsample_scale: int = 8,
):
super().__init__()
print(f"The latent dimmension channes is {encoder_out_channels}")
# pass init params to Encoder
self.encoder = CausalVaeEncoder(
in_channels=encoder_in_channels,
out_channels=encoder_out_channels,
down_block_types=encoder_down_block_types,
spatial_down_sample=encoder_spatial_down_sample,
temporal_down_sample=encoder_temporal_down_sample,
block_out_channels=encoder_block_out_channels,
layers_per_block=encoder_layers_per_block,
act_fn=encoder_act_fn,
norm_num_groups=encoder_norm_num_groups,
double_z=True,
block_dropout=encoder_block_dropout,
)
# pass init params to Decoder
self.decoder = CausalVaeDecoder(
in_channels=decoder_in_channels,
out_channels=decoder_out_channels,
up_block_types=decoder_up_block_types,
spatial_up_sample=decoder_spatial_up_sample,
temporal_up_sample=decoder_temporal_up_sample,
block_out_channels=decoder_block_out_channels,
layers_per_block=decoder_layers_per_block,
norm_num_groups=decoder_norm_num_groups,
act_fn=decoder_act_fn,
interpolate=interpolate,
block_dropout=decoder_block_dropout,
)
self.quant_conv = CausalConv3d(2 * encoder_out_channels, 2 * encoder_out_channels, kernel_size=1, stride=1)
self.post_quant_conv = CausalConv3d(encoder_out_channels, encoder_out_channels, kernel_size=1, stride=1)
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / downsample_scale)
self.encode_tile_overlap_factor = 1 / 8
self.decode_tile_overlap_factor = 1 / 8
self.downsample_scale = downsample_scale
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CausalVaeEncoder, CausalVaeDecoder)):
module.gradient_checkpointing = value
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def encode(
self, x: torch.FloatTensor, return_dict: bool = True,
is_init_image=True, temporal_chunk=False, window_size=16, tile_sample_min_size=256,
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
self.tile_sample_min_size = tile_sample_min_size
self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict, is_init_image=is_init_image,
temporal_chunk=temporal_chunk, window_size=window_size)
if temporal_chunk:
moments = self.chunk_encode(x, window_size=window_size)
else:
h = self.encoder(x, is_init_image=is_init_image, temporal_chunk=False)
moments = self.quant_conv(h, is_init_image=is_init_image, temporal_chunk=False)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@torch.no_grad()
def chunk_encode(self, x: torch.FloatTensor, window_size=16):
# Only used during inference
# Encode a long video clip through a sliding window
num_frames = x.shape[2]
assert (num_frames - 1) % self.downsample_scale == 0
init_window_size = window_size + 1
frame_list = [x[:,:,:init_window_size]]
# To chunk the long video
full_chunk_size = (num_frames - init_window_size) // window_size
fid = init_window_size
for idx in range(full_chunk_size):
frame_list.append(x[:, :, fid:fid+window_size])
fid += window_size
if fid < num_frames:
frame_list.append(x[:, :, fid:])
latent_list = []
for idx, frames in enumerate(frame_list):
if idx == 0:
h = self.encoder(frames, is_init_image=True, temporal_chunk=True)
moments = self.quant_conv(h, is_init_image=True, temporal_chunk=True)
else:
h = self.encoder(frames, is_init_image=False, temporal_chunk=True)
moments = self.quant_conv(h, is_init_image=False, temporal_chunk=True)
latent_list.append(moments)
latent = torch.cat(latent_list, dim=2)
return latent
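    # Worked example (for illustration): with downsample_scale=8 and
    # window_size=16, a clip with num_frames=33 satisfies (33 - 1) % 8 == 0 and
    # is split into [x[:, :, :17], x[:, :, 17:33]]: an initial window of
    # window_size + 1 frames followed by plain windows of window_size. The
    # per-chunk latents are then concatenated along the temporal dimension.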
def get_last_layer(self):
return self.decoder.conv_out.conv.weight
@torch.no_grad()
def chunk_decode(self, z: torch.FloatTensor, window_size=2):
num_frames = z.shape[2]
init_window_size = window_size + 1
frame_list = [z[:,:,:init_window_size]]
# To chunk the long video
full_chunk_size = (num_frames - init_window_size) // window_size
fid = init_window_size
for idx in range(full_chunk_size):
frame_list.append(z[:, :, fid:fid+window_size])
fid += window_size
if fid < num_frames:
frame_list.append(z[:, :, fid:])
dec_list = []
for idx, frames in enumerate(frame_list):
if idx == 0:
z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True)
else:
z_h = self.post_quant_conv(frames, is_init_image=False, temporal_chunk=True)
dec = self.decoder(z_h, is_init_image=False, temporal_chunk=True)
dec_list.append(dec)
dec = torch.cat(dec_list, dim=2)
return dec
def decode(self, z: torch.FloatTensor, is_init_image=True, temporal_chunk=False,
return_dict: bool = True, window_size: int = 2, tile_sample_min_size: int = 256,) -> Union[DecoderOutput, torch.FloatTensor]:
self.tile_sample_min_size = tile_sample_min_size
self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, is_init_image=is_init_image,
temporal_chunk=temporal_chunk, window_size=window_size, return_dict=return_dict)
if temporal_chunk:
dec = self.chunk_decode(z, window_size=window_size)
else:
z = self.post_quant_conv(z, is_init_image=is_init_image, temporal_chunk=False)
dec = self.decoder(z, is_init_image=is_init_image, temporal_chunk=False)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
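    # Worked example (for illustration): with blend_extent=4 the overlapping
    # rows/columns are mixed linearly: at offset 0 the previous tile contributes
    # weight 1.0 and the current tile 0.0, at offset 1 the weights are 0.75/0.25,
    # and at offset 3 they are 0.25/0.75, so adjacent tiles fade into each other
    # across the overlap region.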
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True,
is_init_image=True, temporal_chunk=False, window_size=16,) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.encode_tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.encode_tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the input into overlapping tiles of tile_sample_min_size pixels and encode them separately.
rows = []
for i in range(0, x.shape[3], overlap_size):
row = []
for j in range(0, x.shape[4], overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
if temporal_chunk:
tile = self.chunk_encode(tile, window_size=window_size)
else:
tile = self.encoder(tile, is_init_image=True, temporal_chunk=False)
tile = self.quant_conv(tile, is_init_image=True, temporal_chunk=False)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
moments = torch.cat(result_rows, dim=3)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, is_init_image=True,
temporal_chunk=False, window_size=2, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.decode_tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.decode_tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping tiles of tile_latent_min_size and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[3], overlap_size):
row = []
for j in range(0, z.shape[4], overlap_size):
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
if temporal_chunk:
decoded = self.chunk_decode(tile, window_size=window_size)
else:
tile = self.post_quant_conv(tile, is_init_image=True, temporal_chunk=False)
decoded = self.decoder(tile, is_init_image=True, temporal_chunk=False)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = True,
generator: Optional[torch.Generator] = None,
freeze_encoder: bool = False,
is_init_image=True,
temporal_chunk=False,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `True`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
if is_context_parallel_initialized():
assert self.training, "The context-parallel VAE forward is only supported during training for now"
if freeze_encoder:
with torch.no_grad():
h = self.encoder(x, is_init_image=True, temporal_chunk=False)
moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
posterior = DiagonalGaussianDistribution(moments)
global_posterior = posterior
else:
h = self.encoder(x, is_init_image=True, temporal_chunk=False)
moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
posterior = DiagonalGaussianDistribution(moments)
global_moments = conv_gather_from_context_parallel_region(moments, dim=2, kernel_size=1)
global_posterior = DiagonalGaussianDistribution(global_moments)
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
if get_context_parallel_rank() == 0:
dec = self.decode(z, is_init_image=True).sample
else:
# Do not drop the first upsampled frame
dec = self.decode(z, is_init_image=False).sample
return global_posterior, dec
else:
# The normal (non-context-parallel) training path
if freeze_encoder:
with torch.no_grad():
posterior = self.encode(x, is_init_image=is_init_image,
temporal_chunk=temporal_chunk).latent_dist
else:
posterior = self.encode(x, is_init_image=is_init_image,
temporal_chunk=temporal_chunk).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, is_init_image=is_init_image, temporal_chunk=temporal_chunk).sample
return posterior, dec
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
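# --- Illustrative end-to-end sketch (not part of the original module) ---
# A minimal example, under assumed shapes and settings, of running the VAE with
# spatial tiling and temporal chunking enabled for memory-friendly inference.
# The instantiation arguments are left at their defaults purely for brevity.
def _example_tiled_chunked_roundtrip(video):
    # video: (B, 3, T, H, W) with (T - 1) divisible by the downsample scale.
    vae = CausalVideoVAE()
    vae.eval()
    vae.enable_tiling()
    with torch.no_grad():
        posterior = vae.encode(
            video,
            temporal_chunk=True,       # sliding-window encode along time
            window_size=16,
            tile_sample_min_size=256,  # spatial tile size before overlap
        ).latent_dist
        latents = posterior.sample()
        frames = vae.decode(latents, temporal_chunk=True, window_size=2).sample
    return frames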