# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import datetime
import inspect
import os
import random
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from facechain.utils import snapshot_download
from facechain_animate.magicanimate.models.appearance_encoder import AppearanceEncoderModel
from facechain_animate.magicanimate.models.controlnet import ControlNetModel
from facechain_animate.magicanimate.models.mutual_self_attention import ReferenceAttentionControl
from facechain_animate.magicanimate.models.unet_controlnet import UNet3DConditionModel
from facechain_animate.magicanimate.pipelines.pipeline_animation import AnimationPipeline
from facechain_animate.magicanimate.utils.dist_tools import distributed_init
from facechain_animate.magicanimate.utils.util import save_videos_grid
from facechain_animate.magicanimate.utils.videoreader import VideoReader


def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    config = OmegaConf.load(args.config)

    # Per-process device and distributed arguments forwarded to the pipeline
    device = torch.device(f"cuda:{args.rank}")
    dist_kwargs = {"rank": args.rank, "world_size": args.world_size, "dist": args.dist}

    if config.savename is None:
        time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        savedir = f"samples/{Path(args.config).stem}-{time_str}"
    else:
        savedir = f"samples/{config.savename}"

    if args.dist:
        # broadcast_object_list fills the list in place; read the rank-0 value back
        # so every rank writes into the same timestamped directory
        savedir_holder = [savedir]
        dist.broadcast_object_list(savedir_holder, 0)
        savedir = savedir_holder[0]
        dist.barrier()

    if args.rank == 0:
        os.makedirs(savedir, exist_ok=True)

    inference_config = OmegaConf.load(config.inference_config)

    ### >>> create animation pipeline >>> ###
    sd15_model_dir = snapshot_download('AI-ModelScope/stable-diffusion-v1-5')
    sdvae_model_dir = snapshot_download('zhuzhukeji/sd-vae-ft-mse')
    magicanimate_model_dir = snapshot_download('AI-ModelScope/MagicAnimate')

    tokenizer = CLIPTokenizer.from_pretrained(sd15_model_dir, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(sd15_model_dir, subfolder="text_encoder")
    unet = UNet3DConditionModel.from_pretrained_2d(
        sd15_model_dir,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    )
    vae = AutoencoderKL.from_pretrained(sdvae_model_dir, subfolder="vae")

    # Appearance encoder writes reference features; the denoising UNet reads them back
    appearance_encoder = AppearanceEncoderModel.from_pretrained(
        magicanimate_model_dir, subfolder="appearance_encoder"
    ).cuda()
    reference_control_writer = ReferenceAttentionControl(
        appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks
    )
    reference_control_reader = ReferenceAttentionControl(
        unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks
    )

    ### Load the DensePose-conditioned ControlNet
    controlnet = ControlNetModel.from_pretrained(magicanimate_model_dir, subfolder="densepose_controlnet")
    # Memory-efficient attention and fp16 weights to reduce GPU memory usage
    unet.enable_xformers_memory_efficient_attention()
    appearance_encoder.enable_xformers_memory_efficient_attention()
    controlnet.enable_xformers_memory_efficient_attention()

    vae.to(torch.float16)
    unet.to(torch.float16)
    text_encoder.to(torch.float16)
    appearance_encoder.to(torch.float16)
    controlnet.to(torch.float16)

    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # NOTE: UniPCMultistepScheduler
    )

    # 1. unet ckpt
    # 1.1 motion module
    motion_module_state_dict = torch.load(
        os.path.join(magicanimate_model_dir, 'temporal_attention/temporal_attention.ckpt'),
        map_location="cpu",
    )
    if "global_step" in motion_module_state_dict:
        func_args.update({"global_step": motion_module_state_dict["global_step"]})
    motion_module_state_dict = (
        motion_module_state_dict['state_dict']
        if 'state_dict' in motion_module_state_dict
        else motion_module_state_dict
    )
    try:
        # extra steps for self-trained models: strip the DataParallel "module." prefix
        state_dict = OrderedDict()
        for key in motion_module_state_dict.keys():
            if key.startswith("module."):
                _key = key.split("module.")[-1]
                state_dict[_key] = motion_module_state_dict[key]
            else:
                state_dict[key] = motion_module_state_dict[key]
        motion_module_state_dict = state_dict
        del state_dict
        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
        assert len(unexpected) == 0
    except Exception:
        # fallback: keep only motion-module weights and strip the "unet." prefix
        _tmp_ = OrderedDict()
        for key in motion_module_state_dict.keys():
            if "motion_modules" in key:
                if key.startswith("unet."):
                    _key = key.split('unet.')[-1]
                    _tmp_[_key] = motion_module_state_dict[key]
                else:
                    _tmp_[key] = motion_module_state_dict[key]
        missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
        assert len(unexpected) == 0
        del _tmp_
    del motion_module_state_dict

    pipeline.to(device)
    ### <<< create animation pipeline <<< ###

    random_seeds = config.get("seed", [-1])
    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)

    test_video_dir = args.videos_dir
    source_img_dir = args.images_dir
    test_videos = [os.path.join(test_video_dir, each) for each in os.listdir(test_video_dir)]
    source_images = [os.path.join(source_img_dir, each) for each in os.listdir(source_img_dir)]
    random_seeds = random_seeds * len(source_images) if len(random_seeds) == 1 else random_seeds

    # input test videos (either source videos or condition sequences)
    num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)

    # read size, step from yaml file
    sizes = [config.size] * len(test_videos)
    steps = [config.S] * len(test_videos)

    config.random_seed = []
    prompt = n_prompt = ""
    for idx, (source_image, test_video, random_seed, size, step) in tqdm(
        enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
        total=len(test_videos),
        disable=(args.rank != 0),
    ):
        samples_per_video = []
        samples_per_clip = []
        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
            set_seed(random_seed)
        else:
            torch.seed()
        config.random_seed.append(torch.initial_seed())

        if test_video.endswith('.mp4'):
            control = VideoReader(test_video).read()
            if control[0].shape[0] != size:
                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
            if config.max_length is not None:
                control = control[config.offset: (config.offset + config.max_length)]
            control = np.array(control)

        # keep the file name for output naming before source_image is replaced by pixel data
        source_name = os.path.basename(source_image).split(".")[0]
        if source_image.endswith(".mp4"):
            source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
        else:
            source_image = np.array(Image.open(source_image).resize((size, size)))
        H, W, C = source_image.shape

        print(f"current seed: {torch.initial_seed()}")
        init_latents = None

        # print(f"sampling {prompt} ...")
        original_length = control.shape[0]
        if control.shape[0] % config.L > 0:
            # pad the condition sequence so its length is a multiple of the clip length L
            control = np.pad(
                control,
                ((0, config.L - control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)),
                mode='edge',
            )
        generator = torch.Generator(device=torch.device("cuda:0"))
        generator.manual_seed(torch.initial_seed())
        sample = pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=config.steps,
            guidance_scale=config.guidance_scale,
            width=W,
            height=H,
            video_length=len(control),
            controlnet_condition=control,
            init_latents=init_latents,
            generator=generator,
            num_actual_inference_steps=num_actual_inference_steps,
            appearance_encoder=appearance_encoder,
            reference_control_writer=reference_control_writer,
            reference_control_reader=reference_control_reader,
            source_image=source_image,
            **dist_kwargs,
        ).videos

        if args.rank == 0:
            # assemble [source image | condition | generated video] for the grid output
            source_frames = np.array([source_image] * original_length)
            source_frames = rearrange(torch.from_numpy(source_frames), "t h w c -> 1 c t h w") / 255.0
            samples_per_video.append(source_frames)

            control = control / 255.0
            control = rearrange(control, "t h w c -> 1 c t h w")
            control = torch.from_numpy(control)
            samples_per_video.append(control[:, :, :original_length])

            samples_per_video.append(sample[:, :, :original_length])
            samples_per_video = torch.cat(samples_per_video)

            video_name = os.path.basename(test_video)[:-4]
            save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
            save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")

            if config.save_individual_videos:
                save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
                save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")

        if args.dist:
            dist.barrier()

    if args.rank == 0:
        OmegaConf.save(config, f"{savedir}/config.yaml")


def distributed_main(device_id, args):
    args.rank = device_id
    args.device_id = device_id
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
        torch.cuda.init()
    distributed_init(args)
    main(args)


def run(args):
    if args.dist:
        args.world_size = max(1, torch.cuda.device_count())
        assert args.world_size <= torch.cuda.device_count()

        if args.world_size > 0 and torch.cuda.device_count() > 1:
            port = random.randint(10000, 20000)
            args.init_method = f"tcp://localhost:{port}"
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args,),
                nprocs=args.world_size,
            )
    else:
        main(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--dist", action="store_true", required=False)
    parser.add_argument("--rank", type=int, default=0, required=False)
    parser.add_argument("--world_size", type=int, default=1, required=False)
    parser.add_argument("--videos_dir", type=str, default='facechain_animate/resources/MagicAnimate/driving/densepose/', required=False)
    parser.add_argument("--images_dir", type=str, default='facechain_animate/resources/MagicAnimate/source_image/', required=False)
    args = parser.parse_args()

    run(args)
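
# Example invocation (a sketch: the config path below is an assumption, the
# directory values simply restate the argparse defaults above):
#
#   python <this_script>.py \
#       --config configs/inference/animation.yaml \
#       --videos_dir facechain_animate/resources/MagicAnimate/driving/densepose/ \
#       --images_dir facechain_animate/resources/MagicAnimate/source_image/
#
# Passing --dist spawns one process per visible GPU (when more than one is
# available) via torch.multiprocessing.spawn, and each rank runs the same
# inference loop with its own rank/world_size forwarded to the pipeline.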