Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
30abc324
"models/vscode:/vscode.git/clone" did not exist on "e779b250e1c253e7db7f379744a76d1f66fe63c8"
Commit
30abc324
authored
Apr 08, 2024
by
comfyanonymous
Browse files
Support properly saving CosXL checkpoints.
parent
d644b6bc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
2 deletions
+14
-2
comfy/sd.py
comfy/sd.py
+4
-1
comfy_extras/nodes_model_merging.py
comfy_extras/nodes_model_merging.py
+10
-1
No files found.
comfy/sd.py
View file @
30abc324
...
@@ -600,7 +600,7 @@ def load_unet(unet_path):
...
@@ -600,7 +600,7 @@ def load_unet(unet_path):
raise
RuntimeError
(
"ERROR: Could not detect model type of: {}"
.
format
(
unet_path
))
raise
RuntimeError
(
"ERROR: Could not detect model type of: {}"
.
format
(
unet_path
))
return
model
return
model
def
save_checkpoint
(
output_path
,
model
,
clip
=
None
,
vae
=
None
,
clip_vision
=
None
,
metadata
=
None
):
def
save_checkpoint
(
output_path
,
model
,
clip
=
None
,
vae
=
None
,
clip_vision
=
None
,
metadata
=
None
,
extra_keys
=
{}
):
clip_sd
=
None
clip_sd
=
None
load_models
=
[
model
]
load_models
=
[
model
]
if
clip
is
not
None
:
if
clip
is
not
None
:
...
@@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
...
@@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
model_management
.
load_models_gpu
(
load_models
)
model_management
.
load_models_gpu
(
load_models
)
clip_vision_sd
=
clip_vision
.
get_sd
()
if
clip_vision
is
not
None
else
None
clip_vision_sd
=
clip_vision
.
get_sd
()
if
clip_vision
is
not
None
else
None
sd
=
model
.
model
.
state_dict_for_saving
(
clip_sd
,
vae
.
get_sd
(),
clip_vision_sd
)
sd
=
model
.
model
.
state_dict_for_saving
(
clip_sd
,
vae
.
get_sd
(),
clip_vision_sd
)
for
k
in
extra_keys
:
sd
[
k
]
=
extra_keys
[
k
]
comfy
.
utils
.
save_torch_file
(
sd
,
output_path
,
metadata
=
metadata
)
comfy
.
utils
.
save_torch_file
(
sd
,
output_path
,
metadata
=
metadata
)
comfy_extras/nodes_model_merging.py
View file @
30abc324
...
@@ -2,7 +2,9 @@ import comfy.sd
...
@@ -2,7 +2,9 @@ import comfy.sd
import
comfy.utils
import
comfy.utils
import
comfy.model_base
import
comfy.model_base
import
comfy.model_management
import
comfy.model_management
import
comfy.model_sampling
import
torch
import
folder_paths
import
folder_paths
import
json
import
json
import
os
import
os
...
@@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
...
@@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
# "v2-inpainting"
extra_keys
=
{}
model_sampling
=
model
.
get_model_object
(
"model_sampling"
)
if
isinstance
(
model_sampling
,
comfy
.
model_sampling
.
ModelSamplingContinuousEDM
):
if
isinstance
(
model_sampling
,
comfy
.
model_sampling
.
V_PREDICTION
):
extra_keys
[
"edm_vpred.sigma_max"
]
=
torch
.
tensor
(
model_sampling
.
sigma_max
).
float
()
extra_keys
[
"edm_vpred.sigma_min"
]
=
torch
.
tensor
(
model_sampling
.
sigma_min
).
float
()
if
model
.
model
.
model_type
==
comfy
.
model_base
.
ModelType
.
EPS
:
if
model
.
model
.
model_type
==
comfy
.
model_base
.
ModelType
.
EPS
:
metadata
[
"modelspec.predict_key"
]
=
"epsilon"
metadata
[
"modelspec.predict_key"
]
=
"epsilon"
elif
model
.
model
.
model_type
==
comfy
.
model_base
.
ModelType
.
V_PREDICTION
:
elif
model
.
model
.
model_type
==
comfy
.
model_base
.
ModelType
.
V_PREDICTION
:
...
@@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
...
@@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
output_checkpoint
=
f
"
{
filename
}
_
{
counter
:
05
}
_.safetensors"
output_checkpoint
=
f
"
{
filename
}
_
{
counter
:
05
}
_.safetensors"
output_checkpoint
=
os
.
path
.
join
(
full_output_folder
,
output_checkpoint
)
output_checkpoint
=
os
.
path
.
join
(
full_output_folder
,
output_checkpoint
)
comfy
.
sd
.
save_checkpoint
(
output_checkpoint
,
model
,
clip
,
vae
,
clip_vision
,
metadata
=
metadata
)
comfy
.
sd
.
save_checkpoint
(
output_checkpoint
,
model
,
clip
,
vae
,
clip_vision
,
metadata
=
metadata
,
extra_keys
=
extra_keys
)
class
CheckpointSave
:
class
CheckpointSave
:
def
__init__
(
self
):
def
__init__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment