"lib/llm/vscode:/vscode.git/clone" did not exist on "411f07e038dfe748f17b8b3b4892d29cbbf5ba1b"
Commit e9794a62 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

remove number of chains filter; no longer hard code is_multimer in multimer...

remove number of chains filter; no longer hard code is_multimer in multimer classes; remove recycling dimension in ground truth features
parent 19fe90b2
...@@ -470,6 +470,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -470,6 +470,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids'] chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
is_multimer = (len(chains)>1)
seqs = self.mmcif_data_cache[mmcif_id]['seqs'] seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = "" fasta_str = ""
for c,s in zip(chains,seqs): for c,s in zip(chains,seqs):
...@@ -480,7 +481,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -480,7 +481,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
# process all_chain_features # process all_chain_features
all_chain_features = self.feature_pipeline.process_features(all_chain_features, all_chain_features = self.feature_pipeline.process_features(all_chain_features,
mode=self.mode, mode=self.mode,
is_multimer=True) is_multimer=is_multimer)
alignment_index = None alignment_index = None
ground_truth=[] ground_truth=[]
...@@ -502,13 +503,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -502,13 +503,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self._parse_mmcif( data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index, path, mmcif_id, chain, alignment_dir, alignment_index,
) )
ground_truth_feats = self.feature_pipeline.process_features(data, self.mode) # since it's ground truth features, change the mode to eval in order to avoid padding
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False)
#remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats) ground_truth.append(ground_truth_feats)
elif(ext == ".core"): elif(ext == ".core"):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index, path, alignment_dir, alignment_index,
) )
ground_truth_feats = self.feature_pipeline.process_features(data, self.mode) ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats) ground_truth.append(ground_truth_feats)
elif(ext == ".pdb"): elif(ext == ".pdb"):
structure_index = None structure_index = None
...@@ -520,7 +527,9 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -520,7 +527,9 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_index=alignment_index, alignment_index=alignment_index,
_structure_index=structure_index, _structure_index=structure_index,
) )
ground_truth_feats = self.feature_pipeline.process_features(data, self.mode) ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats) ground_truth.append(ground_truth_feats)
else: else:
raise ValueError("Extension branch missing") raise ValueError("Extension branch missing")
...@@ -740,9 +749,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -740,9 +749,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
mmcif_id = dataset.idx_to_mmcif_id(i) mmcif_id = dataset.idx_to_mmcif_id(i)
chains = mmcif_data_cache[mmcif_id]['chain_ids'] chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id] mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(len(chains)>1) and deterministic_multimer_train_filter(mmcif_data_cache_entry, if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5): max_resolution=9,minimum_number_of_residues=5):
print(f"{mmcif_id} passed the filter now added: {i}")
selected_idx.append(i) selected_idx.append(i)
return selected_idx return selected_idx
...@@ -769,8 +777,6 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -769,8 +777,6 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self.epoch_len = len(selected_idx) self.epoch_len = len(selected_idx)
print(f"self.epoch_len is {self.epoch_len}") print(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ] self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
print(f"datapoints is {self.datapoints}")
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __call__(self, prots): def __call__(self, prots):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment