Commit dc5d43d7 authored by Geoffrey Yu

update train_openfold to handle val_mmcif_cache; update the datamodule to handle the case where the validation mmCIF cache is not provided
parent e097da95
@@ -23,6 +23,7 @@ import tempfile
 from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )
+import random
 @contextlib.contextmanager
@@ -368,15 +369,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         """
         super(OpenFoldSingleMultimerDataset, self).__init__()
         self.data_dir = data_dir
+        self.mmcif_data_cache_path = mmcif_data_cache_path
         self.chain_data_cache = None
         if chain_data_cache_path is not None:
             with open(chain_data_cache_path, "r") as fp:
                 self.chain_data_cache = json.load(fp)
             assert isinstance(self.chain_data_cache, dict)
-        if mmcif_data_cache_path is not None:
-            with open(mmcif_data_cache_path, "r") as infile:
+        if self.mmcif_data_cache_path is not None:
+            with open(self.mmcif_data_cache_path, "r") as infile:
                 self.mmcif_data_cache = json.load(infile)
             assert isinstance(self.mmcif_data_cache, dict)
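The constructor now keeps the cache path on the instance and loads the JSON cache only when a path was given. A standalone sketch of the same pattern (the class name is illustrative, not part of the commit):

```python
import json
from typing import Optional


class CacheBackedDatasetSketch:
    """Minimal sketch: optionally load a JSON metadata cache, as in the diff."""

    def __init__(self, mmcif_data_cache_path: Optional[str] = None):
        self.mmcif_data_cache_path = mmcif_data_cache_path
        self.mmcif_data_cache = None  # stays None when no cache is supplied
        if self.mmcif_data_cache_path is not None:
            with open(self.mmcif_data_cache_path, "r") as infile:
                self.mmcif_data_cache = json.load(infile)
            assert isinstance(self.mmcif_data_cache, dict)
```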
@@ -413,13 +414,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                 c for c in self._chain_ids if c in chains_to_include
             ]
-        if self.mmcif_data_cache is not None:
+        if self.mmcif_data_cache_path is not None:
             self._mmcifs = list(self.mmcif_data_cache.keys())
-            self._mmcif_id_to_idx_dict = {
+        elif self.mmcif_data_cache_path is None and self.alignment_dir is not None:
+            self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
+        else:
+            raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
+        self._mmcif_id_to_idx_dict = {
            mmcif: i for i, mmcif in enumerate(self._mmcifs)
         }

         # changed template_featurizer to hmmsearch for now just to run the test
         template_featurizer = templates.HmmsearchHitFeaturizer(
             mmcif_dir=template_mmcif_dir,
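The id list now falls back to the alignment directory when no cache is given: directory entries are assumed to be named `<mmcif_id>_<chain_id>`, so splitting on the first underscore recovers the structure id. A standalone sketch of that fallback (the function name is hypothetical; duplicate ids collapse in the dict exactly as in the original comprehension):

```python
import os
from typing import Dict, Optional


def resolve_mmcif_index(mmcif_data_cache: Optional[dict] = None,
                        alignment_dir: Optional[str] = None) -> Dict[str, int]:
    """Prefer cache keys; otherwise derive ids from '<mmcif_id>_<chain_id>'
    alignment directory names; error out if neither source is available."""
    if mmcif_data_cache is not None:
        mmcifs = list(mmcif_data_cache.keys())
    elif alignment_dir is not None:
        mmcifs = [name.split("_")[0] for name in os.listdir(alignment_dir)]
    else:
        raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
    # Later keys overwrite earlier ones, so repeated ids map to a single index.
    return {mmcif: i for i, mmcif in enumerate(mmcifs)}
```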
@@ -470,9 +474,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         mmcif_id = self.idx_to_mmcif_id(idx)
         chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
-        print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
         alignment_index = None
         if(self.mode == 'train' or self.mode == 'eval'):
             path = os.path.join(self.data_dir, f"{mmcif_id}")
@@ -715,17 +716,18 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
     def filter_samples(self, dataset_idx):
         dataset = self.datasets[dataset_idx]
-        mmcif_data_cache = dataset.mmcif_data_cache
+        mmcif_data_cache = dataset.mmcif_data_cache if hasattr(dataset, "mmcif_data_cache") else None
         selected_idx = []
-        for i in range(len(mmcif_data_cache)):
-            mmcif_id = dataset.idx_to_mmcif_id(i)
-            chains = mmcif_data_cache[mmcif_id]['chain_ids']
-            mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
-            if deterministic_multimer_train_filter(mmcif_data_cache_entry,
-                                                   max_resolution=9,
-                                                   minimum_number_of_residues=5):
-                selected_idx.append(i)
+        if mmcif_data_cache is not None:
+            for i in range(len(mmcif_data_cache)):
+                mmcif_id = dataset.idx_to_mmcif_id(i)
+                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
+                if deterministic_multimer_train_filter(mmcif_data_cache_entry,
+                                                       max_resolution=9,
+                                                       minimum_number_of_residues=5):
+                    selected_idx.append(i)
+        else:
+            selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
         return selected_idx

     def __getitem__(self, idx):
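With the guard in place, `filter_samples` applies the resolution/length filter only when the dataset actually has a cache; otherwise every index is kept. A standalone sketch of that control flow (the `resolution` and `no_of_residues` keys are guesses at the cache schema for illustration; the commit's real predicate is `deterministic_multimer_train_filter`):

```python
from typing import Callable, List, Optional


def select_indices(mmcif_data_cache: Optional[dict],
                   idx_to_mmcif_id: Callable[[int], str],
                   n_total: int) -> List[int]:
    """Filter only when cache metadata is available; otherwise keep all."""
    if mmcif_data_cache is None:
        return list(range(n_total))
    selected_idx = []
    for i in range(len(mmcif_data_cache)):
        entry = mmcif_data_cache[idx_to_mmcif_id(i)]
        # The schema keys below are assumptions for illustration only.
        if entry.get("resolution", 0.0) <= 9 and entry.get("no_of_residues", 0) >= 5:
            selected_idx.append(i)
    return selected_idx
```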
@@ -746,6 +748,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
         self.datapoints = []
         for dataset_idx in dataset_choices:
             selected_idx = self.filter_samples(dataset_idx)
+            random.shuffle(selected_idx)
             if len(selected_idx) < self.epoch_len:
                 self.epoch_len = len(selected_idx)
                 logging.info(f"self.epoch_len is {self.epoch_len}")
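`random.shuffle` randomizes the filtered indices before the epoch is assembled, and `epoch_len` is clamped when the filter left fewer samples than requested. A small sketch of both steps (the helper name and the seeded `random.Random` are illustrative; the commit shuffles in place with the module-level `random.shuffle`):

```python
import random
from typing import List, Optional, Tuple


def assemble_epoch(selected_idx: List[int], epoch_len: int,
                   seed: Optional[int] = None) -> Tuple[List[int], int]:
    """Shuffle the surviving indices, then clamp the epoch length."""
    rng = random.Random(seed)
    indices = list(selected_idx)
    rng.shuffle(indices)
    if len(indices) < epoch_len:
        epoch_len = len(indices)
    return indices[:epoch_len], epoch_len
```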
@@ -849,7 +852,6 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
         self.stage = stage
         self.generator = generator
-        print('initialised a multimer dataloader')

     def __iter__(self):
         it = super().__iter__()
@@ -1092,6 +1094,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                  template_mmcif_dir: str, max_template_date: str,
                  train_data_dir: Optional[str] = None,
                  train_mmcif_data_cache_path: Optional[str] = None,
+                 val_mmcif_data_cache_path: Optional[str] = None,
                  **kwargs):
         super(OpenFoldMultimerDataModule, self).__init__(config,
                                                          template_mmcif_dir,
@@ -1099,6 +1102,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                                                          train_data_dir, **kwargs)
         self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
         self.training_mode = self.train_data_dir is not None
+        self.val_mmcif_data_cache_path = val_mmcif_data_cache_path

     def setup(self):
         # Most of the arguments are the same for the three datasets
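The data module simply stores the optional validation cache path and later forwards it to the eval dataset (next hunk). A reduced standalone sketch of that wiring (class and method bodies are cut down to the relevant fields; everything else in the real constructor is omitted):

```python
from typing import Optional


class MultimerDataModuleSketch:
    """Reduced sketch: carry an optional validation cache path into setup()."""

    def __init__(self,
                 train_mmcif_data_cache_path: Optional[str] = None,
                 val_mmcif_data_cache_path: Optional[str] = None):
        self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
        self.val_mmcif_data_cache_path = val_mmcif_data_cache_path

    def setup(self) -> dict:
        # A None path is legal: the dataset falls back to alignment_dir ids.
        return {"mmcif_data_cache_path": self.val_mmcif_data_cache_path}
```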
@@ -1167,6 +1171,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
             self.eval_dataset = dataset_gen(
                 data_dir=self.val_data_dir,
                 alignment_dir=self.val_alignment_dir,
+                mmcif_data_cache_path=self.val_mmcif_data_cache_path,
                 filter_path=None,
                 max_template_hits=self.config.eval.max_template_hits,
                 mode="eval",
@@ -1206,7 +1211,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
             batch_size=1,
             num_workers=self.config.data_module.data_loaders.num_workers,
         )
-        print(f"generated training dataloader")
         return dl

 class DummyDataset(torch.utils.data.Dataset):
@@ -509,6 +509,10 @@ if __name__ == "__main__":
         "--val_alignment_dir", type=str, default=None,
         help="Directory containing precomputed validation alignments"
     )
+    parser.add_argument(
+        "--val_mmcif_data_cache_path", type=str, default=None,
+        help="Path to the JSON file recording metadata for the mmCIF structures used during validation"
+    )
     parser.add_argument(
         "--kalign_binary_path", type=str, default='/usr/bin/kalign',
         help="Path to the kalign binary"