Unverified Commit e5b94b4c authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[core] Move community AnimateDiff ControlNet to core (#8972)



* add animatediff controlnet to core

* make style; remove unused method

* fix copied from comment

* add tests

* changes to make tests work

* add utility function to load videos

* update docs

* update pipeline example

* make style

* update docs with example

* address review comments

* add latest freeinit test from #8969

* LoraLoaderMixin -> StableDiffusionLoraLoaderMixin

* fix docs

* Update src/diffusers/utils/loading_utils.py
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* fix: variable out of scope

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 69e72b1d
......@@ -25,6 +25,9 @@ The abstract of the paper is the following:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
| [AnimateDiffControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py) | *Controlled Video-to-Video Generation with AnimateDiff using ControlNet* |
| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |
| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Video-to-Video Generation with AnimateDiff* |
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
## Available checkpoints
......@@ -100,6 +103,83 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
</Tip>
### AnimateDiffControlNetPipeline
AnimateDiff can also be used with ControlNets ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide depth maps, the ControlNet model generates a video that'll preserve the spatial information from the depth maps. It is a more flexible and accurate way to control the video generation process.
```python
import torch
from diffusers import AnimateDiffControlNetPipeline, AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
from diffusers.utils import export_to_gif, load_video
# Additionally, you will need a preprocess videos before they can be used with the ControlNet
# HF maintains just the right package for it: `pip install controlnet_aux`
from controlnet_aux.processor import ZoeDetector
# Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
# Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)
# We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
"SG161222/Realistic_Vision_V5.1_noVAE",
motion_adapter=motion_adapter,
controlnet=controlnet,
vae=vae,
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])
depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
conditioning_frames = []
with pipe.progress_bar(total=len(video)) as progress_bar:
for frame in video:
conditioning_frames.append(depth_detector(frame))
progress_bar.update()
prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
negative_prompt = "bad quality, worst quality"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=len(video),
num_inference_steps=10,
guidance_scale=2.0,
conditioning_frames=conditioning_frames,
generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "animatediff_controlnet.gif", fps=8)
```
Here are some sample outputs:
<table align="center">
<tr>
<th align="center">Source Video</th>
<th align="center">Output Video</th>
</tr>
<tr>
<td align="center">
raccoon playing a guitar
<br />
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" alt="racoon playing a guitar" />
</td>
<td align="center">
a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality
<br/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-controlnet-output.gif" alt="a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality" />
</td>
</tr>
</table>
### AnimateDiffSparseControlNetPipeline
[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933) for achieving controlled generation in text-to-video diffusion models by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.
......@@ -762,6 +842,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
- all
- __call__
## AnimateDiffControlNetPipeline
[[autodoc]] AnimateDiffControlNetPipeline
- all
- __call__
## AnimateDiffSparseControlNetPipeline
[[autodoc]] AnimateDiffSparseControlNetPipeline
......
......@@ -232,6 +232,7 @@ else:
"AmusedImg2ImgPipeline",
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
......@@ -652,6 +653,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffControlNetPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
......
......@@ -118,6 +118,7 @@ else:
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
......@@ -419,6 +420,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import (
AnimateDiffControlNetPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
......
......@@ -22,6 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
_import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
......@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_animatediff import AnimateDiffPipeline
from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import (
... AnimateDiffControlNetPipeline,
... AutoencoderKL,
... ControlNetModel,
... MotionAdapter,
... LCMScheduler,
... )
>>> from diffusers.utils import export_to_gif, load_video
>>> # Additionally, you will need a preprocess videos before they can be used with the ControlNet
>>> # HF maintains just the right package for it: `pip install controlnet_aux`
>>> from controlnet_aux.processor import ZoeDetector
>>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
>>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
>>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)
>>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
>>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
>>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
>>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
... "SG161222/Realistic_Vision_V5.1_noVAE",
... motion_adapter=motion_adapter,
... controlnet=controlnet,
... vae=vae,
... ).to(device="cuda", dtype=torch.float16)
>>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
>>> pipe.load_lora_weights(
... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
... )
>>> pipe.set_adapters(["lcm-lora"], [0.8])
>>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
>>> video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
... )
>>> conditioning_frames = []
>>> with pipe.progress_bar(total=len(video)) as progress_bar:
... for frame in video:
... conditioning_frames.append(depth_detector(frame))
... progress_bar.update()
>>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
>>> negative_prompt = "bad quality, worst quality"
>>> video = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_frames=len(video),
... num_inference_steps=10,
... guidance_scale=2.0,
... conditioning_frames=conditioning_frames,
... generator=torch.Generator().manual_seed(42),
... ).frames[0]
>>> export_to_gif(video, "animatediff_controlnet.gif", fps=8)
```
"""
class AnimateDiffControlNetPipeline(
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
):
r"""
Pipeline for text-to-video generation with ControlNet guidance.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`CLIPTokenizer`):
A [`~transformers.CLIPTokenizer`] to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
motion_adapter ([`MotionAdapter`]):
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers,
feature_extractor: Optional[CLIPImageProcessor] = None,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
):
super().__init__()
if isinstance(unet, UNet2DConditionModel):
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
motion_adapter=motion_adapter,
controlnet=controlnet,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def decode_latents(self, latents, decode_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
video = []
for i in range(0, latents.shape[0], decode_batch_size):
batch_latents = latents[i : i + decode_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)
video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
num_frames,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
video=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
)
if (
isinstance(self.controlnet, ControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if not isinstance(video, list):
raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}")
if len(video) != num_frames:
raise ValueError(f"Excepted image to have length {num_frames} but got {len(video)=}")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
if not isinstance(video, list) or not isinstance(video[0], list):
raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}")
if len(video[0]) != num_frames:
raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}")
if any(len(img) != len(video[0]) for img in video):
raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
else:
assert False
# Check `controlnet_conditioning_scale`
if (
isinstance(self.controlnet, ControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end]
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
if isinstance(self.controlnet, MultiControlNetModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_video(
self,
video,
width,
height,
batch_size,
num_videos_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
dtype=torch.float32
)
video = video.permute(0, 2, 1, 3, 4).flatten(0, 1)
video_batch_size = video.shape[0]
if video_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_videos_per_prompt
video = video.repeat_interleave(repeat_by, dim=0)
video = video.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
video = torch.cat([video] * 2)
return video
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
conditioning_frames: Optional[List[PipelineImageInput]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
decode_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_frames (`int`, *optional*, defaults to 16):
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
amounts to 2 seconds of video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
conditioning_frames (`List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
ControlNets are specified, images must be passed as a list such that each element of the list can be
correctly batched for input to a single ControlNet.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
The ControlNet encoder tries to recognize the content of the input image even if you remove all
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
video=conditioning_frames,
controlnet_conditioning_scale=controlnet_conditioning_scale,
control_guidance_start=control_guidance_start,
control_guidance_end=control_guidance_end,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_videos_per_prompt,
self.do_classifier_free_guidance,
)
if isinstance(controlnet, ControlNetModel):
conditioning_frames = self.prepare_video(
video=conditioning_frames,
width=width,
height=height,
batch_size=batch_size * num_videos_per_prompt * num_frames,
num_videos_per_prompt=num_videos_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
cond_prepared_videos = []
for frame_ in conditioning_frames:
prepared_video = self.prepare_video(
video=frame_,
width=width,
height=height,
batch_size=batch_size * num_videos_per_prompt * num_frames,
num_videos_per_prompt=num_videos_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
cond_prepared_videos.append(prepared_video)
conditioning_frames = cond_prepared_videos
else:
assert False
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
for free_init_iter in range(num_free_init_iters):
if self.free_init_enabled:
latents, timesteps = self._apply_free_init(
latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
)
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
control_model_input = torch.transpose(control_model_input, 1, 2)
control_model_input = control_model_input.reshape(
(-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
)
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=conditioning_frames,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
).sample
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 9. Post processing
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents, decode_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)
......@@ -93,7 +93,7 @@ from .import_utils import (
is_xformers_available,
requires_backends,
)
from .loading_utils import load_image
from .loading_utils import load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
......
......@@ -77,6 +77,21 @@ class AmusedPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
......@@ -9,10 +9,7 @@ import numpy as np
import PIL.Image
import PIL.ImageOps
from .import_utils import (
BACKENDS_MAPPING,
is_opencv_available,
)
from .import_utils import BACKENDS_MAPPING, is_opencv_available
from .logging import get_logger
......
import os
from typing import Callable, Union
import tempfile
from typing import Callable, List, Optional, Union
import PIL.Image
import PIL.ImageOps
import requests
from .import_utils import BACKENDS_MAPPING, is_opencv_available
def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.
......@@ -15,7 +18,7 @@ def load_image(
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
A conversion method to apply to the image after loading it. When set to `None` the image will be converted
"RGB".
......@@ -47,3 +50,73 @@ def load_image(
image = image.convert("RGB")
return image
def load_video(
video: str,
convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
"""
Loads `video` to a list of PIL Image.
Args:
video (`str`):
A URL or Path to a video to convert to a list of PIL Image format.
convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
A conversion method to apply to the video after loading it. When set to `None` the images will be converted
to "RGB".
Returns:
`List[PIL.Image.Image]`:
The video as a list of PIL images.
"""
is_url = video.startswith("http://") or video.startswith("https://")
is_file = os.path.isfile(video)
was_tempfile_created = False
if not (is_url or is_file):
raise ValueError(
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
)
if is_url:
video_data = requests.get(video, stream=True).raw
video_path = tempfile.NamedTemporaryFile(suffix=os.path.splitext(video)[1], delete=False).name
was_tempfile_created = True
with open(video_path, "wb") as f:
f.write(video_data.read())
video = video_path
pil_images = []
if video.endswith(".gif"):
gif = PIL.Image.open(video)
try:
while True:
pil_images.append(gif.copy())
gif.seek(gif.tell() + 1)
except EOFError:
pass
else:
if is_opencv_available():
import cv2
else:
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video"))
video_capture = cv2.VideoCapture(video)
success, frame = video_capture.read()
while success:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_images.append(PIL.Image.fromarray(frame))
success, frame = video_capture.read()
video_capture.release()
if was_tempfile_created:
os.remove(video_path)
if convert_method is not None:
pil_images = convert_method(pil_images)
return pil_images
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import (
AnimateDiffControlNetPipeline,
AutoencoderKL,
ControlNetModel,
DDIMScheduler,
DPMSolverMultistepScheduler,
LCMScheduler,
MotionAdapter,
StableDiffusionPipeline,
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.utils import logging
from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineFromPipeTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class AnimateDiffControlNetPipelineFastTests(
IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
):
pipeline_class = AnimateDiffControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"conditioning_frames"})
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
def get_dummy_components(self):
cross_attention_dim = 8
block_out_channels = (8, 8)
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=block_out_channels,
layers_per_block=2,
sample_size=8,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
controlnet = ControlNetModel(
block_out_channels=block_out_channels,
layers_per_block=2,
in_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
cross_attention_dim=cross_attention_dim,
conditioning_embedding_out_channels=(8, 8),
norm_num_groups=1,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=block_out_channels,
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
motion_adapter = MotionAdapter(
block_out_channels=block_out_channels,
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
)
components = {
"unet": unet,
"controlnet": controlnet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": None,
"image_encoder": None,
}
return components
def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 2):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
video_height = 32
video_width = 32
conditioning_frames = [Image.new("RGB", (video_width, video_height))] * num_frames
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"conditioning_frames": conditioning_frames,
"generator": generator,
"num_inference_steps": 2,
"num_frames": num_frames,
"guidance_scale": 7.5,
"output_type": "pt",
}
return inputs
def test_from_pipe_consistent_config(self):
assert self.original_pipeline_class == StableDiffusionPipeline
original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
original_kwargs = {"requires_safety_checker": False}
# create original_pipeline_class(sd)
pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
# original_pipeline_class(sd) -> pipeline_class
pipe_components = self.get_dummy_components()
pipe_additional_components = {}
for name, component in pipe_components.items():
if name not in pipe_original.components:
pipe_additional_components[name] = component
pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
# pipeline_class -> original_pipeline_class(sd)
original_pipe_additional_components = {}
for name, component in pipe_original.components.items():
if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
original_pipe_additional_components[name] = component
pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
# compare the config
original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
assert original_config_2 == original_config
def test_motion_unet_loading(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
assert isinstance(pipe.unet, UNetMotionModel)
@unittest.skip("Attention slicing is not enabled in this pipeline")
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
[
0.6604,
0.4099,
0.4928,
0.5706,
0.5096,
0.5012,
0.6051,
0.5169,
0.5021,
0.4864,
0.4261,
0.5779,
0.5822,
0.4049,
0.5253,
0.6160,
0.4150,
0.5155,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
if torch_device == "cpu":
expected_slice = np.array([0.6051, 0.5169, 0.5021, 0.6160, 0.4150, 0.5155])
return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for components in pipe.components.values():
if hasattr(components, "set_default_attn_processor"):
components.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs)
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
# pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_prompt_embeds(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
inputs.pop("prompt")
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
pipe(**inputs)
def test_free_init(self):
components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
pipe.enable_free_init(
num_iters=2,
use_fast_sampling=True,
method="butterworth",
order=4,
spatial_stop_frequency=0.25,
temporal_stop_frequency=0.25,
)
inputs_enable_free_init = self.get_dummy_inputs(torch_device)
frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
pipe.disable_free_init()
inputs_disable_free_init = self.get_dummy_inputs(torch_device)
frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
self.assertGreater(
sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
)
self.assertLess(
max_diff_disabled,
1e-4,
"Disabling of FreeInit should lead to results similar to the default pipeline results",
)
def test_free_init_with_schedulers(self):
components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
schedulers_to_test = [
DPMSolverMultistepScheduler.from_config(
components["scheduler"].config,
timestep_spacing="linspace",
beta_schedule="linear",
algorithm_type="dpmsolver++",
steps_offset=1,
clip_sample=False,
),
LCMScheduler.from_config(
components["scheduler"].config,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
clip_sample=False,
),
]
components.pop("scheduler")
for scheduler in schedulers_to_test:
components["scheduler"] = scheduler
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
inputs = self.get_dummy_inputs(torch_device)
frames_enable_free_init = pipe(**inputs).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeInit should lead to results different from the default pipeline results",
)
def test_vae_slicing(self, video_count=2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * video_count
inputs["conditioning_frames"] = [inputs["conditioning_frames"]] * video_count
output_1 = pipe(**inputs)
# make sure sliced vae decode yields the same result
pipe.enable_vae_slicing()
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * video_count
inputs["conditioning_frames"] = [inputs["conditioning_frames"]] * video_count
output_2 = pipe(**inputs)
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment