Commit 0cf1541c authored by Christina Floristean

Refactoring multimer data pipeline and permutation alignment.

parent 377f854c
@@ -19,6 +19,8 @@ dependencies:
   - deepspeed==0.5.10
   - dm-tree==0.1.6
   - ml-collections==0.1.0
+  - jax==0.3.25
+  - pandas==2.0.2
   - numpy==1.21.2
   - PyYAML==5.4.1
   - requests==2.26.0
...
@@ -749,6 +749,9 @@ multimer_config_update = mlc.ConfigDict({
                 "sym_id",
             ]
         },
+        "supervised": {
+            "clamp_prob": 1.
+        },
         # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
         # c.model.input_embedder.num_msa = 508
         # c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
...
@@ -765,7 +768,8 @@ multimer_config_update = mlc.ConfigDict({
             "max_extra_msa": 2048,
             "crop_size": 640,
             "spatial_crop_prob": 0.5,
-            "interface_threshold": 10.
+            "interface_threshold": 10.,
+            "clamp_prob": 1.,
         },
     },
     "model": {
...
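For context, clamp_prob conventionally controls how often a training example is scored with the clamped FAPE loss, so setting it to 1. clamps every multimer example. A minimal sketch of that sampling step, under that assumption (the helper below is illustrative and not part of this diff):

import random

def sample_use_clamped_fape(clamp_prob: float) -> float:
    # Hypothetical helper: with probability clamp_prob an example trains
    # against the clamped FAPE variant; clamp_prob == 1. clamps everything.
    return 1.0 if random.random() < clamp_prob else 0.0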
This diff is collapsed.
@@ -93,24 +93,11 @@ def np_example_to_features(
     with torch.no_grad():
         if is_multimer:
-            if mode == 'train':
-                features,gt_features = input_pipeline_multimer.process_tensors_from_config(
-                    tensor_dict,
-                    cfg.common,
-                    cfg[mode],
-                    is_training=True
-                )
-                return {k: v for k, v in features.items()}, gt_features
-            else:
-                features = input_pipeline_multimer.process_tensors_from_config(
-                    tensor_dict,
-                    cfg.common,
-                    cfg[mode],
-                    is_training=False
-                )
-                return {k: v for k, v in features.items()}
+            features = input_pipeline_multimer.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )
         else:
             features = input_pipeline.process_tensors_from_config(
                 tensor_dict,
...
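The multimer branch now has a single exit path in every mode; ground-truth tensors, when produced, ride along inside the returned feature dict instead of arriving as a second return value. A hedged sketch of the caller's side (the np_example_to_features signature is assumed from context, not shown in the diff):

# One return value regardless of mode:
feats = np_example_to_features(np_example, cfg, mode="train")
gt = feats.get("gt_features")  # present only when the mode enables supervision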
@@ -21,16 +21,17 @@ from openfold.data import (
     data_transforms_multimer,
 )

-def grountruth_transforms_fns():
-    transforms = [data_transforms.make_atom14_masks,
-                  data_transforms.make_atom14_positions,
-                  data_transforms.atom37_to_frames,
-                  data_transforms.atom37_to_torsion_angles(""),
-                  data_transforms.make_pseudo_beta(""),
-                  data_transforms.get_backbone_frames,
-                  data_transforms.get_chi_angles,
-                  ]
+
+def groundtruth_transforms_fns():
+    transforms = [data_transforms.make_atom14_masks,
+                  data_transforms.make_atom14_positions,
+                  data_transforms.atom37_to_frames,
+                  data_transforms.atom37_to_torsion_angles(""),
+                  data_transforms.make_pseudo_beta(""),
+                  data_transforms.get_backbone_frames,
+                  data_transforms.get_chi_angles]
     return transforms

 def nonensembled_transform_fns():
     """Input pipeline data transformers that are not ensembled."""
...
@@ -112,20 +113,24 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
     return transforms

 def prepare_ground_truth_features(tensors):
     """Prepare ground truth features that are only needed for loss calculation during training"""
-    GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
-    gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
+    gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
+    gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
     gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
-    gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
+    gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
     return gt_tensors

-def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
+def process_tensors_from_config(tensors, common_cfg, mode_cfg):
     """Based on the config, apply filters and transformations to the data."""
-    if is_training:
-        gt_tensors= prepare_ground_truth_features(tensors)
+    process_gt_feats = mode_cfg.supervised
+    gt_tensors = {}
+    if process_gt_feats:
+        gt_tensors = prepare_ground_truth_features(tensors)

     ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
     tensors['aatype'] = tensors['aatype'].to(torch.long)
...
@@ -152,10 +157,10 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False)
         lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
     )

-    if is_training:
-        return tensors,gt_tensors
-    else:
-        return tensors
+    if process_gt_feats:
+        tensors['gt_features'] = gt_tensors
+
+    return tensors

 @data_transforms.curry1
 def compose(x, fs):
...
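For orientation, a short usage sketch of the refactored contract, assuming a config whose mode section carries a boolean supervised flag as the diff reads it (the exact set of keys the ground-truth transforms add is determined by the individual transform functions):

tensors = process_tensors_from_config(tensor_dict, cfg.common, cfg.train)
if cfg.train.supervised:
    gt = tensors["gt_features"]
    # gt holds the five raw features plus aatype, expanded by the
    # ground-truth transforms: atom14 positions/masks, rigid-group frames,
    # torsion angles, pseudo-beta coordinates, backbone frames, chi angles.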
This diff is collapsed.
@@ -13,7 +13,7 @@ from tqdm import tqdm
 from openfold.data.mmcif_parsing import parse

-def parse_file(f, args):
+def parse_file(f, args, chain_cluster_size_dict=None):
     with open(os.path.join(args.mmcif_dir, f), "r") as fp:
         mmcif_string = fp.read()
     file_id = os.path.splitext(f)[0]
...
@@ -28,6 +28,18 @@ def parse_file(f, args):
     local_data["release_date"] = mmcif.header["release_date"]
     chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))

+    if chain_cluster_size_dict is not None:
+        cluster_sizes = []
+        for chain_id in chain_ids:
+            full_name = "_".join([file_id, chain_id])
+            cluster_size = chain_cluster_size_dict.get(
+                full_name.upper(), -1
+            )
+            cluster_sizes.append(cluster_size)
+        local_data["cluster_sizes"] = cluster_sizes
+
     local_data["chain_ids"] = chain_ids
     local_data["seqs"] = seqs
     local_data["no_chains"] = len(chain_ids)
...
@@ -38,8 +50,21 @@ def parse_file(f, args):
 def main(args):
+    chain_cluster_size_dict = None
+    if args.cluster_file is not None:
+        chain_cluster_size_dict = {}
+        with open(args.cluster_file, "r") as fp:
+            clusters = [l.strip() for l in fp.readlines()]
+
+        for cluster in clusters:
+            chain_ids = cluster.split()
+            cluster_len = len(chain_ids)
+            for chain_id in chain_ids:
+                chain_id = chain_id.upper()
+                chain_cluster_size_dict[chain_id] = cluster_len
+
     files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
-    fn = partial(parse_file, args=args)
+    fn = partial(parse_file, args=args, chain_cluster_size_dict=chain_cluster_size_dict)
     data = {}
     with Pool(processes=args.no_workers) as p:
         with tqdm(total=len(files)) as pbar:
...
@@ -63,6 +88,15 @@ if __name__ == "__main__":
         "--no_workers", type=int, default=4,
         help="Number of workers to use for parsing"
     )
+    parser.add_argument(
+        "--cluster_file", type=str, default=None,
+        help=(
+            "Path to a cluster file (e.g. PDB40), one cluster "
+            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
+            "Chains not in this cluster file will NOT be filtered by cluster "
+            "size."
+        )
+    )
     parser.add_argument(
         "--chunksize", type=int, default=10,
         help="How many files should be distributed to each worker at a time"
...
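To make the cluster-file handling concrete, here is a tiny hypothetical input in the format the new --cluster_file help text documents, together with the lookup table main() builds from it (the IDs are invented for illustration):

# Two hypothetical cluster lines in the documented {PROT_ID}_{CHAIN_ID} format:
example_clusters = ["1abc_A 2xyz_B 3def_A", "4ghi_C"]

chain_cluster_size_dict = {}
for cluster in example_clusters:
    chain_ids = cluster.split()
    for chain_id in chain_ids:
        # Tokens are upper-cased before insertion, exactly as in main().
        chain_cluster_size_dict[chain_id.upper()] = len(chain_ids)

assert chain_cluster_size_dict == {
    "1ABC_A": 3, "2XYZ_B": 3, "3DEF_A": 3, "4GHI_C": 1
}
# parse_file() then records cluster_sizes per chain, with -1 for any chain
# absent from the file.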
 import argparse
 import logging
 import os
-import random
 import sys
-import time

-import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
-from pytorch_lightning.plugins.environments import SLURMEnvironment
 import torch

 from openfold.config import model_config
-from openfold.data.data_modules import (
-    OpenFoldDataModule,OpenFoldMultimerDataModule,
-    DummyDataLoader,
-)
+from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule):
         super(OpenFoldWrapper, self).__init__()
         self.config = config
         self.model = AlphaFold(config)
-        self.loss = AlphaFoldLoss(config.loss)
+
+        if self.config.globals.is_multimer:
+            self.loss = AlphaFoldMultimerLoss(config.loss)
+        else:
+            self.loss = AlphaFoldLoss(config.loss)
+
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )
...
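With the wrapper unified, the loss object is the only constructor-level difference between monomer and multimer runs; the gate reduces to a single conditional over config.globals.is_multimer (a compact restatement of the diff above, not new behavior):

loss_cls = AlphaFoldMultimerLoss if config.globals.is_multimer else AlphaFoldLoss
self.loss = loss_cls(config.loss)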
@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule):
     )

-class OpenFoldMultimerWrapper(OpenFoldWrapper):
-    def __init__(self, config):
-        super(OpenFoldMultimerWrapper, self).__init__(config)
-        self.config = config
-        self.model = AlphaFold(config)
-        self.loss = AlphaFoldMultimerLoss(config.loss)
-        self.ema = ExponentialMovingAverage(
-            model=self.model, decay=config.ema.decay
-        )
-        self.cached_weights = None
-        self.last_lr_step = -1
-
-    def forward(self, batch):
-        return self.model(batch)
-
-    def training_step(self, batch, batch_idx):
-        features,gt_features = batch
-
-        # Log it
-        if(self.ema.device != features["aatype"].device):
-            self.ema.to(features["aatype"].device)
-
-        # Run the model
-        outputs = self(features)
-
-        # Remove the recycling dimension
-        features = tensor_tree_map(lambda t: t[..., -1], features)
-
-        # Compute loss
-        loss, loss_breakdown = self.loss(
-            outputs, (features,gt_features), _return_breakdown=True
-        )
-
-        # Log it
-        self._log(loss_breakdown, features, outputs)
-
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        features,gt_features = batch
-
-        # At the start of validation, load the EMA weights
-        if(self.cached_weights is None):
-            # model.state_dict() contains references to model weights rather
-            # than copies. Therefore, we need to clone them before calling
-            # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
-            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
-            self.model.load_state_dict(self.ema.state_dict()["params"])
-
-        # Run the model
-        outputs = self(features)
-
-        # Compute loss and other metrics
-        features["use_clamped_fape"] = 0.
-        _, loss_breakdown = self.loss(
-            outputs, (features,gt_features), _return_breakdown=True
-        )
-
-        self._log(loss_breakdown, features, outputs, train=False)
-
-    def validation_epoch_end(self, _):
-        # Restore the model weights to normal
-        self.model.load_state_dict(self.cached_weights)
-        self.cached_weights = None
-

 def main(args):
     if(args.seed is not None):
         seed_everything(args.seed)
...
@@ -331,10 +263,8 @@ def main(args):
         train=True,
         low_prec=(str(args.precision) == "16")
     )
-    if "multimer" in args.config_preset:
-        model_module = OpenFoldMultimerWrapper(config)
-    else:
-        model_module = OpenFoldWrapper(config)
+    model_module = OpenFoldWrapper(config)

     if(args.resume_from_ckpt):
         if(os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
...
@@ -359,7 +289,6 @@ def main(args):
     if(args.script_modules):
         script_preset_(model_module)

-    #data_module = DummyDataLoader("new_batch.pickle")
     if "multimer" in args.config_preset:
         data_module = OpenFoldMultimerDataModule(
             config=config.data,
...
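Taken together, main() now routes multimer runs through the shared wrapper, and only the data module remains preset-dependent. A hedged sketch of the resulting flow (the preset name is illustrative; remaining constructor arguments are elided, as in the diff):

preset = "model_1_multimer"  # illustrative preset name
config = model_config(preset, train=True)
model_module = OpenFoldWrapper(config)  # multimer loss chosen inside __init__
if "multimer" in preset:
    data_module = OpenFoldMultimerDataModule(
        config=config.data,
        # ... remaining arguments elided
    )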