Commit d48c052c authored by Gustaf Ahdritz

Add training parsers

parent eeda001c
@@ -19,7 +19,7 @@ import numpy as np
 import unittest
 import ml_collections as mlc
-from openfold.features.data_transforms import make_atom14_masks
+from openfold.features import data_transforms
 from openfold.utils.affine_utils import T, affine_vector_to_4x4
 import openfold.utils.feats as feats
 from openfold.utils.loss import (
@@ -216,7 +216,7 @@ class TestLoss(unittest.TestCase):
         batch = {
             "atom14_atom_exists": torch.randint(0, 2, (n, 14)),
             "residue_index": torch.arange(n),
-            "aatype": torch.randint(0, 21, (n,)),
+            "aatype": torch.randint(0, 20, (n,)),
             "residx_atom14_to_atom37": torch.randint(0, 37, (n, 14)).long(),
         }
@@ -250,7 +250,7 @@ class TestLoss(unittest.TestCase):
         batch = {
             "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
             "residue_index": np.arange(n_res),
-            "aatype": np.random.randint(0, 21, (n_res,)),
+            "aatype": np.random.randint(0, 20, (n_res,)),
             "residx_atom14_to_atom37":
                 np.random.randint(0, 37, (n_res, 14)).astype(np.int64),
         }
@@ -302,16 +302,20 @@ class TestLoss(unittest.TestCase):
         batch = {
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
-            "aatype": np.random.randint(0, 21, (n_res,)),
+            "aatype": np.random.randint(0, 20, (n_res,)),
             "atom14_gt_positions": np.random.rand(n_res, 14, 3),
             "atom14_gt_exists":
                 np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
+            "all_atom_mask":
+                np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
+            "all_atom_positions":
+                np.random.rand(n_res, 37, 3).astype(np.float32),
         }

         def _build_extra_feats_np():
             b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
-            b.update(feats.build_ambiguity_feats(b))
-            b.update(make_atom14_masks(b))
+            b = data_transforms.make_atom14_masks(b)
+            b = data_transforms.make_atom14_positions(b)
             return tensor_tree_map(lambda t: np.array(t), b)

         batch = _build_extra_feats_np()
@@ -585,7 +589,7 @@ class TestLoss(unittest.TestCase):
         )
         atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

-        batch.update(feats.compute_residx(batch))
+        batch = data_transforms.make_atom14_masks(batch)

         out_repro = violation_loss(
             find_structural_violations(batch, atom14_pred_pos, **c_viol),
@@ -725,7 +729,7 @@ class TestLoss(unittest.TestCase):
         batch = {
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
-            "aatype": np.random.randint(0, 21, (n_res,)),
+            "aatype": np.random.randint(0, 20, (n_res,)),
             "atom14_gt_positions":
                 np.random.rand(n_res, 14, 3).astype(np.float32),
             "atom14_gt_exists":
@@ -738,8 +742,8 @@ class TestLoss(unittest.TestCase):
         def _build_extra_feats_np():
             b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
-            b.update(feats.build_ambiguity_feats(b))
-            b.update(feats.compute_residx(b))
+            b = data_transforms.make_atom14_masks(b)
+            b = data_transforms.make_atom14_positions(b)
             return tensor_tree_map(lambda t: np.array(t), b)

         batch = _build_extra_feats_np()
@@ -764,7 +768,7 @@ class TestLoss(unittest.TestCase):
         value = tree_map(to_tensor, value, np.ndarray)
         atom14_pred_pos = to_tensor(atom14_pred_pos)

-        batch.update(feats.atom37_to_frames(eps=1e-8, **batch))
+        batch = data_transforms.atom37_to_frames(batch)
         batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

         out_repro = sidechain_loss(
...
...@@ -37,7 +37,6 @@ if(compare_utils.alphafold_is_installed()): ...@@ -37,7 +37,6 @@ if(compare_utils.alphafold_is_installed()):
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
def test_dry_run(self): def test_dry_run(self):
batch_size = consts.batch_size
n_seq = consts.n_seq n_seq = consts.n_seq
n_templ = consts.n_templ n_templ = consts.n_templ
n_res = consts.n_res n_res = consts.n_res
...@@ -53,26 +52,26 @@ class TestModel(unittest.TestCase): ...@@ -53,26 +52,26 @@ class TestModel(unittest.TestCase):
batch = {} batch = {}
tf = torch.randint( tf = torch.randint(
c.input_embedder.tf_dim - 1, size=(batch_size, n_res) c.input_embedder.tf_dim - 1, size=(n_res,)
) )
batch["target_feat"] = nn.functional.one_hot( batch["target_feat"] = nn.functional.one_hot(
tf, c.input_embedder.tf_dim).float() tf, c.input_embedder.tf_dim).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1) batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res) batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand( batch["msa_feat"] = torch.rand(
(batch_size, n_seq, n_res, c.input_embedder.msa_dim) (n_seq, n_res, c.input_embedder.msa_dim)
) )
t_feats = random_template_feats(n_templ, n_res, batch_size=batch_size) t_feats = random_template_feats(n_templ, n_res)
batch.update({k:torch.tensor(v) for k, v in t_feats.items()}) batch.update({k:torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats( extra_feats = random_extra_msa_feats(
n_extra_seq, n_res, batch_size=batch_size n_extra_seq, n_res
) )
batch.update({k:torch.tensor(v) for k, v in extra_feats.items()}) batch.update({k:torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint( batch["msa_mask"] = torch.randint(
low=0, high=2, size=(batch_size, n_seq, n_res) low=0, high=2, size=(n_seq, n_res)
).float() ).float()
batch["seq_mask"] = torch.randint( batch["seq_mask"] = torch.randint(
low=0, high=2, size=(batch_size, n_res) low=0, high=2, size=(n_res,)
).float() ).float()
batch.update(make_atom14_masks(batch)) batch.update(make_atom14_masks(batch))
......
import argparse
from functools import partial
import json
import logging
import os
import pickle
import time
from typing import Optional
import ml_collections as mlc
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch
from torch.utils.data import RandomSampler
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.features import (
data_pipeline,
feature_pipeline,
mmcif_parsing,
)
from openfold.features import templates
from openfold.features.np.utils import to_date
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
mmcif_cache_dir: str = 'tmp/',
use_small_bfd: bool = True,
seed: int = 42,
mode: str = "train",
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing:
* bfd_uniclust_hits.a3m/small_bfd_hits.sto
* mgnify_hits.a3m
* pdb70_hits.hhr
* uniref90_hits.a3m
config:
A dataset config object. See openfold.config
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement
"""
super(OpenFoldDataset, self).__init__()
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.seed = seed
self.mode = mode
valid_modes = ["train", "val", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(mapping_path is None):
self.mapping = {
str(i):os.path.splitext(name)[0]
for i, name in enumerate(os.listdir(alignment_dir))
}
else:
with open(mapping_path, 'r') as fp:
self.mapping = json.load(fp)
template_release_dates_path = os.path.join(
mmcif_cache_dir, "template_release_dates.json"
)
if(not os.path.exists(template_release_dates_path)):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_caches.py before running OpenFold"
)
template_release_dates_path = None
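        # Per the AlphaFold supplement, templates are sampled from up to 20
        # hits during training, while inference keeps only the top 4.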
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=(20 if (mode == 'train') else 4),
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_path,
obsolete_pdbs_path=None,
)
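        # DataPipeline assembles the raw feature dict (sequence, MSA, and
        # template features) from disk; FeaturePipeline then converts it
        # into cropped, padded, model-ready tensors.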
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __getitem__(self, idx):
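        # OpenFoldBatchSampler encodes a batch-wide mode into every index it
        # emits: idx = sample_idx * no_batch_modes + batch_mode_idx. Decode
        # the mode first, then recover the true sample index.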
no_batch_modes = len(self.config.common.batch_modes)
batch_mode_idx = idx % no_batch_modes
batch_mode_str = self.config.common.batch_modes[batch_mode_idx][0]
idx = int(idx / no_batch_modes)
name = self.mapping[str(idx)]
if(self.mode == 'train' or self.mode == 'val'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id + '.cif')
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
alignment_dir = os.path.join(self.alignment_dir, name)
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
)
        else:
            path = os.path.join(self.data_dir, name + '.fasta')
            alignment_dir = os.path.join(self.alignment_dir, name)
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
            )
feats = self.feature_pipeline.process_features(
data, self.mode, batch_mode_str
)
return feats
def __len__(self):
return len(self.mapping.keys())
class OpenFoldBatchSampler(torch.utils.data.BatchSampler):
"""
A shameful hack.
    In AlphaFold, certain batches are designated for loss clamping. The
    exact method by which residue cropping within a batch is performed
    depends on that designation.
In idiomatic PyTorch, such "batch-wide" properties generally do not
exist; samples are supposed to be generated independently and only
later batched. This class and OpenFoldDataset get around this design
limitation by encoding batch properties in the indices sent to the
Dataset.
While this works (and efficiently), it precludes the future use of an
IterableDataset (such as WebDataset), which doesn't use indices. In
that case, the same can be accomplished by delaying the feature
processing step to the collate_fn, an argument of the DataLoader. That
solution is avoided here because it requires loading an entire batch's
worth of uncropped features into memory at a time.
A third option would be to generate two separate Dataset objects, one
that generates "clamped" batches and another for "unclamped" ones.
However, this would require parsing the precomputed caches of most
proteins twice, once for each loader. Given how lopsided the chances of
drawing a "clamped" batch are, care would also have to be taken not
to allocate too many resources to the less used DataLoader.
"""
def __init__(self, config, **kwargs):
super(OpenFoldBatchSampler, self).__init__(**kwargs)
self.config = config
self.no_batch_modes = len(self.config.common.batch_modes)
def __iter__(self):
it = super().__iter__()
distr = torch.distributions.categorical.Categorical(
torch.tensor(
[prob for name, prob in self.config.common.batch_modes]
)
)
for sample in it:
mode_idx = distr.sample().item()
sample = [s * self.no_batch_modes + mode_idx for s in sample]
yield sample
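# A worked example of the index encoding above: with two batch modes and a
# sampled batch [3, 7, 9], drawing mode_idx 1 yields [7, 15, 19];
# OpenFoldDataset.__getitem__ then recovers the mode (19 % 2 == 1) and the
# sample index (19 // 2 == 9).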
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
mmcif_cache_dir: str = 'tmp/',
**kwargs
):
super(OpenFoldDataModule, self).__init__()
self.config = config
self.template_mmcif_dir = template_mmcif_dir
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path
self.mmcif_cache_dir = mmcif_cache_dir
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
'At least one of train_data_dir or predict_data_dir must be '
'specified'
)
self.training_mode = self.train_data_dir is not None
if(self.training_mode and self.train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
        elif(not self.training_mode and self.predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
def setup(self, stage):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
mmcif_cache_dir=self.mmcif_cache_dir,
use_small_bfd=self.config.data_module.use_small_bfd,
)
if(self.training_mode):
self.train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
mode='train',
)
if(self.val_data_dir is not None):
self.val_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mapping_path=None,
mode='val',
)
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
mapping_path=None,
mode='predict',
)
def train_dataloader(self):
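        # Collate per-sample feature dicts by stacking corresponding tensors:
        # dict_multimap applies torch.stack(dim=0) entrywise, producing one
        # batched tensor per feature key.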
stack_fn = partial(torch.stack, dim=0)
stack = lambda l: dict_multimap(stack_fn, l)
return torch.utils.data.DataLoader(
self.train_dataset,
batch_sampler=OpenFoldBatchSampler(
config=self.config,
sampler=RandomSampler(self.train_dataset),
batch_size=self.config.data_module.data_loaders.batch_size,
drop_last=False,
),
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=stack,
)
def val_dataloader(self):
stack_fn = partial(torch.stack, dim=0)
stack = lambda l: dict_multimap(stack_fn, l)
return torch.utils.data.DataLoader(
self.val_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=stack
)
def predict_dataloader(self):
stack_fn = partial(torch.stack, dim=0)
stack = lambda l: dict_multimap(stack_fn, l)
return torch.utils.data.DataLoader(
self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=stack
)
class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config):
super(OpenFoldWrapper, self).__init__()
self.config = config
self.model = AlphaFold(config.model)
self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(self.model, decay=config.ema.decay)
def forward(self, batch):
return self.model(batch)
def training_step(self, batch, batch_idx):
# Run the model
outputs = self(batch)
# Remove the recycling dimension
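        # (input features carry a trailing recycling dimension; the loss only
        # needs the slice fed to the final recycling iteration)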
batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss
loss = self.loss(outputs, batch)
return {"loss": loss, "pred": outputs["sm"]["positions"][-1].detach()}
    def training_epoch_end(self, outs):
        # Dump the last batch's final-layer predicted positions for inspection
        out = outs[-1]["pred"].cpu()
        os.makedirs("prediction", exist_ok=True)
        with open("prediction/preds_" + time.strftime("%H:%M:%S") + ".pickle", "wb") as f:
            pickle.dump(out, f, protocol=pickle.HIGHEST_PROTOCOL)
def configure_optimizers(self,
learning_rate: float = 1e-3,
eps: float = 1e-8
) -> torch.optim.Adam:
# Ignored as long as a DeepSpeed optimizer is configured
return torch.optim.Adam(
self.model.parameters(),
lr=learning_rate,
eps=eps
)
def main(args):
config = model_config(
"model_1",
train=True,
low_prec=(args.precision == 16)
)
plugins = []
#plugins.append(DeepSpeedPlugin(config="deepspeed_config.json"))
trainer = pl.Trainer.from_argparse_args(
args,
plugins=plugins,
)
model_module = OpenFoldWrapper(config)
data_module = OpenFoldDataModule(config=config.data, **vars(args))
trainer.fit(model_module, data_module)
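# An example invocation (script name and paths are illustrative only):
#   python train_openfold.py /data/mmcifs /data/alignments \
#       /data/template_mmcifs 2021-07-31 \
#       --val_data_dir /data/val_mmcifs --val_alignment_dir /data/val_alignments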
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"train_data_dir", type=str,
help="Directory containing training mmCIF files"
)
parser.add_argument(
"train_alignment_dir", type=str,
help="Directory containing precomputed training alignments"
)
parser.add_argument(
"template_mmcif_dir", type=str,
help="Directory containing mmCIF files to search for templates"
)
parser.add_argument(
"max_template_date", type=str,
help="""Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target"""
)
parser.add_argument(
"--val_data_dir", type=str, default=None,
help="Directory containing validation mmCIF files"
)
parser.add_argument(
"--val_alignment_dir", type=str, default=None,
help="Directory containing precomputed validation alignments"
)
parser.add_argument(
"--kalign_binary_path", type=str, default='/usr/bin/kalign',
help="Path to the kalign binary"
)
parser.add_argument(
"--train_mapping_path", type=str, default=None,
help="""Optional path to a .json file containing a mapping from
consecutive numerical indices to sample names. Used to filter
the training set"""
)
parser.add_argument(
"--mmcif_cache_dir", type=str, default="tmp/",
help="Directory containing precomputed mmCIF metadata"
)
    parser.add_argument(
        "--use_small_bfd", action="store_true", default=False,
        help="Whether to use a reduced version of the BFD database"
    )
parser.add_argument(
"--seed", type=int, default=42,
help="Random seed for the DataModule"
)
parser = pl.Trainer.add_argparse_args(parser)
parser.set_defaults(
num_sanity_val_steps=0,
)
args = parser.parse_args()
# Seed torch
torch.manual_seed(args.seed)
main(args)