Commit e72e4e62 authored by Dingquan Yu's avatar Dingquan Yu
Browse files

remove debugging statement

parent 53c03a6a
...@@ -22,9 +22,6 @@ from openfold.utils.tensor_utils import ( ...@@ -22,9 +22,6 @@ from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
def calculate_elapse(start, end, name):
elapse = end - start
print(f"{name} runs {round(elapse,3)} seconds i.e. {round(elapse/60, 3)} minutes")
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
...@@ -451,9 +448,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -451,9 +448,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
if self.mode == 'train' or self.mode == 'eval':
import time import time
start = time.time() start = time.time()
if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None ext = None
for e in self.supported_exts: for e in self.supported_exts:
...@@ -478,18 +476,14 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -478,18 +476,14 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_path=path, fasta_path=path,
alignment_dir=self.alignment_dir alignment_dir=self.alignment_dir
) )
end = time.time()
calculate_elapse(start, end, "process_fasta in data_modules")
if self._output_raw: if self._output_raw:
return data return data
process_start = time.time()
# process all_chain_features # process all_chain_features
data = self.feature_pipeline.process_features(data, data = self.feature_pipeline.process_features(data,
mode=self.mode, mode=self.mode,
is_multimer=True) is_multimer=True)
end = time.time()
calculate_elapse(process_start, end, "process_features in data_modules")
calculate_elapse(start, end, "dataset get_item in data_modules")
# if it's inference mode, only need all_chain_features # if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor( data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])], [idx for _ in range(data["aatype"].shape[-1])],
......
...@@ -28,16 +28,14 @@ import pickle ...@@ -28,16 +28,14 @@ import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date, NonDaemonicProcess, NonDaemonicProcessPool from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import concurrent
from concurrent.futures import ThreadPoolExecutor
FeatureDict = MutableMapping[str, np.ndarray] FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def calculate_elapse(start, end, name):
elapse = end - start
print(f"{name} runs {round(elapse,3)} seconds i.e. {round(elapse/60, 3)} minutes")
def make_template_features( def make_template_features(
input_sequence: str, input_sequence: str,
...@@ -738,14 +736,11 @@ class DataPipeline: ...@@ -738,14 +736,11 @@ class DataPipeline:
fp.close() fp.close()
else: else:
# Now will split the following steps into multiple processes # Now will split the following steps into multiple processes
import time
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/parse_msa_files.py" cmd = f"{current_directory}/tools/parse_msa_files.py"
start = time.time()
msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True) msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(),'rb'))) msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(),'rb')))
end = time.time()
calculate_elapse(start, end, "parse_msa_files in data_pipeline")
return msa_data return msa_data
def _parse_template_hit_files( def _parse_template_hit_files(
...@@ -820,19 +815,14 @@ class DataPipeline: ...@@ -820,19 +815,14 @@ class DataPipeline:
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
import time
start_main = time.time()
start = time.time()
msas = self._get_msas( msas = self._get_msas(
alignment_dir, input_sequence, alignment_index alignment_dir, input_sequence, alignment_index
) )
end = time.time()
calculate_elapse(start,end,"get_msas in data_pipeline")
msa_features = make_msa_features( msa_features = make_msa_features(
msas=msas msas=msas
) )
end_main = time.time()
calculate_elapse(start_main, end_main,"process_msa_feats in data_pipeline")
return msa_features return msa_features
# Load and process sequence embedding features # Load and process sequence embedding features
...@@ -1216,8 +1206,10 @@ class DataPipelineMultimer: ...@@ -1216,8 +1206,10 @@ class DataPipelineMultimer:
with open(fasta_path) as f: with open(fasta_path) as f:
input_fasta_str = f.read() input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1 is_homomer_or_monomer = len(set(input_seqs)) == 1
...@@ -1228,6 +1220,7 @@ class DataPipelineMultimer: ...@@ -1228,6 +1220,7 @@ class DataPipelineMultimer:
) )
continue continue
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
...@@ -1236,6 +1229,7 @@ class DataPipelineMultimer: ...@@ -1236,6 +1229,7 @@ class DataPipelineMultimer:
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
...@@ -1243,12 +1237,15 @@ class DataPipelineMultimer: ...@@ -1243,12 +1237,15 @@ class DataPipelineMultimer:
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
...@@ -1284,18 +1281,21 @@ class DataPipelineMultimer: ...@@ -1284,18 +1281,21 @@ class DataPipelineMultimer:
alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1 is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items(): for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id]) desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features: if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy( all_chain_features[desc] = copy.deepcopy(
sequence_features[seq] sequence_features[seq]
) )
continue continue
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
...@@ -1304,23 +1304,29 @@ class DataPipelineMultimer: ...@@ -1304,23 +1304,29 @@ class DataPipelineMultimer:
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
) )
mmcif_feats = self.get_mmcif_features(mmcif, chain_id) mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats) chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
return np_example return np_example
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment