Unverified Commit 818f7607 authored by Aryan, committed by GitHub

[Pipeline] AnimateDiff SDXL (#6721)



* update conversion script to handle motion adapter sdxl checkpoint

* add animatediff xl

* handle addition_embed_type

* fix output

* update

* add imports

* make fix-copies

* add decode latents

* update docstrings

* add animatediff sdxl to docs

* remove unnecessary lines

* update example

* add test

* revert conv_in conv_out kernel param

* remove unused param addition_embed_type_num_heads

* latest IPAdapter impl

* make fix-copies

* fix return

* add IPAdapterTesterMixin to tests

* fix return

* revert based on suggestion

* add freeinit

* fix test_to_dtype test

* use StableDiffusionMixin instead of different helper methods

* fix progress bar iterations

* apply suggestions from review

* hardcode flip_sin_to_cos and freq_shift

* make fix-copies

* fix ip adapter implementation

* fix last failing test

* make style

* Update docs/source/en/api/pipelines/animatediff.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* remove todo

* fix doc-builder errors

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent f29b9348
@@ -101,6 +101,53 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
</Tip>
### AnimateDiffSDXLPipeline
AnimateDiff can also be used with SDXL models. This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.
```python
import torch
from diffusers.models import MotionAdapter
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
clip_sample=False,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
scheduler=scheduler,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
output = pipe(
prompt="a panda surfing in the ocean, realistic, high quality",
negative_prompt="low quality, worst quality",
num_inference_steps=20,
guidance_scale=8,
width=1024,
height=1024,
num_frames=16,
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
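The commit notes above also mention FreeInit support ("add freeinit"). Assuming `AnimateDiffSDXLPipeline` exposes the same `enable_free_init`/`disable_free_init` helpers as `AnimateDiffPipeline`, iterative noise re-initialization could be sketched as follows, continuing from the example above (parameter values are illustrative, not prescriptive):
```python
# Sketch only: relies on the FreeInit helpers referenced in the commit notes.
pipe.enable_free_init(num_iters=3, use_fast_sampling=True)
output = pipe(
    prompt="a panda surfing in the ocean, realistic, high quality",
    negative_prompt="low quality, worst quality",
    num_inference_steps=20,
    guidance_scale=8,
    width=1024,
    height=1024,
    num_frames=16,
)
pipe.disable_free_init()
frames = output.frames[0]
export_to_gif(frames, "animation-freeinit.gif")
```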
### AnimateDiffVideoToVideoPipeline
AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
@@ -522,6 +569,12 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
- all
- __call__
## AnimateDiffSDXLPipeline
[[autodoc]] AnimateDiffSDXLPipeline
- all
- __call__
## AnimateDiffVideoToVideoPipeline
[[autodoc]] AnimateDiffVideoToVideoPipeline
...
@@ -31,6 +31,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--use_motion_mid_block", action="store_true")
parser.add_argument("--motion_max_seq_length", type=int, default=32)
parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
parser.add_argument("--save_fp16", action="store_true")
return parser.parse_args()
@@ -49,11 +50,13 @@ if __name__ == "__main__":
conv_state_dict = convert_motion_module(state_dict)
adapter = MotionAdapter(
-use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
+block_out_channels=args.block_out_channels,
+use_motion_mid_block=args.use_motion_mid_block,
+motion_max_seq_length=args.motion_max_seq_length,
)
# skip loading position embeddings
adapter.load_state_dict(conv_state_dict, strict=False)
adapter.save_pretrained(args.output_path)
if args.save_fp16:
-adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
+adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
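The new `--block_out_channels` flag matters for SDXL because its motion module spans three resolution levels rather than the four used by SD 1.5. A minimal sketch of the adapter the script would construct when given SDXL-shaped widths (the width values and the absence of a motion mid block are assumptions based on the SDXL beta checkpoint, not something this diff states):
```python
from diffusers.models import MotionAdapter

# Roughly what `--block_out_channels 320 640 1280` would produce for an SDXL
# motion module; SD 1.5 adapters default to (320, 640, 1280, 1280) instead.
adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280),  # assumed SDXL widths
    use_motion_mid_block=False,           # assumption: no motion mid block in the beta checkpoint
    motion_max_seq_length=32,
)
```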
@@ -216,6 +216,7 @@ else:
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -595,6 +596,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
...
@@ -121,6 +121,7 @@ def get_down_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
...
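Threading `transformer_layers_per_block` through these block builders is what lets a motion UNet mirror SDXL, whose base UNet uses a different transformer depth at each resolution level rather than the uniform depth of 1 used by SD 1.5. A minimal sketch of the effect, assuming the SDXL base UNet and the beta SDXL motion adapter (repository ids taken from the docs example above; the printed depths are an expectation, not a guarantee):
```python
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# Build a motion UNet from an SDXL UNet; the per-block transformer depths of the
# 2D UNet are preserved when the motion modules are inserted.
unet2d = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta")
motion_unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)

# Expected to report the SDXL depths, e.g. [1, 2, 10]
print(motion_unet.config.transformer_layers_per_block)
```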
@@ -211,6 +211,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
@@ -218,6 +220,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
use_motion_mid_block: int = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None,
time_cond_proj_dim: Optional[int] = None,
):
super().__init__()
@@ -240,6 +245,21 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_kernel = 3
conv_out_kernel = 3
@@ -260,6 +280,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if encoder_hid_dim_type is None:
self.encoder_hid_proj = None
if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -267,6 +291,15 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -276,7 +309,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
down_block = get_down_block(
down_block_type,
-num_layers=layers_per_block,
+num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
@@ -284,13 +317,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
-cross_attention_dim=cross_attention_dim,
+cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)
@@ -302,13 +336,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
-cross_attention_dim=cross_attention_dim,
+cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[-1],
)
else:
@@ -318,11 +353,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
-cross_attention_dim=cross_attention_dim,
+cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
transformer_layers_per_block=transformer_layers_per_block[-1],
)
# count how many layers upsample the images
@@ -331,6 +367,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
up_block = get_up_block(
up_block_type,
-num_layers=layers_per_block + 1,
+num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
-cross_attention_dim=cross_attention_dim,
+cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -835,6 +875,28 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
...
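The `text_time` branch added above consumes SDXL-style micro-conditioning. As a point of reference, here is a minimal sketch (not part of the diff) of how a caller typically assembles `added_cond_kwargs`: `text_embeds` holds the pooled prompt embedding and `time_ids` packs the original size, crop coordinates, and target size; the shapes and dimensions below are illustrative assumptions.
```python
import torch

batch_size, pooled_dim = 2, 1280  # pooled_dim assumed to match text_encoder_2's projection dim

# Pooled prompt embeddings (normally produced by the pipeline's encode_prompt).
text_embeds = torch.randn(batch_size, pooled_dim)

# Six micro-conditioning scalars per sample: original size, crop top-left, target size.
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float32).repeat(batch_size, 1)

added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
# The forward pass above projects `time_ids` with `add_time_proj`, concatenates the result
# with `text_embeds`, passes it through `add_embedding`, and adds it to the time embedding.
```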
@@ -114,6 +114,7 @@ else:
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
@@ -367,7 +368,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
-from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
+from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
AudioLDM2Pipeline,
...
@@ -22,6 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_animatediff import AnimateDiffPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
from .pipeline_output import AnimateDiffPipelineOutput
...
@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffSDXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
import diffusers
from diffusers import (
AnimateDiffSDXLPipeline,
AutoencoderKL,
DDIMScheduler,
MotionAdapter,
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
SDXLOptionalComponentsTesterMixin,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class AnimateDiffPipelineSDXLFastTests(
IPAdapterTesterMixin,
SDFunctionTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = AnimateDiffSDXLPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64, 128),
layers_per_block=2,
time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
# SDXL-specific config below
attention_head_dim=(2, 4, 8),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2, 4),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
norm_num_groups=1,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
motion_adapter = MotionAdapter(
block_out_channels=(32, 64, 128),
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
use_motion_mid_block=False,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_motion_unet_loading(self):
components = self.get_dummy_components()
pipe = AnimateDiffSDXLPipeline(**components)
assert isinstance(pipe.unet, UNetMotionModel)
@unittest.skip("Attention slicing is not enabled in this pipeline")
def test_attention_slicing_forward_pass(self):
pass
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for components in pipe.components.values():
if hasattr(components, "set_default_attn_processor"):
components.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs)
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
# pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_prompt_embeds(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt)
pipe(
**inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_without_offload = pipe(**inputs).frames[0]
output_without_offload = (
output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
)
pipe.enable_xformers_memory_efficient_attention()
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = pipe(**inputs).frames[0]
output_with_offload = (
output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_with_offload
)
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")