"...models/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "d2d32f692cd0d5cc628c2d0399dc29dc70039417"
Commit fbec92cb authored by Gustaf Ahdritz

Update training script

parent 70362e4b
@@ -12,6 +12,7 @@ import time
 import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
@@ -68,10 +69,14 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)
 
         # Log it
+        self.log("train/loss", loss, on_step=True, logger=True)
 
         return loss
 
+    def on_before_zero_grad(self, *args, **kwargs):
+        self.ema.update(self.model)
+
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
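Lightning invokes on_before_zero_grad() after optimizer.step() and before the gradients are cleared, so the EMA shadow weights track the model immediately after each update. A minimal sketch of the kind of helper such a hook expects; this is an illustrative stand-in, not OpenFold's actual ExponentialMovingAverage class:

import torch

class SimpleEMA:
    """Illustrative EMA tracker; OpenFold's real class differs."""
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # Detached shadow copy of every parameter and buffer
        self.params = {
            k: v.clone().detach() for k, v in model.state_dict().items()
        }

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # Called from on_before_zero_grad, i.e. right after optimizer.step()
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                # shadow <- decay * shadow + (1 - decay) * current
                self.params[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                self.params[k].copy_(v)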
@@ -81,13 +86,17 @@ class OpenFoldWrapper(pl.LightningModule):
 
         # Calculate validation loss
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
-        loss = lddt_ca(
+        lddt_ca_score = lddt_ca(
             outputs["final_atom_positions"],
             batch["all_atom_positions"],
             batch["all_atom_mask"],
             eps=self.config.globals.eps,
             per_residue=False,
         )
+        self.log("val/lddt_ca", lddt_ca_score, logger=True)
+        batch["use_clamped_fape"] = 0.
+        loss = self.loss(outputs, batch)
+        self.log("val/loss", loss, logger=True)
 
     def validation_epoch_end(self, _):
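Validation now logs two separate quantities: the lDDT-Cα score, a structure-quality metric where higher is better, and the training loss recomputed with FAPE clamping disabled. A hedged, self-contained sketch of the same two-metric logging pattern in a LightningModule; the toy model and metric are illustrative, not OpenFold's:

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.layer(x)
        loss = F.mse_loss(pred, y)
        # One metric to monitor for checkpointing/early stopping and one for
        # bookkeeping; the "val/" prefix groups them in W&B or TensorBoard.
        self.log("val/quality", -loss, logger=True)  # higher is better
        self.log("val/loss", loss, logger=True)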
@@ -106,9 +115,6 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=eps
         )
 
-    def on_before_zero_grad(self, *args, **kwargs):
-        self.ema.update(self.model)
-
     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
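on_load_checkpoint restores the EMA state stashed under its own key in the checkpoint dict. The save side is not visible in this hunk; a sketch of the assumed pair of hooks, which Lightning hands the raw checkpoint dict:

import pytorch_lightning as pl

class EMACheckpointHooks(pl.LightningModule):
    # Sketch only: assumes self.ema exposes state_dict()/load_state_dict(),
    # as torch-style EMA helpers usually do.
    def on_save_checkpoint(self, checkpoint):
        # Extra state rides along in the checkpoint under its own key
        checkpoint["ema"] = self.ema.state_dict()

    def on_load_checkpoint(self, checkpoint):
        self.ema.load_state_dict(checkpoint["ema"])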
@@ -137,7 +143,7 @@ def main(args):
     if(args.script_modules):
         script_preset_(model_module)
 
-    #data_module = DummyDataLoader("batch.pickle")
+    #data_module = DummyDataLoader("new_batch.pickle")
     data_module = OpenFoldDataModule(
         config=config.data,
         batch_seed=args.seed,
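OpenFoldDataModule follows the LightningDataModule pattern; its real constructor takes more arguments than this hunk shows. A toy sketch of the pattern, with setup() building the datasets before training begins:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule(pl.LightningDataModule):
    def __init__(self, batch_seed: int = 0):
        super().__init__()
        self.batch_seed = batch_seed

    def setup(self, stage=None):
        # Invoked by data_module.setup(); made deterministic via batch_seed
        g = torch.Generator().manual_seed(self.batch_seed)
        x = torch.randn(64, 8, generator=g)
        y = torch.randn(64, 1, generator=g)
        self.train_set = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)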
@@ -148,22 +154,19 @@ def main(args):
     data_module.setup()
 
     callbacks = []
-    if(args.checkpoint_best_val):
-        checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
+    if(args.checkpoint_every_epoch):
+        mc = ModelCheckpoint(
+            filename="openfold_{epoch}_{step}_{val_loss:.2f}",
+            monitor="val/loss",
+            mode="max",
+            every_n_epochs=1,
+        )
+        callbacks.append(mc)
 
     if(args.early_stopping):
         es = EarlyStoppingVerbose(
-            monitor="val/loss",
+            monitor="val/lddt_ca",
             min_delta=args.min_delta,
             patience=args.patience,
             verbose=False,
-            mode="min",
+            mode="max",
             check_finite=True,
             strict=True,
         )
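How these callbacks are eventually consumed is outside this hunk; a sketch of the usual wiring, assuming a Lightning version where ModelCheckpoint accepts every_n_epochs:

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

callbacks = []
mc = ModelCheckpoint(
    filename="model_{epoch}_{step}",  # logged metric keys can be interpolated
    every_n_epochs=1,                 # fire at the end of every training epoch
    save_top_k=-1,                    # keep every epoch checkpoint
)
callbacks.append(mc)

trainer = pl.Trainer(max_epochs=1, callbacks=callbacks)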
@@ -189,14 +192,8 @@ def main(args):
         loggers.append(wdb_logger)
 
     if(args.deepspeed_config_path is not None):
-        #if "SLURM_JOB_ID" in os.environ:
-        #    cluster_environment = SLURMEnvironment()
-        #else:
-        #    cluster_environment = None
         strategy = DeepSpeedPlugin(
             config=args.deepspeed_config_path,
-            # cluster_environment=cluster_environment,
         )
         if(args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
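A sketch of how a DeepSpeedPlugin built this way is typically handed to the Trainer in this era of Lightning; the script's actual Trainer call is outside this hunk:

import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin

strategy = DeepSpeedPlugin(config="deepspeed_config.json")  # path or dict
trainer = pl.Trainer(
    gpus=1,
    precision=16,        # DeepSpeed is normally paired with mixed precision
    plugins=[strategy],  # PL 1.4-era API; newer versions use strategy=
)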
@@ -313,9 +310,8 @@ if __name__ == "__main__":
         help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
     )
     parser.add_argument(
-        "--checkpoint_best_val", type=bool_type, default=True,
-        help="""Whether to save the model parameters that perform best during
-            validation"""
+        "--checkpoint_every_epoch", action="store_true", default=False,
+        help="""Whether to checkpoint at the end of every training epoch"""
     )
     parser.add_argument(
         "--early_stopping", type=bool_type, default=False,