Unverified commit 3e0d128d authored by Mathis Koroglu, committed by GitHub

Motion Model / Adapter versatility (#8301)

* Motion Model / Adapter versatility

- allow a different number of layers per block
- allow a different number of transformer layers per motion layer in each block
- allow a different number of motion attention heads per block
- use the dropout argument in get_down/up_block for the 3D blocks

* Rename the newly added Motion Model arguments & refactor

* Add test for asymmetric UNetMotionModel
parent a536e775
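As a rough usage sketch (not part of this commit), the new per-block arguments added by this PR can be exercised like this; the values below are purely illustrative:

```python
# Hedged sketch: build a MotionAdapter whose motion modules differ per UNet block,
# using the per-block arguments introduced in this PR. Values are illustrative only.
from diffusers import MotionAdapter

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=(2, 2, 2, 2),              # may now be a tuple, one entry per block
    motion_transformer_layers_per_block=(1, 1, 2, 2),  # transformer depth per motion layer, per block
    motion_num_attention_heads=(4, 4, 8, 8),           # motion attention heads per block
    motion_mid_block_layers_per_block=1,
    motion_transformer_layers_per_mid_block=1,
)
```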
@@ -58,7 +58,9 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    dropout: float = 0.0,
 ) -> Union[
     "DownBlock3D",
     "CrossAttnDownBlock3D",
@@ -79,6 +81,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
@@ -100,6 +103,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     if down_block_type == "DownBlockMotion":
         return DownBlockMotion(
@@ -115,6 +119,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlockMotion":
         if cross_attention_dim is None:
@@ -139,6 +145,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "DownBlockSpatioTemporal":
         # added for SDV
@@ -189,7 +197,8 @@ def get_up_block(
     temporal_num_attention_heads: int = 8,
     temporal_cross_attention_dim: Optional[int] = None,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     dropout: float = 0.0,
 ) -> Union[
     "UpBlock3D",
@@ -212,6 +221,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
@@ -234,6 +244,7 @@ def get_up_block(
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     if up_block_type == "UpBlockMotion":
         return UpBlockMotion(
@@ -250,6 +261,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlockMotion":
         if cross_attention_dim is None:
@@ -275,6 +288,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "UpBlockSpatioTemporal":
         # added for SDV
@@ -948,14 +963,31 @@ class DownBlockMotion(nn.Module):
         output_scale_factor: float = 1.0,
         add_downsample: bool = True,
         downsample_padding: int = 1,
-        temporal_num_attention_heads: int = 1,
+        temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
+        # support for variable number of attention head per temporal layers
+        if isinstance(temporal_num_attention_heads, int):
+            temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
+        elif len(temporal_num_attention_heads) != num_layers:
+            raise ValueError(
+                f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -974,15 +1006,16 @@ class DownBlockMotion(nn.Module):
             )
             motion_modules.append(
                 TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
+                    num_attention_heads=temporal_num_attention_heads[i],
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
                     activation_fn="geglu",
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                    attention_head_dim=out_channels // temporal_num_attention_heads[i],
                 )
             )
@@ -1065,7 +1098,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1084,6 +1117,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1093,6 +1127,22 @@ class CrossAttnDownBlockMotion(nn.Module):
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -1116,7 +1166,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1141,6 +1191,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1257,7 +1308,7 @@ class CrossAttnUpBlockMotion(nn.Module):
         resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1275,6 +1326,7 @@ class CrossAttnUpBlockMotion(nn.Module):
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1284,6 +1336,22 @@ class CrossAttnUpBlockMotion(nn.Module):
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1309,7 +1377,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1333,6 +1401,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1467,11 +1536,20 @@ class UpBlockMotion(nn.Module):
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1495,6 +1573,7 @@ class UpBlockMotion(nn.Module):
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=temporal_norm_num_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1596,7 +1675,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1605,13 +1684,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
         cross_attention_dim: int = 1280,
-        dual_cross_attention: float = False,
-        use_linear_projection: float = False,
-        upcast_attention: float = False,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
         attention_type: str = "default",
         temporal_num_attention_heads: int = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
@@ -1619,6 +1699,22 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -1637,14 +1733,14 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         attentions = []
         motion_modules = []
 
-        for _ in range(num_layers):
+        for i in range(num_layers):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1682,6 +1778,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                     num_attention_heads=temporal_num_attention_heads,
                     attention_head_dim=in_channels // temporal_num_attention_heads,
                     in_channels=in_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
...
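Note that every block above normalizes its new per-layer arguments the same way: an `int` is broadcast to a tuple with one entry per layer, while an explicit tuple must already have length `num_layers`. A distilled, hypothetical helper illustrating the pattern (the actual code inlines this check in each `__init__`):

```python
from typing import Tuple, Union


def _normalize_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int, name: str) -> Tuple[int, ...]:
    """Broadcast an int to one entry per layer, or validate an explicit tuple (illustrative only)."""
    if isinstance(value, int):
        return (value,) * num_layers
    if len(value) != num_layers:
        raise ValueError(f"`{name}` must be an integer or a tuple of integers of length {num_layers}")
    return tuple(value)


# e.g. _normalize_per_layer(1, 3, "temporal_transformer_layers_per_block") -> (1, 1, 1)
```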
@@ -57,7 +57,8 @@ class MotionModules(nn.Module):
         self,
         in_channels: int,
         layers_per_block: int = 2,
-        num_attention_heads: int = 8,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Union[int, Tuple[int]] = 8,
         attention_bias: bool = False,
         cross_attention_dim: Optional[int] = None,
         activation_fn: str = "geglu",
@@ -67,10 +68,19 @@ class MotionModules(nn.Module):
         super().__init__()
         self.motion_modules = nn.ModuleList([])
 
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
+        elif len(transformer_layers_per_block) != layers_per_block:
+            raise ValueError(
+                f"The number of transformer layers per block must match the number of layers per block, "
+                f"got {layers_per_block} and {len(transformer_layers_per_block)}"
+            )
+
         for i in range(layers_per_block):
             self.motion_modules.append(
                 TransformerTemporalModel(
                     in_channels=in_channels,
+                    num_layers=transformer_layers_per_block[i],
                     norm_num_groups=norm_num_groups,
                     cross_attention_dim=cross_attention_dim,
                     activation_fn=activation_fn,
@@ -88,9 +98,11 @@ class MotionAdapter(ModelMixin, ConfigMixin):
     def __init__(
         self,
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
-        motion_layers_per_block: int = 2,
+        motion_layers_per_block: Union[int, Tuple[int]] = 2,
+        motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,
         motion_mid_block_layers_per_block: int = 1,
-        motion_num_attention_heads: int = 8,
+        motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,
+        motion_num_attention_heads: Union[int, Tuple[int]] = 8,
         motion_norm_num_groups: int = 32,
         motion_max_seq_length: int = 32,
         use_motion_mid_block: bool = True,
@@ -101,11 +113,15 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         Args:
             block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
                 The tuple of output channels for each UNet block.
-            motion_layers_per_block (`int`, *optional*, defaults to 2):
+            motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2):
                 The number of motion layers per UNet block.
+            motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1):
+                The number of transformer layers to use in each motion layer in each block.
             motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
                 The number of motion layers in the middle UNet block.
-            motion_num_attention_heads (`int`, *optional*, defaults to 8):
+            motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+                The number of transformer layers to use in each motion layer in the middle block.
+            motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8):
                 The number of heads to use in each attention layer of the motion module.
             motion_norm_num_groups (`int`, *optional*, defaults to 32):
                 The number of groups to use in each group normalization layer of the motion module.
@@ -119,6 +135,35 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         down_blocks = []
         up_blocks = []
 
+        if isinstance(motion_layers_per_block, int):
+            motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels)
+        elif len(motion_layers_per_block) != len(block_out_channels):
+            raise ValueError(
+                f"The number of motion layers per block must match the number of blocks, "
+                f"got {len(block_out_channels)} and {len(motion_layers_per_block)}"
+            )
+
+        if isinstance(motion_transformer_layers_per_block, int):
+            motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels)
+
+        if isinstance(motion_transformer_layers_per_mid_block, int):
+            motion_transformer_layers_per_mid_block = (
+                motion_transformer_layers_per_mid_block,
+            ) * motion_mid_block_layers_per_block
+        elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block:
+            raise ValueError(
+                f"The number of layers per mid block ({motion_mid_block_layers_per_block}) "
+                f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})"
+            )
+
+        if isinstance(motion_num_attention_heads, int):
+            motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels)
+        elif len(motion_num_attention_heads) != len(block_out_channels):
+            raise ValueError(
+                f"The length of the attention head number tuple in the motion module must match the "
+                f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}"
+            )
+
         if conv_in_channels:
             # input
             self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
@@ -134,9 +179,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=None,
                     activation_fn="geglu",
                     attention_bias=False,
-                    num_attention_heads=motion_num_attention_heads,
+                    num_attention_heads=motion_num_attention_heads[i],
                     max_seq_length=motion_max_seq_length,
-                    layers_per_block=motion_layers_per_block,
+                    layers_per_block=motion_layers_per_block[i],
+                    transformer_layers_per_block=motion_transformer_layers_per_block[i],
                 )
             )
@@ -147,15 +193,20 @@ class MotionAdapter(ModelMixin, ConfigMixin):
                 cross_attention_dim=None,
                 activation_fn="geglu",
                 attention_bias=False,
-                num_attention_heads=motion_num_attention_heads,
-                layers_per_block=motion_mid_block_layers_per_block,
+                num_attention_heads=motion_num_attention_heads[-1],
                 max_seq_length=motion_max_seq_length,
+                layers_per_block=motion_mid_block_layers_per_block,
+                transformer_layers_per_block=motion_transformer_layers_per_mid_block,
             )
         else:
             self.mid_block = None
 
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
+        reversed_motion_layers_per_block = list(reversed(motion_layers_per_block))
+        reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block))
+        reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
+
         for i, channel in enumerate(reversed_block_out_channels):
             output_channel = reversed_block_out_channels[i]
             up_blocks.append(
@@ -165,9 +216,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=None,
                     activation_fn="geglu",
                     attention_bias=False,
-                    num_attention_heads=motion_num_attention_heads,
+                    num_attention_heads=reversed_motion_num_attention_heads[i],
                     max_seq_length=motion_max_seq_length,
-                    layers_per_block=motion_layers_per_block + 1,
+                    layers_per_block=reversed_motion_layers_per_block[i] + 1,
+                    transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i],
                 )
             )
@@ -208,7 +260,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             "CrossAttnUpBlockMotion",
         ),
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
-        layers_per_block: int = 2,
+        layers_per_block: Union[int, Tuple[int]] = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
@@ -216,12 +268,18 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
-        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+        reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
+        transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
+        temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
         use_linear_projection: bool = False,
         num_attention_heads: Union[int, Tuple[int, ...]] = 8,
         motion_max_seq_length: int = 32,
-        motion_num_attention_heads: int = 8,
-        use_motion_mid_block: int = True,
+        motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
+        reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
+        use_motion_mid_block: bool = True,
+        mid_block_layers: int = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
@@ -264,6 +322,16 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 if isinstance(layer_number_per_block, list):
                     raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
 
+        if (
+            isinstance(temporal_transformer_layers_per_block, list)
+            and reverse_temporal_transformer_layers_per_block is None
+        ):
+            for layer_number_per_block in temporal_transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError(
+                        "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet."
+                    )
+
         # input
         conv_in_kernel = 3
         conv_out_kernel = 3
@@ -304,6 +372,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
 
+        if isinstance(reverse_transformer_layers_per_block, int):
+            reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types)
+
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
+
+        if isinstance(reverse_temporal_transformer_layers_per_block, int):
+            reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len(
+                down_block_types
+            )
+
+        if isinstance(motion_num_attention_heads, int):
+            motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -326,13 +408,19 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 dual_cross_attention=False,
-                temporal_num_attention_heads=motion_num_attention_heads,
+                temporal_num_attention_heads=motion_num_attention_heads[i],
                 temporal_max_seq_length=motion_max_seq_length,
                 transformer_layers_per_block=transformer_layers_per_block[i],
+                temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
             )
             self.down_blocks.append(down_block)
 
         # mid
+        if transformer_layers_per_mid_block is None:
+            transformer_layers_per_mid_block = (
+                transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
+            )
+
         if use_motion_mid_block:
             self.mid_block = UNetMidBlockCrossAttnMotion(
                 in_channels=block_out_channels[-1],
@@ -345,9 +433,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
-                temporal_num_attention_heads=motion_num_attention_heads,
+                num_layers=mid_block_layers,
+                temporal_num_attention_heads=motion_num_attention_heads[-1],
                 temporal_max_seq_length=motion_max_seq_length,
-                transformer_layers_per_block=transformer_layers_per_block[-1],
+                transformer_layers_per_block=transformer_layers_per_mid_block,
+                temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block,
             )
         else:
@@ -362,7 +452,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
-                transformer_layers_per_block=transformer_layers_per_block[-1],
+                num_layers=mid_block_layers,
+                transformer_layers_per_block=transformer_layers_per_mid_block,
             )
 
         # count how many layers upsample the images
@@ -373,7 +464,13 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
-        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+        reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
+
+        if reverse_transformer_layers_per_block is None:
+            reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+        if reverse_temporal_transformer_layers_per_block is None:
+            reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block))
 
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -406,9 +503,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 dual_cross_attention=False,
                 resolution_idx=i,
                 use_linear_projection=use_linear_projection,
-                temporal_num_attention_heads=motion_num_attention_heads,
+                temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
                 temporal_max_seq_length=motion_max_seq_length,
-                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                transformer_layers_per_block=reverse_transformer_layers_per_block[i],
+                temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -440,6 +538,24 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if has_motion_adapter:
             motion_adapter.to(device=unet.device)
 
+            # check compatibility of number of blocks
+            if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]):
+                raise ValueError("Incompatible Motion Adapter, got different number of blocks")
+
+            # check layers compatibility for each block
+            if isinstance(unet.config["layers_per_block"], int):
+                expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"])
+            else:
+                expanded_layers_per_block = list(unet.config["layers_per_block"])
+            if isinstance(motion_adapter.config["motion_layers_per_block"], int):
+                expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len(
+                    motion_adapter.config["block_out_channels"]
+                )
+            else:
+                expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"])
+            if expanded_layers_per_block != expanded_adapter_layers_per_block:
+                raise ValueError("Incompatible Motion Adapter, got different number of layers per block")
+
         # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
         config = dict(unet.config)
         config["_class_name"] = cls.__name__
@@ -458,13 +574,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 up_blocks.append("CrossAttnUpBlockMotion")
             else:
                 up_blocks.append("UpBlockMotion")
 
         config["up_block_types"] = up_blocks
 
         if has_motion_adapter:
             config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
             config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
             config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
+            config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"]
+            config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[
+                "motion_transformer_layers_per_mid_block"
+            ]
+            config["temporal_transformer_layers_per_block"] = motion_adapter.config[
+                "motion_transformer_layers_per_block"
+            ]
+            config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
 
             # For PIA UNets we need to set the number input channels to 9
             if motion_adapter.config["conv_in_channels"]:
...
@@ -306,3 +306,36 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
         self.assertIsNotNone(output)
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_asymmetric_motion_model(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["layers_per_block"] = (2, 3)
+        init_dict["transformer_layers_per_block"] = ((1, 2), (3, 4, 5))
+        init_dict["reverse_transformer_layers_per_block"] = ((7, 6, 7, 4), (4, 2, 2))
+
+        init_dict["temporal_transformer_layers_per_block"] = ((2, 5), (2, 3, 5))
+        init_dict["reverse_temporal_transformer_layers_per_block"] = ((5, 4, 3, 4), (3, 2, 2))
+
+        init_dict["num_attention_heads"] = (2, 4)
+        init_dict["motion_num_attention_heads"] = (4, 4)
+        init_dict["reverse_motion_num_attention_heads"] = (2, 2)
+
+        init_dict["use_motion_mid_block"] = True
+        init_dict["mid_block_layers"] = 2
+        init_dict["transformer_layers_per_mid_block"] = (1, 5)
+        init_dict["temporal_transformer_layers_per_mid_block"] = (2, 4)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
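For context, the new compatibility check in `UNetMotionModel.from_unet2d` can be exercised roughly as below. This is a sketch only: the UNet is built from its default config (which mirrors the SD v1.5 layout), and the mismatched adapter is purely illustrative.

```python
# Hedged sketch of the from_unet2d() compatibility checks added in this PR.
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# Default config: 4 blocks, layers_per_block=2 (randomly initialized, no download needed).
unet2d = UNet2DConditionModel()

# Adapter with matching layers per block -> conversion succeeds.
adapter = MotionAdapter(motion_layers_per_block=2)
unet_motion = UNetMotionModel.from_unet2d(unet2d, adapter)

# Adapter with mismatched layers per block -> the new check raises:
# "Incompatible Motion Adapter, got different number of layers per block"
bad_adapter = MotionAdapter(motion_layers_per_block=(2, 2, 2, 3))
try:
    UNetMotionModel.from_unet2d(unet2d, bad_adapter)
except ValueError as err:
    print(err)
```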