Commit f30d77b7 authored by Gustaf Ahdritz

Prep for training

parent 0dc316c3
@@ -96,7 +96,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

-        if(mapping_path is None):
+        if(_alignment_index is not None):
+            self._chain_ids = list(_alignment_index.keys())
+        elif(mapping_path is None):
             self._chain_ids = list(os.listdir(alignment_dir))
         else:
             with open(mapping_path, "r") as f:
@@ -159,6 +161,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         _alignment_index = None
         if(self._alignment_index is not None):
+            alignment_dir = self.alignment_dir
             _alignment_index = self._alignment_index[name]

         if(self.mode == 'train' or self.mode == 'eval'):
@@ -546,7 +549,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 self.template_release_dates_cache_path,
             obsolete_pdbs_file_path=
                 self.obsolete_pdbs_file_path,
-            _alignment_index=self._alignment_index,
         )

         if(self.training_mode):
@@ -560,6 +562,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 treat_pdb_as_distillation=False,
                 mode="train",
                 _output_raw=True,
+                _alignment_index=self._alignment_index,
             )

             distillation_dataset = None
...
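Note (illustration, not part of the commit): the new _alignment_index branches above assume a per-chain index into a packed alignment database. A minimal sketch of the layout implied by _alignment_index.keys(), the "db" key, and the (name, start, size) tuples under "files" — the chain id and file names here are purely illustrative:

# Hypothetical per-chain index layout that OpenFoldSingleDataset appears to expect.
# Keys are chain names; "db" is a packed file inside alignment_dir; "files"
# lists byte ranges of the individual alignment files concatenated into it.
_alignment_index = {
    "1abc_A": {                                   # chain name, as looked up in __getitem__
        "db": "alignments.db",                    # illustrative name, relative to alignment_dir
        "files": [
            ["uniref90_hits.a3m", 0, 102400],     # (name, start, size) byte range
            ["pdb70_hits.hhr", 102400, 20480],
        ],
    },
}

chain_ids = list(_alignment_index.keys())         # mirrors the new __init__ branch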
@@ -425,12 +425,13 @@ class DataPipeline:
         _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
         msa_data = {}

         if(_alignment_index is not None):
-            fp = open(_alignment_index["db"], "rb")
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")

             def read_msa(start, size):
                 fp.seek(start)
-                msa = fp.read(size).encode("utf-8")
+                msa = fp.read(size).decode("utf-8")
                 return msa

             for (name, start, size) in _alignment_index["files"]:
@@ -448,8 +449,8 @@ class DataPipeline:
                     data = {"msa": msa, "deletion_matrix": deletion_matrix}
                 else:
                     continue

-                msa_data[f] = data
+                msa_data[name] = data

             fp.close()
         else:
@@ -481,18 +482,20 @@ class DataPipeline:
     ) -> Mapping[str, Any]:
         all_hits = {}

         if(_alignment_index is not None):
-            fp = open(_alignment_index["db"], 'rb')
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')

             def read_template(start, size):
                 fp.seek(start)
-                return fp.read(size).encode("utf-8")
+                return fp.read(size).decode("utf-8")

             for (name, start, size) in _alignment_index["files"]:
                 ext = os.path.splitext(name)[-1]

                 if(ext == ".hhr"):
                     hits = parsers.parse_hhr(read_template(start, size))
-                    all_hits[f] = hits
+                    all_hits[name] = hits
+
+            fp.close()
         else:
             for f in os.listdir(alignment_dir):
                 path = os.path.join(alignment_dir, f)
@@ -568,7 +571,7 @@ class DataPipeline:
             num_res=num_res,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {
             **sequence_features,
@@ -607,7 +610,7 @@ class DataPipeline:
             query_release_date=to_date(mmcif.header["release_date"])
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**mmcif_feats, **template_features, **msa_features}
@@ -641,7 +644,7 @@ class DataPipeline:
             self.template_featurizer,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**pdb_feats, **template_features, **msa_features}
...
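Note (illustration, not part of the commit): with an index entry of the shape sketched above, the DataPipeline changes read each alignment straight out of the packed database by byte offset, decoding bytes to text as in the corrected read_msa/read_template helpers. A self-contained sketch of that access pattern:

import os

def read_indexed_alignment(alignment_dir, index_entry, wanted_name):
    # index_entry: {"db": <packed file name>, "files": [(name, start, size), ...]}
    with open(os.path.join(alignment_dir, index_entry["db"]), "rb") as fp:
        for name, start, size in index_entry["files"]:
            if name == wanted_name:
                fp.seek(start)
                return fp.read(size).decode("utf-8")  # bytes -> str
    raise KeyError(f"{wanted_name} not found in alignment index")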
@@ -329,26 +329,15 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
     return pred_lddt_ca * 100


-def lddt_loss(
-    logits: torch.Tensor,
+def lddt(
     all_atom_pred_pos: torch.Tensor,
     all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
-    resolution: torch.Tensor,
     cutoff: float = 15.0,
-    no_bins: int = 50,
-    min_resolution: float = 0.1,
-    max_resolution: float = 3.0,
     eps: float = 1e-10,
-    **kwargs,
+    per_residue: bool = True,
 ) -> torch.Tensor:
     n = all_atom_mask.shape[-2]

-    ca_pos = residue_constants.atom_order["CA"]
-    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
-    all_atom_positions = all_atom_positions[..., ca_pos, :]
-    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
-
     dmat_true = torch.sqrt(
         eps
         + torch.sum(
@@ -389,8 +378,63 @@ def lddt_loss(
     )

     score = score * 0.25

-    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
-    score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
+    dims = (-1,) if per_residue else (-2, -1)
+    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
+    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
+
+    return score
+
+
+def lddt_ca(
+    all_atom_pred_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
+    all_atom_mask: torch.Tensor,
+    cutoff: float = 15.0,
+    eps: float = 1e-10,
+    per_residue: bool = True,
+) -> torch.Tensor:
+    ca_pos = residue_constants.atom_order["CA"]
+    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+    all_atom_positions = all_atom_positions[..., ca_pos, :]
+    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+    return lddt(
+        all_atom_pred_pos,
+        all_atom_positions,
+        all_atom_mask,
+        cutoff=cutoff,
+        eps=eps,
+        per_residue=per_residue,
+    )
+
+
+def lddt_loss(
+    logits: torch.Tensor,
+    all_atom_pred_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
+    all_atom_mask: torch.Tensor,
+    resolution: torch.Tensor,
+    cutoff: float = 15.0,
+    no_bins: int = 50,
+    min_resolution: float = 0.1,
+    max_resolution: float = 3.0,
+    eps: float = 1e-10,
+    **kwargs,
+) -> torch.Tensor:
+    n = all_atom_mask.shape[-2]
+
+    ca_pos = residue_constants.atom_order["CA"]
+    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+    all_atom_positions = all_atom_positions[..., ca_pos, :]
+    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+    score = lddt(
+        all_atom_pred_pos,
+        all_atom_positions,
+        all_atom_mask,
+        cutoff=cutoff,
+        eps=eps
+    )

     score = score.detach()
@@ -1526,4 +1570,8 @@ class AlphaFoldLoss(nn.Module):
                 loss = loss.new_tensor(0., requires_grad=True)
             cum_loss = cum_loss + weight * loss

+        seq_len = torch.mean(batch["seq_length"].float())
+        crop_len = batch["aatype"].shape[-1]
+        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
+
         return cum_loss
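Note (illustration, not part of the commit): after this refactor, lddt computes the raw score, lddt_ca restricts it to C-alpha atoms, and lddt_loss keeps the binned training loss; the final AlphaFoldLoss hunk additionally scales the accumulated loss by the square root of the effective sequence length, capped at the crop length. A small usage sketch of lddt_ca as the new validation metric, on dummy tensors in the 37-atom layout (where residue_constants.atom_order["CA"] picks out the C-alpha column):

import torch
from openfold.utils.loss import lddt_ca  # as imported by the training script below

n_res = 64
pred = torch.randn(1, n_res, 37, 3)           # predicted atom37 coordinates
gt = pred + 0.1 * torch.randn_like(pred)      # perturbed stand-in for ground truth
mask = torch.ones(1, n_res, 37)               # all atoms marked as resolved

# per_residue=False also reduces over residues, giving one lDDT-Ca score per structure
score = lddt_ca(pred, gt, mask, per_residue=False)
print(score.shape, float(score))              # torch.Size([1]), value roughly in [0, 1]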
@@ -138,22 +138,22 @@ def main(args):
         cache = None

     dirs = []
-    #if(cache is not None and args.filter):
-    #    dirs = set(os.listdir(args.output_dir))
-    #    def prot_is_done(f):
-    #        prot_id = os.path.splitext(f)[0]
-    #        if(prot_id in cache):
-    #            chain_ids = cache[prot_id]["chain_ids"]
-    #            for c in chain_ids:
-    #                full_name = prot_id + "_" + c
-    #                if(not full_name in dirs):
-    #                    return False
-    #        else:
-    #            return False
-    #        return True
-    #    files = [f for f in files if not prot_is_done(f)]
+    if(cache is not None and args.filter):
+        dirs = set(os.listdir(args.output_dir))
+        def prot_is_done(f):
+            prot_id = os.path.splitext(f)[0]
+            if(prot_id in cache):
+                chain_ids = cache[prot_id]["chain_ids"]
+                for c in chain_ids:
+                    full_name = prot_id + "_" + c
+                    if(not full_name in dirs):
+                        return False
+            else:
+                return False
+            return True
+        files = [f for f in files if not prot_is_done(f)]

     def split_up_arglist(arglist):
         # Split up the survivors
...
@@ -13,6 +13,7 @@ import time
 import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 import torch
@@ -29,7 +30,7 @@ from openfold.utils.callbacks import (
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.argparse import remove_arguments
-from openfold.utils.loss import AlphaFoldLoss
+from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.seed import seed_everything
 from openfold.utils.tensor_utils import tensor_tree_map
 from scripts.zero_to_fp32 import (
@@ -67,22 +68,27 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)

-        self.log("loss", loss)
+        self.log("train/loss", loss, logger=True)

-        return {"loss": loss}
+        return loss

     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             self.cached_weights = self.model.state_dict()
             self.model.load_state_dict(self.ema.state_dict()["params"])

         # Calculate validation loss
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
-        loss = self.loss(outputs, batch)
-        self.log("val_loss", loss)
-        return {"val_loss": loss}
+        loss = lddt_ca(
+            outputs["final_atom_positions"],
+            batch["all_atom_positions"],
+            batch["all_atom_mask"],
+            eps=self.config.globals.eps,
+            per_residue=False,
+        )
+        self.log("val/loss", loss, logger=True)

     def validation_epoch_end(self, _):
         # Restore the model weights to normal
@@ -91,7 +97,7 @@ class OpenFoldWrapper(pl.LightningModule):
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
-        eps: float = 1e-8
+        eps: float = 1e-5
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
         return torch.optim.Adam(
@@ -103,6 +109,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

+    def on_load_checkpoint(self, checkpoint):
+        self.ema.load_state_dict(checkpoint["ema"])
+
     def on_save_checkpoint(self, checkpoint):
         checkpoint["ema"] = self.ema.state_dict()
@@ -114,7 +123,7 @@ def main(args):
     config = model_config(
         "model_1",
         train=True,
-        low_prec=(args.precision == 16)
+        low_prec=(args.precision == "16")
     )
     model_module = OpenFoldWrapper(config)
@@ -144,13 +153,13 @@
         mc = ModelCheckpoint(
             dirpath=checkpoint_dir,
             filename="openfold_{epoch}_{step}_{val_loss:.2f}",
-            monitor="val_loss",
+            monitor="val/loss",
         )
         callbacks.append(mc)

     if(args.early_stopping):
         es = EarlyStoppingVerbose(
-            monitor="val_loss",
+            monitor="val/loss",
             min_delta=args.min_delta,
             patience=args.patience,
             verbose=False,
@@ -159,7 +168,7 @@
             strict=True,
         )
         callbacks.append(es)

     if(args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
@@ -168,24 +177,40 @@
         )
         callbacks.append(perf)

+    loggers = []
+    if(args.wandb):
+        wdb_logger = WandbLogger(
+            name=args.experiment_name,
+            save_dir=args.output_dir,
+            id=args.wandb_id,
+            project=args.wandb_project,
+            **{"entity": args.wandb_entity}
+        )
+        loggers.append(wdb_logger)
+
     if(args.deepspeed_config_path is not None):
-        if "SLURM_JOB_ID" in os.environ:
-            cluster_environment = SLURMEnvironment()
-        else:
-            cluster_environment = None
+        #if "SLURM_JOB_ID" in os.environ:
+        #    cluster_environment = SLURMEnvironment()
+        #else:
+        #    cluster_environment = None
         strategy = DeepSpeedPlugin(
             config=args.deepspeed_config_path,
-            cluster_environment=cluster_environment,
+            # cluster_environment=cluster_environment,
         )
-    elif (args.gpus is not None and args.gpus) > 1 or args.num_nodes > 1:
+        if(args.wandb):
+            wdb_logger.experiment.save(args.deepspeed_config_path)
+    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None

     trainer = pl.Trainer.from_argparse_args(
         args,
+        default_root_dir=args.output_dir,
         strategy=strategy,
         callbacks=callbacks,
+        logger=loggers,
     )

     if(args.resume_model_weights_only):
@@ -200,7 +225,7 @@
     )

     trainer.save_checkpoint(
-        os.path.join(trainer.logger.log_dir, "checkpoints", "final.ckpt")
+        os.path.join(args.output_dir, "checkpoints", "final.ckpt")
     )
@@ -315,6 +340,21 @@ if __name__ == "__main__":
         "--log_performance", type=bool_type, default=False,
         help="Measure performance"
     )
+    parser.add_argument(
+        "--wandb", action="store_true", default=False,
+    )
+    parser.add_argument(
+        "--experiment_name", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_id", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_project", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_entity", type=str, default=None,
+    )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
         help="Whether to TorchScript eligible components of them model"
@@ -328,6 +368,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
     )
+    parser.add_argument(
+        "--obsolete_pdbs_file_path", type=str,
+    )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
+    )
     parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass
@@ -341,7 +387,8 @@ if __name__ == "__main__":
         [
             "--accelerator",
             "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch"
+            "--reload_dataloaders_every_epoch",
+            "--reload_dataloaders_every_n_epochs",
         ]
     )
@@ -353,6 +400,6 @@ if __name__ == "__main__":
         raise ValueError("For distributed training, --seed must be specified")

     # This re-applies the training-time filters at the beginning of every epoch
-    args.reload_dataloaders_every_epoch = True
+    args.reload_dataloaders_every_n_epochs = 1

     main(args)
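Note (assumption, not shown in this diff): the new --_alignment_index_path argument is added here, but the code that consumes it lies outside the diff. Presumably it names a serialized index mapping chain ids to the {"db", "files"} entries sketched earlier, loaded roughly along these lines before being handed to the data module as _alignment_index:

import json

alignment_index = None
if args._alignment_index_path is not None:
    # assumed JSON layout: {"<chain_id>": {"db": ..., "files": [[name, start, size], ...]}}
    with open(args._alignment_index_path, "r") as fp:
        alignment_index = json.load(fp)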