Unverified Commit 77bfb562 authored by Isaac's avatar Isaac Committed by GitHub
Browse files

adding required parameters while calling the get_up_block and get_down_block (#3210)



* removed unnecessary parameters from get_up_block and get_down_block functions

* adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 70ef774f
...@@ -42,6 +42,9 @@ def get_down_block( ...@@ -42,6 +42,9 @@ def get_down_block(
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat": if down_block_type == "DownBlockFlat":
...@@ -98,6 +101,9 @@ def get_up_block( ...@@ -98,6 +101,9 @@ def get_up_block(
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat": if up_block_type == "UpBlockFlat":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment