Unverified Commit 3e0d128d authored by Mathis Koroglu, committed by GitHub

Motion Model / Adapter versatility (#8301)

* Motion Model / Adapter versatility

- allow a different number of layers per block
- allow a different number of transformer layers per motion layer in each block (see the usage sketch below)
- allow a different number of motion attention heads per block
- use the dropout argument in get_down/up_block for the 3D blocks

* Motion Model added arguments renamed & refactoring

* Add test for asymmetric UNetMotionModel
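Below is a minimal usage sketch of the new per-block arguments introduced by this PR. The parameter names follow the MotionAdapter signature in this diff; the concrete values are illustrative assumptions only and require a diffusers build that includes this change:

# Hedged sketch, not part of the diff: constructing a MotionAdapter with per-block settings.
from diffusers import MotionAdapter

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=(2, 2, 2, 2),              # motion layers per block
    motion_transformer_layers_per_block=(1, 2, 4, 4),  # temporal transformer depth per block
    motion_num_attention_heads=(4, 4, 8, 8),           # attention heads per block
    motion_mid_block_layers_per_block=1,
    motion_transformer_layers_per_mid_block=(2,),      # one entry per mid-block motion layer
    motion_max_seq_length=32,
)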
parent a536e775
......@@ -58,7 +58,9 @@ def get_down_block(
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
dropout: float = 0.0,
) -> Union[
"DownBlock3D",
"CrossAttnDownBlock3D",
......@@ -79,6 +81,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlock3D":
if cross_attention_dim is None:
......@@ -100,6 +103,7 @@ def get_down_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
if down_block_type == "DownBlockMotion":
return DownBlockMotion(
......@@ -115,6 +119,8 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlockMotion":
if cross_attention_dim is None:
......@@ -139,6 +145,8 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "DownBlockSpatioTemporal":
# added for SDV
......@@ -189,7 +197,8 @@ def get_up_block(
temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
dropout: float = 0.0,
) -> Union[
"UpBlock3D",
......@@ -212,6 +221,7 @@ def get_up_block(
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
......@@ -234,6 +244,7 @@ def get_up_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
if up_block_type == "UpBlockMotion":
return UpBlockMotion(
......@@ -250,6 +261,8 @@ def get_up_block(
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlockMotion":
if cross_attention_dim is None:
......@@ -275,6 +288,8 @@ def get_up_block(
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "UpBlockSpatioTemporal":
# added for SDV
......@@ -948,14 +963,31 @@ class DownBlockMotion(nn.Module):
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: int = 1,
temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
)
# support for variable number of attention heads per temporal layer
if isinstance(temporal_num_attention_heads, int):
temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
elif len(temporal_num_attention_heads) != num_layers:
raise ValueError(
f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
......@@ -974,15 +1006,16 @@ class DownBlockMotion(nn.Module):
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
num_attention_heads=temporal_num_attention_heads[i],
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
attention_head_dim=out_channels // temporal_num_attention_heads[i],
)
)
......@@ -1065,7 +1098,7 @@ class CrossAttnDownBlockMotion(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -1084,6 +1117,7 @@ class CrossAttnDownBlockMotion(nn.Module):
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
......@@ -1093,6 +1127,22 @@ class CrossAttnDownBlockMotion(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
......@@ -1116,7 +1166,7 @@ class CrossAttnDownBlockMotion(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -1141,6 +1191,7 @@ class CrossAttnDownBlockMotion(nn.Module):
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
......@@ -1257,7 +1308,7 @@ class CrossAttnUpBlockMotion(nn.Module):
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -1275,6 +1326,7 @@ class CrossAttnUpBlockMotion(nn.Module):
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
......@@ -1284,6 +1336,22 @@ class CrossAttnUpBlockMotion(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
......@@ -1309,7 +1377,7 @@ class CrossAttnUpBlockMotion(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -1333,6 +1401,7 @@ class CrossAttnUpBlockMotion(nn.Module):
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
......@@ -1467,11 +1536,20 @@ class UpBlockMotion(nn.Module):
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
......@@ -1495,6 +1573,7 @@ class UpBlockMotion(nn.Module):
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=temporal_norm_num_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
......@@ -1596,7 +1675,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -1605,13 +1684,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: float = False,
use_linear_projection: float = False,
upcast_attention: float = False,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
......@@ -1619,6 +1699,22 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
......@@ -1637,14 +1733,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
attentions = []
motion_modules = []
for _ in range(num_layers):
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -1682,6 +1778,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
num_attention_heads=temporal_num_attention_heads,
attention_head_dim=in_channels // temporal_num_attention_heads,
in_channels=in_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
......
......@@ -57,7 +57,8 @@ class MotionModules(nn.Module):
self,
in_channels: int,
layers_per_block: int = 2,
num_attention_heads: int = 8,
transformer_layers_per_block: Union[int, Tuple[int]] = 8,
num_attention_heads: Union[int, Tuple[int]] = 8,
attention_bias: bool = False,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
......@@ -67,10 +68,19 @@ class MotionModules(nn.Module):
super().__init__()
self.motion_modules = nn.ModuleList([])
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
elif len(transformer_layers_per_block) != layers_per_block:
raise ValueError(
f"The number of transformer layers per block must match the number of layers per block, "
f"got {layers_per_block} and {len(transformer_layers_per_block)}"
)
for i in range(layers_per_block):
self.motion_modules.append(
TransformerTemporalModel(
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
......@@ -88,9 +98,11 @@ class MotionAdapter(ModelMixin, ConfigMixin):
def __init__(
self,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
motion_layers_per_block: int = 2,
motion_layers_per_block: Union[int, Tuple[int]] = 2,
motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,
motion_mid_block_layers_per_block: int = 1,
motion_num_attention_heads: int = 8,
motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,
motion_num_attention_heads: Union[int, Tuple[int]] = 8,
motion_norm_num_groups: int = 32,
motion_max_seq_length: int = 32,
use_motion_mid_block: bool = True,
......@@ -101,11 +113,15 @@ class MotionAdapter(ModelMixin, ConfigMixin):
Args:
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each UNet block.
motion_layers_per_block (`int`, *optional*, defaults to 2):
motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2):
The number of motion layers per UNet block.
motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1):
The number of transformer layers to use in each motion layer in each block.
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
The number of motion layers in the middle UNet block.
motion_num_attention_heads (`int`, *optional*, defaults to 8):
motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer layers to use in each motion layer in the middle block.
motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8):
The number of heads to use in each attention layer of the motion module.
motion_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use in each group normalization layer of the motion module.
......@@ -119,6 +135,35 @@ class MotionAdapter(ModelMixin, ConfigMixin):
down_blocks = []
up_blocks = []
if isinstance(motion_layers_per_block, int):
motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels)
elif len(motion_layers_per_block) != len(block_out_channels):
raise ValueError(
f"The number of motion layers per block must match the number of blocks, "
f"got {len(block_out_channels)} and {len(motion_layers_per_block)}"
)
if isinstance(motion_transformer_layers_per_block, int):
motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels)
if isinstance(motion_transformer_layers_per_mid_block, int):
motion_transformer_layers_per_mid_block = (
motion_transformer_layers_per_mid_block,
) * motion_mid_block_layers_per_block
elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block:
raise ValueError(
f"The number of layers per mid block ({motion_mid_block_layers_per_block}) "
f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})"
)
if isinstance(motion_num_attention_heads, int):
motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels)
elif len(motion_num_attention_heads) != len(block_out_channels):
raise ValueError(
f"The length of the attention head number tuple in the motion module must match the "
f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}"
)
if conv_in_channels:
# input
self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
......@@ -134,9 +179,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
num_attention_heads=motion_num_attention_heads[i],
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block,
layers_per_block=motion_layers_per_block[i],
transformer_layers_per_block=motion_transformer_layers_per_block[i],
)
)
......@@ -147,15 +193,20 @@ class MotionAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
layers_per_block=motion_mid_block_layers_per_block,
num_attention_heads=motion_num_attention_heads[-1],
max_seq_length=motion_max_seq_length,
layers_per_block=motion_mid_block_layers_per_block,
transformer_layers_per_block=motion_transformer_layers_per_mid_block,
)
else:
self.mid_block = None
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
reversed_motion_layers_per_block = list(reversed(motion_layers_per_block))
reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block))
reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
for i, channel in enumerate(reversed_block_out_channels):
output_channel = reversed_block_out_channels[i]
up_blocks.append(
......@@ -165,9 +216,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
num_attention_heads=reversed_motion_num_attention_heads[i],
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block + 1,
layers_per_block=reversed_motion_layers_per_block[i] + 1,
transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i],
)
)
......@@ -208,7 +260,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
"CrossAttnUpBlockMotion",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
......@@ -216,12 +268,18 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: int = True,
motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
use_motion_mid_block: bool = True,
mid_block_layers: int = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
......@@ -264,6 +322,16 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
if (
isinstance(temporal_transformer_layers_per_block, list)
and reverse_temporal_transformer_layers_per_block is None
):
for layer_number_per_block in temporal_transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError(
"Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet."
)
# input
conv_in_kernel = 3
conv_out_kernel = 3
......@@ -304,6 +372,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if isinstance(reverse_transformer_layers_per_block, int):
reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types)
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
if isinstance(reverse_temporal_transformer_layers_per_block, int):
reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len(
down_block_types
)
if isinstance(motion_num_attention_heads, int):
motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
......@@ -326,13 +408,19 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)
# mid
if transformer_layers_per_mid_block is None:
transformer_layers_per_mid_block = (
transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
)
if use_motion_mid_block:
self.mid_block = UNetMidBlockCrossAttnMotion(
in_channels=block_out_channels[-1],
......@@ -345,9 +433,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
num_layers=mid_block_layers,
temporal_num_attention_heads=motion_num_attention_heads[-1],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[-1],
transformer_layers_per_block=transformer_layers_per_mid_block,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block,
)
else:
......@@ -362,7 +452,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
transformer_layers_per_block=transformer_layers_per_block[-1],
num_layers=mid_block_layers,
transformer_layers_per_block=transformer_layers_per_mid_block,
)
# count how many layers upsample the images
......@@ -373,7 +464,13 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
if reverse_transformer_layers_per_block is None:
reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
if reverse_temporal_transformer_layers_per_block is None:
reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
......@@ -406,9 +503,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
transformer_layers_per_block=reverse_transformer_layers_per_block[i],
temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -440,6 +538,24 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if has_motion_adapter:
motion_adapter.to(device=unet.device)
# check compatibility of number of blocks
if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]):
raise ValueError("Incompatible Motion Adapter, got different number of blocks")
# check that the number of layers per block matches for each block
if isinstance(unet.config["layers_per_block"], int):
expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"])
else:
expanded_layers_per_block = list(unet.config["layers_per_block"])
if isinstance(motion_adapter.config["motion_layers_per_block"], int):
expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len(
motion_adapter.config["block_out_channels"]
)
else:
expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"])
if expanded_layers_per_block != expanded_adapter_layers_per_block:
raise ValueError("Incompatible Motion Adapter, got different number of layers per block")
# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
config = dict(unet.config)
config["_class_name"] = cls.__name__
......@@ -458,13 +574,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
up_blocks.append("CrossAttnUpBlockMotion")
else:
up_blocks.append("UpBlockMotion")
config["up_block_types"] = up_blocks
if has_motion_adapter:
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"]
config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[
"motion_transformer_layers_per_mid_block"
]
config["temporal_transformer_layers_per_block"] = motion_adapter.config[
"motion_transformer_layers_per_block"
]
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
# For PIA UNets we need to set the number input channels to 9
if motion_adapter.config["conv_in_channels"]:
......
......@@ -306,3 +306,36 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_asymmetric_motion_model(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["layers_per_block"] = (2, 3)
init_dict["transformer_layers_per_block"] = ((1, 2), (3, 4, 5))
init_dict["reverse_transformer_layers_per_block"] = ((7, 6, 7, 4), (4, 2, 2))
init_dict["temporal_transformer_layers_per_block"] = ((2, 5), (2, 3, 5))
init_dict["reverse_temporal_transformer_layers_per_block"] = ((5, 4, 3, 4), (3, 2, 2))
init_dict["num_attention_heads"] = (2, 4)
init_dict["motion_num_attention_heads"] = (4, 4)
init_dict["reverse_motion_num_attention_heads"] = (2, 2)
init_dict["use_motion_mid_block"] = True
init_dict["mid_block_layers"] = 2
init_dict["transformer_layers_per_mid_block"] = (1, 5)
init_dict["temporal_transformer_layers_per_mid_block"] = (2, 4)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
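For reference, a hedged sketch of the compatibility check that UNetMotionModel.from_unet2d now performs: the adapter's block count and layers-per-block must line up with the 2D UNet's, otherwise a ValueError is raised. The checkpoint path below is a placeholder assumption:

# Hedged sketch: pairing a 2D UNet with a MotionAdapter (placeholder checkpoint path).
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

unet2d = UNet2DConditionModel.from_pretrained("path/to/sd-unet")  # placeholder

adapter = MotionAdapter(
    block_out_channels=tuple(unet2d.config.block_out_channels),
    motion_layers_per_block=unet2d.config.layers_per_block,  # must match the UNet per the new check
)

# Raises ValueError if the block counts or layers per block disagree.
motion_unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)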