Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
3ed4a4e4
Commit
3ed4a4e4
authored
Mar 22, 2023
by
comfyanonymous
Browse files
Try again with vae tiled decoding if regular fails because of OOM.
parent
aae9fe0c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
28 deletions
+28
-28
comfy/ldm/modules/attention.py
comfy/ldm/modules/attention.py
+1
-6
comfy/ldm/modules/diffusionmodules/model.py
comfy/ldm/modules/diffusionmodules/model.py
+1
-6
comfy/ldm/modules/sub_quadratic_attention.py
comfy/ldm/modules/sub_quadratic_attention.py
+2
-5
comfy/model_management.py
comfy/model_management.py
+5
-0
comfy/sd.py
comfy/sd.py
+19
-11
No files found.
comfy/ldm/modules/attention.py
View file @
3ed4a4e4
...
@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
...
@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
import
os
import
os
_ATTN_PRECISION
=
os
.
environ
.
get
(
"ATTN_PRECISION"
,
"fp32"
)
_ATTN_PRECISION
=
os
.
environ
.
get
(
"ATTN_PRECISION"
,
"fp32"
)
try
:
OOM_EXCEPTION
=
torch
.
cuda
.
OutOfMemoryError
except
:
OOM_EXCEPTION
=
Exception
def
exists
(
val
):
def
exists
(
val
):
return
val
is
not
None
return
val
is
not
None
...
@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
...
@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
r1
[:,
i
:
end
]
=
einsum
(
'b i j, b j d -> b i d'
,
s2
,
v
)
r1
[:,
i
:
end
]
=
einsum
(
'b i j, b j d -> b i d'
,
s2
,
v
)
del
s2
del
s2
break
break
except
OOM_EXCEPTION
as
e
:
except
model_management
.
OOM_EXCEPTION
as
e
:
if
first_op_done
==
False
:
if
first_op_done
==
False
:
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
ipc_collect
()
torch
.
cuda
.
ipc_collect
()
...
...
comfy/ldm/modules/diffusionmodules/model.py
View file @
3ed4a4e4
...
@@ -13,11 +13,6 @@ if model_management.xformers_enabled():
...
@@ -13,11 +13,6 @@ if model_management.xformers_enabled():
import
xformers
import
xformers
import
xformers.ops
import
xformers.ops
try
:
OOM_EXCEPTION
=
torch
.
cuda
.
OutOfMemoryError
except
:
OOM_EXCEPTION
=
Exception
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
This matches the implementation in Denoising Diffusion Probabilistic Models:
...
@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
...
@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
r1
[:,
:,
i
:
end
]
=
torch
.
bmm
(
v
,
s2
)
r1
[:,
:,
i
:
end
]
=
torch
.
bmm
(
v
,
s2
)
del
s2
del
s2
break
break
except
OOM_EXCEPTION
as
e
:
except
model_management
.
OOM_EXCEPTION
as
e
:
steps
*=
2
steps
*=
2
if
steps
>
128
:
if
steps
>
128
:
raise
e
raise
e
...
...
comfy/ldm/modules/sub_quadratic_attention.py
View file @
3ed4a4e4
...
@@ -24,10 +24,7 @@ except ImportError:
...
@@ -24,10 +24,7 @@ except ImportError:
from
torch
import
Tensor
from
torch
import
Tensor
from
typing
import
List
from
typing
import
List
try
:
import
model_management
OOM_EXCEPTION
=
torch
.
cuda
.
OutOfMemoryError
except
:
OOM_EXCEPTION
=
Exception
def
dynamic_slice
(
def
dynamic_slice
(
x
:
Tensor
,
x
:
Tensor
,
...
@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
...
@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
try
:
try
:
attn_probs
=
attn_scores
.
softmax
(
dim
=-
1
)
attn_probs
=
attn_scores
.
softmax
(
dim
=-
1
)
del
attn_scores
del
attn_scores
except
OOM_EXCEPTION
:
except
model_management
.
OOM_EXCEPTION
:
print
(
"ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead"
)
print
(
"ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead"
)
attn_scores
-=
attn_scores
.
max
(
dim
=-
1
,
keepdim
=
True
).
values
attn_scores
-=
attn_scores
.
max
(
dim
=-
1
,
keepdim
=
True
).
values
torch
.
exp
(
attn_scores
,
out
=
attn_scores
)
torch
.
exp
(
attn_scores
,
out
=
attn_scores
)
...
...
comfy/model_management.py
View file @
3ed4a4e4
...
@@ -31,6 +31,11 @@ try:
...
@@ -31,6 +31,11 @@ try:
except
:
except
:
pass
pass
try
:
OOM_EXCEPTION
=
torch
.
cuda
.
OutOfMemoryError
except
:
OOM_EXCEPTION
=
Exception
if
"--disable-xformers"
in
sys
.
argv
:
if
"--disable-xformers"
in
sys
.
argv
:
XFORMERS_IS_AVAILBLE
=
False
XFORMERS_IS_AVAILBLE
=
False
else
:
else
:
...
...
comfy/sd.py
View file @
3ed4a4e4
...
@@ -383,12 +383,26 @@ class VAE:
...
@@ -383,12 +383,26 @@ class VAE:
device
=
model_management
.
get_torch_device
()
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
device
=
device
def
decode
(
self
,
samples
):
def
decode_tiled_
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
decode_fn
=
lambda
a
:
(
self
.
first_stage_model
.
decode
(
1.
/
self
.
scale_factor
*
a
.
to
(
self
.
device
))
+
1.0
)
output
=
torch
.
clamp
((
(
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
//
2
,
tile_y
*
2
,
overlap
,
upscale_amount
=
8
)
+
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
*
2
,
tile_y
//
2
,
overlap
,
upscale_amount
=
8
)
+
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
,
tile_y
,
overlap
,
upscale_amount
=
8
))
/
3.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
return
output
def
decode
(
self
,
samples_in
):
model_management
.
unload_model
()
model_management
.
unload_model
()
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
device
)
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
device
)
samples
=
samples
.
to
(
self
.
device
)
try
:
pixel_samples
=
self
.
first_stage_model
.
decode
(
1.
/
self
.
scale_factor
*
samples
)
samples
=
samples_in
.
to
(
self
.
device
)
pixel_samples
=
torch
.
clamp
((
pixel_samples
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
pixel_samples
=
self
.
first_stage_model
.
decode
(
1.
/
self
.
scale_factor
*
samples
)
pixel_samples
=
torch
.
clamp
((
pixel_samples
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
except
model_management
.
OOM_EXCEPTION
as
e
:
print
(
"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding."
)
pixel_samples
=
self
.
decode_tiled_
(
samples_in
)
self
.
first_stage_model
=
self
.
first_stage_model
.
cpu
()
self
.
first_stage_model
=
self
.
first_stage_model
.
cpu
()
pixel_samples
=
pixel_samples
.
cpu
().
movedim
(
1
,
-
1
)
pixel_samples
=
pixel_samples
.
cpu
().
movedim
(
1
,
-
1
)
return
pixel_samples
return
pixel_samples
...
@@ -396,13 +410,7 @@ class VAE:
...
@@ -396,13 +410,7 @@ class VAE:
def
decode_tiled
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
def
decode_tiled
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
model_management
.
unload_model
()
model_management
.
unload_model
()
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
device
)
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
device
)
decode_fn
=
lambda
a
:
(
self
.
first_stage_model
.
decode
(
1.
/
self
.
scale_factor
*
a
.
to
(
self
.
device
))
+
1.0
)
output
=
self
.
decode_tiled_
(
samples
,
tile_x
,
tile_y
,
overlap
)
output
=
torch
.
clamp
((
(
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
//
2
,
tile_y
*
2
,
overlap
,
upscale_amount
=
8
)
+
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
*
2
,
tile_y
//
2
,
overlap
,
upscale_amount
=
8
)
+
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
,
tile_y
,
overlap
,
upscale_amount
=
8
))
/
3.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
self
.
first_stage_model
=
self
.
first_stage_model
.
cpu
()
self
.
first_stage_model
=
self
.
first_stage_model
.
cpu
()
return
output
.
movedim
(
1
,
-
1
)
return
output
.
movedim
(
1
,
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment