Commit df40d4f3 authored by comfyanonymous's avatar comfyanonymous
Browse files

torch.cuda.OutOfMemoryError is not present on older pytorch versions.

parent 1d9ec62c
...@@ -19,6 +19,11 @@ from typing import Optional, NamedTuple, Protocol, List ...@@ -19,6 +19,11 @@ from typing import Optional, NamedTuple, Protocol, List
from torch import Tensor from torch import Tensor
from typing import List from typing import List
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
def dynamic_slice( def dynamic_slice(
x: Tensor, x: Tensor,
starts: List[int], starts: List[int],
...@@ -151,7 +156,7 @@ def _get_attention_scores_no_kv_chunking( ...@@ -151,7 +156,7 @@ def _get_attention_scores_no_kv_chunking(
try: try:
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
except torch.cuda.OutOfMemoryError: except OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True) summed = torch.sum(attn_scores, dim=-1, keepdim=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment