"git@developer.sourcefind.cn:OpenDAS/Uni-Core.git" did not exist on "f24a5f708a86906514fbb775b0ff1e878524d2d6"
Commit 3dcc01a7 authored by Gustaf Ahdritz
Browse files

Tweak training script. Install new LR scheduler

parent 72a971b0
{
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"eps": 1e-05
}
},
"fp16": {
"enabled": true,
"enabled": false,
"min_loss_scale": 1
},
"amp": {
......@@ -15,7 +8,7 @@
"opt_level": "O2"
},
"bfloat16": {
"enabled": false
"enabled": true
},
"zero_optimization": {
"stage": 2,
......
......@@ -27,12 +27,13 @@ from openfold.data.data_modules import (
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.argparse import remove_arguments
from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
......@@ -58,7 +59,6 @@ class OpenFoldWrapper(pl.LightningModule):
)
self.cached_weights = None
self.last_lr_step = 0
def forward(self, batch):
    """Pass ``batch`` to ``self.model`` and return whatever it produces."""
    model_output = self.model(batch)
    return model_output
......@@ -72,12 +72,12 @@ class OpenFoldWrapper(pl.LightningModule):
on_step=train, on_epoch=(not train), logger=True,
)
if(train):
self.log(
f"train/loss_epoch",
loss_breakdown["loss"],
on_step=False, on_epoch=True, logger=True,
)
if(train):
self.log(
f"{phase}/{loss_name}_epoch",
indiv_loss,
on_step=False, on_epoch=True, logger=True,
)
with torch.no_grad():
other_metrics = self._compute_validation_metrics(
......@@ -116,19 +116,13 @@ class OpenFoldWrapper(pl.LightningModule):
def on_before_zero_grad(self, *args, **kwargs):
    """Call ``self.ema.update`` with the current ``self.model``.

    Any positional or keyword arguments are accepted and ignored.
    Returns ``None``.
    """
    ema_tracker = self.ema
    ema_tracker.update(self.model)
# def training_step_end(self, outputs):
# # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
# if(self.trainer.global_step != self.last_lr_step):
# self.lr_schedulers().step()
# self.last_lr_step = self.trainer.global_step
def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# load_state_dict() is an in-place operation
# it will change the content in any reference of model.state_dict()
# therefore we need to explicitly clone the parameters
clone_param = lambda t: t.clone().detach()
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
......@@ -175,15 +169,15 @@ class OpenFoldWrapper(pl.LightningModule):
eps=self.config.globals.eps,
per_residue=False,
)
metrics["lddt_ca"] = lddt_ca_score
drmsd_ca_score = compute_drmsd(
pred_coords_masked_ca,
gt_coords_masked_ca,
mask=all_atom_mask_ca,
mask=all_atom_mask_ca, # still required here to compute n
)
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
......@@ -207,11 +201,23 @@ class OpenFoldWrapper(pl.LightningModule):
eps: float = 1e-5,
) -> torch.optim.Adam:
# Ignored as long as a DeepSpeed optimizer is configured
return torch.optim.Adam(
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=learning_rate,
eps=eps
)
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
"name": "AlphaFoldLRScheduler",
}
}
def on_load_checkpoint(self, checkpoint):
self.ema.load_state_dict(checkpoint["ema"])
......@@ -236,7 +242,7 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
# TorchScript components of the model
if(args.script_modules):
script_preset_(model_module)
......@@ -255,6 +261,8 @@ def main(args):
if(args.checkpoint_every_epoch):
mc = ModelCheckpoint(
every_n_epochs=1,
auto_insert_metric_name=False,
save_top_k=-1,
)
callbacks.append(mc)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment