"git@developer.sourcefind.cn:OpenDAS/sparseconvnet.git" did not exist on "15fd91a0a1a34376105d85ac9f7e5f24dc266394"
Commit 4e58a6a0 authored by Dingquan Yu's avatar Dingquan Yu
Browse files

now use ThreadPoolExecutor

parent 2204bbb2
...@@ -22,9 +22,9 @@ from openfold.utils.tensor_utils import ( ...@@ -22,9 +22,9 @@ from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
def calculate_elapse(start, end, name):
    """Print how long the step `name` took.

    Args:
        start: start timestamp in seconds (e.g. from time.time()).
        end: end timestamp in seconds.
        name: human-readable label for the timed step, included in the message.
    """
    elapse = end - start
    print(f"{name} runs {round(elapse,3)} seconds i.e. {round(elapse/60, 3)} minutes")
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
...@@ -451,7 +451,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -451,7 +451,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
import time
start = time.time()
if self.mode == 'train' or self.mode == 'eval': if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None ext = None
...@@ -477,15 +478,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -477,15 +478,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_path=path, fasta_path=path,
alignment_dir=self.alignment_dir alignment_dir=self.alignment_dir
) )
end = time.time()
calculate_elapse(start, end, "process_fasta in data_modules")
if self._output_raw: if self._output_raw:
return data return data
process_start = time.time()
# process all_chain_features # process all_chain_features
data = self.feature_pipeline.process_features(data, data = self.feature_pipeline.process_features(data,
mode=self.mode, mode=self.mode,
is_multimer=True) is_multimer=True)
end = time.time()
calculate_elapse(process_start, end, "process_features in data_modules")
calculate_elapse(start, end, "dataset get_item in data_modules")
# if it's inference mode, only need all_chain_features # if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor( data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])], [idx for _ in range(data["aatype"].shape[-1])],
......
...@@ -738,10 +738,14 @@ class DataPipeline: ...@@ -738,10 +738,14 @@ class DataPipeline:
fp.close() fp.close()
else: else:
# Now will split the following steps into multiple processes # Now will split the following steps into multiple processes
import time
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/parse_msa_files.py" cmd = f"{current_directory}/parse_msa_files.py"
start = time.time()
msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True) msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data = pickle.load((open(msa_data.stdout.rstrip(),'rb'))) msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(),'rb')))
end = time.time()
calculate_elapse(start, end, "parse_msa_files in data_pipeline")
return msa_data return msa_data
def _parse_template_hit_files( def _parse_template_hit_files(
...@@ -823,12 +827,12 @@ class DataPipeline: ...@@ -823,12 +827,12 @@ class DataPipeline:
alignment_dir, input_sequence, alignment_index alignment_dir, input_sequence, alignment_index
) )
end = time.time() end = time.time()
calculate_elapse(start,end,"get_msas") calculate_elapse(start,end,"get_msas in data_pipeline")
msa_features = make_msa_features( msa_features = make_msa_features(
msas=msas msas=msas
) )
end_main = time.time() end_main = time.time()
calculate_elapse(start_main, end_main,"process_msa_feats") calculate_elapse(start_main, end_main,"process_msa_feats in data_pipeline")
return msa_features return msa_features
# Load and process sequence embedding features # Load and process sequence embedding features
......
import os, multiprocessing, argparse, pickle, tempfile import os, multiprocessing, argparse, pickle, tempfile, concurrent
import multiprocessing.pool # Need to import multiprocessing.pool first otherwise multiprocessing.pool.Pool cannot be called import multiprocessing.pool # Need to import multiprocessing.pool first otherwise multiprocessing.pool.Pool cannot be called
from openfold.data import parsers from openfold.data import parsers
import contextlib from concurrent.futures import ThreadPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str): def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
path = os.path.join(alignment_dir, stockholm_file) path = os.path.join(alignment_dir, stockholm_file)
file_name,_ = os.path.splitext(stockholm_file) file_name,_ = os.path.splitext(stockholm_file)
...@@ -21,15 +22,20 @@ def parse_a3m_file(alignment_dir: str, a3m_file: str): ...@@ -21,15 +22,20 @@ def parse_a3m_file(alignment_dir: str, a3m_file: str):
return {file_name: msa} return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str):
    """Parse all Stockholm and A3M alignment files concurrently.

    Args:
        stockholm_files: file names (relative to `alignment_dir`) of .sto alignments.
        a3m_files: file names (relative to `alignment_dir`) of .a3m alignments.
        alignment_dir: directory containing the alignment files.

    Returns:
        A dict merging the per-file results of `parse_a3m_file` /
        `parse_stockholm_file` (each task returns a {file_base_name: msa} dict).
        Files whose parsing raises are reported to stdout and omitted.
    """
    msa_results = {}
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
    # Parsing is dominated by file I/O, so a thread pool is sufficient here
    # (no need for one process per task as the previous Pool-based version did).
    with ThreadPoolExecutor() as executor:
        a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks}
        sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks}

        # Collect results as soon as each future completes, in either group.
        for future in concurrent.futures.as_completed(a3m_futures | sto_futures):
            try:
                result = future.result()
                msa_results.update(result)
            except Exception as exc:
                # Best-effort: report the failure but keep the remaining results.
                print(f'Task generated an exception: {exc}')
    return msa_results
def main(): def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment