Commit 39a6d0e6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merging in main branch

parents d8ee9c5f 84659c93
import argparse
import logging
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
import random
import sys
import time
......@@ -43,8 +37,12 @@ from openfold.utils.validation_metrics import (
gdt_ts,
gdt_ha,
)
from openfold.utils.import_weights import (
import_jax_weights_,
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback
......@@ -61,7 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
self.cached_weights = None
self.last_lr_step = 0
self.last_lr_step = -1
def forward(self, batch):
    """Forward the feature batch through the wrapped AlphaFold model."""
    model_output = self.model(batch)
    return model_output
......@@ -102,7 +100,7 @@ class OpenFoldWrapper(pl.LightningModule):
# Run the model
outputs = self(batch)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
......@@ -204,12 +202,23 @@ class OpenFoldWrapper(pl.LightningModule):
learning_rate: float = 1e-3,
eps: float = 1e-5,
) -> torch.optim.Adam:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=learning_rate,
eps=eps
)
if self.last_lr_step != -1:
for group in optimizer.param_groups:
if 'initial_lr' not in group:
group['initial_lr'] = learning_rate
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
......@@ -224,11 +233,28 @@ class OpenFoldWrapper(pl.LightningModule):
}
def on_load_checkpoint(self, checkpoint):
    """Restore the EMA parameters saved in the checkpoint.

    When the template stack is disabled in the current config, the
    checkpoint may still contain template-related EMA entries; those are
    filtered out first so load_state_dict does not receive keys the
    current model has no parameters for.
    """
    ema = checkpoint["ema"]
    if not self.model.template_config.enabled:
        # Drop template-module entries before loading into a template-free model.
        ema["params"] = {
            k: v for k, v in ema["params"].items() if "template" not in k
        }
    self.ema.load_state_dict(ema)
def on_save_checkpoint(self, checkpoint):
    """Stash the current EMA state into the checkpoint dict under "ema"."""
    ema_state = self.ema.state_dict()
    checkpoint["ema"] = ema_state
def resume_last_lr_step(self, lr_step):
    """Record the LR-scheduler step to resume from (consumed when the
    optimizer/scheduler are configured)."""
    self.last_lr_step = lr_step
def load_from_jax(self, jax_path):
    """Initialize model weights from a JAX .npz parameter file.

    The model version is derived from the file name: everything after the
    first underscore of the basename (e.g. "params_model_1.npz" ->
    "model_1") and is forwarded to the weight-import routine.
    """
    normalized = os.path.normpath(jax_path)
    stem, _ = os.path.splitext(os.path.basename(normalized))
    model_version = "_".join(stem.split("_")[1:])
    import_jax_weights_(
        self.model, jax_path, version=model_version
    )
def main(args):
if(args.seed is not None):
......@@ -237,15 +263,29 @@ def main(args):
config = model_config(
args.config_preset,
train=True,
low_prec=(args.precision == "16")
low_prec=(str(args.precision) == "16")
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
if(os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model
if(args.script_modules):
......@@ -397,14 +437,14 @@ if __name__ == "__main__":
help="Path to the kalign binary"
)
parser.add_argument(
"--train_mapping_path", type=str, default=None,
help='''Optional path to a .json file containing a mapping from
consecutive numerical indices to sample names. Used to filter
the training set'''
"--train_filter_path", type=str, default=None,
help='''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set'''
)
parser.add_argument(
"--distillation_mapping_path", type=str, default=None,
help="""See --train_mapping_path"""
"--distillation_filter_path", type=str, default=None,
help="""See --train_filter_path"""
)
parser.add_argument(
"--obsolete_pdbs_file_path", type=str, default=None,
......@@ -453,6 +493,10 @@ if __name__ == "__main__":
"--resume_model_weights_only", type=bool_type, default=False,
help="Whether to load just model weights as opposed to training state"
)
parser.add_argument(
"--resume_from_jax_params", type=str, default=None,
help="""Path to an .npz JAX parameter file with which to initialize the model"""
)
parser.add_argument(
"--log_performance", type=bool_type, default=False,
help="Measure performance"
......@@ -512,10 +556,12 @@ if __name__ == "__main__":
"--_distillation_structure_index_path", type=str, default=None,
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
"--alignment_index_path", type=str, default=None,
help="Training alignment index. See the README for instructions."
)
parser.add_argument(
"--_distillation_alignment_index_path", type=str, default=None,
"--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions."
)
parser = pl.Trainer.add_argparse_args(parser)
......@@ -542,6 +588,12 @@ if __name__ == "__main__":
(args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified")
if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment