Unverified Commit f861ff39 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #376 from dingquanyu/speedup-dataloader

Speed up data loading process 
parents 1606ac08 6f26b0ad
......@@ -21,16 +21,17 @@ import dataclasses
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np
import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
import concurrent
from concurrent.futures import ThreadPoolExecutor
FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
......@@ -735,22 +736,11 @@ class DataPipeline:
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
# Now will split the following steps into multiple processes
current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/tools/parse_msa_files.py"
msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(),'rb')))
return msa_data
......@@ -826,6 +816,7 @@ class DataPipeline:
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
......@@ -1216,8 +1207,10 @@ class DataPipelineMultimer:
with open(fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
......@@ -1228,6 +1221,7 @@ class DataPipelineMultimer:
)
continue
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
......@@ -1236,6 +1230,7 @@ class DataPipelineMultimer:
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
......@@ -1243,17 +1238,20 @@ class DataPipelineMultimer:
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
def get_mmcif_features(
self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
......@@ -1284,18 +1282,21 @@ class DataPipelineMultimer:
alignment_index: Optional[str] = None,
) -> FeatureDict:
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
......@@ -1304,23 +1305,29 @@ class DataPipelineMultimer:
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
\ No newline at end of file
import os, argparse, pickle, tempfile, concurrent
from openfold.data import parsers
from concurrent.futures import ProcessPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
    """Parse a single Stockholm (.sto) alignment file.

    Args:
        alignment_dir: Directory containing the alignment file.
        stockholm_file: File name of the ``.sto`` file inside ``alignment_dir``.

    Returns:
        A single-entry dict mapping the file's base name (extension
        stripped) to the MSA object produced by ``parsers.parse_stockholm``.
    """
    path = os.path.join(alignment_dir, stockholm_file)
    file_name, _ = os.path.splitext(stockholm_file)
    # The context manager closes the file; no explicit close() is needed.
    with open(path, "r") as infile:
        msa = parsers.parse_stockholm(infile.read())
    return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str):
    """Parse a single A3M (.a3m) alignment file.

    Args:
        alignment_dir: Directory containing the alignment file.
        a3m_file: File name of the ``.a3m`` file inside ``alignment_dir``.

    Returns:
        A single-entry dict mapping the file's base name (extension
        stripped) to the MSA object produced by ``parsers.parse_a3m``.
    """
    path = os.path.join(alignment_dir, a3m_file)
    file_name, _ = os.path.splitext(a3m_file)
    # The context manager closes the file; no explicit close() is needed.
    with open(path, "r") as infile:
        msa = parsers.parse_a3m(infile.read())
    return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir: str):
    """Parse all given MSA files in parallel with a process pool.

    Args:
        stockholm_files: File names of ``.sto`` files in ``alignment_dir``.
        a3m_files: File names of ``.a3m`` files in ``alignment_dir``.
        alignment_dir: Directory containing all of the alignment files.

    Returns:
        A dict mapping each file's base name to its parsed MSA object.
        Files whose parsing raised an exception are reported to stdout
        and omitted from the result (best-effort behavior).
    """
    msa_results = {}
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
    total_tasks = len(a3m_tasks) + len(sto_tasks)
    if total_tasks == 0:
        # ProcessPoolExecutor rejects max_workers=0; nothing to do anyway.
        return msa_results
    # Cap the pool at the CPU count: one worker per file would otherwise
    # fork an unbounded number of processes for large alignment dirs.
    max_workers = min(total_tasks, os.cpu_count() or 1)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks}
        futures.update(
            {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks}
        )
        for future in concurrent.futures.as_completed(futures):
            try:
                msa_results.update(future.result())
            except Exception as exc:
                # Report and continue with the remaining files.
                print(f'Task generated an exception: {exc}')
    return msa_results
def main():
    """Parse every MSA file in ``--alignment_dir`` in parallel and print
    the path of a pickle file holding the combined results.

    The pickle path is printed to stdout so the parent process can locate
    and load the parsed MSAs; the caller is responsible for deleting the
    temporary file afterwards.
    """
    parser = argparse.ArgumentParser(description='Process msa files in parallel')
    parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
    args = parser.parse_args()
    alignment_dir = args.alignment_dir
    # Scan the directory once instead of twice.
    entries = os.listdir(alignment_dir)
    # NOTE(review): the in-process parser this script replaced also skipped
    # "uniprot_hits" .sto files, while here only "hmm_output" is excluded —
    # confirm whether parsing uniprot_hits.sto is intentional.
    stockholm_files = [i for i in entries if (i.endswith('.sto') and ("hmm_output" not in i))]
    a3m_files = [i for i in entries if i.endswith('.a3m')]
    msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
    # delete=False: the file must outlive this process so the parent can read it.
    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
        pickle.dump(msa_data, outfile)
        print(outfile.name)


if __name__ == "__main__":
    main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment