Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
e1489ad2
"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "876dadca840f305fafb6ed167d5dc51329fb4083"
Commit
e1489ad2
authored
May 11, 2024
by
comfyanonymous
Browse files
Fix issue with lowvram mode breaking model saving.
parent
4f63ee99
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
9 deletions
+15
-9
comfy/model_management.py
comfy/model_management.py
+4
-4
comfy/model_patcher.py
comfy/model_patcher.py
+9
-3
comfy/sd.py
comfy/sd.py
+1
-1
comfy_extras/nodes_model_merging.py
comfy_extras/nodes_model_merging.py
+1
-1
No files found.
comfy/model_management.py
View file @
e1489ad2
...
@@ -285,7 +285,7 @@ class LoadedModel:
...
@@ -285,7 +285,7 @@ class LoadedModel:
else
:
else
:
return
self
.
model_memory
()
return
self
.
model_memory
()
def
model_load
(
self
,
lowvram_model_memory
=
0
):
def
model_load
(
self
,
lowvram_model_memory
=
0
,
force_patch_weights
=
False
):
patch_model_to
=
self
.
device
patch_model_to
=
self
.
device
self
.
model
.
model_patches_to
(
self
.
device
)
self
.
model
.
model_patches_to
(
self
.
device
)
...
@@ -295,7 +295,7 @@ class LoadedModel:
...
@@ -295,7 +295,7 @@ class LoadedModel:
try
:
try
:
if
lowvram_model_memory
>
0
and
load_weights
:
if
lowvram_model_memory
>
0
and
load_weights
:
self
.
real_model
=
self
.
model
.
patch_model_lowvram
(
device_to
=
patch_model_to
,
lowvram_model_memory
=
lowvram_model_memory
)
self
.
real_model
=
self
.
model
.
patch_model_lowvram
(
device_to
=
patch_model_to
,
lowvram_model_memory
=
lowvram_model_memory
,
force_patch_weights
=
force_patch_weights
)
else
:
else
:
self
.
real_model
=
self
.
model
.
patch_model
(
device_to
=
patch_model_to
,
patch_weights
=
load_weights
)
self
.
real_model
=
self
.
model
.
patch_model
(
device_to
=
patch_model_to
,
patch_weights
=
load_weights
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
...
@@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if
mem_free_torch
>
mem_free_total
*
0.25
:
if
mem_free_torch
>
mem_free_total
*
0.25
:
soft_empty_cache
()
soft_empty_cache
()
def
load_models_gpu
(
models
,
memory_required
=
0
):
def
load_models_gpu
(
models
,
memory_required
=
0
,
force_patch_weights
=
False
):
global
vram_state
global
vram_state
inference_memory
=
minimum_inference_memory
()
inference_memory
=
minimum_inference_memory
()
...
@@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
...
@@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
if
vram_set_state
==
VRAMState
.
NO_VRAM
:
if
vram_set_state
==
VRAMState
.
NO_VRAM
:
lowvram_model_memory
=
64
*
1024
*
1024
lowvram_model_memory
=
64
*
1024
*
1024
cur_loaded_model
=
loaded_model
.
model_load
(
lowvram_model_memory
)
cur_loaded_model
=
loaded_model
.
model_load
(
lowvram_model_memory
,
force_patch_weights
=
force_patch_weights
)
current_loaded_models
.
insert
(
0
,
loaded_model
)
current_loaded_models
.
insert
(
0
,
loaded_model
)
return
return
...
...
comfy/model_patcher.py
View file @
e1489ad2
...
@@ -272,7 +272,7 @@ class ModelPatcher:
...
@@ -272,7 +272,7 @@ class ModelPatcher:
return
self
.
model
return
self
.
model
def
patch_model_lowvram
(
self
,
device_to
=
None
,
lowvram_model_memory
=
0
):
def
patch_model_lowvram
(
self
,
device_to
=
None
,
lowvram_model_memory
=
0
,
force_patch_weights
=
False
):
self
.
patch_model
(
device_to
,
patch_weights
=
False
)
self
.
patch_model
(
device_to
,
patch_weights
=
False
)
logging
.
info
(
"loading in lowvram mode {}"
.
format
(
lowvram_model_memory
/
(
1024
*
1024
)))
logging
.
info
(
"loading in lowvram mode {}"
.
format
(
lowvram_model_memory
/
(
1024
*
1024
)))
...
@@ -296,9 +296,15 @@ class ModelPatcher:
...
@@ -296,9 +296,15 @@ class ModelPatcher:
if
lowvram_weight
:
if
lowvram_weight
:
if
weight_key
in
self
.
patches
:
if
weight_key
in
self
.
patches
:
m
.
weight_function
=
LowVramPatch
(
weight_key
,
self
)
if
force_patch_weights
:
self
.
patch_weight_to_device
(
weight_key
)
else
:
m
.
weight_function
=
LowVramPatch
(
weight_key
,
self
)
if
bias_key
in
self
.
patches
:
if
bias_key
in
self
.
patches
:
m
.
bias_function
=
LowVramPatch
(
bias_key
,
self
)
if
force_patch_weights
:
self
.
patch_weight_to_device
(
bias_key
)
else
:
m
.
bias_function
=
LowVramPatch
(
bias_key
,
self
)
m
.
prev_comfy_cast_weights
=
m
.
comfy_cast_weights
m
.
prev_comfy_cast_weights
=
m
.
comfy_cast_weights
m
.
comfy_cast_weights
=
True
m
.
comfy_cast_weights
=
True
...
...
comfy/sd.py
View file @
e1489ad2
...
@@ -562,7 +562,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
...
@@ -562,7 +562,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models
.
append
(
clip
.
load_model
())
load_models
.
append
(
clip
.
load_model
())
clip_sd
=
clip
.
get_sd
()
clip_sd
=
clip
.
get_sd
()
model_management
.
load_models_gpu
(
load_models
)
model_management
.
load_models_gpu
(
load_models
,
force_patch_weights
=
True
)
clip_vision_sd
=
clip_vision
.
get_sd
()
if
clip_vision
is
not
None
else
None
clip_vision_sd
=
clip_vision
.
get_sd
()
if
clip_vision
is
not
None
else
None
sd
=
model
.
model
.
state_dict_for_saving
(
clip_sd
,
vae
.
get_sd
(),
clip_vision_sd
)
sd
=
model
.
model
.
state_dict_for_saving
(
clip_sd
,
vae
.
get_sd
(),
clip_vision_sd
)
for
k
in
extra_keys
:
for
k
in
extra_keys
:
...
...
comfy_extras/nodes_model_merging.py
View file @
e1489ad2
...
@@ -262,7 +262,7 @@ class CLIPSave:
...
@@ -262,7 +262,7 @@ class CLIPSave:
for
x
in
extra_pnginfo
:
for
x
in
extra_pnginfo
:
metadata
[
x
]
=
json
.
dumps
(
extra_pnginfo
[
x
])
metadata
[
x
]
=
json
.
dumps
(
extra_pnginfo
[
x
])
comfy
.
model_management
.
load_models_gpu
([
clip
.
load_model
()])
comfy
.
model_management
.
load_models_gpu
([
clip
.
load_model
()]
,
force_patch_weights
=
True
)
clip_sd
=
clip
.
get_sd
()
clip_sd
=
clip
.
get_sd
()
for
prefix
in
[
"clip_l."
,
"clip_g."
,
""
]:
for
prefix
in
[
"clip_l."
,
"clip_g."
,
""
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment