Unverified commit ae84e405, authored by Stephen and committed by GitHub
Browse files

Pass use_linear_projection parameter to mid block in UNetMotionModel (#7035)

* pass linear projection parameter to mid block

* add cond_proj_dim to motion UNet

* run style and quality checks
parent 3a661133
......@@ -217,6 +217,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
use_motion_mid_block: int = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
):
super().__init__()
......@@ -252,9 +253,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim
)
if encoder_hid_dim_type is None:
......@@ -306,6 +305,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
)
......@@ -321,6 +321,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
)
# count how many layers upsample the images
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment