"segmentation/configs/vscode:/vscode.git/clone" did not exist on "88dbd1ae88ff3417a05cff8717077f0da1abec7f"
Commit 2204bbb2 authored by Dingquan Yu's avatar Dingquan Yu
Browse files

Fixed errors when running in a subprocess.

parent c3c627e7
......@@ -435,15 +435,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
# print(f" ###### line 442 started mmcif_processing")
# start = time.time()
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
alignment_index=alignment_index
)
# end = time.time()
# calculate_elapse(start , end)s
return data
def mmcif_id_to_idx(self, mmcif_id):
......@@ -453,8 +449,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
return self._mmcifs[idx]
def __getitem__(self, idx):
print(f"####### line 456 idx is {idx}")
mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None
......
......@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np
import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
......@@ -738,22 +738,10 @@ class DataPipeline:
fp.close()
else:
# Now will split the following steps into multiple processes
# a3m_tasks = [(alignment_dir, a3m) for a3m in a3m_files]
# sto_tasks = [(alignment_dir, sto) for sto in stockholm_files]
# with NonDaemonicProcessPool(len(a3m_tasks) + len(sto_tasks)) as pool:
# a3m_results = pool.starmap(parse_a3m_file, a3m_tasks)
# sto_results = pool.starmap(parse_stockholm_file, sto_tasks)
# msa_results = {**a3m_results, **sto_results}
import time, json
current_directory = os.path.dirname(os.path.abspath(__file__))
start = time.time()
cmd = f"{current_directory}/parse_msa_files.py {alignment_dir}"
msa_data = subprocess.check_output(['python', cmd], capture_output=True, text= True)
msa_data = json.load(msa_data)
end = time.time()
calculate_elapse(start, end, "multiprocessing version")
cmd = f"{current_directory}/parse_msa_files.py"
msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data = pickle.load((open(msa_data.stdout.rstrip(),'rb')))
return msa_data
def _parse_template_hit_files(
......@@ -836,12 +824,9 @@ class DataPipeline:
)
end = time.time()
calculate_elapse(start,end,"get_msas")
start = time.time()
msa_features = make_msa_features(
msas=msas
)
end = time.time()
calculate_elapse(start, end, "make_msa_features")
end_main = time.time()
calculate_elapse(start_main, end_main,"process_msa_feats")
return msa_features
......
import argparse
import json
import multiprocessing
# multiprocessing.pool must be imported explicitly, otherwise
# multiprocessing.pool.Pool cannot be resolved at call time.
import multiprocessing.pool
import os
import pickle
import tempfile

from openfold.data import parsers
def parse_stockholm_file(alignment_dir: str, stockholm_file: str) -> dict:
    """Parse a single Stockholm (.sto) MSA file.

    Args:
        alignment_dir: Directory containing the alignment files.
        stockholm_file: Name of the Stockholm file inside ``alignment_dir``.

    Returns:
        A one-entry dict mapping the file's stem (name without extension)
        to the object returned by ``parsers.parse_stockholm``.
    """
    path = os.path.join(alignment_dir, stockholm_file)
    file_name, _ = os.path.splitext(stockholm_file)
    # The with-block closes the file; no explicit close() is needed.
    with open(path, "r") as infile:
        msa = parsers.parse_stockholm(infile.read())
    return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str) -> dict:
    """Parse a single A3M (.a3m) MSA file.

    Args:
        alignment_dir: Directory containing the alignment files.
        a3m_file: Name of the A3M file inside ``alignment_dir``.

    Returns:
        A one-entry dict mapping the file's stem (name without extension)
        to the object returned by ``parsers.parse_a3m``.
    """
    path = os.path.join(alignment_dir, a3m_file)
    file_name, _ = os.path.splitext(a3m_file)
    # The with-block closes the file; no explicit close() is needed.
    with open(path, "r") as infile:
        msa = parsers.parse_a3m(infile.read())
    return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(
    stockholm_files: list, a3m_files: list, alignment_dir: str
) -> dict:
    """Parse all MSA files in parallel with a process pool.

    Args:
        stockholm_files: Stockholm (.sto) file names inside ``alignment_dir``.
        a3m_files: A3M (.a3m) file names inside ``alignment_dir``.
        alignment_dir: Directory containing all the alignment files.

    Returns:
        A dict merging the one-entry dicts produced by ``parse_a3m_file``
        and ``parse_stockholm_file`` (file stem -> parsed MSA).
    """
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
    # Guard: Pool(0) raises ValueError when there is nothing to parse.
    if not a3m_tasks and not sto_tasks:
        return {}
    # NOTE: multiprocessing.pool must be imported at module level; a plain
    # `import multiprocessing` does not expose the `pool` submodule.
    with multiprocessing.pool.Pool(len(a3m_tasks) + len(sto_tasks)) as pool:
        a3m_results = pool.starmap_async(parse_a3m_file, a3m_tasks).get()
        sto_results = pool.starmap_async(parse_stockholm_file, sto_tasks).get()
    msa_results = {}
    for res in [*a3m_results, *sto_results]:
        msa_results.update(res)
    return msa_results
def main():
    """CLI entry point: parse all MSA files under ``--alignment_dir`` in
    parallel, pickle the combined result to a temporary file, and print
    that file's path to stdout.

    The parsed MSA objects are not JSON-serializable, so the results are
    pickled to a file and the path is handed back to the parent process
    via stdout (the caller reads stdout and unpickles the file).
    """
    parser = argparse.ArgumentParser(description='Process msa files in parallel')
    # required=True so a missing argument fails fast instead of crashing
    # later with os.listdir(None).
    parser.add_argument('--alignment_dir', type=str, required=True,
                        help='path to alignment dir')
    args = parser.parse_args()
    alignment_dir = args.alignment_dir
    # Exclude .sto files whose names contain "hmm_output".
    stockholm_files = [
        i for i in os.listdir(alignment_dir)
        if i.endswith('.sto') and "hmm_output" not in i
    ]
    a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
    msa_data = run_parse_all_msa_files_multiprocessing(
        stockholm_files, a3m_files, alignment_dir
    )
    # delete=False: the file must outlive this process so the parent can
    # read it back; the parent is responsible for cleanup.
    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
        pickle.dump(msa_data, outfile)
    print(outfile.name)


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.