chenpangpang / ComfyUI

Commit 3ed4a4e4, authored Mar 22, 2023 by comfyanonymous
Parent: aae9fe0c

    Try again with vae tiled decoding if regular fails because of OOM.

Showing 5 changed files, with 28 additions and 28 deletions:

    comfy/ldm/modules/attention.py                 +1   -6
    comfy/ldm/modules/diffusionmodules/model.py    +1   -6
    comfy/ldm/modules/sub_quadratic_attention.py   +2   -5
    comfy/model_management.py                      +5   -0
    comfy/sd.py                                    +19  -11
comfy/ldm/modules/attention.py

@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def exists(val):
     return val is not None

@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
                     r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
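For context, the handler this hunk retargets sits inside CrossAttentionDoggettx's slicing loop: on a CUDA out-of-memory error it frees cached allocator memory and retries the einsum with smaller slices. A minimal standalone sketch of that retry pattern, assuming a caller-supplied op and a simple slice-halving policy (both stand-ins, not the file's exact logic):

import torch
import model_management

def run_with_oom_retry(op, q, k, v):
    # Try the full sequence in one slice first; on CUDA OOM, release
    # cached blocks and retry with half the slice size until it fits.
    slice_size = q.shape[1]
    while True:
        try:
            return op(q, k, v, slice_size)
        except model_management.OOM_EXCEPTION:
            torch.cuda.empty_cache()  # hand cached blocks back to the driver
            torch.cuda.ipc_collect()  # reclaim CUDA IPC allocations
            if slice_size <= 1:
                raise
            slice_size //= 2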
comfy/ldm/modules/diffusionmodules/model.py

@@ -13,11 +13,6 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:

@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 steps *= 2
                 if steps > 128:
                     raise e
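The steps counter visible in this hunk drives a chunked torch.bmm inside AttnBlock: each OOM doubles the number of chunks, and the block gives up once steps exceeds 128. Pulled out of the class, the control flow looks roughly like this (a sketch with assumed tensor shapes, not the file's exact code):

import torch
import model_management

def chunked_bmm(v, s):
    # v: (b, c, hw), s: (b, hw, hw). Compute torch.bmm(v, s) in column
    # chunks, doubling the chunk count on OOM and capping at 128 chunks.
    steps = 1
    while True:
        try:
            r1 = torch.empty(v.shape[0], v.shape[1], s.shape[2],
                             device=v.device, dtype=v.dtype)
            slice_size = max(1, s.shape[2] // steps)
            for i in range(0, s.shape[2], slice_size):
                end = i + slice_size
                r1[:, :, i:end] = torch.bmm(v, s[:, :, i:end])
            return r1
        except model_management.OOM_EXCEPTION as e:
            steps *= 2
            if steps > 128:
                raise e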
comfy/ldm/modules/sub_quadratic_attention.py

@@ -24,10 +24,7 @@ except ImportError:
 from torch import Tensor
 from typing import List
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
+import model_management
 
 def dynamic_slice(
     x: Tensor,

@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except OOM_EXCEPTION:
+    except model_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
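The fallback branch shown here recomputes the softmax in place so no second scores-sized tensor has to be allocated. The hunk is truncated after the exp step; a self-contained sketch of the stable in-place softmax it implements, where the final sum-and-divide is the natural completion rather than something visible in the excerpt:

import torch

def softmax_inplace(attn_scores):
    # Numerically stable softmax over the last dim that reuses the input
    # buffer: slower than tensor.softmax(), but no extra full-size allocation.
    attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
    torch.exp(attn_scores, out=attn_scores)
    attn_scores /= attn_scores.sum(dim=-1, keepdim=True)  # assumed completion
    return attn_scores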
comfy/model_management.py

@@ -31,6 +31,11 @@ try:
 except:
     pass
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 if "--disable-xformers" in sys.argv:
     XFORMERS_IS_AVAILBLE = False
 else:
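This is the heart of the commit's refactor: torch.cuda.OutOfMemoryError only exists in recent PyTorch builds, so the try/except probe falls back to the generic Exception on older ones, and defining OOM_EXCEPTION once in model_management lets every module catch it uniformly instead of repeating the shim. A hedged usage sketch (risky_matmul is a hypothetical example, not ComfyUI code):

import torch
import model_management

def risky_matmul(x):
    try:
        return x @ x.transpose(-1, -2)  # may OOM for large x on CUDA
    except model_management.OOM_EXCEPTION:
        # Works whether OOM_EXCEPTION is torch.cuda.OutOfMemoryError or
        # the generic Exception fallback on older PyTorch versions.
        torch.cuda.empty_cache()
        return (x.cpu() @ x.transpose(-1, -2).cpu()).to(x.device)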
comfy/sd.py

@@ -383,12 +383,26 @@ class VAE:
             device = model_management.get_torch_device()
         self.device = device
 
-    def decode(self, samples):
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
+        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        output = torch.clamp(((utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8) +
+                               utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8) +
+                               utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8)) / 3.0) / 2.0, min=0.0, max=1.0)
+        return output
+
+    def decode(self, samples_in):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        samples = samples.to(self.device)
-        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
-        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        try:
+            samples = samples_in.to(self.device)
+            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
+            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        except model_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            pixel_samples = self.decode_tiled_(samples_in)
         self.first_stage_model = self.first_stage_model.cpu()
         pixel_samples = pixel_samples.cpu().movedim(1, -1)
         return pixel_samples

@@ -396,13 +410,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
-        output = torch.clamp(((utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8) +
-                               utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8) +
-                               utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8)) / 3.0) / 2.0, min=0.0, max=1.0)
+        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         self.first_stage_model = self.first_stage_model.cpu()
         return output.movedim(1, -1)
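Taken together, the two hunks implement the commit message: decode first attempts a full-resolution pass and only falls back to decode_tiled_ on OOM, while the public decode_tiled now delegates to the same helper. decode_tiled_ averages three tile orientations (tall, wide, square) so the seams of one pass fall inside the tile interiors of the others, then maps the decoder's [-1, 1] output range to [0, 1]. The fallback shape, reduced to its essentials (a sketch; decode_full is a hypothetical stand-in for the body of the try block):

import model_management

def vae_decode_with_fallback(vae, samples_in):
    # Prefer the fast full-image decode; on CUDA OOM, retry with the
    # tiled path, which trades speed for a much smaller peak allocation.
    try:
        return vae.decode_full(samples_in)  # hypothetical full decode
    except model_management.OOM_EXCEPTION:
        print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
        return vae.decode_tiled_(samples_in)

Note that in the real method the call to move first_stage_model back to the CPU sits outside the try block, so the VAE is unloaded from the GPU after either path.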