Commit 04777561 authored by comfyanonymous's avatar comfyanonymous
Browse files

Lower the chances of an OOM.

parent 853e96ad
...@@ -76,7 +76,8 @@ def _summarize_chunk( ...@@ -76,7 +76,8 @@ def _summarize_chunk(
) )
max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach() max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score) torch.exp(attn_weights - max_score, out=attn_weights)
exp_weights = attn_weights
exp_values = torch.bmm(exp_weights, value) exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1) max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment