"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c413353e8e4b7f7652c877f4ade69f7e6926a430"
Unverified Commit bb2c64a0 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add the new SD2 attention params to the VD text unet (#1400)

parent 05a36d5c
...@@ -28,7 +28,9 @@ def get_down_block( ...@@ -28,7 +28,9 @@ def get_down_block(
resnet_groups=None, resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
downsample_padding=None, downsample_padding=None,
dual_cross_attention=None, dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat": if down_block_type == "DownBlockFlat":
...@@ -58,6 +60,9 @@ def get_down_block( ...@@ -58,6 +60,9 @@ def get_down_block(
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
raise ValueError(f"{down_block_type} is not supported.") raise ValueError(f"{down_block_type} is not supported.")
...@@ -75,7 +80,9 @@ def get_up_block( ...@@ -75,7 +80,9 @@ def get_up_block(
attn_num_head_channels, attn_num_head_channels,
resnet_groups=None, resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
dual_cross_attention=None, dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat": if up_block_type == "UpBlockFlat":
...@@ -105,6 +112,9 @@ def get_up_block( ...@@ -105,6 +112,9 @@ def get_up_block(
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
raise ValueError(f"{up_block_type} is not supported.") raise ValueError(f"{up_block_type} is not supported.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment