Unverified commit b2d6bff6, authored by Gustaf Ahdritz and committed by GitHub
Browse files

Merge pull request #263 from l-bick/load_pretrained_jax_weights

Load pretrained jax weights
parents 700fe8fe 0c4a93f7
......@@ -37,6 +37,9 @@ from openfold.utils.validation_metrics import (
gdt_ts,
gdt_ha,
)
from openfold.utils.import_weights import (
import_jax_weights_,
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
......@@ -241,6 +244,17 @@ class OpenFoldWrapper(pl.LightningModule):
# Store the given `lr_step` on this instance as `last_lr_step`.
# NOTE(review): presumably read later to resume the LR schedule after a
# restart -- the consumer is not visible here; confirm against the caller.
def resume_last_lr_step(self, lr_step):
self.last_lr_step = lr_step
def load_from_jax(self, jax_path):
    """Initialize ``self.model`` from a JAX parameter file.

    The version string handed to ``import_jax_weights_`` is derived from
    the file name: the extension is dropped and everything after the
    first underscore of the remaining stem is kept (for example, a file
    named ``params_model_1.npz`` yields the version ``model_1``).

    Args:
        jax_path: Path to the JAX parameter file to import.
    """
    # normpath drops a trailing separator before the basename is taken.
    file_name = os.path.basename(os.path.normpath(jax_path))
    stem, _ext = os.path.splitext(file_name)
    # Keep the text after the first "_" (empty when the stem has none).
    _prefix, _sep, model_version = stem.partition("_")
    import_jax_weights_(self.model, jax_path, version=model_version)
def main(args):
if(args.seed is not None):
......@@ -269,6 +283,9 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model
if(args.script_modules):
......@@ -476,6 +493,10 @@ if __name__ == "__main__":
"--resume_model_weights_only", type=bool_type, default=False,
help="Whether to load just model weights as opposed to training state"
)
parser.add_argument(
"--resume_from_jax_params", type=str, default=None,
help="""Path to an .npz JAX parameter file with which to initialize the model"""
)
parser.add_argument(
"--log_performance", type=bool_type, default=False,
help="Measure performance"
......@@ -570,6 +591,9 @@ if __name__ == "__main__":
if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment