Commit 1921ac99 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add support for public checkpoints

parent d6b36a80
...@@ -189,11 +189,18 @@ def main(args): ...@@ -189,11 +189,18 @@ def main(args):
args.openfold_checkpoint_path, args.openfold_checkpoint_path,
ckpt_path, ckpt_path,
) )
else:
ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"]) model.load_state_dict(d["ema"]["params"])
else:
ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path)
if("ema" in d):
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
else: else:
raise ValueError( raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must " "At least one of jax_param_path or openfold_checkpoint_path must "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment