Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
5e1fced6
Commit
5e1fced6
authored
Jul 11, 2024
by
comfyanonymous
Browse files
Cleaner support for loading different diffusion model types.
parent
ffe0bb0a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
6 deletions
+7
-6
comfy/model_detection.py
comfy/model_detection.py
+5
-0
comfy/sd.py
comfy/sd.py
+2
-6
No files found.
comfy/model_detection.py
View file @
5e1fced6
...
...
@@ -105,6 +105,9 @@ def detect_unet_config(state_dict, key_prefix):
unet_config
[
"audio_model"
]
=
"dit1.0"
return
unet_config
if
'{}input_blocks.0.0.weight'
.
format
(
key_prefix
)
not
in
state_dict_keys
:
return
None
unet_config
=
{
"use_checkpoint"
:
False
,
"image_size"
:
32
,
...
...
@@ -239,6 +242,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
def
model_config_from_unet
(
state_dict
,
unet_key_prefix
,
use_base_if_no_match
=
False
):
unet_config
=
detect_unet_config
(
state_dict
,
unet_key_prefix
)
if
unet_config
is
None
:
return
None
model_config
=
model_config_from_unet_config
(
unet_config
,
state_dict
)
if
model_config
is
None
and
use_base_if_no_match
:
return
comfy
.
supported_models_base
.
BASE
(
unet_config
)
...
...
comfy/sd.py
View file @
5e1fced6
...
...
@@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def
load_unet_state_dict
(
sd
):
#load unet in diffusers or regular format
#Allow loading unets from checkpoint files
checkpoint
=
False
diffusion_model_prefix
=
model_detection
.
unet_prefix_from_state_dict
(
sd
)
temp_sd
=
comfy
.
utils
.
state_dict_prefix_replace
(
sd
,
{
diffusion_model_prefix
:
""
},
filter_keys
=
True
)
if
len
(
temp_sd
)
>
0
:
sd
=
temp_sd
checkpoint
=
True
parameters
=
comfy
.
utils
.
calculate_parameters
(
sd
)
unet_dtype
=
model_management
.
unet_dtype
(
model_params
=
parameters
)
load_device
=
model_management
.
get_torch_device
()
if
checkpoint
or
"input_blocks.0.0.weight"
in
sd
or
'clf.1.weight'
in
sd
:
#ldm or stable cascade
model_config
=
model_detection
.
model_config_from_unet
(
sd
,
""
)
if
model_config
is
None
:
return
None
if
model_config
is
not
None
:
new_sd
=
sd
elif
'transformer_blocks.0.attn.add_q_proj.weight'
in
sd
:
#MMDIT SD3
new_sd
=
model_detection
.
convert_diffusers_mmdit
(
sd
,
""
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment