Commit 0cf1541c authored by Christina Floristean's avatar Christina Floristean
Browse files

Refactoring multimer data pipeline and permutation alignment.

parent 377f854c
......@@ -19,6 +19,8 @@ dependencies:
- deepspeed==0.5.10
- dm-tree==0.1.6
- ml-collections==0.1.0
- jax==0.3.25
- pandas==2.0.2
- numpy==1.21.2
- PyYAML==5.4.1
- requests==2.26.0
......
......@@ -749,6 +749,9 @@ multimer_config_update = mlc.ConfigDict({
"sym_id",
]
},
"supervised": {
"clamp_prob": 1.
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
......@@ -765,7 +768,8 @@ multimer_config_update = mlc.ConfigDict({
"max_extra_msa": 2048,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.
"interface_threshold": 10.,
"clamp_prob": 1.,
},
},
"model": {
......
This diff is collapsed.
......@@ -93,24 +93,11 @@ def np_example_to_features(
with torch.no_grad():
if is_multimer:
if mode == 'train':
features,gt_features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=True
)
return {k: v for k, v in features.items()}, gt_features
else:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=False
)
return {k: v for k, v in features.items()}
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline.process_tensors_from_config(
tensor_dict,
......
......@@ -21,16 +21,17 @@ from openfold.data import (
data_transforms_multimer,
)
def grountruth_transforms_fns():
    # NOTE(review): function name contains a typo ("grountruth"); this is the
    # pre-refactor revision, superseded by groundtruth_transforms_fns below.
    # Returns the ordered list of transforms applied to a ground-truth feature
    # dict before loss computation: atom14 masks/positions, rigid frames,
    # pseudo-beta coordinates, and backbone/chi torsion angles.
    transforms = [data_transforms.make_atom14_masks,
                  data_transforms.make_atom14_positions,
                  data_transforms.atom37_to_frames,
                  data_transforms.atom37_to_torsion_angles(""),
                  data_transforms.make_pseudo_beta(""),
                  data_transforms.get_backbone_frames,
                  data_transforms.get_chi_angles,
                  ]
    return transforms
def groundtruth_transforms_fns():
    """Return the ordered ground-truth feature transforms.

    The returned callables derive the loss-side features (atom14
    masks/positions, rigid frames, pseudo-beta coordinates, backbone and
    chi torsion angles) from raw ground-truth tensors.
    """
    return [
        data_transforms.make_atom14_masks,
        data_transforms.make_atom14_positions,
        data_transforms.atom37_to_frames,
        data_transforms.atom37_to_torsion_angles(""),
        data_transforms.make_pseudo_beta(""),
        data_transforms.get_backbone_frames,
        data_transforms.get_chi_angles,
    ]
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
......@@ -112,20 +113,24 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return transforms
def prepare_ground_truth_features(tensors):
    """Prepare ground-truth features needed only for loss calculation
    during training.

    Args:
        tensors: Mapping of feature name -> tensor. Must contain 'aatype';
            the ground-truth keys listed below are copied out if present.

    Returns:
        A new dict containing only the ground-truth features, with
        'aatype' cast to torch.long and the ground-truth transform
        pipeline applied.
    """
    # NOTE(review): the diff residue here interleaved two revisions of the
    # same lines; this is the cleaned post-refactor version.
    gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id',
                   'sym_id', 'entity_id']
    gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
    # Downstream transforms index by residue type, which requires an
    # integer dtype.
    gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
    gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
    return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
if is_training:
gt_tensors= prepare_ground_truth_features(tensors)
process_gt_feats = mode_cfg.supervised
gt_tensors = {}
if process_gt_feats:
gt_tensors = prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
......@@ -152,10 +157,10 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False)
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
if is_training:
return tensors,gt_tensors
else:
return tensors
if process_gt_feats:
tensors['gt_features'] = gt_tensors
return tensors
@data_transforms.curry1
def compose(x, fs):
......
This diff is collapsed.
......@@ -13,7 +13,7 @@ from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
def parse_file(f, args, chain_cluster_size_dict=None):
with open(os.path.join(args.mmcif_dir, f), "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
......@@ -28,6 +28,18 @@ def parse_file(f, args):
local_data["release_date"] = mmcif.header["release_date"]
chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))
if chain_cluster_size_dict is not None:
cluster_sizes = []
for chain_id in chain_ids:
full_name = "_".join([file_id, chain_id])
cluster_size = chain_cluster_size_dict.get(
full_name.upper(), -1
)
cluster_sizes.append(cluster_size)
local_data["cluster_sizes"] = cluster_sizes
local_data["chain_ids"] = chain_ids
local_data["seqs"] = seqs
local_data["no_chains"] = len(chain_ids)
......@@ -38,8 +50,21 @@ def parse_file(f, args):
def main(args):
chain_cluster_size_dict = None
if args.cluster_file is not None:
chain_cluster_size_dict = {}
with open(args.cluster_file, "r") as fp:
clusters = [l.strip() for l in fp.readlines()]
for cluster in clusters:
chain_ids = cluster.split()
cluster_len = len(chain_ids)
for chain_id in chain_ids:
chain_id = chain_id.upper()
chain_cluster_size_dict[chain_id] = cluster_len
files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
fn = partial(parse_file, args=args)
fn = partial(parse_file, args=args, chain_cluster_size_dict=chain_cluster_size_dict)
data = {}
with Pool(processes=args.no_workers) as p:
with tqdm(total=len(files)) as pbar:
......@@ -63,6 +88,15 @@ if __name__ == "__main__":
"--no_workers", type=int, default=4,
help="Number of workers to use for parsing"
)
parser.add_argument(
"--cluster_file", type=str, default=None,
help=(
"Path to a cluster file (e.g. PDB40), one cluster "
"({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
"Chains not in this cluster file will NOT be filtered by cluster "
"size."
)
)
parser.add_argument(
"--chunksize", type=int, default=10,
help="How many files should be distributed to each worker at a time"
......
# NOTE(review): the original span carried both the pre-refactor import of
# data_modules (with DummyDataLoader) and the post-refactor single-line
# import — conflicting diff residue. This is the cleaned post-commit
# import block, grouped stdlib / third-party / local per PEP 8.
import argparse
import logging
import os
import random
import sys
import time

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
......@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule):
super(OpenFoldWrapper, self).__init__()
self.config = config
self.model = AlphaFold(config)
self.loss = AlphaFoldLoss(config.loss)
if self.config.globals.is_multimer:
self.loss = AlphaFoldMultimerLoss(config.loss)
else:
self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
......@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule):
)
class OpenFoldMultimerWrapper(OpenFoldWrapper):
    """LightningModule wrapper for multimer training.

    NOTE(review): this class is removed by this commit; multimer support
    moves into OpenFoldWrapper (loss chosen via config.globals.is_multimer)
    and ground-truth features now travel inside the main feature dict
    instead of as a separate tuple element of the batch.
    """
    def __init__(self, config):
        super(OpenFoldMultimerWrapper, self).__init__(config)
        self.config = config
        self.model = AlphaFold(config)
        # Multimer-specific loss; see AlphaFoldMultimerLoss.
        self.loss = AlphaFoldMultimerLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        # Holds cloned model weights while validation swaps in EMA weights.
        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        # Batch is a (model features, ground-truth features) pair.
        features,gt_features = batch

        # Keep the EMA shadow weights on the same device as the inputs.
        if(self.ema.device != features["aatype"].device):
            self.ema.to(features["aatype"].device)

        # Run the model
        outputs = self(features)

        # Remove the recycling dimension
        features = tensor_tree_map(lambda t: t[..., -1], features)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, (features,gt_features), _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, features, outputs)

        return loss

    def validation_step(self, batch, batch_idx):
        features,gt_features = batch

        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Run the model
        outputs = self(features)

        # Compute loss and other metrics.
        # Validation always scores against unclamped FAPE.
        features["use_clamped_fape"] = 0.
        _, loss_breakdown = self.loss(
            outputs, (features,gt_features), _return_breakdown=True
        )

        self._log(loss_breakdown, features, outputs, train=False)

    def validation_epoch_end(self, _):
        # Restore the (non-EMA) model weights saved at validation start.
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None
def main(args):
if(args.seed is not None):
seed_everything(args.seed)
......@@ -331,10 +263,8 @@ def main(args):
train=True,
low_prec=(str(args.precision) == "16")
)
if "multimer" in args.config_preset:
model_module = OpenFoldMultimerWrapper(config)
else:
model_module = OpenFoldWrapper(config)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
......@@ -359,7 +289,6 @@ def main(args):
if(args.script_modules):
script_preset_(model_module)
#data_module = DummyDataLoader("new_batch.pickle")
if "multimer" in args.config_preset:
data_module = OpenFoldMultimerDataModule(
config=config.data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment