Commit fbec92cb authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update training script

parent 70362e4b
...@@ -12,6 +12,7 @@ import time ...@@ -12,6 +12,7 @@ import time
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
...@@ -68,10 +69,14 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -68,10 +69,14 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss # Compute loss
loss = self.loss(outputs, batch) loss = self.loss(outputs, batch)
# Log it
self.log("train/loss", loss, on_step=True, logger=True) self.log("train/loss", loss, on_step=True, logger=True)
return loss return loss
def on_before_zero_grad(self, *args, **kwargs):
self.ema.update(self.model)
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if(self.cached_weights is None):
...@@ -81,13 +86,17 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -81,13 +86,17 @@ class OpenFoldWrapper(pl.LightningModule):
# Calculate validation loss # Calculate validation loss
outputs = self(batch) outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch) batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss = lddt_ca( lddt_ca_score = lddt_ca(
outputs["final_atom_positions"], outputs["final_atom_positions"],
batch["all_atom_positions"], batch["all_atom_positions"],
batch["all_atom_mask"], batch["all_atom_mask"],
eps=self.config.globals.eps, eps=self.config.globals.eps,
per_residue=False, per_residue=False,
) )
self.log("val/lddt_ca", lddt_ca_score, logger=True)
batch["use_clamped_fape"] = 0.
loss = self.loss(outputs, batch)
self.log("val/loss", loss, logger=True) self.log("val/loss", loss, logger=True)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
...@@ -106,9 +115,6 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -106,9 +115,6 @@ class OpenFoldWrapper(pl.LightningModule):
eps=eps eps=eps
) )
def on_before_zero_grad(self, *args, **kwargs):
self.ema.update(self.model)
def on_load_checkpoint(self, checkpoint): def on_load_checkpoint(self, checkpoint):
self.ema.load_state_dict(checkpoint["ema"]) self.ema.load_state_dict(checkpoint["ema"])
...@@ -137,7 +143,7 @@ def main(args): ...@@ -137,7 +143,7 @@ def main(args):
if(args.script_modules): if(args.script_modules):
script_preset_(model_module) script_preset_(model_module)
#data_module = DummyDataLoader("batch.pickle") #data_module = DummyDataLoader("new_batch.pickle")
data_module = OpenFoldDataModule( data_module = OpenFoldDataModule(
config=config.data, config=config.data,
batch_seed=args.seed, batch_seed=args.seed,
...@@ -148,22 +154,19 @@ def main(args): ...@@ -148,22 +154,19 @@ def main(args):
data_module.setup() data_module.setup()
callbacks = [] callbacks = []
if(args.checkpoint_best_val): if(args.checkpoint_every_epoch):
checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
mc = ModelCheckpoint( mc = ModelCheckpoint(
filename="openfold_{epoch}_{step}_{val_loss:.2f}", every_n_epochs=1,
monitor="val/loss",
mode="max",
) )
callbacks.append(mc) callbacks.append(mc)
if(args.early_stopping): if(args.early_stopping):
es = EarlyStoppingVerbose( es = EarlyStoppingVerbose(
monitor="val/loss", monitor="val/lddt_ca",
min_delta=args.min_delta, min_delta=args.min_delta,
patience=args.patience, patience=args.patience,
verbose=False, verbose=False,
mode="min", mode="max",
check_finite=True, check_finite=True,
strict=True, strict=True,
) )
...@@ -189,14 +192,8 @@ def main(args): ...@@ -189,14 +192,8 @@ def main(args):
loggers.append(wdb_logger) loggers.append(wdb_logger)
if(args.deepspeed_config_path is not None): if(args.deepspeed_config_path is not None):
#if "SLURM_JOB_ID" in os.environ:
# cluster_environment = SLURMEnvironment()
#else:
# cluster_environment = None
strategy = DeepSpeedPlugin( strategy = DeepSpeedPlugin(
config=args.deepspeed_config_path, config=args.deepspeed_config_path,
# cluster_environment=cluster_environment,
) )
if(args.wandb): if(args.wandb):
wdb_logger.experiment.save(args.deepspeed_config_path) wdb_logger.experiment.save(args.deepspeed_config_path)
...@@ -313,9 +310,8 @@ if __name__ == "__main__": ...@@ -313,9 +310,8 @@ if __name__ == "__main__":
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled" help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
) )
parser.add_argument( parser.add_argument(
"--checkpoint_best_val", type=bool_type, default=True, "--checkpoint_every_epoch", action="store_true", default=False,
help="""Whether to save the model parameters that perform best during help="""Whether to checkpoint at the end of every training epoch"""
validation"""
) )
parser.add_argument( parser.add_argument(
"--early_stopping", type=bool_type, default=False, "--early_stopping", type=bool_type, default=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment