Commit 581411fa authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

finished constructing multimer dataset object and now return both...

finished constructing multimer dataset object and now return both all_seq_features and ground truth structure
parent b55ad675
......@@ -473,34 +473,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
fasta_str+f">{mmcif_id}_{c}\n{s}"
print(fasta_str)
import sys
sys.exit()
alignment_dir = os.path.join(self.alignment_dir, name)
fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
alignment_index = None
if(self.alignment_index is not None):
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]
ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
for chain in chains:
path = os.path.join(self.alignment_dir, f"{mmcif_id}_{chain.upper()}")
print(f"path is {path}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
......@@ -513,47 +495,43 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index,
path, mmcif_id, chain, self.alignment_dir, alignment_index,
)
ground_truth.append(data)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
path, self.alignment_dir, alignment_index,
)
ground_truth.append(data)
elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
alignment_dir=self.alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
chain_id=chain,
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth.append(data)
else:
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
)
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
# if it's training now, then return both all_chain_features and ground_truth
return all_chain_features,ground_truth
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
else:
# if it's inference mode, only need all_chain_features
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
device=all_chain_features["aatype"].device)
return all_chain_features
return feats
def __len__(self):
return len(self._chain_ids)
......@@ -582,7 +560,7 @@ def deterministic_train_filter(
def deterministic_multimer_train_filter(
mmcif_data_cache_entry,
max_resolution: 9.,
max_resolution:float= 9.,
max_single_aa_prop:float=0.8,
minimum_number_of_residues:int=200,
) -> bool:
......@@ -593,7 +571,7 @@ def deterministic_multimer_train_filter(
"""
# First check resolution
resolution = mmcif_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
if(resolution is not None and resolution > max_resolution) or (resolution is None):
return False
# Then check if any single amino acid accounts for more than 80% of the complex sequences
seqs = mmcif_data_cache_entry["seqs"]
......@@ -752,11 +730,11 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx = []
for i in range(len(mmcif_data_cache)):
mmcif_id = dataset.idx_to_mmcif_id(i)
print(f"mmcif_id is {mmcif_id} and candidate_idx: {i}")
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(len(chains)>1) and (not deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9)):
if(len(chains)>1) and deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5):
print(f"{mmcif_id} passed the filter now added: {i}")
selected_idx.append(i)
return selected_idx
......@@ -781,8 +759,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx = self.filter_samples(dataset_idx)
if len(selected_idx)<self.epoch_len:
self.epoch_len = len(selected_idx)
self.datapoints = [(dataset_idx, datapoint_idx) for datapoint_idx in range(self.epoch_len) ]
print(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
print(f"datapoints is {self.datapoints}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment