Commit f30d77b7 authored by Gustaf Ahdritz

Prep for training

parent 0dc316c3
@@ -96,7 +96,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )
-        if(mapping_path is None):
+        if(_alignment_index is not None):
+            self._chain_ids = list(_alignment_index.keys())
+        elif(mapping_path is None):
             self._chain_ids = list(os.listdir(alignment_dir))
         else:
             with open(mapping_path, "r") as f:
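Reviewer context: the dataset can now enumerate chains directly from an alignment index, a mapping from chain IDs to records that point into one packed alignment database via a "db" file name plus (name, start, size) byte spans (see the DataPipeline hunks below). A minimal sketch of loading such an index, assuming a JSON serialization; the actual on-disk format is not shown in this diff:

import json
from typing import Any, Dict

def load_alignment_index(index_path: str) -> Dict[str, Any]:
    # Hypothetical loader: the diff only guarantees that each entry
    # exposes "db" (a name resolved against alignment_dir) and "files"
    # (a list of (name, start, size) spans into that packed file).
    with open(index_path, "r") as fp:
        return json.load(fp)

# Entry shape implied by the hunks below (illustrative values only):
# {
#     "1abc_A": {
#         "db": "alignments.db",
#         "files": [["uniref90_hits.a3m", 0, 204800]],
#     }
# }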
@@ -159,6 +161,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         _alignment_index = None
         if(self._alignment_index is not None):
             alignment_dir = self.alignment_dir
+            _alignment_index = self._alignment_index[name]

         if(self.mode == 'train' or self.mode == 'eval'):
@@ -546,7 +549,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 self.template_release_dates_cache_path,
                 obsolete_pdbs_file_path=
                     self.obsolete_pdbs_file_path,
-                _alignment_index=self._alignment_index,
             )

         if(self.training_mode):
@@ -560,6 +562,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 treat_pdb_as_distillation=False,
                 mode="train",
                 _output_raw=True,
+                _alignment_index=self._alignment_index,
             )

             distillation_dataset = None
@@ -425,12 +425,13 @@ class DataPipeline:
         _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
         msa_data = {}
         if(_alignment_index is not None):
-            fp = open(_alignment_index["db"], "rb")
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")

             def read_msa(start, size):
                 fp.seek(start)
-                msa = fp.read(size).encode("utf-8")
+                msa = fp.read(size).decode("utf-8")
                 return msa

             for (name, start, size) in _alignment_index["files"]:
@@ -448,8 +449,8 @@ class DataPipeline:
                     data = {"msa": msa, "deletion_matrix": deletion_matrix}
                 else:
                     continue

-                msa_data[f] = data
+                msa_data[name] = data

             fp.close()
         else:
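Two fixes land here: the packed database is now resolved relative to alignment_dir, and the raw bytes are decode()d rather than encode()d (bytes objects have no encode method in Python 3, so the old line would have raised AttributeError on first use). A self-contained sketch of the seek-and-decode pattern on an in-memory stand-in:

import io

# Stand-in for the packed alignment database the index points into.
packed = io.BytesIO(b">seq1\nMKV...\n>seq2\nMKI...\n")

def read_span(fp, start: int, size: int) -> str:
    # Seek to the record's byte offset, read its span, and decode,
    # mirroring the corrected read_msa/read_template helpers.
    fp.seek(start)
    return fp.read(size).decode("utf-8")

print(read_span(packed, 0, 13))  # -> '>seq1\nMKV...\n'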
@@ -481,18 +482,20 @@ class DataPipeline:
     ) -> Mapping[str, Any]:
         all_hits = {}
         if(_alignment_index is not None):
-            fp = open(_alignment_index["db"], 'rb')
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')

             def read_template(start, size):
                 fp.seek(start)
-                return fp.read(size).encode("utf-8")
+                return fp.read(size).decode("utf-8")

             for (name, start, size) in _alignment_index["files"]:
                 ext = os.path.splitext(name)[-1]

                 if(ext == ".hhr"):
                     hits = parsers.parse_hhr(read_template(start, size))
-                    all_hits[f] = hits
+                    all_hits[name] = hits

             fp.close()
         else:
             for f in os.listdir(alignment_dir):
                 path = os.path.join(alignment_dir, f)
@@ -568,7 +571,7 @@ class DataPipeline:
             num_res=num_res,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {
             **sequence_features,
@@ -607,7 +610,7 @@ class DataPipeline:
             query_release_date=to_date(mmcif.header["release_date"])
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**mmcif_feats, **template_features, **msa_features}
@@ -641,7 +644,7 @@ class DataPipeline:
             self.template_featurizer,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**pdb_feats, **template_features, **msa_features}
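All three featurization entry points (FASTA, mmCIF, and PDB) now thread the per-chain index entry through to _process_msa_feats, so MSA reads can come from the packed database instead of a directory scan. A toy stand-in for that dispatch, to make the control flow concrete (function and field names here are illustrative, not OpenFold API):

import os
from typing import Any, Mapping, Optional

def process_msa_feats(
    alignment_dir: str,
    input_sequence: str,
    _alignment_index: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
    # With an index entry, alignments come from the packed db file;
    # without one, the pipeline falls back to scanning alignment_dir.
    if _alignment_index is not None:
        source = os.path.join(alignment_dir, _alignment_index["db"])
    else:
        source = alignment_dir
    return {"sequence": input_sequence, "msa_source": source}

print(process_msa_feats("alignments/", "MKV", {"db": "alignments.db"}))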
@@ -329,26 +329,15 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
     return pred_lddt_ca * 100


-def lddt_loss(
-    logits: torch.Tensor,
+def lddt(
     all_atom_pred_pos: torch.Tensor,
     all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
-    resolution: torch.Tensor,
     cutoff: float = 15.0,
-    no_bins: int = 50,
-    min_resolution: float = 0.1,
-    max_resolution: float = 3.0,
     eps: float = 1e-10,
-    **kwargs,
+    per_residue: bool = True,
 ) -> torch.Tensor:
-    n = all_atom_mask.shape[-2]
-
-    ca_pos = residue_constants.atom_order["CA"]
-    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
-    all_atom_positions = all_atom_positions[..., ca_pos, :]
-    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
-
     dmat_true = torch.sqrt(
         eps
         + torch.sum(
@@ -389,8 +378,63 @@ def lddt_loss(
     )
     score = score * 0.25

-    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
-    score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
+    dims = (-1,) if per_residue else (-2, -1)
+    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
+    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
+
+    return score
+
+
+def lddt_ca(
+    all_atom_pred_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
+    all_atom_mask: torch.Tensor,
+    cutoff: float = 15.0,
+    eps: float = 1e-10,
+    per_residue: bool = True,
+) -> torch.Tensor:
+    ca_pos = residue_constants.atom_order["CA"]
+    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+    all_atom_positions = all_atom_positions[..., ca_pos, :]
+    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+    return lddt(
+        all_atom_pred_pos,
+        all_atom_positions,
+        all_atom_mask,
+        cutoff=cutoff,
+        eps=eps,
+        per_residue=per_residue,
+    )
+
+
+def lddt_loss(
+    logits: torch.Tensor,
+    all_atom_pred_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
+    all_atom_mask: torch.Tensor,
+    resolution: torch.Tensor,
+    cutoff: float = 15.0,
+    no_bins: int = 50,
+    min_resolution: float = 0.1,
+    max_resolution: float = 3.0,
+    eps: float = 1e-10,
+    **kwargs,
+) -> torch.Tensor:
+    n = all_atom_mask.shape[-2]
+
+    ca_pos = residue_constants.atom_order["CA"]
+    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+    all_atom_positions = all_atom_positions[..., ca_pos, :]
+    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+    score = lddt(
+        all_atom_pred_pos,
+        all_atom_positions,
+        all_atom_mask,
+        cutoff=cutoff,
+        eps=eps
+    )
+
     score = score.detach()
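Factoring the score computation out into lddt() lets one kernel serve both consumers: lddt_loss keeps per-residue scores to build binned pLDDT targets (score.detach() stops those targets from backpropagating into the coordinates), while per_residue=False also reduces over the residue dimension to give one score per structure, as the new validation path uses. A toy check of the reduction shapes for one 10-residue chain:

import torch

eps = 1e-10
dists_to_score = torch.rand(1, 10, 10)  # [*, N_res, N_res] pair inclusion weights
score = torch.rand(1, 10, 10)           # [*, N_res, N_res] per-pair lDDT terms

for per_residue in (True, False):
    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    out = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
    print(per_residue, tuple(out.shape))  # True -> (1, 10); False -> (1,)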
@@ -1526,4 +1570,8 @@ class AlphaFoldLoss(nn.Module):
                 loss = loss.new_tensor(0., requires_grad=True)
             cum_loss = cum_loss + weight * loss

+        seq_len = torch.mean(batch["seq_length"].float())
+        crop_len = batch["aatype"].shape[-1]
+        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
+
         return cum_loss
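The new trailer rescales the summed loss by the square root of min(mean sequence length, crop size), the per-example weighting described in the AlphaFold 2 supplement: longer (cropped) chains contribute proportionally larger gradients. A worked sketch with hypothetical lengths:

import torch

cum_loss = torch.tensor(2.5)
seq_len = torch.mean(torch.tensor([180.0]))  # batch["seq_length"].float()
crop_len = 256                               # batch["aatype"].shape[-1]

# Here seq_len < crop_len, so the builtin min() returns the tensor and
# torch.sqrt is happy; if the plain int crop_len were the smaller value,
# it would need wrapping in a tensor before torch.sqrt.
scaled = cum_loss * torch.sqrt(min(seq_len, crop_len))
print(scaled)  # 2.5 * sqrt(180) ~= 33.54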
@@ -138,22 +138,22 @@ def main(args):
         cache = None

     dirs = []
-    #if(cache is not None and args.filter):
-    #    dirs = set(os.listdir(args.output_dir))
-    #    def prot_is_done(f):
-    #        prot_id = os.path.splitext(f)[0]
-    #        if(prot_id in cache):
-    #            chain_ids = cache[prot_id]["chain_ids"]
-    #            for c in chain_ids:
-    #                full_name = prot_id + "_" + c
-    #                if(not full_name in dirs):
-    #                    return False
-    #        else:
-    #            return False
-    #        return True
-    #    files = [f for f in files if not prot_is_done(f)]
+    if(cache is not None and args.filter):
+        dirs = set(os.listdir(args.output_dir))
+        def prot_is_done(f):
+            prot_id = os.path.splitext(f)[0]
+            if(prot_id in cache):
+                chain_ids = cache[prot_id]["chain_ids"]
+                for c in chain_ids:
+                    full_name = prot_id + "_" + c
+                    if(not full_name in dirs):
+                        return False
+            else:
+                return False
+            return True
+        files = [f for f in files if not prot_is_done(f)]

     def split_up_arglist(arglist):
         # Split up the survivors
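The previously commented-out resume filter is back on: a FASTA file is skipped when every chain listed in the cache already has an alignment directory named <prot_id>_<chain_id> under output_dir. A self-contained restatement of the check with hypothetical IDs:

import os

cache = {"1abc": {"chain_ids": ["A", "B"]}}  # stand-in for the chain cache
dirs = {"1abc_A", "1abc_B"}                  # stand-in for os.listdir(output_dir)

def prot_is_done(f: str) -> bool:
    prot_id = os.path.splitext(f)[0]
    if prot_id not in cache:
        return False
    # Done only if every chain's output directory already exists.
    return all(f"{prot_id}_{c}" in dirs for c in cache[prot_id]["chain_ids"])

files = ["1abc.fasta", "2xyz.fasta"]
print([f for f in files if not prot_is_done(f)])  # ['2xyz.fasta']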
@@ -13,6 +13,7 @@ import time

 import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 import torch
@@ -29,7 +30,7 @@ from openfold.utils.callbacks import (
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.argparse import remove_arguments
-from openfold.utils.loss import AlphaFoldLoss
+from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.seed import seed_everything
 from openfold.utils.tensor_utils import tensor_tree_map
 from scripts.zero_to_fp32 import (
@@ -67,22 +68,27 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)

-        self.log("loss", loss)
+        self.log("train/loss", loss, logger=True)

-        return {"loss": loss}
+        return loss

     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             self.cached_weights = self.model.state_dict()
             self.model.load_state_dict(self.ema.state_dict()["params"])

         # Calculate validation loss
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

-        loss = self.loss(outputs, batch)
-        self.log("val_loss", loss)
-        return {"val_loss": loss}
+        loss = lddt_ca(
+            outputs["final_atom_positions"],
+            batch["all_atom_positions"],
+            batch["all_atom_mask"],
+            eps=self.config.globals.eps,
+            per_residue=False,
+        )
+
+        self.log("val/loss", loss, logger=True)

     def validation_epoch_end(self, _):
         # Restore the model weights to normal
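Validation now reports CA-lDDT between the prediction and the ground truth rather than the composite training loss; lDDT is superposition-free, so no structural alignment is needed before scoring. Two things worth double-checking downstream: the metric is logged as "val/loss" even though higher is better (the ModelCheckpoint and EarlyStoppingVerbose callbacks below monitor it with Lightning's default mode="min"), and the checkpoint filename still interpolates {val_loss:.2f}, a key that is no longer logged. A toy invocation of the new helper on random coordinates (requires openfold):

import torch
from openfold.utils.loss import lddt_ca  # added by this commit

n_res = 10
pred = torch.rand(1, n_res, 37, 3) * 10      # [*, N_res, 37, 3] atom37 coords
true = pred + 0.05 * torch.randn_like(pred)  # lightly perturbed "ground truth"
mask = torch.ones(1, n_res, 37)              # [*, N_res, 37]

score = lddt_ca(pred, true, mask, per_residue=False)
print(score.shape, float(score))  # one score per structure, close to 1 here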
@@ -91,7 +97,7 @@ class OpenFoldWrapper(pl.LightningModule):
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
-        eps: float = 1e-8
+        eps: float = 1e-5
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
         return torch.optim.Adam(
@@ -103,6 +109,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

+    def on_load_checkpoint(self, checkpoint):
+        self.ema.load_state_dict(checkpoint["ema"])
+
     def on_save_checkpoint(self, checkpoint):
         checkpoint["ema"] = self.ema.state_dict()
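Persisting the EMA state in the checkpoint means a resumed run continues the same moving average instead of re-seeding it from the live weights. A toy round-trip of the save/load contract (OpenFold's ExponentialMovingAverage carries more state than this stand-in, which mirrors only the "params" entry used above):

# Minimal stand-in demonstrating the on_save/on_load contract.
class ToyEMA:
    def __init__(self, params):
        self.params = dict(params)

    def state_dict(self):
        return {"params": dict(self.params)}

    def load_state_dict(self, state):
        self.params = dict(state["params"])

ema = ToyEMA({"w": 1.0})
checkpoint = {}
checkpoint["ema"] = ema.state_dict()         # on_save_checkpoint

restored = ToyEMA({"w": 0.0})
restored.load_state_dict(checkpoint["ema"])  # on_load_checkpoint
print(restored.params)                       # {'w': 1.0}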
@@ -114,7 +123,7 @@ def main(args):
     config = model_config(
         "model_1",
         train=True,
-        low_prec=(args.precision == 16)
+        low_prec=(args.precision == "16")
     )
     model_module = OpenFoldWrapper(config)
@@ -144,13 +153,13 @@ def main(args):
         mc = ModelCheckpoint(
             dirpath=checkpoint_dir,
             filename="openfold_{epoch}_{step}_{val_loss:.2f}",
-            monitor="val_loss",
+            monitor="val/loss",
         )
         callbacks.append(mc)

     if(args.early_stopping):
         es = EarlyStoppingVerbose(
-            monitor="val_loss",
+            monitor="val/loss",
             min_delta=args.min_delta,
             patience=args.patience,
             verbose=False,
@@ -159,7 +168,7 @@ def main(args):
             strict=True,
         )
         callbacks.append(es)

     if(args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
@@ -168,24 +177,40 @@ def main(args):
         )
         callbacks.append(perf)

+    loggers = []
+    if(args.wandb):
+        wdb_logger = WandbLogger(
+            name=args.experiment_name,
+            save_dir=args.output_dir,
+            id=args.wandb_id,
+            project=args.wandb_project,
+            **{"entity": args.wandb_entity}
+        )
+        loggers.append(wdb_logger)
+
     if(args.deepspeed_config_path is not None):
-        if "SLURM_JOB_ID" in os.environ:
-            cluster_environment = SLURMEnvironment()
-        else:
-            cluster_environment = None
+        #if "SLURM_JOB_ID" in os.environ:
+        #    cluster_environment = SLURMEnvironment()
+        #else:
+        #    cluster_environment = None
         strategy = DeepSpeedPlugin(
             config=args.deepspeed_config_path,
-            cluster_environment=cluster_environment,
+            # cluster_environment=cluster_environment,
         )
-    elif (args.gpus is not None and args.gpus) > 1 or args.num_nodes > 1:
+        if(args.wandb):
+            wdb_logger.experiment.save(args.deepspeed_config_path)
+    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None

     trainer = pl.Trainer.from_argparse_args(
         args,
         default_root_dir=args.output_dir,
         strategy=strategy,
         callbacks=callbacks,
+        logger=loggers,
     )

     if(args.resume_model_weights_only):
@@ -200,7 +225,7 @@ def main(args):
         )

     trainer.save_checkpoint(
-        os.path.join(trainer.logger.log_dir, "checkpoints", "final.ckpt")
+        os.path.join(args.output_dir, "checkpoints", "final.ckpt")
     )
@@ -315,6 +340,21 @@ if __name__ == "__main__":
         "--log_performance", type=bool_type, default=False,
         help="Measure performance"
     )
+    parser.add_argument(
+        "--wandb", action="store_true", default=False,
+    )
+    parser.add_argument(
+        "--experiment_name", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_id", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_project", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_entity", type=str, default=None,
+    )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
         help="Whether to TorchScript eligible components of them model"
@@ -328,6 +368,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
     )
+    parser.add_argument(
+        "--obsolete_pdbs_file_path", type=str,
+    )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
+    )
     parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass
@@ -341,7 +387,8 @@ if __name__ == "__main__":
         [
             "--accelerator",
             "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch"
+            "--reload_dataloaders_every_epoch",
+            "--reload_dataloaders_every_n_epochs",
         ]
     )
@@ -353,6 +400,6 @@ if __name__ == "__main__":
         raise ValueError("For distributed training, --seed must be specified")

     # This re-applies the training-time filters at the beginning of every epoch
-    args.reload_dataloaders_every_epoch = True
+    args.reload_dataloaders_every_n_epochs = 1

     main(args)
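Housekeeping for newer PyTorch Lightning releases: reload_dataloaders_every_epoch was deprecated in favor of reload_dataloaders_every_n_epochs, which is why both spellings are stripped from the exposed Trainer flags above and the new one is pinned here. Equivalent direct usage, for reference:

import pytorch_lightning as pl

# Reload train/val dataloaders every epoch so the dataset's
# training-time filters are re-applied, as the comment above notes.
trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)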