Commit 39d4e5c7 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

modified OpenFoldSingleDataset class so that it takes mmcif_cache information;...

Modified the OpenFoldSingleDataset class so that it accepts mmcif_cache information; created a new OpenFoldMultimerDataset class that will build suitable data inputs for multimer training.
parent 226d6011
......@@ -29,6 +29,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
mmcif_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
......@@ -60,6 +61,12 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
mmcif_data_cache_path:
Path to a cache of all mmCIF files, generated by
scripts/generate_mmcif_cache.py. It should be a JSON file that records
which PDB ID contains which chain(s).
kalign_binary_path:
Path to kalign binary.
max_template_hits:
......@@ -91,6 +98,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
if mmcif_data_cache_path is not None:
with open(mmcif_data_cache_path,"r") as infile:
self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache,dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
......@@ -359,6 +371,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
yield idx
def looped_samples(dataset_idx):
print(f"dataset_idx is {dataset_idx} and start looping samples")
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
......@@ -369,6 +382,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
print(f"candidate_idx: {candidate_idx} and chain_id: {chain_id}")
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
......@@ -417,6 +431,11 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldMultimerDataset(torch.utils.data.Dataset):
    """Dataset meant to implement the filtering criteria used in
    AlphaFold-Multimer training.

    NOTE(review): this class is currently a stub — no ``__init__``,
    ``__getitem__``, or ``__len__`` is defined in this commit, so it is
    not yet usable as a ``torch.utils.data.Dataset``. Presumably it will
    consume the ``mmcif_data_cache`` (PDB ID -> chains mapping) added to
    ``OpenFoldSingleDataset`` in this same commit to group chains into
    multimer training examples — confirm against follow-up commits.
    """
class OpenFoldBatchCollator:
def __call__(self, prots):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment