Commit 77605806 authored by chenpangpang

feat: initial commit

parent 2f260963
import os
import torch
import argparse
import torchvision
from pipeline_videogen import VideoGenPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
import sys
sys.path.append(os.path.split(sys.path[0])[0])  # make the repo root importable
from models import get_models
import imageio
from PIL import Image
import numpy as np
from datasets import video_transforms
from torchvision import transforms
from einops import rearrange, repeat
from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
from copy import deepcopy
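# prepare_image: load a single image, apply the video crop/normalize transforms, and encode it with the VAE into a latent with a singleton frame axis [B, C, 1, h, w].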
def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
with open(path, 'rb') as f:
image = Image.open(f).convert('RGB')
image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
image, ori_h, ori_w, crops_coords_top, crops_coords_left = transform_video(image)
image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
image = image.unsqueeze(2)
return image
def main(args):
if args.seed:
torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
unet = get_models(args).to(device, dtype=dtype)
if args.enable_vae_temporal_decoder:
if args.use_dct:
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
else:
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
else:
vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge
# set eval mode
unet.eval()
vae.eval()
text_encoder.eval()
scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
videogen_pipeline = VideoGenPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
unet=unet).to(device)
# videogen_pipeline.enable_xformers_memory_efficient_attention()
# videogen_pipeline.enable_vae_slicing()
if not os.path.exists(args.save_img_path):
os.makedirs(args.save_img_path)
transform_video = video_transforms.Compose([
video_transforms.ToTensorVideo(),
video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])), # center crop using the short edge, then resize
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
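# generate one clip per (image, prompt) pair listed in the config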
for i, (image, prompt) in enumerate(args.image_prompts):
if args.use_dct:
base_content = prepare_image("./animated_images/" + image, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
else:
base_content = prepare_image("./animated_images/" + image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
if args.use_dct:
# filter params
print("Using DCT!")
base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
# define filter
freq_filter = dct_low_pass_filter(dct_coefficients=base_content,
percentage=0.23)
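# fresh latent-space noise; the shape (1, 4, 15, 40, 64) is hard-coded as (batch, channels, frames, H/8, W/8), presumably 15 frames at a 320x512 output given the 8x VAE downsample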
noise = torch.randn(1, 4, 15, 40, 64).to(device)
# add noise to base_content
diffuse_timesteps = torch.full((1,), 975, dtype=torch.long)
# diffuse the repeated base content to timestep 975 so it matches the noise level
base_content_noise = scheduler.add_noise(
original_samples=base_content_repeat.to(device),
noise=noise,
timesteps=diffuse_timesteps.to(device))
# keep the noise's high-frequency DCT bands and the base content's low-frequency bands
latents = exchanged_mixed_dct_freq(noise=noise,
base_content=base_content_noise,
LPF_3d=freq_filter).to(dtype=torch.float16)
base_content = base_content.to(dtype=torch.float16)
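# the DCT-mixed latents seed the sampler when use_dct is set; otherwise the pipeline draws its own initial noise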
videos = videogen_pipeline(prompt,
latents=latents if args.use_dct else None,
base_content=base_content,
video_length=args.video_length,
height=args.image_size[0],
width=args.image_size[1],
num_inference_steps=args.num_sampling_steps,
guidance_scale=args.guidance_scale,
motion_bucket_id=args.motion_bucket_id,
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % i + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8) # highest quality is 10, lowest is 0
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/sample.yaml")
args = parser.parse_args()
main(OmegaConf.load(args.config))
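# run the image animation pipeline on GPU 0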
export CUDA_VISIBLE_DEVICES=0
python pipelines/animation.py --config configs/sample.yaml
import os
import torch
import argparse
import torchvision
from pipeline_videogen import VideoGenPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
import sys
sys.path.append(os.path.split(sys.path[0])[0])  # make the repo root importable
from pipelines.pipeline_inversion import VideoGenInversionPipeline  # imported after the path update so the repo-root 'pipelines' package resolves
from utils import find_model
from models import get_models
import imageio
import decord
import numpy as np
from copy import deepcopy
from PIL import Image
from datasets import video_transforms
from torchvision import transforms
from models.unet import UNet3DConditionModel
from einops import repeat
from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
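# Video editing: DDIM-invert a source clip into latents, then re-generate it conditioned on an edited first frame (optionally DCT-mixed) and a new prompt.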
def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
with open(path, 'rb') as f:
image = Image.open(f).convert('RGB')
image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
image, ori_h, ori_w, crops_coords_top, crops_coords_left = transform_video(image)
image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
image = image.unsqueeze(2)
return image
def separation_content_motion(video_clip):
"""
Separate content and motion in a given video.
Args:
video_clip: A given video clip, shape [B, C, F, H, W]
Return:
base_frame: Base frame, shape [B, C, 1, H, W]
motions: Motions based on base frame, shape [B, C, F-1, H, W]
"""
# Selecting the first frame from each video in the batch as the base frame
base_frame = video_clip[:, :, :1, :, :]
# Calculating the motion (difference between each frame and the base frame)
motions = video_clip[:, :, 1:, :, :] - base_frame
return base_frame, motions
class DecordInit(object):
"""Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
def __init__(self, num_threads=1):
self.num_threads = num_threads
self.ctx = decord.cpu(0)
def __call__(self, filename):
"""Perform the Decord initialization.
Args:
filename (str): Path of the video file to load.
"""
reader = decord.VideoReader(filename,
ctx=self.ctx,
num_threads=self.num_threads)
return reader
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'num_threads={self.num_threads})')
return repr_str
def main(args):
# torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
# unet = get_models(args).to(device, dtype=torch.float16)
# state_dict = find_model(args.ckpt)
# unet.load_state_dict(state_dict)
unet = UNet3DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet").to(device, dtype=torch.float16)
if args.enable_vae_temporal_decoder:
if args.use_dct:
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
else:
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
else:
vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
# set eval mode
unet.eval()
vae.eval()
text_encoder.eval()
scheduler_inversion = DDIMScheduler.from_pretrained(args.pretrained_model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,)
scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
# beta_end=0.017,
beta_schedule=args.beta_schedule,)
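# the generation and inversion pipelines share the same UNet, VAE, tokenizer, and text encoder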
videogen_pipeline = VideoGenPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler_inversion,
unet=unet).to(device)
videogen_pipeline_inversion = VideoGenInversionPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
unet=unet).to(device)
# videogen_pipeline.enable_xformers_memory_efficient_attention()
# videogen_pipeline.enable_vae_slicing()
transform_video = video_transforms.Compose([
video_transforms.ToTensorVideo(),
video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])), # center crop using the short edge, then resize
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
# video_path = './video_editing/A_man_walking_on_the_beach.mp4'
# video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
video_path = './video_editing/test_03.mp4'
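# read the first 16 frames of the source clip, rescale to [-1, 1], and encode them into VAE latents [B, C, F, h, w]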
video_reader = DecordInit()
video = video_reader(video_path)
frame_indice = np.linspace(0, 15, 16, dtype=int)
video = torch.from_numpy(video.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
video = video / 255.0
video = video * 2.0 - 1.0
latents = vae.encode(video.to(dtype=torch.float16, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor).unsqueeze(0).permute(0, 2, 1, 3, 4)
base_content, motion_latents = separation_content_motion(latents)
# image_path = "./video_editing/a_man_walking_in_the_park.png"
image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
if args.use_dct:
edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
else:
edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
if not os.path.exists(args.save_img_path):
os.makedirs(args.save_img_path)
# prompt_inversion = 'a man walking on the beach'
prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
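# DDIM inversion of the motion latents under the source prompt (guidance scale 1.0); output_type="latent" returns latents instead of decoded frames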
latents = videogen_pipeline_inversion(prompt_inversion,
latents=motion_latents,
base_content=base_content,
video_length=args.video_length,
height=args.image_size[0],
width=args.image_size[1],
num_inference_steps=args.num_sampling_steps,
guidance_scale=1.0,
# guidance_scale=args.guidance_scale,
motion_bucket_id=args.motion_bucket_id,
output_type="latent").video
# prompt = 'a man walking in the park'
# prompt = 'a corgi walking in the park at sunrise, oil painting style'
prompt = 'A girl is playing the guitar in her room'
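# optionally mix the low-frequency DCT content of the edited frame into the inverted latents before re-generation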
if args.use_dct:
# filter params
print("Using DCT!")
edit_content_repeat = repeat(edit_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
# define filter
freq_filter = dct_low_pass_filter(dct_coefficients=edit_content,
percentage=0.23)
noise = latents.to(dtype=torch.float64)
# add noise to the repeated edit content
diffuse_timesteps = torch.full((1,), 985, dtype=torch.long)
# diffuse the repeated edit content to timestep 985 so it matches the noise level
edit_content_noise = scheduler.add_noise(
original_samples=edit_content_repeat.to(device),
noise=noise,
timesteps=diffuse_timesteps.to(device))
# keep the inverted latents' high-frequency DCT bands and the edit content's low-frequency bands
latents = exchanged_mixed_dct_freq(noise=noise,
base_content=edit_content_noise,
LPF_3d=freq_filter).to(dtype=torch.float16)
latents = latents.to(dtype=torch.float16)
edit_content = edit_content.to(dtype=torch.float16)
videos = videogen_pipeline(prompt,
latents=latents,
base_content=edit_content,
video_length=args.video_length,
height=args.image_size[0],
width=args.image_size[1],
num_inference_steps=args.num_sampling_steps,
guidance_scale=1.0,
# guidance_scale=args.guidance_scale,
motion_bucket_id=args.motion_bucket_id,
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8) # highest quality is 10, lowest is 0
print('save path {}'.format(args.save_img_path))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/sample.yaml")
args = parser.parse_args()
main(OmegaConf.load(args.config))
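# run the video editing pipeline on GPU 0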
export CUDA_VISIBLE_DEVICES=0
python pipelines/video_editing.py --config configs/sample.yaml
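# Python dependencies; install with: pip install -r requirements.txt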
torch
torchvision
torchaudio
timm
diffusers
accelerate
tensorboard
einops
transformers
av
scikit-image
decord
pandas
imageio-ffmpeg
sentencepiece
beautifulsoup4
ftfy
omegaconf
torch_dct
gradio==3.40.0
import os
import math
import torch
import logging
import subprocess
import numpy as np
import torch.distributed as dist
# from torch._six import inf
from torch import inf
from PIL import Image
from typing import Union, Iterable
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict
import torch_dct
from diffusers.utils import is_bs4_available, is_ftfy_available
import html
import re
import urllib.parse as ul
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
import torch.fft as fft
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
#################################################################################
# Testing Utils #
#################################################################################
def find_model(model_name):
"""
Loads a checkpoint and returns its EMA state dict if present, otherwise the model state dict.
"""
assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
if "ema" in checkpoint: # supports checkpoints from train.py
print('Using ema ckpt!')
checkpoint = checkpoint["ema"]
else:
checkpoint = checkpoint["model"]
print("Using model ckpt!")
return checkpoint
def save_video_grid(video, nrow=None):
b, t, h, w, c = video.shape
if nrow is None:
nrow = math.ceil(math.sqrt(b))
ncol = math.ceil(b / nrow)
padding = 1
video_grid = torch.zeros((t, (padding + h) * nrow + padding,
(padding + w) * ncol + padding, c), dtype=torch.uint8)
# print(video_grid.shape)
for i in range(b):
r = i // ncol
c = i % ncol
start_r = (padding + h) * r
start_c = (padding + w) * c
video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
return video_grid
def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, nrow=None, fps=8):
from einops import rearrange
import imageio
import torchvision
b, _, _, _, _ = videos.shape
if nrow is None:
nrow = math.ceil(math.sqrt(b))
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=nrow)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = (x * 255).numpy().astype(np.uint8)
outputs.append(x)
# os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.mimsave(path, outputs, fps=fps)
#################################################################################
# MMCV Utils #
#################################################################################
def collect_env():
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash
"""Collect the information of the running environments."""
env_info = collect_base_env()
env_info['MMClassification'] = get_git_hash()[:7]
for name, val in env_info.items():
print(f'{name}: {val}')
print(torch.cuda.get_arch_list())
print(torch.version.cuda)
#################################################################################
# DCT Functions #
#################################################################################
def dct_low_pass_filter(dct_coefficients, percentage=0.3): # [b c f h w]
"""
Builds a spatial low-pass mask for the given DCT coefficients.
:param dct_coefficients: tensor of DCT coefficients, shape [B, C, F, H, W]
:param percentage: fraction of coefficients to keep along each spatial axis (between 0 and 1)
:return: binary mask with the same shape as dct_coefficients; 1 marks the kept low-frequency region
"""
# Determine the cutoff indices for both dimensions
cutoff_x = int(dct_coefficients.shape[-2] * percentage)
cutoff_y = int(dct_coefficients.shape[-1] * percentage)
# Create a mask with the same shape as the DCT coefficients
mask = torch.zeros_like(dct_coefficients)
# Set the top-left corner of the mask to 1 (the low-frequency area)
mask[:, :, :, :cutoff_x, :cutoff_y] = 1
return mask
def normalize(tensor):
"""将Tensor归一化到[0, 1]范围内。"""
min_val = tensor.min()
max_val = tensor.max()
normalized = (tensor - min_val) / (max_val - min_val)
return normalized
def denormalize(tensor, max_val_target, min_val_target):
"""将Tensor从[0, 1]范围反归一化到目标的[min_val_target, max_val_target]范围。"""
denormalized = tensor * (max_val_target - min_val_target) + min_val_target
return denormalized
def exchanged_mixed_dct_freq(noise, base_content, LPF_3d, normalized=False):
# noise dct
noise_freq = torch_dct.dct_3d(noise, 'ortho')
# complementary high-pass mask: keep the noise's high-frequency bands
HPF_3d = 1 - LPF_3d
noise_freq_high = noise_freq * HPF_3d
# base frame dct
base_content_freq = torch_dct.dct_3d(base_content, 'ortho')
# base content low frequency
base_content_freq_low = base_content_freq * LPF_3d
# mixed frequency
mixed_freq = base_content_freq_low + noise_freq_high
# idct
mixed_freq = torch_dct.idct_3d(mixed_freq, 'ortho')
return mixed_freq
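# Usage sketch (mirrors pipelines/animation.py): mix the low-frequency DCT content of a
# still-image latent into fresh noise before sampling:
#   freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=0.23)
#   noised = scheduler.add_noise(original_samples=base_content_repeat, noise=noise, timesteps=diffuse_timesteps)
#   init_latents = exchanged_mixed_dct_freq(noise=noise, base_content=noised, LPF_3d=freq_filter)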