Commit d35816e3 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

add OpenFoldMultimerDataModule

parent 4d9a4bc2
......@@ -19,7 +19,16 @@ from openfold.data import (
templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
import contextlib
import tempfile
@contextlib.contextmanager
def temp_fasta_file(sequence_str):
"""function that create temparory fasta file used in multimer datapipeline"""
with tempfile.NamedTemporaryFile("w", suffix=".fasta") as fasta_file:
fasta_file.write(sequence_str)
fasta_file.seek(0)
yield fasta_file.name
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
......@@ -278,7 +287,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
dtype=torch.int64,
device=feats["aatype"].device)
return feats,data
return feats
def __len__(self):
return len(self._chain_ids)
......@@ -399,32 +408,12 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
if self.mmcif_data_cache is not None:
self._mmcifs = list(self.mmcif_data_cache.keys())
self._mmcif_id_to_idx_dict = {
mmcif: i for i, mmcif in enumerate(self._mmcifs)
}
# changed template_featurizer to hmmsearch for now just to run the test
template_featurizer = templates.HmmsearchHitFeaturizer(
......@@ -440,6 +429,9 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
self.multimer_data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=self.data_pipeline
)
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
......@@ -468,14 +460,23 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def mmcif_id_to_idx(self, chain_id):
return self._mmcif_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx]
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
fasta_str+f">{mmcif_id}_{c}\n{s}"
print(fasta_str)
import sys
sys.exit()
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None
......@@ -642,6 +643,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
## TO DO: add filtering cretieria for multimer
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
......@@ -690,12 +692,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldMultimerDataset(torch.utils.data.Dataset):
"""
Implement the filtering criteria used in AlphaFold Multimer training
"""
print(f"datapoints is {self.datapoints}")
class OpenFoldBatchCollator:
def __call__(self, prots):
......@@ -799,6 +796,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
train_mmcif_data_cache_path:Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
......@@ -826,6 +824,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.train_chain_data_cache_path = train_chain_data_cache_path
self.train_mmcif_data_cache_path=train_mmcif_data_cache_path
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_chain_data_cache_path = (
......@@ -1007,7 +1006,110 @@ class OpenFoldDataModule(pl.LightningDataModule):
def predict_dataloader(self):
return self._gen_dataloader("predict")
class OpenFoldMultimerDataModule(OpenFoldDataModule):
"""
Create a datamodule specifically for multimer training
Compared to OpenFoldDataModule, OpenFoldMultimerDataModule
requires mmcif_data_cache_path which is the product of
scripts/generate_mmcif_cache.py mmcif_data_cache_path should be
a file that record what chain(s) each mmcif file has
"""
def __init__(self, config: mlc.ConfigDict,
template_mmcif_dir: str, max_template_date: str,
train_data_dir: Optional[str] = None,
train_mmcif_data_cache_path:Optional[str] = None,
**kwargs):
super(OpenFoldMultimerDataModule,self).__init__(config,
template_mmcif_dir,
max_template_date,
train_data_dir,**kwargs)
self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
self.training_mode = self.train_data_dir is not None
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleMultimerDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
)
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
mmcif_data_cache_path=self.train_mmcif_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
alignment_index=self.alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
)
d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
else:
datasets = [train_dataset]
probabilities = [1.]
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
generator=generator,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
filter_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
)
else:
self.eval_dataset = None
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
filter_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, batch_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment