Commit 1921ac99 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add support for public checkpoints

parent d6b36a80
......@@ -189,11 +189,18 @@ def main(args):
args.openfold_checkpoint_path,
ckpt_path,
)
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
else:
ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
d = torch.load(ckpt_path)
if("ema" in d):
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
else:
raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment