Commit 773cdabf authored by comfyanonymous's avatar comfyanonymous
Browse files

Same thing but for the other places where it's used.

parent df40d4f3
......@@ -20,6 +20,11 @@ except:
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
# Alias the CUDA out-of-memory exception so except-clauses below can catch it
# even on torch builds (CPU-only / older versions) where the attribute is absent.
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:  # narrow catch: only the missing attribute, not ^C/SystemExit
    # Fallback: catch any Exception when the specific OOM type is unavailable.
    OOM_EXCEPTION = Exception
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
......@@ -316,7 +321,7 @@ class CrossAttentionDoggettx(nn.Module):
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except torch.cuda.OutOfMemoryError as e:
except OOM_EXCEPTION as e:
if first_op_done == False:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
......
......@@ -16,6 +16,10 @@ except:
XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.")
# Alias the CUDA out-of-memory exception so except-clauses below can catch it
# even on torch builds (CPU-only / older versions) where the attribute is absent.
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:  # narrow catch: only the missing attribute, not ^C/SystemExit
    # Fallback: catch any Exception when the specific OOM type is unavailable.
    OOM_EXCEPTION = Exception
def get_timestep_embedding(timesteps, embedding_dim):
"""
......@@ -229,7 +233,7 @@ class AttnBlock(nn.Module):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except torch.cuda.OutOfMemoryError as e:
except OOM_EXCEPTION as e:
if first_op_done == False:
steps *= 2
if steps > 128:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment