Unverified Commit bb2c64a0 authored by Anton Lozhkov, committed by GitHub

Add the new SD2 attention params to the VD text unet (#1400)

parent 05a36d5c
@@ -28,7 +28,9 @@ def get_down_block(
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
-    dual_cross_attention=None,
+    dual_cross_attention=False,
+    use_linear_projection=False,
+    only_cross_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlockFlat":
@@ -58,6 +60,9 @@ def get_down_block(
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
         )
     raise ValueError(f"{down_block_type} is not supported.")
@@ -75,7 +80,9 @@ def get_up_block(
     attn_num_head_channels,
     resnet_groups=None,
     cross_attention_dim=None,
-    dual_cross_attention=None,
+    dual_cross_attention=False,
+    use_linear_projection=False,
+    only_cross_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlockFlat":
@@ -105,6 +112,9 @@ def get_up_block(
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
         )
     raise ValueError(f"{up_block_type} is not supported.")
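The change threads three keyword arguments (dual_cross_attention, use_linear_projection, only_cross_attention) through the two flat-UNet block factories and corrects the dual_cross_attention default from None to False, mirroring the SD2 attention options on the regular UNet blocks. Below is a minimal sketch of a call that exercises the new parameters; the import path, block type, and all numeric values are illustrative assumptions, not part of this commit.

```python
# Illustrative sketch only: building one cross-attention down block through the
# patched factory. The import path and the argument values below are assumptions;
# only the three new keyword arguments come from this diff.
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import get_down_block

down_block = get_down_block(
    "CrossAttnDownBlockFlat",   # flat-UNet block type handled by this factory
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    resnet_groups=32,
    cross_attention_dim=1024,   # SD2-style text-encoder width (illustrative)
    downsample_padding=1,
    # The three parameters added by this commit, with SD2-style values:
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
)
```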