"vscode:/vscode.git/clone" did not exist on "994d0b7e56fb120e3318afe8bd8cfd9a493bd168"
Commit a7216c73 authored by Jennifer's avatar Jennifer
Browse files

changes reloading weights to account for possible state_dict headers.

parent 711ba0ab
...@@ -295,7 +295,12 @@ def main(args): ...@@ -295,7 +295,12 @@ def main(args):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else: else:
sd = torch.load(args.resume_from_ckpt) sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()} if 'module' in sd:
module_sd = {k[len("module."):]:v for k,v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=module_sd)
elif 'state_dict' in sd:
import_openfold_weights_(model=model_module, state_dict=sd['state_dict'])
else:
import_openfold_weights_(model=model_module, state_dict=sd) import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params): if(args.resume_from_jax_params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment