chenpangpang / ComfyUI · Commits

Commit c18a203a, authored Mar 20, 2024 by comfyanonymous
Parent: 150a3e94

    Don't unload model weights for non weight patches.

Showing 2 changed files with 76 additions and 28 deletions:

    comfy/model_management.py   +36  -8
    comfy/model_patcher.py      +40  -20
comfy/model_management.py

@@ -273,6 +273,7 @@ class LoadedModel:
     def __init__(self, model):
         self.model = model
         self.device = model.load_device
+        self.weights_loaded = False

     def model_memory(self):
         return self.model.model_size()
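The new `weights_loaded` flag is the core of the commit: a `LoadedModel` now remembers whether its weights are already patched onto the load device, so a later load can re-apply only the non-weight patches. A minimal runnable sketch of the bookkeeping, with `FakeModel` as a hypothetical stand-in for a real `ModelPatcher`:

```python
# Sketch of the weights_loaded bookkeeping; FakeModel is hypothetical.
class FakeModel:
    load_device = "cuda:0"

class LoadedModelSketch:
    def __init__(self, model):
        self.model = model
        self.device = model.load_device
        self.weights_loaded = False   # new in this commit

    def model_load(self):
        load_weights = not self.weights_loaded   # patch weights only once
        print("patching weights:", load_weights)
        self.weights_loaded = True

lm = LoadedModelSketch(FakeModel())
lm.model_load()  # patching weights: True
lm.model_load()  # patching weights: False -- weights stay resident
```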
@@ -289,11 +290,13 @@ class LoadedModel:
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

+        load_weights = not self.weights_loaded
+
         try:
-            if lowvram_model_memory > 0:
+            if lowvram_model_memory > 0 and load_weights:
                 self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
             else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
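With `load_weights` computed before the `try`, the low-VRAM path is only taken when weights actually need loading, and a warm reload goes through `patch_model` with `patch_weights=False` so only the non-weight patches are re-applied (the `patch_weights` parameter of `patch_model` is handled in a part of `model_patcher.py` that is collapsed in this diff). A condensed, runnable sketch of the branching:

```python
def choose_load_path(weights_loaded, lowvram_model_memory):
    # Mirrors the branch added in this hunk; the returned strings stand in
    # for the real patch_model_lowvram / patch_model calls.
    load_weights = not weights_loaded
    if lowvram_model_memory > 0 and load_weights:
        return "patch_model_lowvram(...)"
    return f"patch_model(patch_weights={load_weights})"

print(choose_load_path(False, 0))     # patch_model(patch_weights=True)
print(choose_load_path(True, 2**30))  # patch_model(patch_weights=False)
```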
@@ -302,11 +305,13 @@ class LoadedModel:
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

+        self.weights_loaded = True
         return self.real_model

-    def model_unload(self):
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights

     def __eq__(self, other):
         return self.model is other.model
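`model_unload` now distinguishes a full unload from a patches-only unload. The flag update `self.weights_loaded = self.weights_loaded and not unpatch_weights` is compact; enumerating it makes the intent obvious (weights count as still loaded only when they were loaded before and were not unpatched):

```python
# Truth table for the weights_loaded update in model_unload.
for weights_loaded in (False, True):
    for unpatch_weights in (False, True):
        after = weights_loaded and not unpatch_weights
        print(f"loaded={weights_loaded!s:5} unpatch={unpatch_weights!s:5} -> {after}")
# Only a previously loaded model unloaded with unpatch_weights=False
# keeps weights_loaded == True afterwards.
```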
@@ -314,15 +319,35 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model):
+def unload_model_clones(loaded_model, unload_weights_only=True):
+    model = loaded_model.model
+
     to_unload = []
     for i in range(len(current_loaded_models)):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload

-    for i in to_unload:
-        logging.debug("unload clone {}".format(i))
-        current_loaded_models.pop(i).model_unload()
+    if len(to_unload) == 0:
+        return
+
+    same_weights = 0
+    for i in to_unload:
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if unload_weights_only and unload_weight == False:
+        return
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    loaded_model.weights_loaded = not unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
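`unload_model_clones` now makes a collective decision: weights can stay loaded only if every clone carries the same weight patches; if any clone differs, all of them get a full weight unload. With `unload_weights_only=True` it returns early when no weight unload is needed, leaving the patches-only cleanup to a second pass. A self-contained sketch of just that decision:

```python
def decide_unload(clone_same_weights, unload_weights_only):
    # clone_same_weights: one boolean per clone, True if that clone's
    # weight patches match the model being loaded (clone_has_same_weights).
    same_weights = sum(clone_same_weights)
    unload_weight = same_weights != len(clone_same_weights)
    if unload_weights_only and not unload_weight:
        return None   # first pass: weights can stay, defer to second pass
    return unload_weight

print(decide_unload([True, True], True))    # None  -- nothing to do yet
print(decide_unload([True, False], True))   # True  -- full weight unload
print(decide_unload([True, True], False))   # False -- patches-only unload
```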
@@ -377,13 +402,16 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
+        unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
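`load_models_gpu` therefore calls `unload_model_clones` twice: once before memory budgeting, to evict clones whose weights differ (their VRAM must be reclaimed), and once after `free_memory`, to roll back the remaining clones' non-weight patches while leaving their weights in place. A stub-based sketch of the call order (the helper names here are hypothetical):

```python
def unload_clones(name, unload_weights_only):
    # Stub standing in for unload_model_clones; prints the call it models.
    print(f"unload_model_clones({name!r}, unload_weights_only={unload_weights_only})")

def load_models_gpu_sketch(models_to_load):
    for name in models_to_load:   # pass 1: evict clones with different weights
        unload_clones(name, unload_weights_only=True)
    print("free_memory(...)       # budget ~1.3x the required VRAM per device")
    for name in models_to_load:   # pass 2: remaining clones keep their weights
        unload_clones(name, unload_weights_only=False)

load_models_gpu_sketch(["unet"])
```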
comfy/model_patcher.py

@@ -2,6 +2,7 @@ import torch
 import copy
 import inspect
 import logging
+import uuid

 import comfy.utils
 import comfy.model_management
@@ -25,6 +26,7 @@ class ModelPatcher:
         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
+        self.patches_uuid = uuid.uuid4()

     def model_size(self):
         if self.size > 0:
@@ -39,10 +41,13 @@ class ModelPatcher:
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
+        n.patches_uuid = self.patches_uuid

         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         return n

     def is_clone(self, other):
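Sharing `backup` and `object_patches_backup` between a patcher and its clones (rather than copying them) matters once weights can outlive an individual clone: whichever instance eventually performs the real unpatch must see, and be able to empty, the same saved originals. This is presumably also why `unpatch_model` below switches from `self.backup = {}` to `self.backup.clear()`, since `clear()` mutates the shared dict in place while rebinding would detach only one instance. A toy demonstration of the difference:

```python
backup = {"layer.weight": "saved tensor"}
shared = backup          # shared reference, as in ModelPatcher.clone()
copied = dict(backup)    # what a copy would have done instead

backup.clear()           # the instance doing the real unpatch empties it
print(shared)            # {} -- every clone observes the restore
print(copied)            # {'layer.weight': 'saved tensor'} -- stale copy
```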
@@ -50,6 +55,19 @@ class ModelPatcher:
             return True
         return False

+    def clone_has_same_weights(self, clone):
+        if not self.is_clone(clone):
+            return False
+
+        if len(self.patches) == 0 and len(clone.patches) == 0:
+            return True
+
+        if self.patches_uuid == clone.patches_uuid:
+            if len(self.patches) != len(clone.patches):
+                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+            else:
+                return True
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)
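`clone_has_same_weights` answers "could these two patchers share loaded weights?": trivially yes when neither has patches, and yes when both carry the same `patches_uuid` with consistent patch counts. Note there is no final `return`, so a uuid mismatch falls through to `None`, which callers treat as falsy. A condensed decision table, assuming `is_clone` already returned `True`:

```python
def check(a_uuid, b_uuid, a_len, b_len):
    # Condensed clone_has_same_weights, assuming is_clone(...) is True.
    if a_len == 0 and b_len == 0:
        return True
    if a_uuid == b_uuid:
        if a_len != b_len:
            print("WARNING: same patch uuid but different length of patches.")
        else:
            return True
    # no explicit return: falls through to None (falsy), as in the original

print(bool(check("u1", "u1", 2, 2)))  # True  -- identical patch sets
print(bool(check("u1", "u2", 2, 2)))  # False -- uuids differ
```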
@@ -154,6 +172,7 @@ class ModelPatcher:
                 current_patches.append((strength_patch, patches[k], strength_model))
                 self.patches[k] = current_patches

+        self.patches_uuid = uuid.uuid4()
         return list(p)

     def get_key_patches(self, filter_prefix=None):
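Together, these `model_patcher.py` hunks give each patch set a cheap identity: `clone()` copies the uuid, and any `add_patches()` call mints a fresh one, so two patchers compare equal exactly when their patch lists share a common, unmodified origin. A runnable sketch of that lifecycle (class `P` is a hypothetical miniature of `ModelPatcher`):

```python
import uuid

class P:
    def __init__(self):
        self.patches = {}
        self.patches_uuid = uuid.uuid4()

    def clone(self):
        n = P()
        n.patches = {k: v[:] for k, v in self.patches.items()}
        n.patches_uuid = self.patches_uuid   # same uuid == same patch set
        return n

    def add_patch(self, key, patch):
        self.patches.setdefault(key, []).append(patch)
        self.patches_uuid = uuid.uuid4()     # any edit invalidates the uuid

a = P()
b = a.clone()
print(a.patches_uuid == b.patches_uuid)  # True: weights would match
b.add_patch("w", "lora")
print(a.patches_uuid == b.patches_uuid)  # False: weights now differ
```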
@@ -387,31 +406,32 @@ class ModelPatcher:
         return weight

-    def unpatch_model(self, device_to=None):
-        if self.model_lowvram:
-            for m in self.model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-                m.weight_function = None
-                m.bias_function = None
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        if unpatch_weights:
+            if self.model_lowvram:
+                for m in self.model.modules():
+                    if hasattr(m, "prev_comfy_cast_weights"):
+                        m.comfy_cast_weights = m.prev_comfy_cast_weights
+                        del m.prev_comfy_cast_weights
+                    m.weight_function = None
+                    m.bias_function = None

-            self.model_lowvram = False
+                self.model_lowvram = False

-        keys = list(self.backup.keys())
+            keys = list(self.backup.keys())

-        if self.weight_inplace_update:
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            if self.weight_inplace_update:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
+            else:
+                for k in keys:
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])

-        self.backup = {}
+            self.backup.clear()

-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+            if device_to is not None:
+                self.model.to(device_to)
+                self.current_device = device_to

         keys = list(self.object_patches_backup.keys())
         for k in keys:
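Finally, wrapping the whole weight-restore block of `unpatch_model` in `if unpatch_weights:` means a patches-only unload skips the weight rollback and the device move, but still falls through to the object-patch restore at the bottom. A sketch of the resulting control flow:

```python
def unpatch_sketch(unpatch_weights):
    # Which restore steps run, mirroring the restructured unpatch_model.
    steps = []
    if unpatch_weights:
        steps += ["undo lowvram casts", "restore weight backups", "move to device"]
    steps.append("restore object patches")   # always runs
    return steps

print(unpatch_sketch(True))   # full unload
print(unpatch_sketch(False))  # non-weight patches only
```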