Commit 83d969e3 authored by comfyanonymous's avatar comfyanonymous
Browse files

Disable xformers when tracing model.

parent 1900e511
...@@ -313,8 +313,18 @@ except: ...@@ -313,8 +313,18 @@ except:
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
disabled_xformers = False
if BROKEN_XFORMERS: if BROKEN_XFORMERS:
if b * heads > 65535: if b * heads > 65535:
disabled_xformers = True
if not disabled_xformers:
if torch.jit.is_tracing() or torch.jit.is_scripting():
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask) return attention_pytorch(q, k, v, heads, mask)
q, k, v = map( q, k, v = map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment