Unverified Commit b2d6bff6 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #263 from l-bick/load_pretrained_jax_weights

Load pretrained jax weights
parents 700fe8fe 0c4a93f7
@@ -37,6 +37,9 @@ from openfold.utils.validation_metrics import (
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
import_jax_weights_,
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
@@ -241,6 +244,17 @@ class OpenFoldWrapper(pl.LightningModule):
    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step
def load_from_jax(self, jax_path):
    """Initialize self.model from a pretrained JAX ``.npz`` parameter file.

    The model version string passed to ``import_jax_weights_`` is derived
    from the file name: the basename is taken (trailing path separators
    normalized away), its extension stripped, and everything after the
    first underscore-separated token is joined back together. E.g.
    ``params_model_1_ptm.npz`` -> version ``model_1_ptm``.
    """
    filename = os.path.basename(os.path.normpath(jax_path))
    stem, _ext = os.path.splitext(filename)
    version = "_".join(stem.split("_")[1:])
    import_jax_weights_(self.model, jax_path, version=version)
def main(args):
    if(args.seed is not None):
@@ -269,6 +283,9 @@ def main(args):
        sd = {k[len("module."):]:v for k,v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
    # TorchScript components of the model
    if(args.script_modules):
@@ -476,6 +493,10 @@ if __name__ == "__main__":
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
parser.add_argument(
"--resume_from_jax_params", type=str, default=None,
help="""Path to an .npz JAX parameter file with which to initialize the model"""
)
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
@@ -570,6 +591,9 @@ if __name__ == "__main__":
    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment