Commit 3e5ea74a authored by comfyanonymous

Make buggy xformers fall back on pytorch attention.

parent 982338b9
@@ -278,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
    )
    return r1

BROKEN_XFORMERS = False
try:
    x_vers = xformers.__version__
    #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
    BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
except:
    pass

def attention_xformers(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads

    if BROKEN_XFORMERS:
        if b * heads > 65535:
            return attention_pytorch(q, k, v, heads, mask)

    q, k, v = map(
        lambda t: t.unsqueeze(3)
......
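For context, below is a minimal sketch of what a pure-PyTorch fallback path for this attention layout can look like. It is an illustration, not the repository's attention_pytorch implementation: the name attention_pytorch_sketch and the exact reshape steps are assumptions, and it relies on torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), which does not hit the batch-size error of the affected xformers builds.

import torch
from torch.nn import functional as F

def attention_pytorch_sketch(q, k, v, heads, mask=None):
    # q, k, v: (batch, seq_len, heads * dim_head), the same layout attention_xformers receives.
    b, _, dim_head = q.shape
    dim_head //= heads

    # Split out the head dimension: (b, seq, heads * dim_head) -> (b, heads, seq, dim_head).
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    )

    # PyTorch's fused scaled-dot-product attention; unlike the broken xformers
    # versions it does not error out when b * heads exceeds 65535.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

    # Merge heads back: (b, heads, seq, dim_head) -> (b, seq, heads * dim_head).
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

# Example: batch 4, 8 heads, 77 tokens, 64 dims per head.
q = k = v = torch.randn(4, 77, 8 * 64)
out = attention_pytorch_sketch(q, k, v, heads=8)  # out.shape == (4, 77, 512)

Note that the guard in the commit only reroutes work when both conditions hold: the installed xformers version is one of the flagged releases and b * heads crosses the 65535 threshold mentioned in the code comment, so normal-sized batches keep using xformers.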