chenpangpang / ComfyUI · Commits

Commit 853e96ad, authored Feb 08, 2023 by comfyanonymous

    Increase it/s by batching together some stuff sent to unet.

Parent: c92633ea
Showing 2 changed files, with 128 additions and 42 deletions:

    comfy/model_management.py   +29  -6
    comfy/samplers.py           +99  -36
comfy/model_management.py (+29 -6)

@@ -7,6 +7,7 @@ NORMAL_VRAM = 3
 accelerate_enabled = False
 vram_state = NORMAL_VRAM
 
+total_vram = 0
 total_vram_available_mb = -1
 
 import sys
@@ -17,6 +18,12 @@ if "--lowvram" in sys.argv:
 if "--novram" in sys.argv:
     set_vram_to = NO_VRAM
 
+try:
+    import torch
+    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+except:
+    pass
+
 if set_vram_to != NORMAL_VRAM:
     try:
         import accelerate
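
A note on the detection block added above (illustrative, not part of the commit): torch.cuda.mem_get_info returns a (free_bytes, total_bytes) tuple, so indexing [1] picks the device's total VRAM, which the division converts to megabytes. A minimal standalone sketch:

    import torch

    # mem_get_info returns (free_bytes, total_bytes) for the given device
    free_bytes, total_bytes = torch.cuda.mem_get_info(torch.cuda.current_device())
    total_vram_mb = total_bytes / (1024 * 1024)   # e.g. ~8192.0 on an 8 GiB card
    print("total: %.0f MB, free: %.0f MB" % (total_vram_mb, free_bytes / (1024 * 1024)))

The try/except keeps the module importable on machines without a working CUDA setup, where total_vram simply stays 0.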
@@ -26,12 +33,8 @@ if set_vram_to != NORMAL_VRAM:
         import traceback
         print(traceback.format_exc())
         print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
-    try:
-        import torch
-        total_vram_available_mb = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
-    except:
-        pass
-    total_vram_available_mb = (total_vram_available_mb - 1024) // 2
+    total_vram_available_mb = (total_vram - 1024) // 2
     total_vram_available_mb = int(max(256, total_vram_available_mb))
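
The rewritten budget line derives the low-VRAM budget from the total detected at import time instead of re-querying the device here. A worked example with an assumed 8 GiB card (numbers illustrative, not from the commit):

    total_vram = 8192                      # detected total VRAM, in MB
    budget = (total_vram - 1024) // 2      # reserve ~1 GiB, then halve -> 3584
    budget = int(max(256, budget))         # if detection failed (total_vram == 0),
                                           # (0 - 1024) // 2 == -512, clamped to 256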
@@ -81,6 +84,26 @@ def load_model_gpu(model):
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
+        print(device_map, "{}MiB".format(total_vram_available_mb))
         accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
         model_accelerated = True
     return current_loaded_model
+
+def get_free_memory():
+    dev = torch.cuda.current_device()
+    stats = torch.cuda.memory_stats(dev)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+    mem_free_torch = mem_reserved - mem_active
+    return mem_free_cuda + mem_free_torch
+
+def maximum_batch_area():
+    global vram_state
+    if vram_state == NO_VRAM:
+        return 0
+
+    memory_free = get_free_memory() / (1024 * 1024)
+    area = ((memory_free - 1024) * 0.9) / (0.6)
+    return int(max(area, 0))
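
maximum_batch_area() is the new bridge between memory management and the sampler: it converts currently free VRAM into a budget of latent elements (batch x height x width) that a single unet call may contain. The constants are best read as heuristics: roughly 1 GiB of headroom, a 0.9 safety factor, and about 0.6 MB consumed per latent element; that interpretation is ours, the commit does not document it. An illustrative evaluation:

    memory_free = 6000.0                        # MB free at call time (example value)
    area = ((memory_free - 1024) * 0.9) / 0.6   # (6000 - 1024) * 0.9 / 0.6
    print(int(max(area, 0)))                    # -> 7464 latent elements

Note that get_free_memory() counts both memory CUDA has never handed to PyTorch (from mem_get_info) and memory PyTorch has reserved but is not actively using (reserved minus active bytes), so cached-but-idle allocator blocks still count toward the batch budget.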
comfy/samplers.py (+99 -36)

@@ -2,6 +2,7 @@ import k_diffusion.sampling
 import k_diffusion.external
 import torch
 import contextlib
+import model_management
 
 class CFGDenoiser(torch.nn.Module):
     def __init__(self, model):
@@ -24,55 +25,117 @@ class CFGDenoiserComplex(torch.nn.Module):
         super().__init__()
         self.inner_model = model
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
-        def calc_cond(cond, x_in, sigma):
+        def get_area_and_mult(cond, x_in, sigma):
+            area = (x_in.shape[2], x_in.shape[3], 0, 0)
+            strength = 1.0
+            min_sigma = 0.0
+            max_sigma = 999.0
+            if 'area' in cond[1]:
+                area = cond[1]['area']
+            if 'strength' in cond[1]:
+                strength = cond[1]['strength']
+            if 'min_sigma' in cond[1]:
+                min_sigma = cond[1]['min_sigma']
+            if 'max_sigma' in cond[1]:
+                max_sigma = cond[1]['max_sigma']
+            if sigma < min_sigma or sigma > max_sigma:
+                return None
+
+            input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+            mult = torch.ones_like(input_x) * strength
+
+            rr = 8
+            if area[2] != 0:
+                for t in range(rr):
+                    mult[:,:,area[2]+t:area[2]+1+t,:] *= ((1.0/rr) * (t + 1))
+            if (area[0] + area[2]) < x_in.shape[2]:
+                for t in range(rr):
+                    mult[:,:,area[0] + area[2] - 1 - t:area[0] + area[2] - t,:] *= ((1.0/rr) * (t + 1))
+            if area[3] != 0:
+                for t in range(rr):
+                    mult[:,:,:,area[3] + t:area[3] + 1 + t] *= ((1.0/rr) * (t + 1))
+            if (area[1] + area[3]) < x_in.shape[3]:
+                for t in range(rr):
+                    mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
+            return (input_x, mult, cond[0], area)
+
+        def calc_cond_uncond_batch(cond, uncond, x_in, sigma, max_total_area):
             out_cond = torch.zeros_like(x_in)
             out_count = torch.ones_like(x_in)/100000.0
 
+            out_uncond = torch.zeros_like(x_in)
+            out_uncond_count = torch.ones_like(x_in)/100000.0
+
             sigma_cmp = sigma[0]
 
+            COND = 0
+            UNCOND = 1
+
+            to_run = []
             for x in cond:
-                area = (x_in.shape[2], x_in.shape[3], 0, 0)
-                strength = 1.0
-                min_sigma = 0.0
-                max_sigma = 999.0
-                if 'area' in x[1]:
-                    area = x[1]['area']
-                if 'strength' in x[1]:
-                    strength = x[1]['strength']
-                if 'min_sigma' in x[1]:
-                    min_sigma = x[1]['min_sigma']
-                if 'max_sigma' in x[1]:
-                    max_sigma = x[1]['max_sigma']
-                if sigma_cmp < min_sigma or sigma_cmp > max_sigma:
+                p = get_area_and_mult(x, x_in, sigma_cmp)
+                if p is None:
                     continue
 
-                input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-                mult = torch.ones_like(input_x) * strength
-
-                rr = 8
-                if area[2] != 0:
-                    for t in range(rr):
-                        mult[:,:,area[2]+t:area[2]+1+t,:] *= ((1.0/rr) * (t + 1))
-                if (area[0] + area[2]) < x_in.shape[2]:
-                    for t in range(rr):
-                        mult[:,:,area[0] + area[2] - 1 - t:area[0] + area[2] - t,:] *= ((1.0/rr) * (t + 1))
-                if area[3] != 0:
-                    for t in range(rr):
-                        mult[:,:,:,area[3] + t:area[3] + 1 + t] *= ((1.0/rr) * (t + 1))
-                if (area[1] + area[3]) < x_in.shape[3]:
-                    for t in range(rr):
-                        mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
-
-                out_cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] += self.inner_model(input_x, sigma, cond=x[0]) * mult
-                out_count[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] += mult
+                to_run += [(p, COND)]
+            for x in uncond:
+                p = get_area_and_mult(x, x_in, sigma_cmp)
+                if p is None:
+                    continue
+
+                to_run += [(p, UNCOND)]
+
+            while len(to_run) > 0:
+                first = to_run[0]
+                first_shape = first[0][0].shape
+                to_batch = []
+                for x in range(len(to_run)):
+                    if to_run[x][0][0].shape == first_shape:
+                        if to_run[x][0][2].shape == first[0][2].shape:
+                            to_batch += [x]
+                            if (len(to_batch) * first_shape[0] * first_shape[2] * first_shape[3] >= max_total_area):
+                                break
+
+                to_batch.reverse()
+                input_x = []
+                mult = []
+                c = []
+                cond_or_uncond = []
+                area = []
+                for x in to_batch:
+                    o = to_run.pop(x)
+                    p = o[0]
+                    input_x += [p[0]]
+                    mult += [p[1]]
+                    c += [p[2]]
+                    area += [p[3]]
+                    cond_or_uncond += [o[1]]
+
+                batch_chunks = len(cond_or_uncond)
+                input_x = torch.cat(input_x)
+                c = torch.cat(c)
+                sigma_ = torch.cat([sigma] * batch_chunks)
+
+                output = self.inner_model(input_x, sigma_, cond=c).chunk(batch_chunks)
                 del input_x
+
+                for o in range(batch_chunks):
+                    if cond_or_uncond[o] == COND:
+                        out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+                        out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+                    else:
+                        out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+                        out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
                 del mult
 
             out_cond /= out_count
             del out_count
-            return out_cond
+            out_uncond /= out_uncond_count
+            del out_uncond_count
+            return out_cond, out_uncond
 
-        cond = calc_cond(cond, x, sigma)
-        uncond = calc_cond(uncond, x, sigma)
-
+        max_total_area = model_management.maximum_batch_area()
+        cond, uncond = calc_cond_uncond_batch(cond, uncond, x, sigma, max_total_area)
         return uncond + (cond - uncond) * cond_scale
 
 def simple_scheduler(model, steps):
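
Where the speedup in the commit title comes from: the old forward ran the unet once per cond and once per uncond (two or more sequential calls per step). The new calc_cond_uncond_batch collects every pending cond/uncond patch whose input and conditioning shapes match, concatenates them along the batch dimension up to max_total_area, runs self.inner_model once, and splits the result back apart with .chunk. A standalone sketch of that pattern (the unet stub and shapes below are placeholders for illustration, not ComfyUI API):

    import torch

    def unet(x, sigma, cond=None):    # stand-in for self.inner_model
        return x * 0.5                # dummy denoiser for illustration

    x = torch.randn(1, 4, 64, 64)     # one latent image
    sigma = torch.tensor([14.6])
    c = torch.randn(1, 77, 768)       # cond embedding
    uc = torch.randn(1, 77, 768)      # uncond embedding

    # Before this commit: two separate unet calls per step.
    out_cond = unet(x, sigma, cond=c)
    out_uncond = unet(x, sigma, cond=uc)

    # After: one batched call, split back into per-chunk outputs.
    out_c, out_uc = unet(torch.cat([x, x]),
                         torch.cat([sigma, sigma]),
                         cond=torch.cat([c, uc])).chunk(2)

    assert torch.allclose(out_cond, out_c) and torch.allclose(out_uncond, out_uc)

Batching roughly halves per-step kernel-launch and scheduling overhead at the cost of a larger activation footprint, which is why the batch size is capped by model_management.maximum_batch_area().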