"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "238866194bc5d909dc935ed2612f3248cd916854"
Commit dc5d43d7 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update train_openfold to handle val_mmcif_cache;update datamodule to handle if...

update train_openfold to handle val_mmcif_cache;update datamodule to handle if validation mmcif cache is not provided
parent e097da95
...@@ -23,6 +23,7 @@ import tempfile ...@@ -23,6 +23,7 @@ import tempfile
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
import random
@contextlib.contextmanager @contextlib.contextmanager
...@@ -368,15 +369,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -368,15 +369,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
""" """
super(OpenFoldSingleMultimerDataset, self).__init__() super(OpenFoldSingleMultimerDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
self.mmcif_data_cache_path=mmcif_data_cache_path
self.chain_data_cache = None self.chain_data_cache = None
if chain_data_cache_path is not None: if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp: with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp) self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict) assert isinstance(self.chain_data_cache, dict)
if mmcif_data_cache_path is not None: if self.mmcif_data_cache_path is not None:
with open(mmcif_data_cache_path,"r") as infile: with open(self.mmcif_data_cache_path,"r") as infile:
self.mmcif_data_cache = json.load(infile) self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache,dict) assert isinstance(self.mmcif_data_cache,dict)
...@@ -413,13 +414,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -413,13 +414,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
c for c in self._chain_ids if c in chains_to_include c for c in self._chain_ids if c in chains_to_include
] ]
if self.mmcif_data_cache is not None: if self.mmcif_data_cache_path is not None:
self._mmcifs = list(self.mmcif_data_cache.keys()) self._mmcifs = list(self.mmcif_data_cache.keys())
self._mmcif_id_to_idx_dict = { elif self.mmcif_data_cache_path is None and self.alignment_dir is not None:
self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
else:
raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
self._mmcif_id_to_idx_dict = {
mmcif: i for i, mmcif in enumerate(self._mmcifs) mmcif: i for i, mmcif in enumerate(self._mmcifs)
} }
# changed template_featurizer to hmmsearch for now just to run the test # changed template_featurizer to hmmsearch for now just to run the test
template_featurizer = templates.HmmsearchHitFeaturizer( template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir, mmcif_dir=template_mmcif_dir,
...@@ -470,9 +474,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -470,9 +474,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
alignment_index = None alignment_index = None
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
...@@ -715,17 +716,18 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -715,17 +716,18 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
def filter_samples(self,dataset_idx): def filter_samples(self,dataset_idx):
dataset = self.datasets[dataset_idx] dataset = self.datasets[dataset_idx]
mmcif_data_cache = dataset.mmcif_data_cache mmcif_data_cache = dataset.mmcif_data_cache if hasattr(dataset,"mmcif_data_cache") else None
selected_idx = [] selected_idx = []
for i in range(len(mmcif_data_cache)): if mmcif_data_cache is not None:
mmcif_id = dataset.idx_to_mmcif_id(i) for i in range(len(mmcif_data_cache)):
chains = mmcif_data_cache[mmcif_id]['chain_ids'] mmcif_id = dataset.idx_to_mmcif_id(i)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id] mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry, if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9, max_resolution=9,
minimum_number_of_residues=5): minimum_number_of_residues=5):
selected_idx.append(i) selected_idx.append(i)
else:
selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
return selected_idx return selected_idx
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -746,6 +748,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -746,6 +748,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self.datapoints = [] self.datapoints = []
for dataset_idx in dataset_choices: for dataset_idx in dataset_choices:
selected_idx = self.filter_samples(dataset_idx) selected_idx = self.filter_samples(dataset_idx)
random.shuffle(selected_idx)
if len(selected_idx)<self.epoch_len: if len(selected_idx)<self.epoch_len:
self.epoch_len = len(selected_idx) self.epoch_len = len(selected_idx)
logging.info(f"self.epoch_len is {self.epoch_len}") logging.info(f"self.epoch_len is {self.epoch_len}")
...@@ -849,7 +852,6 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader): ...@@ -849,7 +852,6 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
self.stage = stage self.stage = stage
self.generator = generator self.generator = generator
print('initialised a multimer dataloader')
def __iter__(self): def __iter__(self):
it = super().__iter__() it = super().__iter__()
...@@ -1092,6 +1094,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1092,6 +1094,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
template_mmcif_dir: str, max_template_date: str, template_mmcif_dir: str, max_template_date: str,
train_data_dir: Optional[str] = None, train_data_dir: Optional[str] = None,
train_mmcif_data_cache_path:Optional[str] = None, train_mmcif_data_cache_path:Optional[str] = None,
val_mmcif_data_cache_path:Optional[str] = None,
**kwargs): **kwargs):
super(OpenFoldMultimerDataModule,self).__init__(config, super(OpenFoldMultimerDataModule,self).__init__(config,
template_mmcif_dir, template_mmcif_dir,
...@@ -1099,6 +1102,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1099,6 +1102,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
train_data_dir,**kwargs) train_data_dir,**kwargs)
self.train_mmcif_data_cache_path = train_mmcif_data_cache_path self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
self.training_mode = self.train_data_dir is not None self.training_mode = self.train_data_dir is not None
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
def setup(self): def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
...@@ -1167,6 +1171,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1167,6 +1171,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir, data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir, alignment_dir=self.val_alignment_dir,
mmcif_data_cache_path=self.val_mmcif_data_cache_path,
filter_path=None, filter_path=None,
max_template_hits=self.config.eval.max_template_hits, max_template_hits=self.config.eval.max_template_hits,
mode="eval", mode="eval",
...@@ -1206,7 +1211,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1206,7 +1211,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
batch_size=1, batch_size=1,
num_workers=self.config.data_module.data_loaders.num_workers, num_workers=self.config.data_module.data_loaders.num_workers,
) )
print(f"generated training dataloader")
return dl return dl
class DummyDataset(torch.utils.data.Dataset): class DummyDataset(torch.utils.data.Dataset):
......
...@@ -509,6 +509,10 @@ if __name__ == "__main__": ...@@ -509,6 +509,10 @@ if __name__ == "__main__":
"--val_alignment_dir", type=str, default=None, "--val_alignment_dir", type=str, default=None,
help="Directory containing precomputed validation alignments" help="Directory containing precomputed validation alignments"
) )
parser.add_argument(
"--val_mmcif_data_cache_path", type=str, default=None,
help="path to the json file which records all the information of mmcif structures used during training"
)
parser.add_argument( parser.add_argument(
"--kalign_binary_path", type=str, default='/usr/bin/kalign', "--kalign_binary_path", type=str, default='/usr/bin/kalign',
help="Path to the kalign binary" help="Path to the kalign binary"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment