chenpangpang / ComfyUI / Commits / db8b59ec

Commit db8b59ec, authored Mar 13, 2024 by comfyanonymous

    Lower memory usage for loras in lowvram mode at the cost of perf.

parent eda87043
Showing 3 changed files with 101 additions and 48 deletions:

    comfy/model_management.py   +5   -31
    comfy/model_patcher.py      +74  -17
    comfy/ops.py                +22  -0
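Taken together, the three diffs below change how lora patches are applied when a model is loaded with a lowvram budget: modules that fit under the budget still get their patches baked into the weights up front and are moved to the GPU, while modules over the budget keep their unpatched weights and recompute the patched weight on the fly at each forward pass through the new weight_function / bias_function hooks (less memory, more per-step work, matching the commit message). A minimal standalone sketch of the greedy budget split; place_modules, patch_now and patch_on_the_fly are stand-in names for illustration, not ComfyUI API:

    # Standalone sketch of the budget-driven split (names are stand-ins):
    def place_modules(model, budget_bytes, device, patch_now, patch_on_the_fly):
        mem_counter = 0
        for name, m in model.named_modules():
            if not hasattr(m, "weight") or m.weight is None:
                continue
            size = sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))
            if mem_counter + size < budget_bytes:
                patch_now(name, m)         # bake patches into the weight now
                m.to(device)               # and keep the module on the GPU
                mem_counter += size
            else:
                patch_on_the_fly(name, m)  # leave the weight unpatched; a hook
                                           # reapplies the patches at each forward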
comfy/model_management.py    View file @ db8b59ec

@@ -272,7 +272,6 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
         self.model = model
-        self.model_accelerated = False
         self.device = model.load_device

     def model_memory(self):
@@ -285,52 +284,27 @@ class LoadedModel:
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0):
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device

         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+            if lowvram_model_memory > 0:
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+            else:
+                self.real_model = self.model.patch_model(device_to=patch_model_to)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e

-        if lowvram_model_memory > 0:
-            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024*1024)))
-            mem_counter = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "comfy_cast_weights"):
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
-                    m.to(self.device)
-                    mem_counter += module_size(m)
-                    logging.warning("lowvram: loaded module regularly {}".format(m))
-
-            self.model_accelerated = True
-
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

         return self.real_model

     def model_unload(self):
-        if self.model_accelerated:
-            for m in self.real_model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-
-            self.model_accelerated = False
-
         self.model.unpatch_model(self.model.offload_device)
         self.model.model_patches_to(self.model.offload_device)
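With the per-module loop gone from LoadedModel, model_load now just dispatches on the budget. A hedged usage sketch; `patcher` is assumed to be an existing comfy.model_patcher.ModelPatcher instance and the 2 GiB figure is arbitrary:

    from comfy.model_management import LoadedModel

    lm = LoadedModel(patcher)                                 # patcher: a ModelPatcher (assumed)
    budget = 2 * 1024 * 1024 * 1024                           # arbitrary ~2 GiB weight budget
    real_model = lm.model_load(lowvram_model_memory=budget)   # takes the patch_model_lowvram path
    # real_model = lm.model_load()                            # budget of 0: plain patch_model path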
comfy/model_patcher.py    View file @ db8b59ec

@@ -24,6 +24,7 @@ class ModelPatcher:
         self.current_device = current_device
         self.weight_inplace_update = weight_inplace_update
+        self.model_lowvram = False

     def model_size(self):
         if self.size > 0:
@@ -178,20 +179,11 @@ class ModelPatcher:
             sd.pop(k)
         return sd

-    def patch_model(self, device_to=None, patch_weights=True):
-        for k in self.object_patches:
-            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
-            if k not in self.object_patches_backup:
-                self.object_patches_backup[k] = old
-
-        if patch_weights:
-            model_sd = self.model_state_dict()
-            for key in self.patches:
-                if key not in model_sd:
-                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
-                    continue
-
-                weight = model_sd[key]
+    def patch_weight_to_device(self, key, device_to=None):
+        if key not in self.patches:
+            return
+
+        weight = comfy.utils.get_attr(self.model, key)

         inplace_update = self.weight_inplace_update
@@ -207,7 +199,21 @@ class ModelPatcher:
             comfy.utils.copy_to_param(self.model, key, out_weight)
         else:
             comfy.utils.set_attr_param(self.model, key, out_weight)
+        del temp_weight
+
+    def patch_model(self, device_to=None, patch_weights=True):
+        for k in self.object_patches:
+            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+            if k not in self.object_patches_backup:
+                self.object_patches_backup[k] = old
+
+        if patch_weights:
+            model_sd = self.model_state_dict()
+            for key in self.patches:
+                if key not in model_sd:
+                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
+                    continue
+                self.patch_weight_to_device(key, device_to)

         if device_to is not None:
             self.model.to(device_to)
@@ -215,6 +221,47 @@ class ModelPatcher:
         return self.model

+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+        self.patch_model(device_to, patch_weights=False)
+
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024*1024)))
+        class LowVramPatch:
+            def __init__(self, key, model_patcher):
+                self.key = key
+                self.model_patcher = model_patcher
+            def __call__(self, weight):
+                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+        mem_counter = 0
+        for n, m in self.model.named_modules():
+            lowvram_weight = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                if mem_counter + module_mem >= lowvram_model_memory:
+                    lowvram_weight = True
+
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
+            if lowvram_weight:
+                if weight_key in self.patches:
+                    m.weight_function = LowVramPatch(weight_key, self)
+                if bias_key in self.patches:
+                    m.bias_function = LowVramPatch(weight_key, self)
+
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+            else:
+                if hasattr(m, "weight"):
+                    self.patch_weight_to_device(weight_key, device_to)
+                    self.patch_weight_to_device(bias_key, device_to)
+                    m.to(device_to)
+                    mem_counter += comfy.model_management.module_size(m)
+                    logging.debug("lowvram: loaded module regularly {}".format(m))
+
+        self.model_lowvram = True
+        return self.model
+
     def calculate_weight(self, patches, weight, key):
         for p in patches:
             alpha = p[0]
@@ -341,6 +388,16 @@ class ModelPatcher:
         return weight

     def unpatch_model(self, device_to=None):
+        if self.model_lowvram:
+            for m in self.model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
+                m.weight_function = None
+                m.bias_function = None
+
+            self.model_lowvram = False
+
         keys = list(self.backup.keys())

         if self.weight_inplace_update:
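The key move above is LowVramPatch: it closes over a patch key and defers calculate_weight until the weight is actually requested, so unpatched weights can stay offloaded. A self-contained toy of the same pattern; the tuple patch format and the simplified lora merge are invented for illustration and are not ComfyUI's calculate_weight:

    import torch

    def calculate_weight(patches, weight):
        # Toy lora-style merge: weight + alpha * (up @ down).
        for alpha, up, down in patches:
            weight = weight + alpha * (up @ down)
        return weight

    class LowVramPatch:
        # Defers the merge until the weight is cast for a forward pass.
        def __init__(self, key, patches):
            self.key = key
            self.patches = patches
        def __call__(self, weight):
            return calculate_weight(self.patches[self.key], weight)

    patches = {"layer.weight": [(0.5, torch.randn(8, 4), torch.randn(4, 8))]}
    hook = LowVramPatch("layer.weight", patches)
    base = torch.zeros(8, 8)
    patched = hook(base)  # the lora math runs here, at use time, not at load time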
comfy/ops.py    View file @ db8b59ec

@@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
     non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
     if s.bias is not None:
         bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        if s.bias_function is not None:
+            bias = s.bias_function(bias)
+
     weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    if s.weight_function is not None:
+        weight = s.weight_function(weight)
     return weight, bias

 class disable_weight_init:
     class Linear(torch.nn.Linear):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None

         def reset_parameters(self):
             return None
@@ -46,6 +53,9 @@ class disable_weight_init:
     class Conv2d(torch.nn.Conv2d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None

         def reset_parameters(self):
             return None
@@ -61,6 +71,9 @@ class disable_weight_init:
     class Conv3d(torch.nn.Conv3d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None

         def reset_parameters(self):
             return None
@@ -76,6 +89,9 @@ class disable_weight_init:
     class GroupNorm(torch.nn.GroupNorm):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None

         def reset_parameters(self):
             return None
@@ -92,6 +108,9 @@ class disable_weight_init:
     class LayerNorm(torch.nn.LayerNorm):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None

         def reset_parameters(self):
             return None
@@ -111,6 +130,9 @@ class disable_weight_init:
     class ConvTranspose2d(torch.nn.ConvTranspose2d):
         comfy_cast_weights = False
+        weight_function = None
+        bias_function = None

         def reset_parameters(self):
             return None
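On the ops side, cast_bias_weight now pipes the freshly cast tensors through the per-module hooks before they are used. A minimal standalone module showing how such hooks slot into a forward pass; HookedLinear mirrors the shape of the mechanism, not ComfyUI's exact classes:

    import torch

    class HookedLinear(torch.nn.Linear):
        # Class-level defaults, as in the diff; set per instance to activate.
        weight_function = None
        bias_function = None

        def forward(self, input):
            weight = self.weight.to(device=input.device, dtype=input.dtype)
            if self.weight_function is not None:
                weight = self.weight_function(weight)  # e.g. a LowVramPatch
            bias = self.bias
            if bias is not None:
                bias = bias.to(device=input.device, dtype=input.dtype)
                if self.bias_function is not None:
                    bias = self.bias_function(bias)
            return torch.nn.functional.linear(input, weight, bias)

    lin = HookedLinear(4, 3)
    lin.weight_function = lambda w: w * 2.0  # stand-in for on-the-fly patching
    y = lin(torch.randn(2, 4))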