Commit 1daccf36 authored by comfyanonymous's avatar comfyanonymous
Browse files

Run softmax in place if it OOMs.

parent 0d8ad938
......@@ -146,8 +146,17 @@ def _get_attention_scores_no_kv_chunking(
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except torch.cuda.OutOfMemoryError:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed
attn_probs = attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment