"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3e4e353d5d0c9f48c169477f3d03c9fad8f36df7"
Unverified Commit a934e5bc authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Versatile Diffusion] add upcast_attention (#1605)

add upcast_attention arg
parent a643c630
...@@ -31,6 +31,7 @@ def get_down_block( ...@@ -31,6 +31,7 @@ def get_down_block(
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat": if down_block_type == "DownBlockFlat":
...@@ -83,6 +84,7 @@ def get_up_block( ...@@ -83,6 +84,7 @@ def get_up_block(
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat": if up_block_type == "UpBlockFlat":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment