Commit 3dcc01a7 authored by Gustaf Ahdritz

Tweak training script. Install new LR scheduler

parent 72a971b0
 {
-    "optimizer": {
-        "type": "Adam",
-        "params": {
-            "lr": 0.001,
-            "eps": 1e-05
-        }
-    },
     "fp16": {
-        "enabled": true,
+        "enabled": false,
         "min_loss_scale": 1
     },
     "amp": {
@@ -15,7 +8,7 @@
         "opt_level": "O2"
     },
     "bfloat16": {
-        "enabled": false
+        "enabled": true
     },
     "zero_optimization": {
         "stage": 2,
...
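For orientation, the net effect of this hunk can be sketched as a Python dict; only the keys visible in the hunk are included, and the removed "optimizer" block disappears because the optimizer (and the new LR scheduler) are now built in the training script's configure_optimizers() instead:

# Minimal sketch of the DeepSpeed config fragment after this change; keys
# outside the hunk are omitted. fp16 is disabled and bfloat16 enabled, and
# the optimizer is no longer configured here.
deepspeed_config_fragment = {
    "fp16": {"enabled": False, "min_loss_scale": 1},
    "bfloat16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}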
@@ -27,12 +27,13 @@ from openfold.data.data_modules import (
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
+from openfold.utils.argparse import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.argparse import remove_arguments
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
+from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
@@ -58,7 +59,6 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self.cached_weights = None
-        self.last_lr_step = 0

     def forward(self, batch):
         return self.model(batch)
@@ -74,8 +74,8 @@
         if(train):
             self.log(
-                f"train/loss_epoch",
-                loss_breakdown["loss"],
+                f"{phase}/{loss_name}_epoch",
+                indiv_loss,
                 on_step=False, on_epoch=True, logger=True,
             )
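The new log call references phase and loss_name, which suggests it now runs once per entry of loss_breakdown rather than only once for the total loss. The enclosing loop is not part of this hunk, so the following is only a hedged sketch with assumed variable names:

# Assumed surrounding structure (not shown in the hunk above):
phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items():
    if train:
        self.log(
            f"{phase}/{loss_name}_epoch",
            indiv_loss,
            on_step=False, on_epoch=True, logger=True,
        )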
@@ -116,19 +116,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

-    # def training_step_end(self, outputs):
-    #     # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
-    #     if(self.trainer.global_step != self.last_lr_step):
-    #         self.lr_schedulers().step()
-    #         self.last_lr_step = self.trainer.global_step
-
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
-            # load_state_dict() is an in-place operation
-            # it will change the content in any reference of model.state_dict()
-            # therefore we need to explicitly clone the parameters
-            clone_param = lambda t: t.clone().detach()
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
             self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])
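The rewritten comment describes standard torch.nn.Module behavior: state_dict() values share storage with the live parameters, so loading the EMA weights in place would silently overwrite an uncloned cache. A small self-contained demonstration (not from this repo):

import torch

# state_dict() returns references: loading new weights in place clobbers
# an uncloned "cache".
model = torch.nn.Linear(2, 2)
cached = model.state_dict()
model.load_state_dict({k: torch.zeros_like(v) for k, v in cached.items()})
print(cached["weight"].abs().sum())  # tensor(0.) -- the cache was overwritten

# Cloning first (as tensor_tree_map(clone_param, ...) does above) keeps an
# independent copy that survives the in-place load.
model = torch.nn.Linear(2, 2)
cached = {k: v.detach().clone() for k, v in model.state_dict().items()}
model.load_state_dict({k: torch.zeros_like(v) for k, v in cached.items()})
print(cached["weight"].abs().sum())  # non-zero -- the cached copy is intact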
@@ -181,7 +175,7 @@ class OpenFoldWrapper(pl.LightningModule):
         drmsd_ca_score = compute_drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
-            mask=all_atom_mask_ca,
+            mask=all_atom_mask_ca, # still required here to compute n
         )

         metrics["drmsd_ca"] = drmsd_ca_score
@@ -207,11 +201,23 @@ class OpenFoldWrapper(pl.LightningModule):
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
-        return torch.optim.Adam(
+        optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
+        lr_scheduler = AlphaFoldLRScheduler(
+            optimizer,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+                "name": "AlphaFoldLRScheduler",
+            }
+        }

     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
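Returning the Lightning dictionary format with "interval": "step" makes the trainer call lr_scheduler.step() after every optimizer step, so the schedule is expressed in training steps rather than epochs. AlphaFoldLRScheduler itself lives in openfold/utils/lr_schedulers.py and is not shown in this diff; the sketch below only illustrates the warm-up-then-decay shape of the AlphaFold training schedule, and every hyperparameter value in it is an assumption for illustration:

import torch

class WarmupThenDecayLR(torch.optim.lr_scheduler._LRScheduler):
    # Illustrative stand-in, NOT the AlphaFoldLRScheduler source: linear
    # warm-up to max_lr, hold, then multiplicative decay at fixed intervals.
    def __init__(self, optimizer, max_lr=1e-3, warmup_steps=1000,
                 start_decay_after=50_000, decay_every=50_000,
                 decay_factor=0.95, last_epoch=-1):
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after = start_decay_after
        self.decay_every = decay_every
        self.decay_factor = decay_factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch  # incremented once per scheduler.step()
        if step <= self.warmup_steps:
            lr = self.max_lr * step / self.warmup_steps
        elif step <= self.start_decay_after:
            lr = self.max_lr
        else:
            n = (step - self.start_decay_after) // self.decay_every + 1
            lr = self.max_lr * self.decay_factor ** n
        return [lr for _ in self.base_lrs]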
@@ -255,6 +261,8 @@ def main(args):
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
+            auto_insert_metric_name=False,
+            save_top_k=-1,
         )
         callbacks.append(mc)

...
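On the checkpoint callback: save_top_k=-1 tells pytorch_lightning's ModelCheckpoint to keep every checkpoint rather than only the best k, and auto_insert_metric_name=False keeps monitored-metric names out of the filename template, which is recommended when metric names contain "/" (as the phase/loss_name keys logged above do). Standalone, the callback added here is equivalent to:

from pytorch_lightning.callbacks import ModelCheckpoint

mc = ModelCheckpoint(
    every_n_epochs=1,               # checkpoint at the end of every epoch
    auto_insert_metric_name=False,  # keep metric names out of the filename
    save_top_k=-1,                  # keep all checkpoints, not just the best k
)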