Commit 1a4edd19 authored by comfyanonymous
Browse files

Fix overflow issue with inplace softmax.

parent 509c7dfc
......@@ -158,6 +158,7 @@ def _get_attention_scores_no_kv_chunking(
del attn_scores
except OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment