Commit 4358096c authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix pLDDT bug

parent 236c6865
...@@ -550,7 +550,7 @@ config = mlc.ConfigDict( ...@@ -550,7 +550,7 @@ config = mlc.ConfigDict(
"eps": 1e-4, "eps": 1e-4,
"weight": 1.0, "weight": 1.0,
}, },
"lddt": { "plddt_loss": {
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"cutoff": 15.0, "cutoff": 15.0,
......
...@@ -1562,7 +1562,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1562,7 +1562,7 @@ class AlphaFoldLoss(nn.Module):
"plddt_loss": lambda: lddt_loss( "plddt_loss": lambda: lddt_loss(
logits=out["lddt_logits"], logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"], all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt}, **{**batch, **self.config.plddt_loss},
), ),
"masked_msa": lambda: masked_msa_loss( "masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"], logits=out["masked_msa_logits"],
......
...@@ -8,7 +8,6 @@ import os ...@@ -8,7 +8,6 @@ import os
#os.environ["NODE_RANK"]="0" #os.environ["NODE_RANK"]="0"
import random import random
import sys
import time import time
import numpy as np import numpy as np
...@@ -27,22 +26,14 @@ from openfold.data.data_modules import ( ...@@ -27,22 +26,14 @@ from openfold.data.data_modules import (
) )
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_ from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import ( from openfold.utils.callbacks import (
EarlyStoppingVerbose, EarlyStoppingVerbose,
) )
from openfold.utils.exponential_moving_average import ExponentialMovingAverage from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.argparse import remove_arguments
from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
drmsd,
gdt_ts,
gdt_ha,
)
from scripts.zero_to_fp32 import ( from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint get_fp32_state_dict_from_zero_checkpoint
) )
...@@ -66,36 +57,6 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -66,36 +57,6 @@ class OpenFoldWrapper(pl.LightningModule):
def forward(self, batch): def forward(self, batch):
return self.model(batch) return self.model(batch)
def _log(self, loss_breakdown, batch, outputs, train=True):
phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
f"{phase}/{loss_name}",
indiv_loss,
on_step=train, on_epoch=(not train), logger=True,
)
if(train):
self.log(
f"{phase}/{loss_name}_epoch",
indiv_loss,
on_step=False, on_epoch=True, logger=True,
)
with torch.no_grad():
other_metrics = self._compute_validation_metrics(
batch,
outputs,
superimposition_metrics=(not train)
)
for k,v in other_metrics.items():
self.log(
f"{phase}/{k}",
v,
on_step=False, on_epoch=True, logger=True
)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
if(self.ema.device != batch["aatype"].device): if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device) self.ema.to(batch["aatype"].device)
...@@ -107,121 +68,54 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -107,121 +68,54 @@ class OpenFoldWrapper(pl.LightningModule):
batch = tensor_tree_map(lambda t: t[..., -1], batch) batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss # Compute loss
loss, loss_breakdown = self.loss( loss = self.loss(outputs, batch)
outputs, batch, _return_breakdown=True
)
# Log it self.log("train/loss", loss, on_step=True, logger=True)
self._log(loss_breakdown, batch, outputs)
return loss return loss
def on_before_zero_grad(self, *args, **kwargs): def training_step_end(self, outputs):
self.ema.update(self.model) # Temporary measure to address DeepSpeed scheduler bug
if(self.trainer.global_step != self.last_lr_step):
self.lr_schedulers().step()
self.last_lr_step = self.trainer.global_step
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather self.cached_weights = self.model.state_dict()
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model # Calculate validation loss
outputs = self(batch) outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch) batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss = lddt_ca(
# Compute loss and other metrics outputs["final_atom_positions"],
batch["use_clamped_fape"] = 0. batch["all_atom_positions"],
_, loss_breakdown = self.loss( batch["all_atom_mask"],
outputs, batch, _return_breakdown=True eps=self.config.globals.eps,
per_residue=False,
) )
self.log("val/loss", loss, logger=True)
self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
# Restore the model weights to normal # Restore the model weights to normal
self.model.load_state_dict(self.cached_weights) self.model.load_state_dict(self.cached_weights)
self.cached_weights = None self.cached_weights = None
def _compute_validation_metrics(self,
batch,
outputs,
superimposition_metrics=False
):
metrics = {}
gt_coords = batch["all_atom_positions"]
pred_coords = outputs["final_atom_positions"]
all_atom_mask = batch["all_atom_mask"]
# This is super janky for superimposition. Fix later
gt_coords_masked = gt_coords * all_atom_mask[..., None]
pred_coords_masked = pred_coords * all_atom_mask[..., None]
ca_pos = residue_constants.atom_order["CA"]
gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
all_atom_mask_ca = all_atom_mask[..., ca_pos]
lddt_ca_score = lddt_ca(
pred_coords,
gt_coords,
all_atom_mask,
eps=self.config.globals.eps,
per_residue=False,
)
metrics["lddt_ca"] = lddt_ca_score
drmsd_ca_score = drmsd(
pred_coords_masked_ca,
gt_coords_masked_ca,
mask=all_atom_mask_ca, # still required here to compute n
)
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
)
gdt_ts_score = gdt_ts(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
gdt_ha_score = gdt_ha(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
metrics["alignment_rmsd"] = alignment_rmsd
metrics["gdt_ts"] = gdt_ts_score
metrics["gdt_ha"] = gdt_ha_score
return metrics
def configure_optimizers(self, def configure_optimizers(self,
learning_rate: float = 1e-3, learning_rate: float = 1e-3,
eps: float = 1e-5, eps: float = 1e-5,
) -> torch.optim.Adam: ) -> torch.optim.Adam:
# Ignored as long as a DeepSpeed optimizer is configured # Ignored as long as a DeepSpeed optimizer is configured
optimizer = torch.optim.Adam( return torch.optim.Adam(
self.model.parameters(), self.model.parameters(),
lr=learning_rate, lr=learning_rate,
eps=eps eps=eps
) )
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
return { def on_before_zero_grad(self, *args, **kwargs):
"optimizer": optimizer, self.ema.update(self.model)
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
"name": "AlphaFoldLRScheduler",
}
}
def on_load_checkpoint(self, checkpoint): def on_load_checkpoint(self, checkpoint):
self.ema.load_state_dict(checkpoint["ema"]) self.ema.load_state_dict(checkpoint["ema"])
...@@ -235,7 +129,7 @@ def main(args): ...@@ -235,7 +129,7 @@ def main(args):
seed_everything(args.seed) seed_everything(args.seed)
config = model_config( config = model_config(
args.config_preset, "model_1",
train=True, train=True,
low_prec=(args.precision == "16") low_prec=(args.precision == "16")
) )
...@@ -246,7 +140,7 @@ def main(args): ...@@ -246,7 +140,7 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()} sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd) model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
# TorchScript components of the model # TorchScript components of the model
if(args.script_modules): if(args.script_modules):
script_preset_(model_module) script_preset_(model_module)
...@@ -265,18 +159,16 @@ def main(args): ...@@ -265,18 +159,16 @@ def main(args):
if(args.checkpoint_every_epoch): if(args.checkpoint_every_epoch):
mc = ModelCheckpoint( mc = ModelCheckpoint(
every_n_epochs=1, every_n_epochs=1,
auto_insert_metric_name=False,
save_top_k=-1,
) )
callbacks.append(mc) callbacks.append(mc)
if(args.early_stopping): if(args.early_stopping):
es = EarlyStoppingVerbose( es = EarlyStoppingVerbose(
monitor="val/lddt_ca", monitor="val/loss",
min_delta=args.min_delta, min_delta=args.min_delta,
patience=args.patience, patience=args.patience,
verbose=False, verbose=False,
mode="max", mode="min",
check_finite=True, check_finite=True,
strict=True, strict=True,
) )
...@@ -306,8 +198,14 @@ def main(args): ...@@ -306,8 +198,14 @@ def main(args):
loggers.append(wdb_logger) loggers.append(wdb_logger)
if(args.deepspeed_config_path is not None): if(args.deepspeed_config_path is not None):
#if "SLURM_JOB_ID" in os.environ:
# cluster_environment = SLURMEnvironment()
#else:
# cluster_environment = None
strategy = DeepSpeedPlugin( strategy = DeepSpeedPlugin(
config=args.deepspeed_config_path, config=args.deepspeed_config_path,
# cluster_environment=cluster_environment,
) )
if(args.wandb): if(args.wandb):
wdb_logger.experiment.save(args.deepspeed_config_path) wdb_logger.experiment.save(args.deepspeed_config_path)
...@@ -316,12 +214,7 @@ def main(args): ...@@ -316,12 +214,7 @@ def main(args):
strategy = DDPPlugin(find_unused_parameters=False) strategy = DDPPlugin(find_unused_parameters=False)
else: else:
strategy = None strategy = None
if(args.wandb):
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
default_root_dir=args.output_dir, default_root_dir=args.output_dir,
...@@ -459,65 +352,37 @@ if __name__ == "__main__": ...@@ -459,65 +352,37 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--wandb", action="store_true", default=False, "--wandb", action="store_true", default=False,
help="Whether to log metrics to Weights & Biases"
) )
parser.add_argument( parser.add_argument(
"--experiment_name", type=str, default=None, "--experiment_name", type=str, default=None,
help="Name of the current experiment. Used for wandb logging"
) )
parser.add_argument( parser.add_argument(
"--wandb_id", type=str, default=None, "--wandb_id", type=str, default=None,
help="ID of a previous run to be resumed"
) )
parser.add_argument( parser.add_argument(
"--wandb_project", type=str, default=None, "--wandb_project", type=str, default=None,
help="Name of the wandb project to which this run will belong"
) )
parser.add_argument( parser.add_argument(
"--wandb_entity", type=str, default=None, "--wandb_entity", type=str, default=None,
help="wandb username or team name to which runs are attributed"
) )
parser.add_argument( parser.add_argument(
"--script_modules", type=bool_type, default=False, "--script_modules", type=bool_type, default=False,
help="Whether to TorchScript eligible components of them model" help="Whether to TorchScript eligible components of them model"
) )
parser.add_argument( parser.add_argument(
"--train_chain_data_cache_path", type=str, default=None, "--train_prot_data_cache_path", type=str, default=None,
) )
parser.add_argument( parser.add_argument(
"--distillation_chain_data_cache_path", type=str, default=None, "--distillation_prot_data_cache_path", type=str, default=None,
) )
parser.add_argument( parser.add_argument(
"--train_epoch_len", type=int, default=10000, "--train_epoch_len", type=int, default=10000,
help=(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
)
parser.add_argument(
"--log_lr", action="store_true", default=False,
help="Whether to log the actual learning rate"
)
parser.add_argument(
"--config_preset", type=str, default="initial_training",
help=(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
) )
parser.add_argument( parser.add_argument(
"--_distillation_structure_index_path", type=str, default=None, "--_alignment_index_path", type=str, default=None,
) )
parser.add_argument( parser.add_argument(
"--alignment_index_path", type=str, default=None, "--log_lr", action="store_true", default=False,
help="Training alignment index. See the README for instructions."
)
parser.add_argument(
"--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions."
) )
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
......
Markdown is supported
Attach a file by dragging & dropping or clicking to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment