Commit 4c9d372d authored by Gustaf Ahdritz

Add training callbacks

parent 727e68c2
openfold/utils/callbacks.py (new file)

from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class EarlyStoppingVerbose(EarlyStopping):
    """
    The default EarlyStopping callback's verbose mode is too verbose.
    This class outputs a message only when it is getting ready to stop.
    """
    def _evaluate_stopping_criteria(self, *args):
        should_stop, reason = super()._evaluate_stopping_criteria(*args)
        if(should_stop):
            rank_zero_info(f"{reason}\n")

        return should_stop, reason
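A minimal usage sketch for context (illustrative only — the commit's actual wiring appears in train_openfold.py below, and the "val_loss" metric name is an assumption):

import pytorch_lightning as pl
from openfold.utils.callbacks import EarlyStoppingVerbose

# Behaves exactly like EarlyStopping, but prints the stopping reason only
# once, on rank 0, at the moment training is actually about to stop.
es = EarlyStoppingVerbose(monitor="val_loss", patience=3, mode="min")
trainer = pl.Trainer(max_epochs=10, callbacks=[es])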
train_openfold.py

@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
-#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
+os.environ["CUDA_VISIBLE_DEVICES"] = "6"
 #os.environ["MASTER_ADDR"]="10.119.81.14"
 #os.environ["MASTER_PORT"]="42069"
 #os.environ["NODE_RANK"]="0"
@@ -13,6 +13,7 @@ import time
 import numpy as np
 import pytorch_lightning as pl
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.plugins import DDPPlugin
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
 import torch
@@ -23,6 +24,9 @@ from openfold.data.data_modules import (
     DummyDataLoader,
 )
 from openfold.model.model import AlphaFold
+from openfold.utils.callbacks import (
+    EarlyStoppingVerbose,
+)
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss
 from openfold.utils.seed import seed_everything
@@ -88,6 +92,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)
 
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["ema"] = self.ema.state_dict()
+
 
 def main(args):
     if(args.seed is not None):
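The restore path for the saved EMA state is not part of this commit; a minimal sketch of the matching LightningModule hook, assuming the ema object exposes the usual state_dict/load_state_dict pair:

    def on_load_checkpoint(self, checkpoint):
        # Hypothetical counterpart to on_save_checkpoint above: restore the
        # EMA weights when training resumes from a checkpoint.
        if("ema" in checkpoint):
            self.ema.load_state_dict(checkpoint["ema"])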
@@ -108,7 +115,29 @@ def main(args):
     )
     data_module.prepare_data()
     data_module.setup()
 
+    callbacks = []
+    if(args.checkpoint_best_val):
+        checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
+        mc = ModelCheckpoint(
+            dirpath=checkpoint_dir,
+            filename="openfold_{epoch}_{step}_{val_loss:.2f}",
+            monitor="val_loss",
+        )
+        callbacks.append(mc)
+
+    if(args.early_stopping):
+        es = EarlyStoppingVerbose(
+            monitor="val_loss",
+            min_delta=args.min_delta,
+            patience=args.patience,
+            verbose=False,
+            mode="min",
+            check_finite=True,
+            strict=True,
+        )
+        callbacks.append(es)
+
 plugins = []
 if(args.deepspeed_config_path is not None):
     plugins.append(DeepSpeedPlugin(config=args.deepspeed_config_path))
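The Trainer construction falls outside this hunk's context; presumably the new callbacks list is forwarded alongside the plugins, roughly as follows (a sketch, not the commit's exact code):

    trainer = pl.Trainer.from_argparse_args(
        args,
        plugins=plugins,
        callbacks=callbacks,  # assumption: the list built above is passed here
    )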
@@ -119,6 +148,7 @@ def main(args):
     )
     trainer.fit(model_module, datamodule=data_module)
+    trainer.save_checkpoint("final.ckpt")
 
 if __name__ == "__main__":
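Note that trainer.save_checkpoint("final.ckpt") writes relative to the current working directory rather than to args.output_dir; an output-directory-relative variant (an alternative, not what the commit does) would be:

    trainer.save_checkpoint(os.path.join(args.output_dir, "final.ckpt"))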
...@@ -135,10 +165,15 @@ if __name__ == "__main__": ...@@ -135,10 +165,15 @@ if __name__ == "__main__":
"template_mmcif_dir", type=str, "template_mmcif_dir", type=str,
help="Directory containing mmCIF files to search for templates" help="Directory containing mmCIF files to search for templates"
) )
parser.add_argument(
"output_dir", type=str,
help='''Directory in which to output checkpoints, logs, etc. Ignored
if not on rank 0'''
)
parser.add_argument( parser.add_argument(
"max_template_date", type=str, "max_template_date", type=str,
help="""Cutoff for all templates. In training mode, templates are also help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target""" filtered by the release date of the target'''
) )
parser.add_argument( parser.add_argument(
"--distillation_data_dir", type=str, default=None, "--distillation_data_dir", type=str, default=None,
...@@ -162,9 +197,9 @@ if __name__ == "__main__": ...@@ -162,9 +197,9 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--train_mapping_path", type=str, default=None, "--train_mapping_path", type=str, default=None,
help="""Optional path to a .json file containing a mapping from help='''Optional path to a .json file containing a mapping from
consecutive numerical indices to sample names. Used to filter consecutive numerical indices to sample names. Used to filter
the training set""" the training set'''
) )
parser.add_argument( parser.add_argument(
"--distillation_mapping_path", type=str, default=None, "--distillation_mapping_path", type=str, default=None,
@@ -187,6 +222,24 @@ if __name__ == "__main__":
         "--deepspeed_config_path", type=str, default=None,
         help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
     )
+    parser.add_argument(
+        "--checkpoint_best_val", type=int, default=True,
+        help="""Whether to save the model parameters that perform best during
+                validation"""
+    )
+    parser.add_argument(
+        "--early_stopping", type=int, default=False,
+        help="Whether to stop training when validation loss fails to decrease"
+    )
+    parser.add_argument(
+        "--min_delta", type=float, default=0,
+        help="""The smallest decrease in validation loss that counts as an
+                improvement for the purposes of early stopping"""
+    )
+    parser.add_argument(
+        "--patience", type=int, default=3,
+        help="Early stopping patience"
+    )
     parser = pl.Trainer.add_argparse_args(parser)
     parser.set_defaults(
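Aside: argparse's type=bool is a common pitfall — bool("False") is True, since any non-empty string is truthy — so integer flags such as --early_stopping 1 parse more predictably. If string booleans were preferred, a converter along these lines would work (hypothetical helper, not part of this commit):

def bool_type(s):
    # Map common string spellings to booleans; argparse's type=bool would
    # treat any non-empty string, including "False", as True.
    if(s.lower() in ("true", "1", "yes")):
        return True
    elif(s.lower() in ("false", "0", "no")):
        return False
    raise ValueError(f"Unrecognized boolean value: {s}")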
@@ -195,4 +248,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
+    if(args.seed is None and
+       ((args.gpus is not None and args.gpus > 1) or
+        (args.num_nodes is not None and args.num_nodes > 1))):
+        raise ValueError("For distributed training, --seed must be specified")
+
     main(args)