Unverified Commit a934e5bc authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Versatile Diffusion] add upcast_attention (#1605)

add upcast_attention arg
parent a643c630
......@@ -31,6 +31,7 @@ def get_down_block(
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat":
......@@ -83,6 +84,7 @@ def get_up_block(
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment