"...git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "9119e2ac391390e92fa99a88ff559c43b1ff612b"
Commit 5e1fced6 authored by comfyanonymous's avatar comfyanonymous
Browse files

Cleaner support for loading different diffusion model types.

parent ffe0bb0a
...@@ -105,6 +105,9 @@ def detect_unet_config(state_dict, key_prefix): ...@@ -105,6 +105,9 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["audio_model"] = "dit1.0" unet_config["audio_model"] = "dit1.0"
return unet_config return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
unet_config = { unet_config = {
"use_checkpoint": False, "use_checkpoint": False,
"image_size": 32, "image_size": 32,
...@@ -239,6 +242,8 @@ def model_config_from_unet_config(unet_config, state_dict=None): ...@@ -239,6 +242,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix) unet_config = detect_unet_config(state_dict, unet_key_prefix)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict) model_config = model_config_from_unet_config(unet_config, state_dict)
if model_config is None and use_base_if_no_match: if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config) return comfy.supported_models_base.BASE(unet_config)
......
...@@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o ...@@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers or regular format def load_unet_state_dict(sd): #load unet in diffusers or regular format
#Allow loading unets from checkpoint files #Allow loading unets from checkpoint files
checkpoint = False
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0: if len(temp_sd) > 0:
sd = temp_sd sd = temp_sd
checkpoint = True
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade if model_config is not None:
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd new_sd = sd
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3 elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
new_sd = model_detection.convert_diffusers_mmdit(sd, "") new_sd = model_detection.convert_diffusers_mmdit(sd, "")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment