Commit 6f3e0c0c authored by Dingquan Yu's avatar Dingquan Yu
Browse files

now used asynchronised version in parse_msa_data

parent aec12764
......@@ -21,7 +21,7 @@ import dataclasses
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import asyncio
import numpy as np
import torch
......@@ -737,30 +737,37 @@ class DataPipeline:
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
if(ext == ".a3m"):
# Now will split the following steps into multiple processes
async def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
path = os.path.join(alignment_dir, stockholm_file)
file_name,_ = os.path.splitext(stockholm_file)
with open(path, "r") as infile:
msa = parsers.parse_stockholm(infile.read())
infile.close()
return {file_name: msa}
async def parse_a3m_file(alignment_dir: str, a3m_file: str):
path = os.path.join(alignment_dir, a3m_file)
file_name,_ = os.path.splitext(a3m_file)
with open(path, "r") as infile:
msa = parsers.parse_a3m(infile.read())
infile.close()
return {file_name: msa}
async def run_parse_all_msa_files(stockholm_files: list, a3m_files: list, alignment_dir:str):
all_tasks = [asyncio.create_task(parse_stockholm_file(alignment_dir, sto)) for sto in stockholm_files]
all_tasks += [asyncio.create_task(parse_a3m_file(alignment_dir, a3m)) for a3m in a3m_files]
results = await asyncio.gather(*all_tasks)
return results
stockholm_files = [i for i in os.listdir(alignment_dir) if (i.endswith('.sto') and ("hmm_output" not in i))]
a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
import time
start = time.time()
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
msa_results = asyncio.run(run_parse_all_msa_files(stockholm_files, a3m_files, alignment_dir))
end = time.time()
calculate_elapse(start, end, "parser.parse_a3m")
elif(ext == ".sto" and not "hmm_output" == filename):
import time
start = time.time()
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
end = time.time()
calculate_elapse(start, end, "parsers.parse_stockholm")
else:
continue
msa_data[f] = msa
calculate_elapse(start, end, "asynchronised version")
for i in msa_results:
msa_data.update({k:v for k,v in i.items()})
return msa_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment