Commit 2204bbb2 authored by Dingquan Yu

fixed errors when running in subprocess

parent c3c627e7
@@ -435,15 +435,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
             raise list(mmcif_object.errors.values())[0]
         mmcif_object = mmcif_object.mmcif_object
-        # print(f" ###### line 442 started mmcif_processing")
-        # start = time.time()
         data = self.data_pipeline.process_mmcif(
             mmcif=mmcif_object,
             alignment_dir=alignment_dir,
             alignment_index=alignment_index
         )
-        # end = time.time()
-        # calculate_elapse(start , end)s
         return data

     def mmcif_id_to_idx(self, mmcif_id):
@@ -453,8 +449,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         return self._mmcifs[idx]

     def __getitem__(self, idx):
-        print(f"####### line 456 idx is {idx}")
         mmcif_id = self.idx_to_mmcif_id(idx)
         alignment_index = None
...
@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
 import subprocess
 import numpy as np
 import torch
+import pickle
 from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
 from openfold.data.templates import get_custom_template_features, empty_template_feats
 from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
@@ -738,22 +738,10 @@ class DataPipeline:
            fp.close()
        else:
            # Now will split the following steps into multiple processes
-           # a3m_tasks = [(alignment_dir, a3m) for a3m in a3m_files]
-           # sto_tasks = [(alignment_dir, sto) for sto in stockholm_files]
-           # with NonDaemonicProcessPool(len(a3m_tasks) + len(sto_tasks)) as pool:
-           #     a3m_results = pool.starmap(parse_a3m_file, a3m_tasks)
-           #     sto_results = pool.starmap(parse_stockholm_file, sto_tasks)
-           # msa_results = {**a3m_results, **sto_results}
-           import time, json
            current_directory = os.path.dirname(os.path.abspath(__file__))
-           start = time.time()
-           cmd = f"{current_directory}/parse_msa_files.py {alignment_dir}"
-           msa_data = subprocess.check_output(['python', cmd], capture_output=True, text=True)
-           msa_data = json.load(msa_data)
-           end = time.time()
-           calculate_elapse(start, end, "multiprocessing version")
+           cmd = f"{current_directory}/parse_msa_files.py"
+           msa_data = subprocess.run(['python', cmd, f"--alignment_dir={alignment_dir}"], capture_output=True, text=True)
+           msa_data = pickle.load(open(msa_data.stdout.rstrip(), 'rb'))
            return msa_data

    def _parse_template_hit_files(
@@ -836,12 +824,9 @@ class DataPipeline:
            )
            end = time.time()
            calculate_elapse(start, end, "get_msas")
-           start = time.time()
            msa_features = make_msa_features(
                msas=msas
            )
-           end = time.time()
-           calculate_elapse(start, end, "make_msa_features")
            end_main = time.time()
            calculate_elapse(start_main, end_main, "process_msa_feats")
            return msa_features
...
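The pipeline-side change above replaces an in-process pool with a child interpreter: parse_msa_files.py runs via subprocess.run, prints the path of the pickle it wrote, and the pipeline loads that file. A minimal sketch of this handoff pattern, with a hypothetical worker.py and alignment directory standing in for the real script and paths:

    import pickle
    import subprocess
    import sys

    # Hypothetical stand-in for the real helper script.
    worker_script = "worker.py"

    # Launch the worker in a fresh interpreter. sys.executable is a safer
    # choice than the bare 'python' used in the committed code whenever more
    # than one interpreter is installed.
    result = subprocess.run(
        [sys.executable, worker_script, "--alignment_dir=/tmp/alignments"],
        capture_output=True, text=True, check=True,
    )

    # The child prints the path of the pickle it wrote; strip the trailing
    # newline that print() appends before opening the file.
    with open(result.stdout.rstrip(), "rb") as fh:
        msa_data = pickle.load(fh)

Passing the payload through a named file keeps the pipe free for the path alone; because the child creates the file with delete=False, the parent is responsible for removing it after loading. The helper script itself, parse_msa_files.py, is reworked below from a hand-rolled Process/Queue fan-out to a pool.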
-import os, multiprocessing, argparse, json
+import os, multiprocessing, argparse, pickle, tempfile
+import multiprocessing.pool  # Need to import multiprocessing.pool first, otherwise multiprocessing.pool.Pool cannot be called
 from openfold.data import parsers
+import contextlib

-def parse_stockholm_file(alignment_dir: str, stockholm_file: str, queue: multiprocessing.Queue):
+def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
     path = os.path.join(alignment_dir, stockholm_file)
     file_name, _ = os.path.splitext(stockholm_file)
     with open(path, "r") as infile:
         msa = parsers.parse_stockholm(infile.read())
     infile.close()
-    queue.put({file_name: msa})
+    # queue.put()
+    return {file_name: msa}

-def parse_a3m_file(alignment_dir: str, a3m_file: str, queue: multiprocessing.Queue):
+def parse_a3m_file(alignment_dir: str, a3m_file: str):
     path = os.path.join(alignment_dir, a3m_file)
     file_name, _ = os.path.splitext(a3m_file)
     with open(path, "r") as infile:
         msa = parsers.parse_a3m(infile.read())
     infile.close()
-    queue.put({file_name: msa})
+    # queue.put({file_name: msa})
+    return {file_name: msa}

 def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir: str):
-    print(f"#### line 764 start running in multiprocessing way")
     msa_results = {}
     processes = []
-    queue = multiprocessing.Queue()
-    for f in stockholm_files:
-        process = multiprocessing.Process(target=parse_stockholm_file, args=(alignment_dir, f, queue))
-        process.deamon = False
-        processes.append(process)
-        process.start()
-    for f in a3m_files:
-        process = multiprocessing.Process(target=parse_a3m_file, args=(alignment_dir, f, queue))
-        process.daemon = False
-        processes.append(process)
-        process.start()
-    for p in processes:
-        res = queue.get()
-        msa_results.update(res)
-        p.join()
+    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
+    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
+    with multiprocessing.pool.Pool(len(a3m_tasks) + len(sto_tasks)) as pool:
+        a3m_results = pool.starmap_async(parse_a3m_file, a3m_tasks).get()
+        sto_results = pool.starmap_async(parse_stockholm_file, sto_tasks).get()
+    for res in [*a3m_results, *sto_results]:
+        msa_results.update(res)
+    return msa_results

-def main(alignment_dir):
+def main():
     parser = argparse.ArgumentParser(description='Process msa files in parallel')
-    parser.add_argument('alignment_dir', metavar='N', type=int, nargs='+',
-                        help='an integer for the accumulator')
+    parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
     args = parser.parse_args()
+    alignment_dir = args.alignment_dir
     stockholm_files = [i for i in os.listdir(alignment_dir) if (i.endswith('.sto') and ("hmm_output" not in i))]
     a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
     msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
-    return json.dumps(msa_data)
+    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
+        pickle.dump(msa_data, outfile)
+    print(outfile.name)

 if __name__ == "__main__":
     main()
\ No newline at end of file
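Two details of the rewritten helper deserve a note. First, Python does not import a package's submodules automatically, so multiprocessing.pool may not exist as an attribute after a bare import multiprocessing; the explicit import at the top of the script works around exactly that. Second, calling .get() immediately after starmap_async blocks until the batch finishes, so the two async calls above run one batch after the other, equivalent to plain starmap. A minimal sketch under those observations, with hypothetical task data in place of the alignment files:

    import os
    import multiprocessing.pool  # binds the 'pool' submodule explicitly

    def parse_one(alignment_dir: str, file_name: str) -> dict:
        # Illustrative stand-in for parse_a3m_file / parse_stockholm_file:
        # each worker returns a one-entry dict keyed by the file it handled.
        return {file_name: f"parsed {os.path.join(alignment_dir, file_name)}"}

    if __name__ == "__main__":
        tasks = [("aln", "hits_a.a3m"), ("aln", "hits_b.sto")]  # hypothetical names
        # Cap the pool at the CPU count rather than one worker per file; a pool
        # sized len(a3m_tasks) + len(sto_tasks) can oversubscribe the machine
        # when many alignment files are present.
        workers = min(len(tasks), os.cpu_count() or 1)
        with multiprocessing.pool.Pool(processes=workers) as pool:
            results = pool.starmap_async(parse_one, tasks).get()
        merged = {}
        for res in results:
            merged.update(res)
        print(merged)

The module-level guard matters here: on spawn-based platforms the pool re-imports the script, and without if __name__ == "__main__" that import would try to create the pool again.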