Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
4ef1479d
Commit
4ef1479d
authored
Jun 22, 2024
by
comfyanonymous
Browse files
Multi dimension tiled scale function and tiled VAE audio encoding fallback.
parent
887a6341
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
42 deletions
+52
-42
comfy/sd.py
comfy/sd.py
+11
-20
comfy/utils.py
comfy/utils.py
+41
-22
No files found.
comfy/sd.py
View file @
4ef1479d
...
...
@@ -298,25 +298,9 @@ class VAE:
/
3.0
)
return
output
def decode_tiled_1d(self, samples, tile_x=128, overlap=64):
    """Decode 1D latents tile by tile along the last axis and blend the tiles.

    Each tile of ``tile_x`` latent positions is decoded on its own; tiles
    start every ``tile_x - overlap`` positions, so neighbours overlap.  Every
    decoded tile is multiplied by a weight mask that ramps linearly over
    ``overlap`` output samples at both ends, and the weighted sum is finally
    divided by the summed weights.

    Args:
        samples: latent tensor indexed as ``[batch, channel, length]``
            (the slicing below uses exactly three indices).
        tile_x: tile length, in latent positions.
        overlap: overlap between consecutive tiles, in latent positions; also
            the length of the weight ramp in output samples.

    Returns:
        Tensor on ``self.output_device`` of shape
        ``(batch, self.output_channels, length * self.upscale_ratio)``.
    """
    ratio = self.upscale_ratio
    out_shape = (samples.shape[0], self.output_channels) + tuple(
        a * ratio for a in samples.shape[2:]
    )
    output = torch.zeros(out_shape, device=self.output_device)
    # Running sum of the blend weights; same shape/device as the output.
    output_mult = torch.zeros_like(output)

    for j in range(samples.shape[0]):
        for i in range(0, samples.shape[-1], tile_x - overlap):
            f = i
            t = i + tile_x
            # View into the accumulator; the slice is clipped automatically
            # for the last (possibly shorter) tile.
            o = output[j:j + 1, :, f * ratio:t * ratio]
            m = torch.ones_like(o)
            length = m.shape[-1]
            # Never ramp over more samples than the tile has: otherwise the
            # trailing-edge index goes negative and wraps around.
            for x in range(min(overlap, length)):
                c = (x + 1) / overlap
                m[:, :, x:x + 1] *= c
                m[:, :, length - x - 1:length - x] *= c
            tile = samples[j:j + 1, :, f:t].to(self.vae_dtype).to(self.device)
            decoded = self.first_stage_model.decode(tile).float()
            o += decoded.to(self.output_device) * m
            output_mult[j:j + 1, :, f * ratio:t * ratio] += m
    return output / output_mult
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
    """Decode 1D latents in overlapping tiles via ``tiled_scale_multidim``.

    ``samples`` is handed to ``comfy.utils.tiled_scale_multidim`` with a
    single tiled dimension of size ``tile_x`` and the given ``overlap``; the
    scale factor is ``self.upscale_ratio`` and the result has
    ``self.output_channels`` channels on ``self.output_device``.
    """
    def decode_tile(latent_tile):
        # Cast and move the tile to the VAE's dtype/device, decode it, and
        # hand the result back as float32.
        prepared = latent_tile.to(self.vae_dtype).to(self.device)
        return self.first_stage_model.decode(prepared).float()

    return comfy.utils.tiled_scale_multidim(
        samples,
        decode_tile,
        tile=(tile_x,),
        overlap=overlap,
        upscale_amount=self.upscale_ratio,
        out_channels=self.output_channels,
        output_device=self.output_device,
    )
def
encode_tiled_
(
self
,
pixel_samples
,
tile_x
=
512
,
tile_y
=
512
,
overlap
=
64
):
steps
=
pixel_samples
.
shape
[
0
]
*
comfy
.
utils
.
get_tiled_scale_steps
(
pixel_samples
.
shape
[
3
],
pixel_samples
.
shape
[
2
],
tile_x
,
tile_y
,
overlap
)
...
...
@@ -331,6 +315,10 @@ class VAE:
samples
/=
3.0
return
samples
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
    """Encode 1D input in overlapping tiles via ``tiled_scale_multidim``.

    ``samples`` is handed to ``comfy.utils.tiled_scale_multidim`` with a
    single tiled dimension of size ``tile_x`` and the given ``overlap``.  The
    scale factor is ``1 / self.downscale_ratio`` and the result has
    ``self.latent_channels`` channels on ``self.output_device``.
    """
    def encode_tile(input_tile):
        # Preprocess, cast/move to the VAE's dtype and device, encode, and
        # return the latent tile as float32.
        prepared = self.process_input(input_tile).to(self.vae_dtype).to(self.device)
        return self.first_stage_model.encode(prepared).float()

    return comfy.utils.tiled_scale_multidim(
        samples,
        encode_tile,
        tile=(tile_x,),
        overlap=overlap,
        upscale_amount=(1 / self.downscale_ratio),
        out_channels=self.latent_channels,
        output_device=self.output_device,
    )
def
decode
(
self
,
samples_in
):
try
:
memory_used
=
self
.
memory_used_decode
(
samples_in
.
shape
,
self
.
vae_dtype
)
...
...
@@ -374,6 +362,9 @@ class VAE:
except
model_management
.
OOM_EXCEPTION
as
e
:
logging
.
warning
(
"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding."
)
if
len
(
pixel_samples
.
shape
)
==
3
:
samples
=
self
.
encode_tiled_1d
(
pixel_samples
)
else
:
samples
=
self
.
encode_tiled_
(
pixel_samples
)
return
samples
...
...
comfy/utils.py
View file @
4ef1479d
...
...
@@ -6,6 +6,7 @@ import safetensors.torch
import
numpy
as
np
from
PIL
import
Image
import
logging
import
itertools
def
load_torch_file
(
ckpt
,
safe_load
=
False
,
device
=
None
):
if
device
is
None
:
...
...
@@ -506,34 +507,52 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return
math
.
ceil
((
height
/
(
tile_y
-
overlap
)))
*
math
.
ceil
((
width
/
(
tile_x
-
overlap
)))
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    """Apply ``function`` to overlapping tiles of ``samples`` and blend them.

    Works for any number of tiled dimensions: ``len(tile)`` dimensions,
    starting at dim 2 of ``samples``, are split into tiles.  Each tile's
    result is weighted by a mask that ramps linearly towards every tile edge,
    accumulated into the output, and finally divided by the summed weights.

    Args:
        samples: input tensor; dim 0 is iterated one item at a time and the
            dims from 2 onward are tiled.
        function: callable applied to each tile.  Its result is assumed to be
            the tile scaled by ``upscale_amount`` along the tiled dims with
            ``out_channels`` channels -- TODO confirm against callers.
        tile: tile size for each tiled dimension, in input units.
        overlap: overlap between neighbouring tiles, in input units (shared
            by all dimensions).
        upscale_amount: output/input size ratio; may be fractional.
        out_channels: size of dim 1 of the output.
        output_device: device on which the output is accumulated.
        pbar: optional object whose ``update(1)`` is called once per tile.

    Returns:
        Tensor of shape ``[batch, out_channels, round(size * upscale_amount), ...]``
        on ``output_device``.
    """
    dims = len(tile)
    output = torch.empty(
        [samples.shape[0], out_channels] + [round(a * upscale_amount) for a in samples.shape[2:]],
        device=output_device,
    )

    for b in range(samples.shape[0]):
        s = samples[b:b + 1]
        out = torch.zeros(
            [s.shape[0], out_channels] + [round(a * upscale_amount) for a in s.shape[2:]],
            device=output_device,
        )
        # Running sum of blend weights, used to normalise ``out`` at the end.
        out_div = torch.zeros_like(out)

        # Tile origins: one range per tiled dimension, stepping by the tile
        # size minus the overlap.
        starts = [range(0, size, tile_size - overlap) for size, tile_size in zip(s.shape[2:], tile)]
        for it in itertools.product(*starts):
            s_in = s
            upscaled = []
            for d in range(dims):
                # Pull the origin back so the last tile keeps at least
                # ``overlap`` elements (when the input is that long).
                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
                length = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, length)
                upscaled.append(round(pos * upscale_amount))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)
            feather = round(overlap * upscale_amount)
            for t in range(feather):
                weight = (1.0 / feather) * (t + 1)
                for d in range(2, dims + 2):
                    # A tile output shorter than the feather width cannot be
                    # indexed at ``t``; stop ramping there instead of raising.
                    if t >= mask.shape[d]:
                        continue
                    mask.narrow(d, t, 1).mul_(weight)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(weight)

            # Narrow the accumulators down to this tile's output region.
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o += ps * mask
            o_d += mask

            if pbar is not None:
                pbar.update(1)

        output[b:b + 1] = out / out_div
    return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    """Two-dimensional convenience wrapper around ``tiled_scale_multidim``.

    The tile sizes are passed as ``(tile_y, tile_x)``; every other argument
    is forwarded unchanged.
    """
    tile = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples, function, tile, overlap, upscale_amount, out_channels, output_device, pbar
    )
PROGRESS_BAR_ENABLED
=
True
def
set_progress_bar_enabled
(
enabled
):
global
PROGRESS_BAR_ENABLED
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment