Commit 581411fa authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

finished constructing multimer dataset object and now return both...

finished constructing multimer dataset object and now return both all_seq_features and ground truth structure
parent b55ad675
...@@ -473,34 +473,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -473,34 +473,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
seqs = self.mmcif_data_cache[mmcif_id]['seqs'] seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = "" fasta_str = ""
for c,s in zip(chains,seqs): for c,s in zip(chains,seqs):
fasta_str+f">{mmcif_id}_{c}\n{s}" fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
print(fasta_str) all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
import sys
sys.exit()
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None alignment_index = None
if(self.alignment_index is not None): ground_truth=[]
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1) for chain in chains:
if(len(spl) == 2): path = os.path.join(self.alignment_dir, f"{mmcif_id}_{chain.upper()}")
file_id, chain_id = spl print(f"path is {path}")
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None ext = None
for e in self.supported_exts: for e in self.supported_exts:
if(os.path.exists(path + e)): if(os.path.exists(path + e)):
...@@ -510,50 +492,46 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -510,50 +492,46 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if(ext is None): if(ext is None):
raise ValueError("Invalid file type") raise ValueError("Invalid file type")
path += ext path += ext
if(ext == ".cif"): if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index, path, mmcif_id, chain, self.alignment_dir, alignment_index,
) )
elif(ext == ".core"): ground_truth.append(data)
data = self.data_pipeline.process_core( elif(ext == ".core"):
path, alignment_dir, alignment_index, data = self.data_pipeline.process_core(
) path, self.alignment_dir, alignment_index,
elif(ext == ".pdb"): )
structure_index = None ground_truth.append(data)
if(self._structure_index is not None): elif(ext == ".pdb"):
structure_index = self._structure_index[name] structure_index = None
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path, pdb_path=path,
alignment_dir=alignment_dir, alignment_dir=self.alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id, chain_id=chain,
alignment_index=alignment_index, alignment_index=alignment_index,
_structure_index=structure_index, _structure_index=structure_index,
) )
else: ground_truth.append(data)
raise ValueError("Extension branch missing") else:
raise ValueError("Extension branch missing")
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
# if it's training now, then return both all_chain_features and ground_truth
return all_chain_features,ground_truth
else: else:
path = os.path.join(name, name + ".fasta") # if it's inference mode, only need all_chain_features
data = self.data_pipeline.process_fasta( all_chain_features["batch_idx"] = torch.tensor(
fasta_path=path, [idx for _ in range(all_chain_features["aatype"].shape[-1])],
alignment_dir=alignment_dir, dtype=torch.int64,
alignment_index=alignment_index, device=all_chain_features["aatype"].device)
) return all_chain_features
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
...@@ -582,7 +560,7 @@ def deterministic_train_filter( ...@@ -582,7 +560,7 @@ def deterministic_train_filter(
def deterministic_multimer_train_filter( def deterministic_multimer_train_filter(
mmcif_data_cache_entry, mmcif_data_cache_entry,
max_resolution: 9., max_resolution:float= 9.,
max_single_aa_prop:float=0.8, max_single_aa_prop:float=0.8,
minimum_number_of_residues:int=200, minimum_number_of_residues:int=200,
) -> bool: ) -> bool:
...@@ -593,7 +571,7 @@ def deterministic_multimer_train_filter( ...@@ -593,7 +571,7 @@ def deterministic_multimer_train_filter(
""" """
# First check resolution # First check resolution
resolution = mmcif_data_cache_entry.get("resolution", None) resolution = mmcif_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution): if(resolution is not None and resolution > max_resolution) or (resolution is None):
return False return False
# Then check if any single amino acid accounts for more than 80% of the complex sequences # Then check if any single amino acid accounts for more than 80% of the complex sequences
seqs = mmcif_data_cache_entry["seqs"] seqs = mmcif_data_cache_entry["seqs"]
...@@ -752,11 +730,11 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -752,11 +730,11 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx = [] selected_idx = []
for i in range(len(mmcif_data_cache)): for i in range(len(mmcif_data_cache)):
mmcif_id = dataset.idx_to_mmcif_id(i) mmcif_id = dataset.idx_to_mmcif_id(i)
print(f"mmcif_id is {mmcif_id} and candidate_idx: {i}")
chains = mmcif_data_cache[mmcif_id]['chain_ids'] chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id] mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(len(chains)>1) and (not deterministic_multimer_train_filter(mmcif_data_cache_entry, if(len(chains)>1) and deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9)): max_resolution=9,minimum_number_of_residues=5):
print(f"{mmcif_id} passed the filter now added: {i}")
selected_idx.append(i) selected_idx.append(i)
return selected_idx return selected_idx
...@@ -781,8 +759,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -781,8 +759,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx = self.filter_samples(dataset_idx) selected_idx = self.filter_samples(dataset_idx)
if len(selected_idx)<self.epoch_len: if len(selected_idx)<self.epoch_len:
self.epoch_len = len(selected_idx) self.epoch_len = len(selected_idx)
print(f"self.epoch_len is {self.epoch_len}")
self.datapoints = [(dataset_idx, datapoint_idx) for datapoint_idx in range(self.epoch_len) ] self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
print(f"datapoints is {self.datapoints}") print(f"datapoints is {self.datapoints}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment