Commit a7216c73 authored by Jennifer's avatar Jennifer
Browse files

changes reloading weights to account for possible state_dict headers.

parent 711ba0ab
......@@ -295,7 +295,12 @@ def main(args):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
if 'module' in sd:
module_sd = {k[len("module."):]:v for k,v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=module_sd)
elif 'state_dict' in sd:
import_openfold_weights_(model=model_module, state_dict=sd['state_dict'])
else:
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment