chenpangpang / ComfyUI · Commit c18a203a
Authored Mar 20, 2024 by comfyanonymous
Parent: 150a3e94

Don't unload model weights for non weight patches.

2 changed files with 76 additions and 28 deletions:
    comfy/model_management.py  +36 -8
    comfy/model_patcher.py     +40 -20
comfy/model_management.py
@@ -273,6 +273,7 @@ class LoadedModel:
     def __init__(self, model):
         self.model = model
         self.device = model.load_device
+        self.weights_loaded = False
 
     def model_memory(self):
         return self.model.model_size()
@@ -289,11 +290,13 @@ class LoadedModel:
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())
 
+        load_weights = not self.weights_loaded
+
         try:
-            if lowvram_model_memory > 0:
+            if lowvram_model_memory > 0 and load_weights:
                 self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
             else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
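Note on the hunk above: once a LoadedModel already has its weights on the device (weights_loaded is True), a later model_load only re-applies the non-weight patches instead of re-patching the weights. A simplified standalone sketch of that control flow, using toy classes rather than the real comfy.model_management / comfy.model_patcher API (the lowvram branch is omitted):

# Toy sketch of the load_weights branch above (hypothetical classes, not ComfyUI code).
class ToyPatcher:
    def patch_model(self, device_to=None, patch_weights=True):
        if patch_weights:
            print("patching weights and moving them to", device_to)
        else:
            print("weights already on", device_to, "- applying only non-weight patches")
        return self

class ToyLoadedModel:
    def __init__(self, model):
        self.model = model
        self.weights_loaded = False   # flag introduced by this commit

    def model_load(self, device):
        load_weights = not self.weights_loaded
        real_model = self.model.patch_model(device_to=device, patch_weights=load_weights)
        self.weights_loaded = True
        return real_model

lm = ToyLoadedModel(ToyPatcher())
lm.model_load("cuda:0")   # first load: weights are patched and moved
lm.model_load("cuda:0")   # second load: weights stay in place, only patches re-applied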
@@ -302,11 +305,13 @@ class LoadedModel:
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
+        self.weights_loaded = True
         return self.real_model
 
-    def model_unload(self):
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights
 
     def __eq__(self, other):
         return self.model is other.model
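The weights_loaded bookkeeping above keeps the flag truthful across partial unloads: a full model_unload() clears it, while model_unload(unpatch_weights=False) leaves it set. A small standalone illustration of just that expression (plain Python, not ComfyUI code):

# Standalone illustration of the weights_loaded update in model_unload above.
def next_weights_loaded(weights_loaded, unpatch_weights):
    # Weights remain "loaded" only if they were loaded before and were not unpatched.
    return weights_loaded and not unpatch_weights

print(next_weights_loaded(True, unpatch_weights=True))    # False: full unload drops the weights
print(next_weights_loaded(True, unpatch_weights=False))   # True: only non-weight patches were removed
print(next_weights_loaded(False, unpatch_weights=False))  # False: weights were never loaded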
@@ -314,15 +319,35 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
 
-def unload_model_clones(model):
+def unload_model_clones(loaded_model, unload_weights_only=True):
+    model = loaded_model.model
+
     to_unload = []
     for i in range(len(current_loaded_models)):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
+    if len(to_unload) == 0:
+        return
+
+    same_weights = 0
     for i in to_unload:
-        logging.debug("unload clone {}".format(i))
-        current_loaded_models.pop(i).model_unload()
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if unload_weights_only and unload_weight == False:
+        return
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    loaded_model.weights_loaded = not unload_weight
 
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
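The reworked unload_model_clones above decides once per model whether weights really need to be dropped: only if at least one registered clone carries different weight patches. A standalone sketch of that decision, with a plain list of per-clone flags standing in for current_loaded_models (hypothetical helper, not ComfyUI code):

# Standalone sketch of the unload decision above.
def decide_unload(clone_same_weight_flags, unload_weights_only):
    # clone_same_weight_flags: one bool per clone, True if its weight patches match.
    if len(clone_same_weight_flags) == 0:
        return None                      # nothing to unload
    unload_weight = not all(clone_same_weight_flags)
    if unload_weights_only and not unload_weight:
        return None                      # first pass: leave clones with identical weights alone
    return unload_weight                 # True -> unpatch weights, False -> keep weights resident

print(decide_unload([True, True], unload_weights_only=True))    # None: postponed to the second pass
print(decide_unload([True, False], unload_weights_only=True))   # True: some clone differs, drop weights
print(decide_unload([True, True], unload_weights_only=False))   # False: unload clones, keep weights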
@@ -377,13 +402,16 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
+        unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
 
+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
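The two passes over models_to_load above appear deliberate: the first call (unload_weights_only=True) removes only clones whose weight patches genuinely differ, before free_memory runs, while the second call (unload_weights_only=False) retires the remaining clone entries with their weights left in place and marks the new LoadedModel as already having its weights loaded. A rough standalone sketch of how the second pass feeds the later model_load (hypothetical names, not ComfyUI code):

# Sketch tying the two passes above to the load path.
def unload_clones_pass(loaded_model, clones_same_weights, unload_weights_only):
    unload_weight = (not all(clones_same_weights)) if clones_same_weights else None
    if unload_weight is None or (unload_weights_only and not unload_weight):
        return
    # ... here the real code pops the clones and calls model_unload(unpatch_weights=unload_weight) ...
    loaded_model["weights_loaded"] = not unload_weight

loaded_model = {"weights_loaded": False}
unload_clones_pass(loaded_model, [True], unload_weights_only=True)   # pass 1: skipped, weights identical
unload_clones_pass(loaded_model, [True], unload_weights_only=False)  # pass 2: clone removed, weights kept
print(loaded_model["weights_loaded"])  # True -> model_load will not re-patch the weights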
comfy/model_patcher.py
@@ -2,6 +2,7 @@ import torch
 import copy
 import inspect
 import logging
+import uuid
 
 import comfy.utils
 import comfy.model_management
@@ -25,6 +26,7 @@ class ModelPatcher:
         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
+        self.patches_uuid = uuid.uuid4()
 
     def model_size(self):
         if self.size > 0:
@@ -39,10 +41,13 @@ class ModelPatcher:
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
+        n.patches_uuid = self.patches_uuid
 
         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         return n
 
     def is_clone(self, other):
@@ -50,6 +55,19 @@ class ModelPatcher:
                 return True
         return False
 
+    def clone_has_same_weights(self, clone):
+        if not self.is_clone(clone):
+            return False
+
+        if len(self.patches) == 0 and len(clone.patches) == 0:
+            return True
+
+        if self.patches_uuid == clone.patches_uuid:
+            if len(self.patches) != len(clone.patches):
+                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+            else:
+                return True
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)
@@ -154,6 +172,7 @@ class ModelPatcher:
             current_patches.append((strength_patch, patches[k], strength_model))
             self.patches[k] = current_patches
 
+        self.patches_uuid = uuid.uuid4()
         return list(p)
 
     def get_key_patches(self, filter_prefix=None):
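The patches_uuid introduced above gives clone_has_same_weights a cheap equality check: clone() copies the uuid, and add_patches() generates a fresh one, so two patchers share a uuid exactly when their weight patches have not diverged since cloning. A standalone sketch of that lifecycle (toy class, not the real ModelPatcher):

import uuid

# Toy sketch of the patches_uuid lifecycle above (hypothetical class, not comfy.model_patcher).
class ToyPatches:
    def __init__(self):
        self.patches = {}
        self.patches_uuid = uuid.uuid4()

    def clone(self):
        c = ToyPatches()
        c.patches = {k: v[:] for k, v in self.patches.items()}
        c.patches_uuid = self.patches_uuid          # shared until either side adds patches
        return c

    def add_patches(self, patches, strength=1.0):
        for k, v in patches.items():
            self.patches.setdefault(k, []).append((strength, v))
        self.patches_uuid = uuid.uuid4()            # weight patches changed -> new identity

    def has_same_weights(self, other):
        if len(self.patches) == 0 and len(other.patches) == 0:
            return True
        return self.patches_uuid == other.patches_uuid

a = ToyPatches()
b = a.clone()
print(a.has_same_weights(b))   # True: nothing changed since cloning
b.add_patches({"diffusion_model.x.weight": "lora_delta"})
print(a.has_same_weights(b))   # False: b now carries different weight patches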
@@ -387,31 +406,32 @@ class ModelPatcher:
         return weight
 
-    def unpatch_model(self, device_to=None):
-        if self.model_lowvram:
-            for m in self.model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        if unpatch_weights:
+            if self.model_lowvram:
+                for m in self.model.modules():
+                    if hasattr(m, "prev_comfy_cast_weights"):
+                        m.comfy_cast_weights = m.prev_comfy_cast_weights
+                        del m.prev_comfy_cast_weights
 
-                m.weight_function = None
-                m.bias_function = None
+                    m.weight_function = None
+                    m.bias_function = None
 
-            self.model_lowvram = False
+                self.model_lowvram = False
 
-        keys = list(self.backup.keys())
+            keys = list(self.backup.keys())
 
-        if self.weight_inplace_update:
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            if self.weight_inplace_update:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
+            else:
+                for k in keys:
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
 
-        self.backup = {}
+            self.backup.clear()
 
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+            if device_to is not None:
+                self.model.to(device_to)
+                self.current_device = device_to
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
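With unpatch_weights=False the reworked unpatch_model above skips the entire weight-restore branch (backup copy-back, lowvram cast cleanup, device move) and only restores object patches. A minimal hedged sketch of that split (standalone toy, not the real class):

# Minimal sketch of the unpatch_weights split above (standalone toy, not comfy.model_patcher).
class ToyUnpatcher:
    def __init__(self):
        self.backup = {"layer.weight": "original tensor"}        # pretend weight backups
        self.object_patches_backup = {"sampler": "original object"}

    def unpatch_model(self, unpatch_weights=True):
        if unpatch_weights:
            # restore backed-up weights, then drop the backups
            print("restoring", list(self.backup.keys()))
            self.backup.clear()
        # object patches are always restored, weights or not
        print("restoring", list(self.object_patches_backup.keys()))
        self.object_patches_backup.clear()

ToyUnpatcher().unpatch_model(unpatch_weights=False)  # only the object patch is restored
ToyUnpatcher().unpatch_model(unpatch_weights=True)   # weights and object patch both restored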