"components/vscode:/vscode.git/clone" did not exist on "e1a95dabb2a0c6da332c3f2c06ebc5fe4b3f70e7"
Commit 581411fa authored by Geoffrey Yu
Browse files

finished constructing multimer dataset object and now return both...

Finished constructing the multimer dataset object; it now returns both all_seq_features and the ground-truth structure.
parent b55ad675
......@@ -473,34 +473,16 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
fasta_str+f">{mmcif_id}_{c}\n{s}"
print(fasta_str)
import sys
sys.exit()
alignment_dir = os.path.join(self.alignment_dir, name)
fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
alignment_index = None
if(self.alignment_index is not None):
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]
ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
for chain in chains:
path = os.path.join(self.alignment_dir, f"{mmcif_id}_{chain.upper()}")
print(f"path is {path}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
......@@ -510,50 +492,46 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
)
else:
raise ValueError("Extension branch missing")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, mmcif_id, chain, self.alignment_dir, alignment_index,
)
ground_truth.append(data)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, self.alignment_dir, alignment_index,
)
ground_truth.append(data)
elif(ext == ".pdb"):
structure_index = None
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=self.alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain,
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth.append(data)
else:
raise ValueError("Extension branch missing")
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
# if it's training now, then return both all_chain_features and ground_truth
return all_chain_features,ground_truth
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
)
# if it's inference mode, only need all_chain_features
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
return all_chain_features
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats
def __len__(self):
return len(self._chain_ids)
......@@ -582,7 +560,7 @@ def deterministic_train_filter(
def deterministic_multimer_train_filter(
mmcif_data_cache_entry,
max_resolution: 9.,
max_resolution:float= 9.,
max_single_aa_prop:float=0.8,
minimum_number_of_residues:int=200,
) -> bool:
......@@ -593,7 +571,7 @@ def deterministic_multimer_train_filter(
"""
# First check resolution
resolution = mmcif_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
if(resolution is not None and resolution > max_resolution) or (resolution is None):
return False
# Then check if any single amino acid accounts for more than 80% of the complex sequences
seqs = mmcif_data_cache_entry["seqs"]
......@@ -752,11 +730,11 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx = []
for i in range(len(mmcif_data_cache)):
mmcif_id = dataset.idx_to_mmcif_id(i)
print(f"mmcif_id is {mmcif_id} and candidate_idx: {i}")
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(len(chains)>1) and (not deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9)):
if(len(chains)>1) and deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5):
print(f"{mmcif_id} passed the filter now added: {i}")
selected_idx.append(i)
return selected_idx
......@@ -781,8 +759,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx = self.filter_samples(dataset_idx)
if len(selected_idx)<self.epoch_len:
self.epoch_len = len(selected_idx)
self.datapoints = [(dataset_idx, datapoint_idx) for datapoint_idx in range(self.epoch_len) ]
print(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
print(f"datapoints is {self.datapoints}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment