chenpangpang / ComfyUI / Commits

Commit db8b59ec, authored Mar 13, 2024 by comfyanonymous

Lower memory usage for loras in lowvram mode at the cost of perf.

Parent: eda87043
Showing 3 changed files with 101 additions and 48 deletions (+101 -48):
  comfy/model_management.py  +5  -31
  comfy/model_patcher.py     +74 -17
  comfy/ops.py               +22 -0
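At a glance, the commit moves LoRA application for memory-starved modules from load time to cast time: modules that do not fit in the lowvram byte budget keep their unpatched weights and get weight_function/bias_function hooks that re-apply the patches each time the weights are cast for a forward pass. A minimal, self-contained sketch of that tradeoff (ToyLowVramPatch is illustrative, not from the codebase):

    import torch

    # Toy stand-in for ModelPatcher.calculate_weight: adds a fixed delta.
    class ToyLowVramPatch:
        def __init__(self, delta):
            self.delta = delta

        def __call__(self, weight):
            return weight + self.delta

    lin = torch.nn.Linear(4, 4, bias=False)
    hook = ToyLowVramPatch(torch.full_like(lin.weight, 0.1))
    x = torch.randn(1, 4)

    # New lowvram path: the patch is applied on the fly at cast time, so the
    # patched copy is transient; it costs extra compute on every forward pass.
    y_lowvram = torch.nn.functional.linear(x, hook(lin.weight))

    # Old path: the patch is baked into the parameter once; faster per pass,
    # but the patched weight stays resident in device memory.
    with torch.no_grad():
        lin.weight += 0.1
    y_baked = lin(x)
    assert torch.allclose(y_lowvram, y_baked)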
comfy/model_management.py (view file @ db8b59ec)
@@ -272,7 +272,6 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
         self.model = model
-        self.model_accelerated = False
         self.device = model.load_device
 
     def model_memory(self):
@@ -285,52 +284,27 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
-        patch_model_to = None
-        if lowvram_model_memory == 0:
-            patch_model_to = self.device
+        patch_model_to = self.device
 
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())
 
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+            if lowvram_model_memory > 0:
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+            else:
+                self.real_model = self.model.patch_model(device_to=patch_model_to)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e
 
-        if lowvram_model_memory > 0:
-            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024*1024)))
-            mem_counter = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "comfy_cast_weights"):
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
-                    m.to(self.device)
-                    mem_counter += module_size(m)
-                    logging.warning("lowvram: loaded module regularly {}".format(m))
-
-            self.model_accelerated = True
-
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
         return self.real_model
 
     def model_unload(self):
-        if self.model_accelerated:
-            for m in self.real_model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-
-            self.model_accelerated = False
-
         self.model.unpatch_model(self.model.offload_device)
         self.model.model_patches_to(self.model.offload_device)
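model_load now just dispatches on lowvram_model_memory; the byte-budget walk that used to live here moved into ModelPatcher.patch_model_lowvram. The accounting itself is a greedy first-fit over modules, roughly like this hypothetical standalone version (plan_placement is an invented name):

    # Hypothetical first-fit placement mirroring the mem_counter loop in
    # patch_model_lowvram: a module is kept on the GPU while it still fits in
    # the byte budget; anything that would overflow falls back to lowvram mode.
    def plan_placement(module_sizes, budget_bytes):
        on_device, lowvram, used = [], [], 0
        for name, size in module_sizes:
            if used + size < budget_bytes:
                on_device.append(name)
                used += size
            else:
                lowvram.append(name)
        return on_device, lowvram

    on_device, lowvram = plan_placement(
        [("input_blocks.0", 40), ("middle_block", 50), ("output_blocks.0", 30)],
        budget_bytes=100,
    )
    # on_device == ["input_blocks.0", "middle_block"]; lowvram == ["output_blocks.0"]

Note that, as in the real loop, the counter only grows for modules actually placed on the device, so a later, smaller module can still fit after a larger one has spilled.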
comfy/model_patcher.py (view file @ db8b59ec)
@@ -24,6 +24,7 @@ class ModelPatcher:
         self.current_device = current_device
         self.weight_inplace_update = weight_inplace_update
+        self.model_lowvram = False
 
     def model_size(self):
         if self.size > 0:
@@ -178,6 +179,27 @@ class ModelPatcher:
             sd.pop(k)
         return sd
 
+    def patch_weight_to_device(self, key, device_to=None):
+        if key not in self.patches:
+            return
+
+        weight = comfy.utils.get_attr(self.model, key)
+
+        inplace_update = self.weight_inplace_update
+
+        if key not in self.backup:
+            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+        if device_to is not None:
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+        else:
+            temp_weight = weight.to(torch.float32, copy=True)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        if inplace_update:
+            comfy.utils.copy_to_param(self.model, key, out_weight)
+        else:
+            comfy.utils.set_attr_param(self.model, key, out_weight)
+
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
             old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
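patch_weight_to_device factors out the backup-then-patch logic previously inlined in patch_model, so patch_model_lowvram (below) can reuse it per key. The pattern in isolation, as a hypothetical standalone helper (patch_param, the module-level backup dict, and the dict-based params are all illustrative):

    import torch

    backup = {}

    # Illustrative: stash the pristine tensor once, do the patch math in fp32,
    # then store the result back in the parameter's original dtype, which is
    # what patch_weight_to_device does via comfy.utils and calculate_weight.
    def patch_param(params, patches, key, offload_device="cpu"):
        weight = params[key]
        if key not in backup:  # only the first patch records the original
            backup[key] = weight.to(offload_device, copy=True)
        temp = weight.to(torch.float32, copy=True)
        for delta in patches.get(key, []):
            temp += delta.to(torch.float32)
        params[key] = temp.to(weight.dtype)

    params = {"w": torch.zeros(2, 2, dtype=torch.float16)}
    patch_param(params, {"w": [torch.ones(2, 2)]}, "w")
    assert params["w"].dtype == torch.float16
    assert torch.equal(backup["w"], torch.zeros(2, 2, dtype=torch.float16))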
@@ -191,23 +213,7 @@ class ModelPatcher:
                 logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                 continue
 
-            weight = model_sd[key]
-
-            inplace_update = self.weight_inplace_update
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
-
-            if device_to is not None:
-                temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
-            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-            if inplace_update:
-                comfy.utils.copy_to_param(self.model, key, out_weight)
-            else:
-                comfy.utils.set_attr_param(self.model, key, out_weight)
-            del temp_weight
+            self.patch_weight_to_device(key, device_to)
 
         if device_to is not None:
             self.model.to(device_to)
@@ -215,6 +221,47 @@ class ModelPatcher:
 
         return self.model
 
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+        self.patch_model(device_to, patch_weights=False)
+
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024*1024)))
+        class LowVramPatch:
+            def __init__(self, key, model_patcher):
+                self.key = key
+                self.model_patcher = model_patcher
+            def __call__(self, weight):
+                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+        mem_counter = 0
+        for n, m in self.model.named_modules():
+            lowvram_weight = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                if mem_counter + module_mem >= lowvram_model_memory:
+                    lowvram_weight = True
+
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
+            if lowvram_weight:
+                if weight_key in self.patches:
+                    m.weight_function = LowVramPatch(weight_key, self)
+                if bias_key in self.patches:
+                    m.bias_function = LowVramPatch(weight_key, self)
+
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+            else:
+                if hasattr(m, "weight"):
+                    self.patch_weight_to_device(weight_key, device_to)
+                    self.patch_weight_to_device(bias_key, device_to)
+                    m.to(device_to)
+                    mem_counter += comfy.model_management.module_size(m)
+                    logging.debug("lowvram: loaded module regularly {}".format(m))
+
+        self.model_lowvram = True
+        return self.model
+
     def calculate_weight(self, patches, weight, key):
         for p in patches:
             alpha = p[0]
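LowVramPatch is just a deferred call to calculate_weight: instead of computing the patched weight once at load, the hook recomputes it on every forward pass through the module, which is exactly the perf cost the commit title concedes. A toy demonstration of that per-pass cost (hook and the call counter are illustrative):

    import torch

    calls = {"n": 0}

    def hook(weight):
        # A real LowVramPatch would run calculate_weight(patches[key], weight, key);
        # here we only count invocations.
        calls["n"] += 1
        return weight

    w = torch.randn(4, 4)
    x = torch.randn(1, 4)
    for _ in range(3):
        torch.nn.functional.linear(x, hook(w))
    assert calls["n"] == 3  # one patch recomputation per forward pass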
@@ -341,6 +388,16 @@ class ModelPatcher:
         return weight
 
     def unpatch_model(self, device_to=None):
+        if self.model_lowvram:
+            for m in self.model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+                m.weight_function = None
+                m.bias_function = None
+
+            self.model_lowvram = False
+
         keys = list(self.backup.keys())
 
         if self.weight_inplace_update:
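unpatch_model now mirrors the lowvram setup: the original comfy_cast_weights flag is restored from prev_comfy_cast_weights and the hooks are cleared, so repeated patch/unpatch cycles start from a clean slate. The stash/restore idiom in miniature (Mod is an invented stand-in):

    # Hypothetical mini-demo of the save/restore pattern unpatch_model uses.
    class Mod:
        comfy_cast_weights = False

    m = Mod()
    m.prev_comfy_cast_weights = m.comfy_cast_weights  # enable lowvram casting
    m.comfy_cast_weights = True
    # ... run inference ...
    m.comfy_cast_weights = m.prev_comfy_cast_weights  # restore original flag
    del m.prev_comfy_cast_weights
    assert m.comfy_cast_weights is False
    assert not hasattr(m, "prev_comfy_cast_weights")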
comfy/ops.py (view file @ db8b59ec)
@@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
     non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
     if s.bias is not None:
         bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        if s.bias_function is not None:
+            bias = s.bias_function(bias)
 
     weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    if s.weight_function is not None:
+        weight = s.weight_function(weight)
     return weight, bias
 
 class disable_weight_init:
     class Linear(torch.nn.Linear):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
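cast_bias_weight is where the hooks actually fire: after weight and bias are cast to the input's device and dtype, any weight_function/bias_function installed by patch_model_lowvram is applied to the transient copies, never to the stored parameters. A self-contained mock of that flow (CastLinear is illustrative; the real ops also handle non_blocking transfers):

    import torch

    class CastLinear(torch.nn.Linear):
        # Class-level defaults, as in comfy.ops.disable_weight_init.Linear.
        comfy_cast_weights = False
        weight_function = None
        bias_function = None

        def forward(self, input):
            weight = self.weight.to(device=input.device, dtype=input.dtype)
            if self.weight_function is not None:
                weight = self.weight_function(weight)  # e.g. a LowVramPatch
            bias = None
            if self.bias is not None:
                bias = self.bias.to(device=input.device, dtype=input.dtype)
                if self.bias_function is not None:
                    bias = self.bias_function(bias)
            return torch.nn.functional.linear(input, weight, bias)

    lin = CastLinear(4, 4)
    lin.weight_function = lambda w: w * 2  # stand-in for a LoRA patch hook
    out = lin(torch.randn(1, 4))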
@@ -46,6 +53,9 @@ class disable_weight_init:
     class Conv2d(torch.nn.Conv2d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
@@ -61,6 +71,9 @@ class disable_weight_init:
     class Conv3d(torch.nn.Conv3d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
@@ -76,6 +89,9 @@ class disable_weight_init:
     class GroupNorm(torch.nn.GroupNorm):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
@@ -92,6 +108,9 @@ class disable_weight_init:
     class LayerNorm(torch.nn.LayerNorm):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
@@ -111,6 +130,9 @@ class disable_weight_init:
     class ConvTranspose2d(torch.nn.ConvTranspose2d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None
+
         def reset_parameters(self):
             return None
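The remaining hunks add the same two class attributes to every cast-weights op (Conv2d, Conv3d, GroupNorm, LayerNorm, ConvTranspose2d). Declaring them at class level means cast_bias_weight can read them unconditionally, while patch_model_lowvram can install an instance-level hook on one layer without touching any other. A quick illustration of that shadowing (hypothetical minimal class):

    import torch

    class Linear(torch.nn.Linear):
        comfy_cast_weights = False
        weight_function = None
        bias_function = None

    a, b = Linear(2, 2), Linear(2, 2)
    a.weight_function = lambda w: w * 0  # instance attribute shadows the class default
    assert a.weight_function is not None
    assert b.weight_function is None     # other layers still see the class-level None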