"official/vision/beta/projects/basnet/train.py" did not exist on "bab70e6b59129ed6c64f1dd44514eff9ea942317"
Commit b96ae489 authored by mashun1

magic-animate

# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s
# Modifications”). All Bytedance Inc.'s Modifications are Copyright (2023)
# Bytedance Inc.
# *************************************************************************

# Adapted from https://github.com/guoyww/AnimateDiff
import os
import imageio
import numpy as np

import torch
import torchvision

from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


def save_images_grid(images: torch.Tensor, path: str):
    assert images.shape[2] == 1  # no time dimension
    images = images.squeeze(2)
    grid = torchvision.utils.make_grid(images)
    grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    Image.fromarray(grid).save(path)
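
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): write a random video
# tensor out as a tiled grid. The shapes, path, and fps below are arbitrary
# example values, only meant to show the expected (b, c, t, h, w) layout.
# ---------------------------------------------------------------------------
def _demo_save_videos_grid(out_path: str = "samples/demo_grid.mp4"):
    # (batch, channels, frames, height, width) tensor with values in [0, 1];
    # pass rescale=True instead if the tensor lives in [-1, 1].
    dummy = torch.rand(2, 3, 16, 64, 64)
    save_videos_grid(dummy, out_path, rescale=False, n_rows=2, fps=8)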

# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])

    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
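
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): invert a batch of video
# latents back toward noise. `pipeline` is assumed to be a diffusers-style
# pipeline exposing `tokenizer`, `text_encoder`, `unet` and `device`; the
# scheduler's timesteps must be set before the inversion loop runs.
# ---------------------------------------------------------------------------
def _demo_ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps=50):
    ddim_scheduler.set_timesteps(num_inv_steps)
    all_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="")
    # The last element of the list is the fully inverted (noised) latent.
    return all_latents[-1]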

def video2images(path, step=4, length=16, start=0):
    reader = imageio.get_reader(path)
    frames = []
    for frame in reader:
        frames.append(np.array(frame))
    frames = frames[start::step][:length]
    return frames


def images2video(video, path, fps=8):
    imageio.mimsave(path, video, fps=fps)
    return
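
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): keep every 4th frame of
# an input clip, take 16 of them, and write them back out. Paths are
# placeholders.
# ---------------------------------------------------------------------------
def _demo_resample_video(src="input.mp4", dst="clip.mp4"):
    frames = video2images(src, step=4, length=16, start=0)
    images2video(frames, dst, fps=8)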

tensor_interpolation = None


def get_tensor_interpolation_method():
    return tensor_interpolation


def set_tensor_interpolation_method(is_slerp):
    global tensor_interpolation
    tensor_interpolation = slerp if is_slerp else linear


def linear(v1, v2, t):
    return (1.0 - t) * v1 + t * v2


def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    u0 = v0 / v0.norm()
    u1 = v1 / v1.norm()
    dot = (u0 * u1).sum()
    if dot.abs() > DOT_THRESHOLD:
        # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
        return (1.0 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
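
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): select spherical
# interpolation and blend two latent tensors at the midpoint.
# ---------------------------------------------------------------------------
def _demo_interpolate_latents():
    set_tensor_interpolation_method(is_slerp=True)
    interp = get_tensor_interpolation_method()
    v0, v1 = torch.randn(4, 64), torch.randn(4, 64)
    return interp(v0, v1, 0.5)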

# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s
# Modifications”). All Bytedance Inc.'s Modifications are Copyright (2023)
# Bytedance Inc.
# *************************************************************************
# Copyright 2022 ByteDance and/or its affiliates.
#
# Copyright (2022) PV3D Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import av, gc
import torch
import warnings
import numpy as np
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 20
# remove warnings
av.logging.set_level(av.logging.ERROR)

class VideoReader():
    """
    Simple wrapper around PyAV that exposes a few useful functions for
    dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries.
    Acknowledgement: Codes are borrowed from Bruno Korbar
    """
    def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
        """
        Arguments:
            video (str or bytes): path to, or raw bytes of, the video to be loaded
        """
        self.container = av.open(video)
        self.num_frames = num_frames
        self.bi_frame = bi_frame

        self.resampler = None
        if audio_resample_rate is not None:
            self.resampler = av.AudioResampler(rate=audio_resample_rate)

        if self.container.streams.video:
            # enable multi-threaded video decoding
            if decode_lossy:
                warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
                self.container.streams.video[0].thread_type = 'AUTO'
            self.video_stream = self.container.streams.video[0]
        else:
            self.video_stream = None

        self.fps = self._get_video_frame_rate()

    def seek(self, pts, backward=True, any_frame=False):
        stream = self.video_stream
        self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream)

    def _occasional_gc(self):
        # there are a lot of reference cycles in PyAV, so need to manually call
        # the garbage collector from time to time
        global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
        _CALLED_TIMES += 1
        if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
            gc.collect()

    def _read_video(self, offset):
        self._occasional_gc()

        pts = self.container.duration * offset
        time_ = pts / float(av.time_base)
        self.container.seek(int(pts))

        video_frames = []
        count = 0
        for _, frame in enumerate(self._iter_frames()):
            if frame.pts * frame.time_base >= time_:
                video_frames.append(frame)
            if count >= self.num_frames - 1:
                break
            count += 1
        return video_frames

    def _iter_frames(self):
        for packet in self.container.demux(self.video_stream):
            for frame in packet.decode():
                yield frame

    def _compute_video_stats(self):
        if self.video_stream is None or self.container is None:
            return 0
        num_of_frames = self.container.streams.video[0].frames
        if num_of_frames == 0:
            num_of_frames = self.fps * float(self.container.streams.video[0].duration * self.video_stream.time_base)
        self.seek(0, backward=False)
        count = 0
        time_base = 512
        for p in self.container.decode(video=0):
            count = count + 1
            if count == 1:
                start_pts = p.pts
            elif count == 2:
                time_base = p.pts - start_pts
                break
        return start_pts, time_base, num_of_frames

    def _get_video_frame_rate(self):
        return float(self.container.streams.video[0].guessed_rate)

    def sample(self, debug=False):
        if self.container is None:
            raise RuntimeError('video stream not found')
        sample = dict()
        _, _, total_num_frames = self._compute_video_stats()
        offset = torch.randint(max(1, total_num_frames - self.num_frames - 1), [1]).item()
        video_frames = self._read_video(offset / total_num_frames)
        video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
        sample["frames"] = video_frames
        sample["frame_idx"] = [offset]
        if self.bi_frame:
            frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)]
            frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)]
            frames.sort()
            video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]])
            Ts = [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)]
            sample["frames"] = video_frames
            sample["real_t"] = torch.tensor(Ts, dtype=torch.float32)
            sample["frame_idx"] = [offset + min(frames), offset + max(frames)]
            return sample

        return sample

    def read_frames(self, frame_indices):
        self.num_frames = frame_indices[1] - frame_indices[0]
        video_frames = self._read_video(frame_indices[0] / self.get_num_frames())
        video_frames = np.array([
            np.uint8(video_frames[0].to_rgb().to_ndarray()),
            np.uint8(video_frames[-1].to_rgb().to_ndarray())
        ])
        return video_frames

    def read(self):
        video_frames = self._read_video(0)
        video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
        return video_frames

    def get_num_frames(self):
        _, _, total_num_frames = self._compute_video_stats()
        return total_num_frames
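
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): randomly sample a short
# clip from a video file. The path and frame count are placeholders.
# ---------------------------------------------------------------------------
def _demo_video_reader(path="example.mp4"):
    reader = VideoReader(path, num_frames=16)
    sample = reader.sample()
    # sample["frames"] holds up to `num_frames` decoded RGB frames as a
    # uint8 array shaped (N, H, W, 3).
    return sample["frames"]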

# Unique model identifier
modelCode = 486
# Model name
modelName = magic-animate_pytorch
# Model description
modelDescription = magic-animate animates the person in an image according to a given motion sequence.
# Application scenarios
appScenario = Inference,AIGC,Media,Research,Education
# Framework type
frameType = pytorch
python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml
python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml --dist
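The first command runs single-process inference with the animation config; the second adds the --dist flag, which by its name presumably enables multi-GPU distributed inference and would typically be started through a distributed launcher (for example torchrun), though the launch script is not included here.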