"...resnet50_tensorflow.git" did not exist on "d991ac0ac88594517bc3055e67e395441b087b94"
Commit 798c90e1 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix pytorch 2.0 cross attention not working.

parent f9d09c26
......@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
if "--use-pytorch-cross-attention" in sys.argv:
print("Using pytorch cross attention")
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
CrossAttention = CrossAttentionPytorch
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
......@@ -497,6 +499,7 @@ else:
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment