Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
4ef1479d
"test/vscode:/vscode.git/clone" did not exist on "7c45b8b4bb05abdbb089b7287d0fca890e911840"
Commit
4ef1479d
authored
Jun 22, 2024
by
comfyanonymous
Browse files
Multi dimension tiled scale function and tiled VAE audio encoding fallback.
parent
887a6341
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
42 deletions
+52
-42
comfy/sd.py
comfy/sd.py
+11
-20
comfy/utils.py
comfy/utils.py
+41
-22
No files found.
comfy/sd.py
View file @
4ef1479d
...
...
@@ -298,25 +298,9 @@ class VAE:
/
3.0
)
return
output
def
decode_tiled_1d
(
self
,
samples
,
tile_x
=
128
,
overlap
=
64
):
output
=
torch
.
zeros
((
samples
.
shape
[
0
],
self
.
output_channels
)
+
tuple
(
map
(
lambda
a
:
a
*
self
.
upscale_ratio
,
samples
.
shape
[
2
:])),
device
=
self
.
output_device
)
output_mult
=
torch
.
zeros
((
samples
.
shape
[
0
],
self
.
output_channels
)
+
tuple
(
map
(
lambda
a
:
a
*
self
.
upscale_ratio
,
samples
.
shape
[
2
:])),
device
=
self
.
output_device
)
for
j
in
range
(
samples
.
shape
[
0
]):
for
i
in
range
(
0
,
samples
.
shape
[
-
1
],
tile_x
-
overlap
):
f
=
i
t
=
i
+
tile_x
o
=
output
[
j
:
j
+
1
,:,
f
*
self
.
upscale_ratio
:
t
*
self
.
upscale_ratio
]
m
=
torch
.
ones_like
(
o
)
l
=
m
.
shape
[
-
1
]
for
x
in
range
(
overlap
):
c
=
((
x
+
1
)
/
overlap
)
m
[:,:,
x
:
x
+
1
]
*=
c
m
[:,:,
l
-
x
-
1
:
l
-
x
]
*=
c
o
+=
self
.
first_stage_model
.
decode
(
samples
[
j
:
j
+
1
,:,
f
:
t
].
to
(
self
.
vae_dtype
).
to
(
self
.
device
)).
float
().
to
(
self
.
output_device
)
*
m
output_mult
[
j
:
j
+
1
,:,
f
*
self
.
upscale_ratio
:
t
*
self
.
upscale_ratio
]
+=
m
return
output
/
output_mult
def
decode_tiled_1d
(
self
,
samples
,
tile_x
=
128
,
overlap
=
32
):
decode_fn
=
lambda
a
:
self
.
first_stage_model
.
decode
(
a
.
to
(
self
.
vae_dtype
).
to
(
self
.
device
)).
float
()
return
comfy
.
utils
.
tiled_scale_multidim
(
samples
,
decode_fn
,
tile
=
(
tile_x
,),
overlap
=
overlap
,
upscale_amount
=
self
.
upscale_ratio
,
out_channels
=
self
.
output_channels
,
output_device
=
self
.
output_device
)
def
encode_tiled_
(
self
,
pixel_samples
,
tile_x
=
512
,
tile_y
=
512
,
overlap
=
64
):
steps
=
pixel_samples
.
shape
[
0
]
*
comfy
.
utils
.
get_tiled_scale_steps
(
pixel_samples
.
shape
[
3
],
pixel_samples
.
shape
[
2
],
tile_x
,
tile_y
,
overlap
)
...
...
@@ -331,6 +315,10 @@ class VAE:
samples
/=
3.0
return
samples
def
encode_tiled_1d
(
self
,
samples
,
tile_x
=
128
*
2048
,
overlap
=
32
*
2048
):
encode_fn
=
lambda
a
:
self
.
first_stage_model
.
encode
((
self
.
process_input
(
a
)).
to
(
self
.
vae_dtype
).
to
(
self
.
device
)).
float
()
return
comfy
.
utils
.
tiled_scale_multidim
(
samples
,
encode_fn
,
tile
=
(
tile_x
,),
overlap
=
overlap
,
upscale_amount
=
(
1
/
self
.
downscale_ratio
),
out_channels
=
self
.
latent_channels
,
output_device
=
self
.
output_device
)
def
decode
(
self
,
samples_in
):
try
:
memory_used
=
self
.
memory_used_decode
(
samples_in
.
shape
,
self
.
vae_dtype
)
...
...
@@ -374,6 +362,9 @@ class VAE:
except
model_management
.
OOM_EXCEPTION
as
e
:
logging
.
warning
(
"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding."
)
if
len
(
pixel_samples
.
shape
)
==
3
:
samples
=
self
.
encode_tiled_1d
(
pixel_samples
)
else
:
samples
=
self
.
encode_tiled_
(
pixel_samples
)
return
samples
...
...
comfy/utils.py
View file @
4ef1479d
...
...
@@ -6,6 +6,7 @@ import safetensors.torch
import
numpy
as
np
from
PIL
import
Image
import
logging
import
itertools
def
load_torch_file
(
ckpt
,
safe_load
=
False
,
device
=
None
):
if
device
is
None
:
...
...
@@ -506,34 +507,52 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return
math
.
ceil
((
height
/
(
tile_y
-
overlap
)))
*
math
.
ceil
((
width
/
(
tile_x
-
overlap
)))
@
torch
.
inference_mode
()
def
tiled_scale
(
samples
,
function
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
8
,
upscale_amount
=
4
,
out_channels
=
3
,
output_device
=
"cpu"
,
pbar
=
None
):
output
=
torch
.
empty
((
samples
.
shape
[
0
],
out_channels
,
round
(
samples
.
shape
[
2
]
*
upscale_amount
),
round
(
samples
.
shape
[
3
]
*
upscale_amount
)),
device
=
output_device
)
def
tiled_scale_multidim
(
samples
,
function
,
tile
=
(
64
,
64
),
overlap
=
8
,
upscale_amount
=
4
,
out_channels
=
3
,
output_device
=
"cpu"
,
pbar
=
None
):
dims
=
len
(
tile
)
output
=
torch
.
empty
([
samples
.
shape
[
0
],
out_channels
]
+
list
(
map
(
lambda
a
:
round
(
a
*
upscale_amount
),
samples
.
shape
[
2
:])),
device
=
output_device
)
for
b
in
range
(
samples
.
shape
[
0
]):
s
=
samples
[
b
:
b
+
1
]
out
=
torch
.
zeros
((
s
.
shape
[
0
],
out_channels
,
round
(
s
.
shape
[
2
]
*
upscale_amount
),
round
(
s
.
shape
[
3
]
*
upscale_amount
)),
device
=
output_device
)
out_div
=
torch
.
zeros
((
s
.
shape
[
0
],
out_channels
,
round
(
s
.
shape
[
2
]
*
upscale_amount
),
round
(
s
.
shape
[
3
]
*
upscale_amount
)),
device
=
output_device
)
for
y
in
range
(
0
,
s
.
shape
[
2
],
tile_y
-
overlap
):
for
x
in
range
(
0
,
s
.
shape
[
3
],
tile_x
-
overlap
):
x
=
max
(
0
,
min
(
s
.
shape
[
-
1
]
-
overlap
,
x
))
y
=
max
(
0
,
min
(
s
.
shape
[
-
2
]
-
overlap
,
y
))
s_in
=
s
[:,:,
y
:
y
+
tile_y
,
x
:
x
+
tile_x
]
out
=
torch
.
zeros
([
s
.
shape
[
0
],
out_channels
]
+
list
(
map
(
lambda
a
:
round
(
a
*
upscale_amount
),
s
.
shape
[
2
:])),
device
=
output_device
)
out_div
=
torch
.
zeros
([
s
.
shape
[
0
],
out_channels
]
+
list
(
map
(
lambda
a
:
round
(
a
*
upscale_amount
),
s
.
shape
[
2
:])),
device
=
output_device
)
for
it
in
itertools
.
product
(
*
map
(
lambda
a
:
range
(
0
,
a
[
0
],
a
[
1
]
-
overlap
),
zip
(
s
.
shape
[
2
:],
tile
))):
s_in
=
s
upscaled
=
[]
for
d
in
range
(
dims
):
pos
=
max
(
0
,
min
(
s
.
shape
[
d
+
2
]
-
overlap
,
it
[
d
]))
l
=
min
(
tile
[
d
],
s
.
shape
[
d
+
2
]
-
pos
)
s_in
=
s_in
.
narrow
(
d
+
2
,
pos
,
l
)
upscaled
.
append
(
round
(
pos
*
upscale_amount
))
ps
=
function
(
s_in
).
to
(
output_device
)
mask
=
torch
.
ones_like
(
ps
)
feather
=
round
(
overlap
*
upscale_amount
)
for
t
in
range
(
feather
):
mask
[:,:,
t
:
1
+
t
,:]
*=
((
1.0
/
feather
)
*
(
t
+
1
))
mask
[:,:,
mask
.
shape
[
2
]
-
1
-
t
:
mask
.
shape
[
2
]
-
t
,:]
*=
((
1.0
/
feather
)
*
(
t
+
1
))
mask
[:,:,:,
t
:
1
+
t
]
*=
((
1.0
/
feather
)
*
(
t
+
1
))
mask
[:,:,:,
mask
.
shape
[
3
]
-
1
-
t
:
mask
.
shape
[
3
]
-
t
]
*=
((
1.0
/
feather
)
*
(
t
+
1
))
out
[:,:,
round
(
y
*
upscale_amount
):
round
((
y
+
tile_y
)
*
upscale_amount
),
round
(
x
*
upscale_amount
):
round
((
x
+
tile_x
)
*
upscale_amount
)]
+=
ps
*
mask
out_div
[:,:,
round
(
y
*
upscale_amount
):
round
((
y
+
tile_y
)
*
upscale_amount
),
round
(
x
*
upscale_amount
):
round
((
x
+
tile_x
)
*
upscale_amount
)]
+=
mask
for
d
in
range
(
2
,
dims
+
2
):
m
=
mask
.
narrow
(
d
,
t
,
1
)
m
*=
((
1.0
/
feather
)
*
(
t
+
1
))
m
=
mask
.
narrow
(
d
,
mask
.
shape
[
d
]
-
1
-
t
,
1
)
m
*=
((
1.0
/
feather
)
*
(
t
+
1
))
o
=
out
o_d
=
out_div
for
d
in
range
(
dims
):
o
=
o
.
narrow
(
d
+
2
,
upscaled
[
d
],
mask
.
shape
[
d
+
2
])
o_d
=
o_d
.
narrow
(
d
+
2
,
upscaled
[
d
],
mask
.
shape
[
d
+
2
])
o
+=
ps
*
mask
o_d
+=
mask
if
pbar
is
not
None
:
pbar
.
update
(
1
)
output
[
b
:
b
+
1
]
=
out
/
out_div
return
output
def
tiled_scale
(
samples
,
function
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
8
,
upscale_amount
=
4
,
out_channels
=
3
,
output_device
=
"cpu"
,
pbar
=
None
):
return
tiled_scale_multidim
(
samples
,
function
,
(
tile_y
,
tile_x
),
overlap
,
upscale_amount
,
out_channels
,
output_device
,
pbar
)
PROGRESS_BAR_ENABLED
=
True
def
set_progress_bar_enabled
(
enabled
):
global
PROGRESS_BAR_ENABLED
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment