Commit aec12764 authored by Dingquan Yu
Browse files

added timing steps

parent cdeb8d1b
...@@ -22,6 +22,9 @@ from openfold.utils.tensor_utils import ( ...@@ -22,6 +22,9 @@ from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
def calculate_elapse(start, end, name="this function"):
    """Print how long a timed section took, in seconds and in minutes.

    Args:
        start: Reading taken before the timed section (e.g. ``time.time()``).
        end: Reading taken after the timed section, in the same units.
        name: Label printed at the start of the message. The default
            reproduces the original two-argument output exactly.

    Returns:
        The elapsed value ``end - start``, so callers may aggregate timings
        instead of re-deriving them from the printed text.
    """
    elapsed = end - start
    print(f"{name} runs {round(elapsed, 3)} seconds i.e. {round(elapsed / 60, 3)} minutes")
    return elapsed
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
...@@ -195,7 +198,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -195,7 +198,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_index=alignment_index, alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled seqemb_mode=self.config.seqemb_mode.enabled
) )
return data return data
def chain_id_to_idx(self, chain_id): def chain_id_to_idx(self, chain_id):
...@@ -423,24 +425,25 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -423,24 +425,25 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
    """Read an mmCIF file, parse it, and build features from the result.

    Args:
        path: Path of the mmCIF file to read.
        file_id: Identifier passed to ``mmcif_parsing.parse``.
        alignment_dir: Forwarded to ``self.data_pipeline.process_mmcif``.
        alignment_index: Forwarded to ``self.data_pipeline.process_mmcif``.

    Returns:
        Whatever ``self.data_pipeline.process_mmcif`` returns.

    Raises:
        The first error recorded by the parser when it produced no
        mmCIF object.
    """
    with open(path, 'r') as f:
        mmcif_string = f.read()

    mmcif_object = mmcif_parsing.parse(
        file_id=file_id, mmcif_string=mmcif_string
    )

    # Crash if an error is encountered. Any parsing errors should have
    # been dealt with at the alignment stage.
    if mmcif_object.mmcif_object is None:
        raise list(mmcif_object.errors.values())[0]

    mmcif_object = mmcif_object.mmcif_object

    data = self.data_pipeline.process_mmcif(
        mmcif=mmcif_object,
        alignment_dir=alignment_dir,
        alignment_index=alignment_index
    )

    return data
def mmcif_id_to_idx(self, mmcif_id): def mmcif_id_to_idx(self, mmcif_id):
...@@ -450,6 +453,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -450,6 +453,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
return self._mmcifs[idx] return self._mmcifs[idx]
def __getitem__(self, idx): def __getitem__(self, idx):
print(f"####### line 456 idx is {idx}")
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
...@@ -741,9 +746,7 @@ class OpenFoldMultimerDataset(OpenFoldDataset): ...@@ -741,9 +746,7 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
generator=self.generator, generator=self.generator,
) )
samples = samples.squeeze() samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s] cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache: for datapoint_idx in cache:
yield datapoint_idx yield datapoint_idx
......
...@@ -35,6 +35,9 @@ from openfold.np import residue_constants, protein ...@@ -35,6 +35,9 @@ from openfold.np import residue_constants, protein
FeatureDict = MutableMapping[str, np.ndarray] FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def calculate_elapse(start, end, name="this function"):
    """Print how long a named section took, in seconds and in minutes.

    Args:
        start: Reading taken before the timed section (e.g. ``time.time()``).
        end: Reading taken after the timed section, in the same units.
        name: Label printed at the start of the message. It is optional so
            that two-argument calls also work; existing three-argument
            calls print exactly what they did before.

    Returns:
        The elapsed value ``end - start``, so callers may aggregate timings
        instead of re-deriving them from the printed text.
    """
    elapsed = end - start
    print(f"{name} runs {round(elapsed, 3)} seconds i.e. {round(elapsed / 60, 3)} minutes")
    return elapsed
def make_template_features( def make_template_features(
input_sequence: str, input_sequence: str,
...@@ -739,13 +742,21 @@ class DataPipeline: ...@@ -739,13 +742,21 @@ class DataPipeline:
filename, ext = os.path.splitext(f) filename, ext = os.path.splitext(f)
if(ext == ".a3m"): if(ext == ".a3m"):
import time
start = time.time()
with open(path, "r") as fp: with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read()) msa = parsers.parse_a3m(fp.read())
end = time.time()
calculate_elapse(start, end, "parser.parse_a3m")
elif(ext == ".sto" and not "hmm_output" == filename): elif(ext == ".sto" and not "hmm_output" == filename):
import time
start = time.time()
with open(path, "r") as fp: with open(path, "r") as fp:
msa = parsers.parse_stockholm( msa = parsers.parse_stockholm(
fp.read() fp.read()
) )
end = time.time()
calculate_elapse(start, end, "parsers.parse_stockholm")
else: else:
continue continue
...@@ -825,13 +836,22 @@ class DataPipeline: ...@@ -825,13 +836,22 @@ class DataPipeline:
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
import time
start_main = time.time()
start = time.time()
msas = self._get_msas( msas = self._get_msas(
alignment_dir, input_sequence, alignment_index alignment_dir, input_sequence, alignment_index
) )
end = time.time()
calculate_elapse(start,end,"get_msas")
start = time.time()
msa_features = make_msa_features( msa_features = make_msa_features(
msas=msas msas=msas
) )
end = time.time()
calculate_elapse(start, end, "make_msa_features")
end_main = time.time()
calculate_elapse(start_main, end_main,"process_msa_feats")
return msa_features return msa_features
# Load and process sequence embedding features # Load and process sequence embedding features
......
...@@ -20,7 +20,7 @@ import itertools ...@@ -20,7 +20,7 @@ import itertools
import re import re
import string import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
import asyncio
DeletionMatrix = Sequence[Sequence[int]] DeletionMatrix = Sequence[Sequence[int]]
...@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa: ...@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa:
line = line.strip() line = line.strip()
if not line or line.startswith(("#", "//")): if not line or line.startswith(("#", "//")):
continue continue
name, sequence = line.split() name, sequence = line.split(maxsplit=1)
if name not in name_to_sequence: if name not in name_to_sequence:
name_to_sequence[name] = "" name_to_sequence.setdefault(name,"")
name_to_sequence[name] += sequence name_to_sequence[name] += sequence
msa = [] msa = []
......
...@@ -42,6 +42,9 @@ from scripts.zero_to_fp32 import ( ...@@ -42,6 +42,9 @@ from scripts.zero_to_fp32 import (
from openfold.utils.logger import PerformanceLoggingCallback from openfold.utils.logger import PerformanceLoggingCallback
def calculate_elapse(start, end, name="this function"):
    """Print how long a timed section took, in seconds and in minutes.

    Args:
        start: Reading taken before the timed section (e.g. ``time.time()``).
        end: Reading taken after the timed section, in the same units.
        name: Label printed at the start of the message. The default
            reproduces the original two-argument output exactly.

    Returns:
        The elapsed value ``end - start``, so callers may aggregate timings
        instead of re-deriving them from the printed text.
    """
    elapsed = end - start
    print(f"{name} runs {round(elapsed, 3)} seconds i.e. {round(elapsed / 60, 3)} minutes")
    return elapsed
class OpenFoldWrapper(pl.LightningModule): class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config): def __init__(self, config):
...@@ -316,7 +319,7 @@ def main(args): ...@@ -316,7 +319,7 @@ def main(args):
batch_seed=args.seed, batch_seed=args.seed,
**vars(args) **vars(args)
) )
data_module.prepare_data() data_module.prepare_data()
data_module.setup() data_module.setup()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment