Commit 509c7dfc authored by comfyanonymous's avatar comfyanonymous
Browse files

Use real softmax in split op to fix issue with some images.

parent 7e1e193f
......@@ -215,32 +215,24 @@ class AttnBlock(nn.Module):
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
first_op_done = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
first_op_done = True
torch.exp(s1, out=s1)
summed = torch.sum(s1, dim=2, keepdim=True)
s1 /= summed
s2 = s1.permute(0,2,1)
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except OOM_EXCEPTION as e:
if first_op_done == False:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
h_ = r1.reshape(b,c,h,w)
del r1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment