Commit a59ae7c1 authored by Gustaf Ahdritz

Refactor data pipeline; add distillation parsing

parent 07e64267
from . import model
from . import utils
from . import np
__all__ = ["model", "utils", "np"]
@@ -183,6 +183,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_template_hits": 4,
"max_templates": 4,
"num_ensemble": 1,
"crop": False,
@@ -194,6 +195,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_template_hits": 4,
"max_templates": 4,
"num_ensemble": 1,
"crop": False,
@@ -205,6 +207,7 @@ config = mlc.ConfigDict(
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_template_hits": 20,
"max_templates": 4,
"num_ensemble": 1,
"crop": True,
from functools import partial
import json
import logging
import os
from typing import Optional, Sequence
import ml_collections as mlc
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
from openfold.data import (
data_pipeline,
feature_pipeline,
mmcif_parsing,
templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
max_template_hits: int = 4,
template_release_dates_cache_path: Optional[str] = None,
use_small_bfd: bool = True,
output_raw: bool = False,
mode: str = "train",
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing:
* bfd_uniclust_hits.a3m/small_bfd_hits.sto
* mgnify_hits.a3m
* pdb70_hits.hhr
* uniref90_hits.a3m
config:
A dataset config object. See openfold.config
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.output_raw = output_raw
self.mode = mode
valid_modes = ["train", "val", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(mapping_path is None):
self.mapping = {
str(i):os.path.splitext(name)[0]
for i, name in enumerate(os.listdir(alignment_dir))
}
else:
with open(mapping_path, 'r') as fp:
self.mapping = json.load(fp)
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_caches.py before running OpenFold"
)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=None,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
)
if(not self.output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
)
return data
def __getitem__(self, idx):
name = self.mapping[str(idx)]
alignment_dir = os.path.join(self.alignment_dir, name)
if(self.mode == 'train' or self.mode == 'val'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id + '.cif')
if(os.path.exists(path)):
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir
)
else:
# Try to search for a distillation PDB file instead
path = os.path.join(self.data_dir, file_id + '.pdb')
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir
)
        else:
            path = os.path.join(self.data_dir, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
            )
if(self.output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode, "unclamped"
)
return feats
def __len__(self):
return len(self.mapping.keys())
def looped_sequence(sequence):
    # Cycle through the input sequence indefinitely
    while True:
        for x in sequence:
            yield x
class OpenFoldDataset(torch.utils.data.IterableDataset):
"""
The Dataset is written to accommodate the requirement that proteins are
sampled from the distillation set with some probability p
and from the PDB set with probability (1 - p). Proteins are sampled
from both sets without replacement, and as soon as either set is
emptied, it is refilled. The Dataset therefore has an arbitrary length.
Nevertheless, for compatibility with various PyTorch Lightning
functionalities, it is possible to specify an epoch length. This length
has no effect on the output of the Dataset.
"""
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
    ):
        super(OpenFoldDataset, self).__init__()
        self.datasets = datasets
        self.samplers = [
            looped_sequence(RandomSampler(d)) for d in datasets
        ]
        self.epoch_len = epoch_len
        self.distr = torch.distributions.categorical.Categorical(
            probs=torch.tensor(probabilities),
        )
def __iter__(self):
return self
def __next__(self):
dataset_idx = self.distr.sample()
sampler = self.samplers[dataset_idx]
element_idx = next(sampler)
return self.datasets[dataset_idx][element_idx]
def __len__(self):
return self.epoch_len
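A hedged usage sketch of the sampling behavior described in the docstring; the dataset names below are hypothetical stand-ins for real OpenFoldSingleDataset instances, and the 75/25 split mirrors the distillation/PDB ratio described in the AlphaFold supplement:

    combined = OpenFoldDataset(
        datasets=[pdb_dataset, distillation_dataset],
        probabilities=[0.25, 0.75],
        epoch_len=len(pdb_dataset) + len(distillation_dataset),
    )
    # Each draw first picks a dataset via the categorical distribution, then
    # an element via that dataset's looped RandomSampler, so neither set is
    # ever exhausted.
    first_example = next(iter(combined))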
class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"):
self.config = config
batch_modes = config.common.batch_modes
batch_mode_names, batch_mode_probs = list(zip(*batch_modes))
self.batch_mode_names = batch_mode_names
self.batch_mode_probs = batch_mode_probs
self.generator = generator
self.stage = stage
self.batch_mode_probs_tensor = torch.tensor(self.batch_mode_probs)
self.feature_pipeline = feature_pipeline.FeaturePipeline(self.config)
def __call__(self, raw_prots):
# We use torch.multinomial here rather than Categorical because the
# latter doesn't accept a generator for some reason
batch_mode_idx = torch.multinomial(
self.batch_mode_probs_tensor,
1,
generator=self.generator
).item()
batch_mode_name = self.batch_mode_names[batch_mode_idx]
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage, batch_mode_name
)
processed_prots.append(features)
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots)
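A minimal sketch of why the collator draws the batch mode with torch.multinomial rather than Categorical: multinomial accepts a torch.Generator, so collators constructed with identically seeded generators (one per process) agree on the mode for every batch. The variables below are illustrative only:

    g1, g2 = torch.Generator(), torch.Generator()
    g1.manual_seed(0)
    g2.manual_seed(0)
    probs = torch.tensor([0.9, 0.1])  # e.g. ("clamped", "unclamped") weights
    idx1 = torch.multinomial(probs, 1, generator=g1).item()
    idx2 = torch.multinomial(probs, 1, generator=g2).item()
    assert idx1 == idx2  # same seed, same batch mode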
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
self.config = config
self.template_mmcif_dir = template_mmcif_dir
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path
self.distillation_mapping_path = distillation_mapping_path
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
'At least one of train_data_dir or predict_data_dir must be '
'specified'
)
self.training_mode = self.train_data_dir is not None
if(self.training_mode and self.train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
        elif(not self.training_mode and self.predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
def setup(self, stage):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
use_small_bfd=self.config.data_module.use_small_bfd,
)
if(self.training_mode):
self.train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
max_template_hits=self.config.train.max_template_hits,
output_raw=True,
mode="train",
)
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
output_raw=True,
mode="train",
)
d_prob = self.config.train.distillation_prob
self.train_dataset = OpenFoldDataset(
datasets=[self.train_dataset, distillation_dataset],
probabilities=[1 - d_prob, d_prob],
                    epoch_len=(
                        len(self.train_dataset) + len(distillation_dataset)
                    ),
)
if(self.val_data_dir is not None):
self.val_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mapping_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
)
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
mapping_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
self.batch_collation_seed = torch.Generator().seed()
def _gen_batch_collator(self, stage):
""" We want each process to use the same batch collation seed """
generator = torch.Generator()
generator = generator.manual_seed(self.batch_collation_seed)
collate_fn = OpenFoldBatchCollator(
self.config, generator, stage
)
return collate_fn
def train_dataloader(self):
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("train"),
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self.val_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("eval")
)
def predict_dataloader(self):
return torch.utils.data.DataLoader(
self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("eval")
)
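A sketch of how the module might be wired up, assuming the data sub-config produced by model_config; all paths below are placeholders, not values from the commit:

    data_module = OpenFoldDataModule(
        config=config.data,
        template_mmcif_dir="/data/pdb_mmcif/mmcif_files",
        max_template_date="2021-10-10",
        train_data_dir="/data/mmcif",
        train_alignment_dir="/data/alignments",
        distillation_data_dir="/data/distillation_pdbs",
        distillation_alignment_dir="/data/distillation_alignments",
    )
    data_module.setup("fit")
    train_loader = data_module.train_dataloader()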
@@ -22,7 +22,7 @@ import numpy as np
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants
from openfold.np import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
@@ -81,9 +81,43 @@ def make_mmcif_features(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats
def make_pdb_features(
    protein_object: protein.Protein,
    description: str,
    confidence_threshold: float = 0.5,
) -> FeatureDict:
    pdb_feats = {}

    # make_sequence_features expects a string sequence, so convert the
    # integer aatype array to one-letter residue codes first
    sequence = ''.join([
        residue_constants.restypes_with_x[aatype]
        for aatype in protein_object.aatype
    ])
    pdb_feats.update(
        make_sequence_features(
            sequence=sequence,
            description=description,
            num_res=len(protein_object.aatype),
        )
    )

    all_atom_positions = protein_object.atom_positions
    all_atom_mask = protein_object.atom_mask

    # Zero the atom mask of residues whose B-factor column (which stores
    # per-residue confidence in distillation PDBs) never exceeds the threshold
    high_confidence = protein_object.b_factors > confidence_threshold
    high_confidence = np.any(high_confidence, axis=-1)
    for i, confident in enumerate(high_confidence):
        if(not confident):
            all_atom_mask[i] = 0

    pdb_feats["all_atom_positions"] = all_atom_positions
    pdb_feats["all_atom_mask"] = all_atom_mask
    pdb_feats["is_distillation"] = np.array(1., dtype=np.float32)

    return pdb_feats
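A small worked example of the confidence masking above (values invented for illustration): residues whose B-factor column never exceeds the threshold have their entire atom mask zeroed:

    b_factors = np.array([[90.0], [0.2]])  # two residues, one column each
    atom_mask = np.ones((2, 1))
    keep = np.any(b_factors > 0.5, axis=-1)  # -> [True, False]
    atom_mask[~keep] = 0  # second residue is masked out entirely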
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
@@ -311,7 +345,11 @@ class DataPipeline:
alignments["mgnify_deletion_matrix"],
),
)
return {**sequence_features, **msa_features, **templates_result.data}
return {
**sequence_features,
**msa_features,
**templates_result.features
}
def process_mmcif(
self,
@@ -357,4 +395,47 @@ class DataPipeline:
),
)
return {**mmcif_feats, **templates_result.data, **msa_features}
return {**mmcif_feats, **templates_result.features, **msa_features}
    def process_pdb(
        self,
        pdb_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """
            Assembles features for a protein in a PDB file.
        """
        with open(pdb_path, 'r') as f:
            pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str)
        # Use the file name as the chain description
        description = os.path.splitext(os.path.basename(pdb_path))[0]
        pdb_feats = make_pdb_features(protein_object, description)

        # Recover the one-letter sequence from the parsed aatype array
        input_sequence = ''.join([
            residue_constants.restypes_with_x[aatype]
            for aatype in protein_object.aatype
        ])
        alignments = self._parse_alignment_output(alignment_dir)
        # Distillation targets have no mmCIF header, so there is no release
        # date to filter templates by
        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=alignments["hhsearch_hits"],
        )

        msa_features = make_msa_features(
            msas=(
                alignments["uniref90_msa"],
                alignments["bfd_msa"],
                alignments["mgnify_msa"],
            ),
            deletion_matrices=(
                alignments["uniref90_deletion_matrix"],
                alignments["bfd_deletion_matrix"],
                alignments["mgnify_deletion_matrix"],
            ),
        )

        return {**pdb_feats, **templates_result.features, **msa_features}
@@ -21,7 +21,7 @@ import numpy as np
import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.tools import residue_constants as rc
from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
tree_map,
@@ -1104,7 +1104,7 @@ def random_crop_to_size(
else:
num_templates = protein["aatype"].new_zeros((1,))
num_res_crop_size = min(seq_length, crop_size)
num_res_crop_size = min(seq_length.item(), crop_size)
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
@@ -1112,18 +1112,16 @@
g.manual_seed(seed)
    def _randint(lower, upper):
        return int(
            torch.randint(
                lower,
                upper,
                (1,),
                device=protein["seq_length"].device,
                generator=g,
            )[0]
        )
    def _randint(lower, upper):
        return torch.randint(
            lower,
            upper + 1,
            (1,),
            device=protein["seq_length"].device,
            generator=g,
        )[0].item()
if subsample_templates:
templates_crop_start = _randint(0, num_templates + 1)
templates_crop_start = _randint(0, num_templates)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
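Commentary on the _randint change above: the new version samples the closed interval [lower, upper] by passing upper + 1 to torch.randint (whose upper bound is exclusive), which is why the call site drops its own "+ 1". A quick sanity check under that reading:

    g = torch.Generator()
    g.manual_seed(0)
    draws = {
        int(torch.randint(0, 4 + 1, (1,), generator=g)[0])
        for _ in range(1000)
    }
    # With 1000 draws, seeing all five values is overwhelmingly likely
    assert draws == {0, 1, 2, 3, 4}  # upper bound included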
@@ -130,7 +130,7 @@ def _is_after_cutoff(
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
logging.warning(
logging.info(
"Template structure not in release dates dict: %s", pdb_id
)
return False
@@ -72,8 +72,8 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
# args = checkpoint(chunker(s, e), *args)
args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = checkpoint(chunker(s, e), *args)
#args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args)
return args
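A self-contained sketch of the pattern checkpoint_blocks switches to here, using torch.utils.checkpoint in place of deepspeed.checkpointing (function and variable names below are illustrative, not the commit's API):

    import torch
    from torch.utils.checkpoint import checkpoint

    def run_blocks(blocks, x, blocks_per_ckpt=2):
        # Recompute each chunk's activations during the backward pass
        # instead of storing them, trading compute for memory
        def chunker(s, e):
            def run(x):
                for block in blocks[s:e]:
                    x = block(x)
                return x
            return run
        for s in range(0, len(blocks), blocks_per_ckpt):
            x = checkpoint(chunker(s, s + blocks_per_ckpt), x)
        return x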
@@ -1464,10 +1464,7 @@ class AlphaFoldLoss(nn.Module):
for k, loss_fn in loss_fns.items():
weight = self.config[k].weight
if weight:
# print(k)
loss = loss_fn()
# print(weight * loss)
cum_loss = cum_loss + weight * loss
# print(cum_loss)
return cum_loss
@@ -87,7 +87,7 @@ def main(args):
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
config.data.predict.num_ensemble = num_ensemble
feature_processor = feature_pipeline.FeaturePipeline(config)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
alignment_dir = os.path.join(output_dir_base, "alignments")
import argparse
from functools import partial
import json
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import random
import time
from typing import Optional
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch
from torch.utils.data import RandomSampler
torch.manual_seed(42)
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.features import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
)
from openfold.data.data_modules import (
    OpenFoldDataModule,
)
from openfold.features import templates
from openfold.features.np.utils import to_date
from openfold.model.model import AlphaFold
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
mmcif_cache_dir: str = 'tmp/',
use_small_bfd: bool = True,
seed: int = 42,
mode: str = "train",
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing:
* bfd_uniclust_hits.a3m/small_bfd_hits.sto
* mgnify_hits.a3m
* pdb70_hits.hhr
* uniref90_hits.a3m
config:
A dataset config object. See openfold.config
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement
"""
super(OpenFoldDataset, self).__init__()
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.seed = seed
self.mode = mode
valid_modes = ["train", "val", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(mapping_path is None):
self.mapping = {
str(i):os.path.splitext(name)[0]
for i, name in enumerate(os.listdir(alignment_dir))
}
else:
with open(mapping_path, 'r') as fp:
self.mapping = json.load(fp)
template_release_dates_path = os.path.join(
mmcif_cache_dir, "template_release_dates.json"
)
if(not os.path.exists(template_release_dates_path)):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_caches.py before running OpenFold"
)
template_release_dates_path = None
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=(20 if (mode == 'train') else 4),
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_path,
obsolete_pdbs_path=None,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __getitem__(self, idx):
no_batch_modes = len(self.config.common.batch_modes)
batch_mode_idx = idx % no_batch_modes
batch_mode_str = self.config.common.batch_modes[batch_mode_idx][0]
idx = int(idx / no_batch_modes)
name = self.mapping[str(idx)]
if(self.mode == 'train' or self.mode == 'val'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id + '.cif')
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
alignment_dir = os.path.join(self.alignment_dir, name)
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
)
else:
            path = os.path.join(self.data_dir, name + '.fasta')
            alignment_dir = os.path.join(self.alignment_dir, name)
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
)
feats = self.feature_pipeline.process_features(
data, self.mode, batch_mode_str
)
return feats
def __len__(self):
return len(self.mapping.keys())
class OpenFoldBatchSampler(torch.utils.data.BatchSampler):
"""
A shameful hack.
        In AlphaFold, certain batches are designated for loss clamping. The
        exact way residue cropping within such a batch is performed
        depends on that designation.
In idiomatic PyTorch, such "batch-wide" properties generally do not
exist; samples are supposed to be generated independently and only
later batched. This class and OpenFoldDataset get around this design
limitation by encoding batch properties in the indices sent to the
Dataset.
While this works (and efficiently), it precludes the future use of an
IterableDataset (such as WebDataset), which doesn't use indices. In
that case, the same can be accomplished by delaying the feature
processing step to the collate_fn, an argument of the DataLoader. That
solution is avoided here because it requires loading an entire batch's
worth of uncropped features into memory at a time.
A third option would be to generate two separate Dataset objects, one
that generates "clamped" batches and another for "unclamped" ones.
However, this would require parsing the precomputed caches of most
proteins twice, once for each loader. Given how lopsided the chances of
drawing a "clamped" batch are, care would also have to be taken not
to allocate too many resources to the less used DataLoader.
"""
def __init__(self, config, **kwargs):
super(OpenFoldBatchSampler, self).__init__(**kwargs)
self.config = config
self.no_batch_modes = len(self.config.common.batch_modes)
def __iter__(self):
it = super().__iter__()
distr = torch.distributions.categorical.Categorical(
torch.tensor(
[prob for name, prob in self.config.common.batch_modes]
)
)
for sample in it:
mode_idx = distr.sample().item()
sample = [s * self.no_batch_modes + mode_idx for s in sample]
yield sample
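A worked example of the index encoding described in the docstring: the sampler scales each sample index by the number of batch modes and adds the drawn mode index, and OpenFoldDataset.__getitem__ inverts this with a modulo and an integer division:

    no_batch_modes = 2  # e.g. ("clamped", "unclamped")
    sample_idx, mode_idx = 17, 1
    encoded = sample_idx * no_batch_modes + mode_idx  # -> 35
    assert encoded % no_batch_modes == mode_idx
    assert encoded // no_batch_modes == sample_idx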
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
mmcif_cache_dir: str = 'tmp/',
**kwargs
):
super(OpenFoldDataModule, self).__init__()
self.config = config
self.template_mmcif_dir = template_mmcif_dir
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path
self.mmcif_cache_dir = mmcif_cache_dir
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
'At least one of train_data_dir or predict_data_dir must be '
'specified'
)
self.training_mode = self.train_data_dir is not None
if(self.training_mode and self.train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
        elif(not self.training_mode and self.predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
def setup(self, stage):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
mmcif_cache_dir=self.mmcif_cache_dir,
use_small_bfd=self.config.data_module.use_small_bfd,
)
if(self.training_mode):
self.train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
mode='train',
)
if(self.val_data_dir is not None):
self.val_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mapping_path=None,
mode='val',
)
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
mapping_path=None,
mode='predict',
)
def train_dataloader(self):
stack_fn = partial(torch.stack, dim=0)
stack = lambda l: dict_multimap(stack_fn, l)
return torch.utils.data.DataLoader(
self.train_dataset,
batch_sampler=OpenFoldBatchSampler(
config=self.config,
sampler=RandomSampler(self.train_dataset),
batch_size=self.config.data_module.data_loaders.batch_size,
drop_last=False,
),
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=stack,
)
def val_dataloader(self):
stack_fn = partial(torch.stack, dim=0)
stack = lambda l: dict_multimap(stack_fn, l)
return torch.utils.data.DataLoader(
self.val_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=stack
)
def predict_dataloader(self):
stack_fn = partial(torch.stack, dim=0)
stack = lambda l: dict_multimap(stack_fn, l)
return torch.utils.data.DataLoader(
self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=stack
)
from openfold.utils.tensor_utils import tensor_tree_map
class OpenFoldWrapper(pl.LightningModule):
@@ -380,6 +61,8 @@ class OpenFoldWrapper(pl.LightningModule):
eps=eps
)
def on_before_zero_grad(self, *args, **kwargs):
self.ema.update(self.model)
def main(args):
config = model_config(
@@ -421,6 +104,14 @@ if __name__ == "__main__":
help="""Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target"""
)
parser.add_argument(
"--distillation_data_dir", type=str, default=None,
help="Directory containing training PDB files"
)
parser.add_argument(
"--distillation_alignment_dir", type=str, default=None,
help="Directory containing precomputed distillation alignments"
)
parser.add_argument(
"--val_data_dir", type=str, default=None,
help="Directory containing validation mmCIF files"
@@ -440,16 +131,20 @@ if __name__ == "__main__":
the training set"""
)
    parser.add_argument(
        "--mmcif_cache_dir", type=str, default="tmp/",
        help="Directory containing precomputed mmCIF metadata"
    )
    parser.add_argument(
        "--distillation_mapping_path", type=str, default=None,
        help="""See --train_mapping_path"""
    )
parser.add_argument(
"--template_release_dates_cache_path", type=str, default=None,
help="Output of templates.generate_mmcif_cache"
)
parser.add_argument(
"--use_small_bfd", type=bool, default=False,
help="Whether to use a reduced version of the BFD database"
)
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for the DataModule"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
parser = pl.Trainer.add_argparse_args(parser)
@@ -459,7 +154,9 @@ if __name__ == "__main__":
args = parser.parse_args()
# Seed torch
if(args.seed is not None):
torch.manual_seed(args.seed)
random.seed(args.seed + 1)
np.random.seed(args.seed + 2)
main(args)