Commit bb064c97 authored by comfyanonymous

Add a separate optimized_attention_masked function.

parent 7e09e889
@@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
     )
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.transpose(1, 2).reshape(b, -1, heads * dim_head)
     )
     return out
 optimized_attention = attention_basic
+optimized_attention_masked = attention_basic
 if model_management.xformers_enabled():
     print("Using xformers cross attention")
@@ -309,6 +308,9 @@ else:
     print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
     optimized_attention = attention_sub_quad
+if model_management.pytorch_attention_enabled():
+    optimized_attention_masked = attention_pytorch
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
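
The new module-level block gives masked attention its own dispatch slot: optimized_attention_masked starts out bound to attention_basic and is rebound to attention_pytorch only when that backend is enabled, independently of whichever backend won the unmasked optimized_attention slot. A hedged sketch of that pattern with stand-in names (the *_stub functions and the capability check are placeholders, not the real model_management API):

def attention_basic_stub(q, k, v, heads, mask=None):
    raise NotImplementedError("illustrative stub")

def attention_pytorch_stub(q, k, v, heads, mask=None):
    raise NotImplementedError("illustrative stub")

def pytorch_attention_enabled_stub():
    # Placeholder for a runtime capability check such as
    # model_management.pytorch_attention_enabled().
    return True

optimized_attention = attention_basic_stub         # unmasked slot, chosen elsewhere
optimized_attention_masked = attention_basic_stub  # safe default for masked calls

if pytorch_attention_enabled_stub():
    # In this commit only the PyTorch backend is wired up for masked calls,
    # so only the masked slot is rebound; the unmasked choice stays as-is.
    optimized_attention_masked = attention_pytorch_stub
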
@@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
         else:
             v = self.to_v(context)
-        out = optimized_attention(q, k, v, self.heads, mask)
+        if mask is None:
+            out = optimized_attention(q, k, v, self.heads)
+        else:
+            out = optimized_attention_masked(q, k, v, self.heads, mask)
         return self.to_out(out)
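
CrossAttention.forward now branches on the mask instead of always threading it through the fast path. The hedged sketch below mirrors that caller-side dispatch with free functions in place of the real backends (the names, the single-head simplification, and the boolean-mask convention are assumptions for illustration):

import torch

def basic_attention(q, k, v, heads):
    # Unmasked softmax attention; `heads` is accepted to mirror the real
    # signature, but head splitting is omitted in this single-head sketch.
    scale = q.shape[-1] ** -0.5
    sim = q @ k.transpose(-2, -1) * scale
    return sim.softmax(dim=-1) @ v

def masked_attention(q, k, v, heads, mask):
    # Masked variant: positions where mask is False are excluded before softmax.
    scale = q.shape[-1] ** -0.5
    sim = q @ k.transpose(-2, -1) * scale
    sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
    return sim.softmax(dim=-1) @ v

def forward_dispatch(q, k, v, heads, mask=None):
    # Mirrors the new CrossAttention.forward logic: the unmasked fast path
    # never receives a mask argument; masked calls use the masked entry point.
    if mask is None:
        return basic_attention(q, k, v, heads)
    return masked_attention(q, k, v, heads, mask)

q = k = v = torch.randn(1, 16, 64)
mask = torch.ones(1, 16, 16, dtype=torch.bool)
print(forward_dispatch(q, k, v, heads=1).shape)             # unmasked path
print(forward_dispatch(q, k, v, heads=1, mask=mask).shape)  # masked path
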