Commit 3711b31d authored by comfyanonymous's avatar comfyanonymous
Browse files

Support Stable Cascade in checkpoint format.

parent d91f45ef
...@@ -319,6 +319,10 @@ class Stable_Cascade_C(supported_models_base.BASE): ...@@ -319,6 +319,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
"shift": 2.0, "shift": 2.0,
} }
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoder."]
clip_vision_prefix = "clip_l_vision."
def process_unet_state_dict(self, state_dict): def process_unet_state_dict(self, state_dict):
key_list = list(state_dict.keys()) key_list = list(state_dict.keys())
for y in ["weight", "bias"]: for y in ["weight", "bias"]:
...@@ -355,6 +359,8 @@ class Stable_Cascade_B(Stable_Cascade_C): ...@@ -355,6 +359,8 @@ class Stable_Cascade_B(Stable_Cascade_C):
"shift": 1.0, "shift": 1.0,
} }
clip_vision_prefix = None
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device) out = model_base.StableCascade_B(self, device=device)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment