# Copyright (c) Alibaba, Inc. and its affiliates. import os import torch from diffusers import DDIMScheduler, MotionAdapter from diffusers.pipelines import AnimateDiffPipeline from diffusers.utils import export_to_gif from swift import Swift, snapshot_download from swift.aigc.utils import AnimateDiffInferArguments from swift.utils import get_logger, get_main logger = get_logger() def animatediff_infer(args: AnimateDiffInferArguments) -> None: generator = torch.Generator(device='cpu') generator.manual_seed(args.seed) # Load scheduler, tokenizer and models. noise_scheduler = DDIMScheduler( num_train_timesteps=args.num_train_timesteps, beta_start=args.beta_start, beta_end=args.beta_end, beta_schedule=args.beta_schedule, steps_offset=args.steps_offset, clip_sample=args.clip_sample, ) if not os.path.exists(args.model_id_or_path): pretrained_model_path = snapshot_download(args.model_id_or_path, revision=args.model_revision) else: pretrained_model_path = args.model_id_or_path motion_adapter = None if args.motion_adapter_id_or_path is not None: if not os.path.exists(args.motion_adapter_id_or_path): args.motion_adapter_id_or_path = snapshot_download( args.motion_adapter_id_or_path, revision=args.motion_adapter_revision) motion_adapter = MotionAdapter.from_pretrained(args.motion_adapter_id_or_path) if args.sft_type == 'full': motion_adapter_dir = args.ckpt_dir if args.ckpt_dir is not None else os.path.join( pretrained_model_path, 'motion_adapter') motion_adapter = MotionAdapter.from_pretrained(motion_adapter_dir) validation_pipeline = AnimateDiffPipeline.from_pretrained( pretrained_model_path, motion_adapter=motion_adapter, ).to('cuda') validation_pipeline.scheduler = noise_scheduler if not args.sft_type == 'full': model = Swift.from_pretrained(validation_pipeline.unet, args.ckpt_dir) if args.merge_lora: ckpt_dir, ckpt_name = os.path.split(args.ckpt_dir) merged_lora_path = os.path.join(ckpt_dir, f'{ckpt_name}-merged') logger.info(f'merged_lora_path: `{merged_lora_path}`') logger.info("Setting args.sft_type: 'full'") logger.info(f'Setting args.ckpt_dir: {merged_lora_path}') args.sft_type = 'full' args.ckpt_dir = merged_lora_path if os.path.exists(args.ckpt_dir) and not args.replace_if_exists: logger.warn(f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, ' 'skipping the saving process. ' 'you can pass `replace_if_exists=True` to overwrite it.') return Swift.merge_and_unload(model) validation_pipeline.unet = model.model validation_pipeline.save_pretrained(args.ckpt_dir) validation_pipeline.enable_vae_slicing() validation_pipeline.enable_model_cpu_offload() if args.eval_human: idx = 0 while True: prompt = input('<<< ') sample = validation_pipeline( prompt, negative_prompt='bad quality, worse quality', generator=generator, num_frames=args.sample_n_frames, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, ).frames[0] os.makedirs(args.output_path, exist_ok=True) logger.info(f'Output saved to: {f"{args.output_path}/output-{idx}.gif"}') export_to_gif(sample, f'{args.output_path}/output-{idx}.gif') idx += 1 else: with open(args.validation_prompts_path, 'r') as f: validation_data = f.readlines() for idx, prompt in enumerate(validation_data): sample = validation_pipeline( prompt, negative_prompt='bad quality, worse quality', generator=generator, num_frames=args.sample_n_frames, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, ).frames[0] os.makedirs(args.output_path, exist_ok=True) logger.info(f'Output saved to: {f"{args.output_path}/output-{idx}.gif"}') export_to_gif(sample, f'{args.output_path}/output-{idx}.gif') animatediff_infer_main = get_main(AnimateDiffInferArguments, animatediff_infer)