Commit 4c9d372d authored by Gustaf Ahdritz
Browse files

Add training callbacks

parent 727e68c2
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
class EarlyStoppingVerbose(EarlyStopping):
    """
        Quieter replacement for the stock EarlyStopping callback.

        The parent's verbose mode is considered too chatty, so this
        subclass emits a message (through rank_zero_info) only once the
        stopping criteria say training should stop.
    """
    def _evalute_stopping_criteria(self, *args):
        # The spelling "evalute" mirrors the parent hook invoked through
        # super() below; keep the two names in sync.
        outcome = super()._evalute_stopping_criteria(*args)
        stop_now, why = outcome
        if stop_now:
            rank_zero_info(f"{why}\n")
        return stop_now, why
......@@ -2,7 +2,7 @@ import argparse
import logging
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
......@@ -13,6 +13,7 @@ import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch
......@@ -23,6 +24,9 @@ from openfold.data.data_modules import (
DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.seed import seed_everything
......@@ -88,6 +92,9 @@ class OpenFoldWrapper(pl.LightningModule):
def on_before_zero_grad(self, *args, **kwargs):
    """Refresh the EMA tracker from the current model.

    Any positional or keyword arguments passed to the hook are accepted
    and ignored.
    """
    ema_tracker = self.ema
    ema_tracker.update(self.model)
def on_save_checkpoint(self, checkpoint):
    """Store the EMA state dict in ``checkpoint`` under the "ema" key."""
    ema_state = self.ema.state_dict()
    checkpoint["ema"] = ema_state
def main(args):
if(args.seed is not None):
......@@ -108,7 +115,29 @@ def main(args):
)
data_module.prepare_data()
data_module.setup()
# Collect the optional trainer callbacks requested on the command line.
callbacks = []

# Checkpointing monitored on "val_loss"; files go to <output_dir>/checkpoints.
if args.checkpoint_best_val:
    checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
    mc = ModelCheckpoint(
        monitor="val_loss",
        filename="openfold_{epoch}_{step}_{val_loss:.2f}",
        dirpath=checkpoint_dir,
    )
    callbacks.append(mc)

# Early stopping on "val_loss"; verbose=False leaves reporting to the
# EarlyStoppingVerbose subclass itself.
if args.early_stopping:
    early_stopping_options = {
        "monitor": "val_loss",
        "mode": "min",
        "min_delta": args.min_delta,
        "patience": args.patience,
        "verbose": False,
        "check_finite": True,
        "strict": True,
    }
    es = EarlyStoppingVerbose(**early_stopping_options)
    callbacks.append(es)
plugins = []
if(args.deepspeed_config_path is not None):
plugins.append(DeepSpeedPlugin(config=args.deepspeed_config_path))
......@@ -119,6 +148,7 @@ def main(args):
)
trainer.fit(model_module, datamodule=data_module)
trainer.save_checkpoint("final.ckpt")
if __name__ == "__main__":
......@@ -135,10 +165,15 @@ if __name__ == "__main__":
"template_mmcif_dir", type=str,
help="Directory containing mmCIF files to search for templates"
)
parser.add_argument(
"output_dir", type=str,
help='''Directory in which to output checkpoints, logs, etc. Ignored
if not on rank 0'''
)
parser.add_argument(
"max_template_date", type=str,
help="""Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target"""
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser.add_argument(
"--distillation_data_dir", type=str, default=None,
......@@ -162,9 +197,9 @@ if __name__ == "__main__":
)
parser.add_argument(
"--train_mapping_path", type=str, default=None,
help="""Optional path to a .json file containing a mapping from
help='''Optional path to a .json file containing a mapping from
consecutive numerical indices to sample names. Used to filter
the training set"""
the training set'''
)
parser.add_argument(
"--distillation_mapping_path", type=str, default=None,
......@@ -187,6 +222,24 @@ if __name__ == "__main__":
"--deepspeed_config_path", type=str, default=None,
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so ``--early_stopping False``
    used to switch early stopping ON. This converter understands the usual
    spellings (true/false, yes/no, on/off, t/f, y/n) and, for backward
    compatibility with the previous ``type=int`` flag, any integer
    (0 is False, everything else True).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling or integer.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on"):
        return True
    if text in ("false", "f", "no", "n", "off"):
        return False
    try:
        return bool(int(text))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (e.g. true/false or 1/0), got {value!r}"
        ) from None

# Both flags are booleans and share one explicit parser so that
# "--flag False" and "--flag 0" really disable the feature.
parser.add_argument(
    "--checkpoint_best_val", type=_str_to_bool, default=True,
    help="""Whether to save the model parameters that perform best during
            validation"""
)
parser.add_argument(
    "--early_stopping", type=_str_to_bool, default=False,
    help="Whether to stop training when validation loss fails to decrease"
)
parser.add_argument(
"--min_delta", type=float, default=0,
help="""The smallest decrease in validation loss that counts as an
improvement for the purposes of early stopping"""
)
parser.add_argument(
"--patience", type=int, default=3,
help="Early stopping patience"
)
parser = pl.Trainer.add_argparse_args(parser)
parser.set_defaults(
......@@ -195,4 +248,9 @@ if __name__ == "__main__":
args = parser.parse_args()
if(args.seed is None and
((args.gpus is not None and args.gpus > 1) or
(args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified")
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment