Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
 import argparse
 import logging
 import os
-#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-#os.environ["MASTER_ADDR"]="10.119.81.14"
-#os.environ["MASTER_PORT"]="42069"
-#os.environ["NODE_RANK"]="0"
 import random
 import sys
 import time
@@ -43,8 +37,12 @@ from openfold.utils.validation_metrics import (
     gdt_ts,
     gdt_ha,
 )
+from openfold.utils.import_weights import (
+    import_jax_weights_,
+)
 from scripts.zero_to_fp32 import (
-    get_fp32_state_dict_from_zero_checkpoint
+    get_fp32_state_dict_from_zero_checkpoint,
+    get_global_step_from_zero_checkpoint
 )
 from openfold.utils.logger import PerformanceLoggingCallback
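For context, the two helpers imported from scripts/zero_to_fp32.py operate on a DeepSpeed ZeRO checkpoint directory rather than a single file. A minimal sketch of their intended use (the directory path is hypothetical):

from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint,
)

ckpt_dir = "checkpoints/latest"  # hypothetical ZeRO checkpoint directory

# Consolidates the sharded ZeRO parameter states into a single fp32 state dict
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)

# Reads the global step recorded in the checkpoint, used later to resume the LR schedule
global_step = get_global_step_from_zero_checkpoint(ckpt_dir)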
@@ -61,7 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self.cached_weights = None
-        self.last_lr_step = 0
+        self.last_lr_step = -1

     def forward(self, batch):
         return self.model(batch)
@@ -102,7 +100,7 @@ class OpenFoldWrapper(pl.LightningModule):
         # Run the model
         outputs = self(batch)

         # Remove the recycling dimension
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
@@ -204,12 +202,23 @@ class OpenFoldWrapper(pl.LightningModule):
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
+        # return torch.optim.Adam(
+        #     self.model.parameters(),
+        #     lr=learning_rate,
+        #     eps=eps
+        # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )

+        if self.last_lr_step != -1:
+            for group in optimizer.param_groups:
+                if 'initial_lr' not in group:
+                    group['initial_lr'] = learning_rate
+
         lr_scheduler = AlphaFoldLRScheduler(
             optimizer,
         )
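Why the initial_lr patch matters: PyTorch LR schedulers constructed with last_epoch != -1 expect every param group to already contain an 'initial_lr' entry (normally written by a previous scheduler instance) and raise a KeyError otherwise. A minimal sketch of the failure and the fix, using a stock LambdaLR in place of AlphaFoldLRScheduler:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)

# Resuming mid-schedule without 'initial_lr' raises:
#   KeyError: "param 'initial_lr' is not specified in param_groups[0] when resuming an optimizer"
# torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0, last_epoch=1000)

# The fix mirrored above: backfill 'initial_lr' before constructing the scheduler
for group in optimizer.param_groups:
    group.setdefault("initial_lr", group["lr"])

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: 1.0, last_epoch=1000
)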
@@ -224,11 +233,28 @@ class OpenFoldWrapper(pl.LightningModule):
         }

     def on_load_checkpoint(self, checkpoint):
-        self.ema.load_state_dict(checkpoint["ema"])
+        ema = checkpoint["ema"]
+        if(not self.model.template_config.enabled):
+            ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k}
+        self.ema.load_state_dict(ema)

     def on_save_checkpoint(self, checkpoint):
         checkpoint["ema"] = self.ema.state_dict()

+    def resume_last_lr_step(self, lr_step):
+        self.last_lr_step = lr_step
+
+    def load_from_jax(self, jax_path):
+        model_basename = os.path.splitext(
+            os.path.basename(
+                os.path.normpath(jax_path)
+            )
+        )[0]
+        model_version = "_".join(model_basename.split("_")[1:])
+        import_jax_weights_(
+            self.model, jax_path, version=model_version
+        )
+

 def main(args):
     if(args.seed is not None):
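load_from_jax derives the version string purely from the filename, so it assumes the AlphaFold naming convention params_<version>.npz. A small illustration (the path is hypothetical):

import os

jax_path = "openfold/resources/params/params_model_1_ptm.npz"  # hypothetical path

basename = os.path.splitext(os.path.basename(os.path.normpath(jax_path)))[0]
# basename == "params_model_1_ptm"

version = "_".join(basename.split("_")[1:])
# version == "model_1_ptm", which import_jax_weights_ uses to select the weight mapping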
@@ -237,15 +263,29 @@ def main(args):
     config = model_config(
         args.config_preset,
         train=True,
-        low_prec=(args.precision == "16")
+        low_prec=(str(args.precision) == "16")
     )
     model_module = OpenFoldWrapper(config)

+    if(args.resume_from_ckpt):
+        if(os.path.isdir(args.resume_from_ckpt)):
+            last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
+        else:
+            sd = torch.load(args.resume_from_ckpt)
+            last_global_step = int(sd['global_step'])
+        model_module.resume_last_lr_step(last_global_step)
+        logging.info("Successfully loaded last lr step...")
+
     if(args.resume_from_ckpt and args.resume_model_weights_only):
-        sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
+        if(os.path.isdir(args.resume_from_ckpt)):
+            sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
+        else:
+            sd = torch.load(args.resume_from_ckpt)
         sd = {k[len("module."):]:v for k,v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")

+    if(args.resume_from_jax_params):
+        model_module.load_from_jax(args.resume_from_jax_params)
+        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
+
     # TorchScript components of the model
     if(args.script_modules):
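The resume logic accepts either a DeepSpeed ZeRO checkpoint directory or a single checkpoint file; in the latter case it assumes a Lightning-style checkpoint whose top-level dict carries 'global_step'. A minimal sketch of that assumption (the path is hypothetical):

import torch

ckpt_path = "checkpoints/last.ckpt"  # hypothetical single-file checkpoint

ckpt = torch.load(ckpt_path, map_location="cpu")
# Lightning checkpoints store trainer state alongside the weights, e.g.
#   ckpt["global_step"], ckpt["epoch"], ckpt["state_dict"], ...
last_global_step = int(ckpt["global_step"])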
@@ -397,14 +437,14 @@ if __name__ == "__main__":
         help="Path to the kalign binary"
     )
     parser.add_argument(
-        "--train_mapping_path", type=str, default=None,
-        help='''Optional path to a .json file containing a mapping from
-                consecutive numerical indices to sample names. Used to filter
-                the training set'''
+        "--train_filter_path", type=str, default=None,
+        help='''Optional path to a text file containing names of training
+                examples to include, one per line. Used to filter the training
+                set'''
     )
     parser.add_argument(
-        "--distillation_mapping_path", type=str, default=None,
-        help="""See --train_mapping_path"""
+        "--distillation_filter_path", type=str, default=None,
+        help="""See --train_filter_path"""
     )
     parser.add_argument(
         "--obsolete_pdbs_file_path", type=str, default=None,
@@ -453,6 +493,10 @@ if __name__ == "__main__":
         "--resume_model_weights_only", type=bool_type, default=False,
         help="Whether to load just model weights as opposed to training state"
     )
+    parser.add_argument(
+        "--resume_from_jax_params", type=str, default=None,
+        help="""Path to an .npz JAX parameter file with which to initialize the model"""
+    )
     parser.add_argument(
         "--log_performance", type=bool_type, default=False,
         help="Measure performance"
@@ -512,10 +556,12 @@ if __name__ == "__main__":
         "--_distillation_structure_index_path", type=str, default=None,
     )
     parser.add_argument(
-        "--_alignment_index_path", type=str, default=None,
+        "--alignment_index_path", type=str, default=None,
+        help="Training alignment index. See the README for instructions."
     )
     parser.add_argument(
-        "--_distillation_alignment_index_path", type=str, default=None,
+        "--distillation_alignment_index_path", type=str, default=None,
+        help="Distillation alignment index. See the README for instructions."
     )
     parser = pl.Trainer.add_argparse_args(parser)
@@ -542,6 +588,12 @@ if __name__ == "__main__":
             (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")

+    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
+        raise ValueError("DeepSpeed and FP16 training are not compatible")
+
+    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
+        raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
+
     # This re-applies the training-time filters at the beginning of every epoch
     args.reload_dataloaders_every_n_epochs = 1