Unverified Commit c4a3b09a authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[UNet2DConditionModel] add cross_attention_dim as an argument (#155)

add cross_attention_dim as an argument
parent 616c3a42
...@@ -28,6 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -28,6 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
act_fn="silu", act_fn="silu",
norm_num_groups=32, norm_num_groups=32,
norm_eps=1e-5, norm_eps=1e-5,
cross_attention_dim=1280,
attention_head_dim=8, attention_head_dim=8,
): ):
super().__init__() super().__init__()
...@@ -64,6 +65,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -64,6 +65,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block, add_downsample=not is_final_block,
resnet_eps=norm_eps, resnet_eps=norm_eps,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
) )
...@@ -77,6 +79,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -77,6 +79,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor, output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
) )
...@@ -101,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -101,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block, add_upsample=not is_final_block,
resnet_eps=norm_eps, resnet_eps=norm_eps,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
......
...@@ -31,6 +31,7 @@ def get_down_block( ...@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps, resnet_eps,
resnet_act_fn, resnet_act_fn,
attn_num_head_channels, attn_num_head_channels,
cross_attention_dim=None,
downsample_padding=None, downsample_padding=None,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
...@@ -58,6 +59,8 @@ def get_down_block( ...@@ -58,6 +59,8 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
) )
elif down_block_type == "CrossAttnDownBlock2D": elif down_block_type == "CrossAttnDownBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
return CrossAttnDownBlock2D( return CrossAttnDownBlock2D(
num_layers=num_layers, num_layers=num_layers,
in_channels=in_channels, in_channels=in_channels,
...@@ -67,6 +70,7 @@ def get_down_block( ...@@ -67,6 +70,7 @@ def get_down_block(
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
) )
elif down_block_type == "SkipDownBlock2D": elif down_block_type == "SkipDownBlock2D":
...@@ -115,6 +119,7 @@ def get_up_block( ...@@ -115,6 +119,7 @@ def get_up_block(
resnet_eps, resnet_eps,
resnet_act_fn, resnet_act_fn,
attn_num_head_channels, attn_num_head_channels,
cross_attention_dim=None,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D": if up_block_type == "UpBlock2D":
...@@ -129,6 +134,8 @@ def get_up_block( ...@@ -129,6 +134,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
) )
elif up_block_type == "CrossAttnUpBlock2D": elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
return CrossAttnUpBlock2D( return CrossAttnUpBlock2D(
num_layers=num_layers, num_layers=num_layers,
in_channels=in_channels, in_channels=in_channels,
...@@ -138,6 +145,7 @@ def get_up_block( ...@@ -138,6 +145,7 @@ def get_up_block(
add_upsample=add_upsample, add_upsample=add_upsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
) )
elif up_block_type == "AttnUpBlock2D": elif up_block_type == "AttnUpBlock2D":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment