"...resnet50_tensorflow.git" did not exist on "675f4de323cc66a049d3b1c70c9c8afb75215148"
Commit 9f4214e5 authored by comfyanonymous's avatar comfyanonymous
Browse files

Preparing to add another function to load checkpoints.

parent 3cd7d84b
......@@ -26,12 +26,7 @@ def load_torch_file(ckpt):
sd = pl_sd
return sd
def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
print(f"Loading model from {ckpt}")
sd = load_torch_file(ckpt)
model = instantiate_from_config(config.model)
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False)
k = list(sd.keys())
......@@ -654,5 +649,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
model = instantiate_from_config(config.model)
sd = load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
return (ModelPatcher(model), clip, vae)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment