Unverified Commit fbe29c62 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[refactor] create modeling blocks specific to AnimateDiff (#8979)



* animatediff specific transformer model

* make style

* make fix-copies

* move blocks to unet motion model

* make style

* remove dummy object

* fix incorrectly passed param causing test failures

* rename model and output class

* fix sparsectrl imports

* remove todo comments

* remove temporal double self attn param from controlnet sparsectrl

* add deprecated versions of blocks

* apply suggestions from review

* update

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 7071b746
......@@ -32,10 +32,7 @@ from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unets.unet_2d_condition import UNet2DConditionModel
from .unets.unet_3d_blocks import (
CrossAttnDownBlockMotion,
DownBlockMotion,
)
from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -317,7 +314,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin):
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
temporal_double_self_attention=False,
)
elif down_block_type == "DownBlockMotion":
down_block = DownBlockMotion(
......@@ -334,7 +330,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_double_self_attention=False,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
else:
......
......@@ -27,17 +27,58 @@ from ..resnet import (
TemporalConvLayer,
Upsample2D,
)
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
from ..transformers.transformer_temporal import (
TransformerSpatioTemporalModel,
TransformerTemporalModel,
)
from .unet_motion_model import (
CrossAttnDownBlockMotion,
CrossAttnUpBlockMotion,
DownBlockMotion,
UNetMidBlockCrossAttnMotion,
UpBlockMotion,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DownBlockMotion(DownBlockMotion):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
deprecate("DownBlockMotion", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
class UpBlockMotion(UpBlockMotion):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
deprecate("UpBlockMotion", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
def get_down_block(
down_block_type: str,
num_layers: int,
......@@ -64,8 +105,6 @@ def get_down_block(
) -> Union[
"DownBlock3D",
"CrossAttnDownBlock3D",
"DownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
]:
......@@ -105,49 +144,6 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
if down_block_type == "DownBlockMotion":
return DownBlockMotion(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlockMotion":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "DownBlockSpatioTemporal":
# added for SDV
return DownBlockSpatioTemporal(
......@@ -203,8 +199,6 @@ def get_up_block(
) -> Union[
"UpBlock3D",
"CrossAttnUpBlock3D",
"UpBlockMotion",
"CrossAttnUpBlockMotion",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
]:
......@@ -246,51 +240,6 @@ def get_up_block(
resolution_idx=resolution_idx,
dropout=dropout,
)
if up_block_type == "UpBlockMotion":
return UpBlockMotion(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlockMotion":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "UpBlockSpatioTemporal":
# added for SDV
return UpBlockSpatioTemporal(
......@@ -947,924 +896,6 @@ class UpBlock3D(nn.Module):
return hidden_states
class DownBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_double_self_attention: bool = True,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
)
# support for variable number of attention head per temporal layers
if isinstance(temporal_num_attention_heads, int):
temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
elif len(temporal_num_attention_heads) != num_layers:
raise ValueError(
f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads[i],
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads[i],
double_self_attention=temporal_double_self_attention,
)
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
num_frames: int = 1,
*args,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnDownBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_double_self_attention: bool = True,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
double_self_attention=temporal_double_self_attention,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.Tensor] = None,
):
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
output_states = ()
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnUpBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
blocks = zip(self.resnets, self.attentions, self.motion_modules)
for resnet, attn, motion_module in blocks:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class UpBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size=None,
num_frames: int = 1,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class UNetMidBlockCrossAttnMotion(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
motion_modules = []
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
attention_head_dim=in_channels // temporal_num_attention_heads,
in_channels=in_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
activation_fn="geglu",
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
hidden_states = resnet(hidden_states, temb)
return hidden_states
class MidBlockTemporalDecoder(nn.Module):
def __init__(
self,
......
......@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
......@@ -20,7 +22,9 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import logging
from ...utils import BaseOutput, deprecate, is_torch_version, logging
from ...utils.torch_utils import apply_freeu
from ..attention import BasicTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
......@@ -35,24 +39,1094 @@ from ..attention_processor import (
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (
CrossAttnDownBlockMotion,
CrossAttnUpBlockMotion,
DownBlockMotion,
UNetMidBlockCrossAttnMotion,
UpBlockMotion,
get_down_block,
get_up_block,
)
from .unet_3d_condition import UNet3DConditionOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNetMotionOutput(BaseOutput):
"""
The output of [`UNetMotionOutput`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.Tensor
class AnimateDiffTransformer3D(nn.Module):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
for _ in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""
The [`AnimateDiffTransformer3D`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
num_frames (`int`, *optional*, defaults to 1):
The number of frames to be processed per batch. This is used to reshape the hidden states.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Returns:
torch.Tensor:
The output tensor.
"""
# 1. Input
batch_frames, channel, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
residual = hidden_states
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
hidden_states = self.proj_in(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
.permute(0, 3, 4, 1, 2)
.contiguous()
)
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
output = hidden_states + residual
return output
class DownBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
)
# support for variable number of attention head per temporal layers
if isinstance(temporal_num_attention_heads, int):
temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
elif len(temporal_num_attention_heads) != num_layers:
raise ValueError(
f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
AnimateDiffTransformer3D(
num_attention_heads=temporal_num_attention_heads[i],
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads[i],
)
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
num_frames: int = 1,
*args,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnDownBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
motion_modules.append(
AnimateDiffTransformer3D(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels,
use_conv=True,
out_channels=out_channels,
padding=downsample_padding,
name="op",
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.Tensor] = None,
):
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
output_states = ()
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnUpBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
motion_modules.append(
AnimateDiffTransformer3D(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
blocks = zip(self.resnets, self.attentions, self.motion_modules)
for resnet, attn, motion_module in blocks:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class UpBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
AnimateDiffTransformer3D(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
upsample_size=None,
num_frames: int = 1,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class UNetMidBlockCrossAttnMotion(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
motion_modules = []
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
AnimateDiffTransformer3D(
num_attention_heads=temporal_num_attention_heads,
attention_head_dim=in_channels // temporal_num_attention_heads,
in_channels=in_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
activation_fn="geglu",
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class MotionModules(nn.Module):
def __init__(
self,
......@@ -79,7 +1153,7 @@ class MotionModules(nn.Module):
for i in range(layers_per_block):
self.motion_modules.append(
TransformerTemporalModel(
AnimateDiffTransformer3D(
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
norm_num_groups=norm_num_groups,
......@@ -394,26 +1468,45 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
if down_block_type == "CrossAttnDownBlockMotion":
down_block = CrossAttnDownBlockMotion(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
num_layers=layers_per_block[i],
transformer_layers_per_block=transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
num_attention_heads=num_attention_heads[i],
cross_attention_dim=cross_attention_dim[i],
downsample_padding=downsample_padding,
add_downsample=not is_final_block,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
elif down_block_type == "DownBlockMotion":
down_block = DownBlockMotion(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
num_layers=layers_per_block[i],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
downsample_padding=downsample_padding,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
else:
raise ValueError(
"Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
)
self.down_blocks.append(down_block)
# mid
......@@ -488,27 +1581,47 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reverse_transformer_layers_per_block[i],
temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
)
if up_block_type == "CrossAttnUpBlockMotion":
up_block = CrossAttnUpBlockMotion(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
resolution_idx=i,
num_layers=reversed_layers_per_block[i] + 1,
transformer_layers_per_block=reverse_transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
num_attention_heads=reversed_num_attention_heads[i],
cross_attention_dim=reversed_cross_attention_dim[i],
add_upsample=add_upsample,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
)
elif up_block_type == "UpBlockMotion":
up_block = UpBlockMotion(
in_channels=input_channel,
prev_output_channel=prev_output_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
resolution_idx=i,
num_layers=reversed_layers_per_block[i] + 1,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_upsample=add_upsample,
temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
)
else:
raise ValueError(
"Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`"
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -958,7 +2071,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]:
r"""
The [`UNetMotionModel`] forward method.
......@@ -984,12 +2097,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain
tuple.
Returns:
[`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
[`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned,
otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
......@@ -1173,4 +2286,4 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
if not return_dict:
return (sample,)
return UNet3DConditionOutput(sample=sample)
return UNetMotionOutput(sample=sample)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment