Commit acdc6f42 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix loading some malformed checkpoints?

parent 051f472e
...@@ -17,7 +17,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]): ...@@ -17,7 +17,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment