"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "47293041477869e9ffa2335adf3f552575a7caaf"
Commit 334ba48c authored by comfyanonymous's avatar comfyanonymous
Browse files

More generic unet prefix detection code.

parent 14764aa2
...@@ -261,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal ...@@ -261,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config return model_config
def unet_prefix_from_state_dict(state_dict): def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models candidates = ["model.diffusion_model.", #ldm/sgm models
unet_key_prefix = "model.model." "model.model.", #audio models
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow ]
unet_key_prefix = "model." counts = {k: 0 for k in candidates}
for k in state_dict:
for c in candidates:
if k.startswith(c):
counts[c] += 1
break
top = max(counts, key=counts.get)
if counts[top] > 5:
return top
else: else:
unet_key_prefix = "model.diffusion_model." return "model." #aura flow and others
return unet_key_prefix
def convert_config(unet_config): def convert_config(unet_config):
new_config = unet_config.copy() new_config = unet_config.copy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment