chenpangpang / ComfyUI · Commit 3ed4a4e4
"doc/vscode:/vscode.git/clone" did not exist on "fb06f0faa0748293a478d0cb4d7bb0b010a26950"
Commit 3ed4a4e4, authored Mar 22, 2023 by comfyanonymous
Parent: aae9fe0c

Try again with vae tiled decoding if regular fails because of OOM.
Showing 5 changed files with 28 additions and 28 deletions (+28, -28).
comfy/ldm/modules/attention.py (+1, -6)
comfy/ldm/modules/diffusionmodules/model.py (+1, -6)
comfy/ldm/modules/sub_quadratic_attention.py (+2, -5)
comfy/model_management.py (+5, -0)
comfy/sd.py (+19, -11)
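Taken together, the commit moves the OOM_EXCEPTION definition into comfy/model_management.py so every module catches the same exception type, and teaches VAE.decode to retry with tiled decoding when a full-image decode runs out of VRAM. A minimal sketch of the overall pattern, assuming illustrative names regular_decode and tiled_decode that are not from the repo:

import torch

# Centralized OOM type: newer PyTorch exposes torch.cuda.OutOfMemoryError;
# older builds fall back to the base Exception (mirrors model_management.py).
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:
    OOM_EXCEPTION = Exception

def decode_with_fallback(latents, regular_decode, tiled_decode):
    # Try the fast full-image decode first; fall back to tiles on OOM.
    try:
        return regular_decode(latents)
    except OOM_EXCEPTION:
        print("Warning: out of memory, retrying with tiled decoding.")
        return tiled_decode(latents)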
comfy/ldm/modules/attention.py
@@ -20,11 +20,6 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def exists(val):
     return val is not None
@@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
                     r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
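The surrounding CrossAttentionDoggettx loop (only partially visible in the hunk above) frees cached CUDA memory the first time the sliced einsum OOMs, then retries before shrinking the slices. A hedged sketch of that retry shape, with compute_in_slices standing in for the real slicing logic:

import torch

def sliced_op_with_retry(compute_in_slices, slice_size):
    cache_cleared = False
    while slice_size >= 1:
        try:
            # compute_in_slices is a stand-in for the real sliced einsum
            return compute_in_slices(slice_size)
        except torch.cuda.OutOfMemoryError:
            if not cache_cleared:
                torch.cuda.empty_cache()  # drop cached allocator blocks
                torch.cuda.ipc_collect()  # reclaim shared CUDA memory
                cache_cleared = True
            else:
                slice_size //= 2  # otherwise retry with smaller slices
    raise RuntimeError("out of memory even with single-row slices")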
comfy/ldm/modules/diffusionmodules/model.py
@@ -13,11 +13,6 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
-
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except OOM_EXCEPTION as e:
+            except model_management.OOM_EXCEPTION as e:
                 steps *= 2
                 if steps > 128:
                     raise e
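In AttnBlock the fallback instead doubles the number of attention slices each time the batched matmul OOMs, giving up past 128 steps. A simplified sketch of that loop, where slice_attention is an illustrative stand-in:

import torch

def attention_with_step_doubling(q, k, v, slice_attention):
    steps = 1
    while True:
        try:
            # slice_attention splits the work into `steps` chunks
            return slice_attention(q, k, v, steps)
        except torch.cuda.OutOfMemoryError as e:
            steps *= 2       # halve the per-chunk memory footprint
            if steps > 128:  # same cutoff as the hunk above
                raise e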
comfy/ldm/modules/sub_quadratic_attention.py
@@ -24,10 +24,7 @@ except ImportError:
 from torch import Tensor
 from typing import List
 
-try:
-    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
-except:
-    OOM_EXCEPTION = Exception
+import model_management
 
 def dynamic_slice(
     x: Tensor,
@@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except OOM_EXCEPTION:
+    except model_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
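This softmax fallback trades speed for memory: instead of allocating a second tensor for the result, it rewrites attn_scores in place. A sketch of the full in-place computation; the diff shows only the first two statements, so the final normalization step here is the standard continuation and an assumption on my part:

import torch

def softmax_inplace(attn_scores):
    # Subtract the row max first so exp() cannot overflow.
    attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
    torch.exp(attn_scores, out=attn_scores)
    # Normalize rows to sum to 1, still without a fresh allocation.
    attn_scores /= attn_scores.sum(dim=-1, keepdim=True)
    return attn_scores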
comfy/model_management.py
@@ -31,6 +31,11 @@ try:
 except:
     pass
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 if "--disable-xformers" in sys.argv:
     XFORMERS_IS_AVAILBLE = False
 else:
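torch.cuda.OutOfMemoryError only exists in recent PyTorch releases (roughly 1.13 onward), which is why the new block guards the lookup and degrades to the base Exception on older builds. An equivalent one-liner, for comparison:

import torch

# Same effect as the try/except above: use the dedicated OOM type when
# the installed torch has it, otherwise catch any Exception.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)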
comfy/sd.py
@@ -383,12 +383,26 @@ class VAE:
             device = model_management.get_torch_device()
         self.device = device
 
-    def decode(self, samples):
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
+        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        output = torch.clamp((
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8) +
+             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8) +
+             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8))
+            / 3.0) / 2.0, min=0.0, max=1.0)
+        return output
+
+    def decode(self, samples_in):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        samples = samples.to(self.device)
-        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
-        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        try:
+            samples = samples_in.to(self.device)
+            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
+            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+        except model_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            pixel_samples = self.decode_tiled_(samples_in)
+
         self.first_stage_model = self.first_stage_model.cpu()
         pixel_samples = pixel_samples.cpu().movedim(1, -1)
         return pixel_samples
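The new decode_tiled_ helper decodes the latents three times with differently shaped tiles (wide, tall, square) and averages the results: each layout puts its seams in different places, so the mean softens visible tile boundaries, and the final division by 2.0 plus clamp maps the decoder's roughly [-1, 1] output into [0, 1]. A hedged sketch of that blend, with tiled_scale standing in for comfy.utils.tiled_scale:

import torch

def blended_tiled_decode(samples, decode_fn, tiled_scale, tile=64, overlap=16):
    # Three tile layouts whose seams do not line up with each other.
    wide = tiled_scale(samples, decode_fn, tile // 2, tile * 2, overlap, upscale_amount=8)
    tall = tiled_scale(samples, decode_fn, tile * 2, tile // 2, overlap, upscale_amount=8)
    square = tiled_scale(samples, decode_fn, tile, tile, overlap, upscale_amount=8)
    # decode_fn adds 1.0 to the decoder output, so values sit in [0, 2];
    # average the layouts, rescale to [0, 1], and clamp.
    return torch.clamp(((wide + tall + square) / 3.0) / 2.0, min=0.0, max=1.0)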
@@ -396,13 +410,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
-        output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8) +
-             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8) +
-             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8))
-            / 3.0) / 2.0, min=0.0, max=1.0)
+        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         self.first_stage_model = self.first_stage_model.cpu()
         return output.movedim(1, -1)