Commit e9794a62 authored by Geoffrey Yu
Browse files

remove number of chains filter; no longer hard-code is_multimer in multimer...

remove number of chains filter; no longer hard-code is_multimer in multimer classes; remove recycling dimension in ground truth features
parent 19fe90b2
......@@ -470,6 +470,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
is_multimer = (len(chains)>1)
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
......@@ -480,7 +481,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
# process all_chain_features
all_chain_features = self.feature_pipeline.process_features(all_chain_features,
mode=self.mode,
is_multimer=True)
is_multimer=is_multimer)
alignment_index = None
ground_truth=[]
......@@ -502,13 +503,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, self.mode)
# since it's ground truth features, change the mode to eval in order to avoid padding
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False)
#remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, self.mode)
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".pdb"):
structure_index = None
......@@ -520,7 +527,9 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, self.mode)
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
else:
raise ValueError("Extension branch missing")
......@@ -740,9 +749,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
mmcif_id = dataset.idx_to_mmcif_id(i)
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(len(chains)>1) and deterministic_multimer_train_filter(mmcif_data_cache_entry,
if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5):
print(f"{mmcif_id} passed the filter now added: {i}")
selected_idx.append(i)
return selected_idx
......@@ -769,8 +777,6 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self.epoch_len = len(selected_idx)
print(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
print(f"datapoints is {self.datapoints}")
class OpenFoldBatchCollator:
def __call__(self, prots):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment