Commit 125b03ee authored by comfyanonymous

Fix some OOM issues with split attention.

parent 41b07ff8
@@ -229,7 +229,7 @@ def attention_split(q, k, v, heads, mask=None):
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
-    modifier = 3 if element_size == 2 else 2.5
+    modifier = 3
     mem_required = tensor_size * modifier
     steps = 1
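The estimate above sizes the attention score matrix `q @ k^T` at `(batch*heads, q_len, k_len)` entries times `element_size` bytes, then pads it with `modifier` to cover softmax temporaries. The change drops the looser 2.5x factor that fp32 used to get and applies the conservative 3x to every dtype, so the query is split into more slices before the first einsum can run out of memory. A minimal sketch of how an estimate like this can pick the split count (`mem_free_total` is assumed to come from whatever the backend reports as free device memory; the function name is hypothetical):

```python
import math

def estimate_steps(q, k, element_size, mem_free_total):
    # Bytes for the full score matrix q @ k^T:
    # (batch*heads, q_len, k_len) entries at element_size bytes each.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    # Pad for softmax temporaries; after this commit the 3x factor
    # applies to fp32 as well as fp16.
    mem_required = tensor_size * 3
    steps = 1
    if mem_required > mem_free_total:
        # Round the split count up to a power of two so every slice
        # of q fits within the free-memory budget.
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps
```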
@@ -257,10 +257,10 @@ def attention_split(q, k, v, heads, mask=None):
                         s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
-                first_op_done = True
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
+                first_op_done = True
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
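The second hunk moves `first_op_done = True` past the softmax. The flag guards an OOM-retry path around this loop: while it is still `False`, an out-of-memory error empties the cache and doubles `steps` instead of propagating. Setting it only after `s1` has been softmaxed and freed means an OOM inside the softmax, not just inside the first einsum, can still fall back to a finer split. A self-contained sketch of that interaction, under stated assumptions: the retry loop here is a reconstruction rather than the file's exact code, and `torch.cuda.OutOfMemoryError` stands in for whatever exception type the real module catches:

```python
import torch

def attend_in_slices(q, k, v, scale, steps=1):
    # Sketch of split attention with an OOM fallback; assumes CUDA
    # tensors shaped (batch*heads, seq, dim_head).
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2],
                     device=q.device, dtype=v.dtype)
    first_op_done = False
    while True:
        try:
            slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                # Only now has a slice survived both the einsum and the
                # softmax.  Before this commit the flag was set before
                # the softmax, so a softmax OOM raised instead of retrying.
                first_op_done = True
                r1[:, i:end] = torch.einsum('b i j, b j d -> b i d', s2, v)
                del s2
            return r1
        except torch.cuda.OutOfMemoryError:
            if first_op_done or steps >= 64:
                raise  # splitting further can no longer help
            torch.cuda.empty_cache()
            steps *= 2  # halve each slice and retry from the top
```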