Unverified Commit bb2c64a0 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add the new SD2 attention params to the VD text unet (#1400)

parent 05a36d5c
...@@ -28,7 +28,9 @@ def get_down_block( ...@@ -28,7 +28,9 @@ def get_down_block(
resnet_groups=None, resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
downsample_padding=None, downsample_padding=None,
dual_cross_attention=None, dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat": if down_block_type == "DownBlockFlat":
...@@ -58,6 +60,9 @@ def get_down_block( ...@@ -58,6 +60,9 @@ def get_down_block(
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
raise ValueError(f"{down_block_type} is not supported.") raise ValueError(f"{down_block_type} is not supported.")
...@@ -75,7 +80,9 @@ def get_up_block( ...@@ -75,7 +80,9 @@ def get_up_block(
attn_num_head_channels, attn_num_head_channels,
resnet_groups=None, resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
dual_cross_attention=None, dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat": if up_block_type == "UpBlockFlat":
...@@ -105,6 +112,9 @@ def get_up_block( ...@@ -105,6 +112,9 @@ def get_up_block(
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
raise ValueError(f"{up_block_type} is not supported.") raise ValueError(f"{up_block_type} is not supported.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment