Commit 2b08407d authored by Lucas Bickmann's avatar Lucas Bickmann
Browse files

Added support for Jax-parameter loading to train_openfold.py

parent 2ef7893a
...@@ -37,6 +37,9 @@ from openfold.utils.validation_metrics import ( ...@@ -37,6 +37,9 @@ from openfold.utils.validation_metrics import (
gdt_ts, gdt_ts,
gdt_ha, gdt_ha,
) )
from openfold.utils.import_weights import (
import_jax_weights_,
)
from scripts.zero_to_fp32 import ( from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint, get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint get_global_step_from_zero_checkpoint
...@@ -241,6 +244,17 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -241,6 +244,17 @@ class OpenFoldWrapper(pl.LightningModule):
def resume_last_lr_step(self, lr_step): def resume_last_lr_step(self, lr_step):
self.last_lr_step = lr_step self.last_lr_step = lr_step
def load_from_jax(self, jax_path):
model_basename = os.path.splitext(
os.path.basename(
os.path.normpath(jax_path)
)
)[0]
model_version = "_".join(model_basename.split("_")[1:])
import_jax_weights_(
self.model, jax_path, version=model_version
)
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
...@@ -269,6 +283,9 @@ def main(args): ...@@ -269,6 +283,9 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()} sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd) model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.jax_param_path):
model_module.load_from_jax(args.jax_param_path)
logging.info(f"Successfully loaded JAX parameters at {args.jax_param_path}...")
# TorchScript components of the model # TorchScript components of the model
if(args.script_modules): if(args.script_modules):
...@@ -531,6 +548,12 @@ if __name__ == "__main__": ...@@ -531,6 +548,12 @@ if __name__ == "__main__":
'used.' 'used.'
) )
) )
parser.add_argument(
"--jax_param_path", type=str, default=None,
help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser.add_argument( parser.add_argument(
"--_distillation_structure_index_path", type=str, default=None, "--_distillation_structure_index_path", type=str, default=None,
) )
...@@ -570,6 +593,9 @@ if __name__ == "__main__": ...@@ -570,6 +593,9 @@ if __name__ == "__main__":
if(str(args.precision) == "16" and args.deepspeed_config_path is not None): if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible") raise ValueError("DeepSpeed and FP16 training are not compatible")
if(str(args.jax_param_path) is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch # This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1 args.reload_dataloaders_every_n_epochs = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment