Commit a7216c73 authored by Jennifer's avatar Jennifer
Browse files

changes reloading weights to account for possible state_dict headers.

parent 711ba0ab
...@@ -295,8 +295,13 @@ def main(args): ...@@ -295,8 +295,13 @@ def main(args):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else: else:
sd = torch.load(args.resume_from_ckpt) sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()} if 'module' in sd:
import_openfold_weights_(model=model_module, state_dict=sd) module_sd = {k[len("module."):]:v for k,v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=module_sd)
elif 'state_dict' in sd:
import_openfold_weights_(model=model_module, state_dict=sd['state_dict'])
else:
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params): if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment