Unverified Commit f861ff39 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #376 from dingquanyu/speedup-dataloader

Speed up data loading process 
parents 1606ac08 6f26b0ad
...@@ -21,16 +21,17 @@ import dataclasses ...@@ -21,16 +21,17 @@ import dataclasses
from multiprocessing import cpu_count from multiprocessing import cpu_count
import tempfile import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np import numpy as np
import torch import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import concurrent
from concurrent.futures import ThreadPoolExecutor
FeatureDict = MutableMapping[str, np.ndarray] FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
...@@ -735,22 +736,11 @@ class DataPipeline: ...@@ -735,22 +736,11 @@ class DataPipeline:
fp.close() fp.close()
else: else:
for f in os.listdir(alignment_dir): # Now will split the following steps into multiple processes
path = os.path.join(alignment_dir, f) current_directory = os.path.dirname(os.path.abspath(__file__))
filename, ext = os.path.splitext(f) cmd = f"{current_directory}/tools/parse_msa_files.py"
msa_data = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
if(ext == ".a3m"): msa_data = pickle.load((open(msa_data.stdout.lstrip().rstrip(),'rb')))
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data return msa_data
...@@ -826,6 +816,7 @@ class DataPipeline: ...@@ -826,6 +816,7 @@ class DataPipeline:
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = self._get_msas( msas = self._get_msas(
alignment_dir, input_sequence, alignment_index alignment_dir, input_sequence, alignment_index
) )
...@@ -1216,8 +1207,10 @@ class DataPipelineMultimer: ...@@ -1216,8 +1207,10 @@ class DataPipelineMultimer:
with open(fasta_path) as f: with open(fasta_path) as f:
input_fasta_str = f.read() input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1 is_homomer_or_monomer = len(set(input_seqs)) == 1
...@@ -1228,6 +1221,7 @@ class DataPipelineMultimer: ...@@ -1228,6 +1221,7 @@ class DataPipelineMultimer:
) )
continue continue
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
...@@ -1236,6 +1230,7 @@ class DataPipelineMultimer: ...@@ -1236,6 +1230,7 @@ class DataPipelineMultimer:
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
...@@ -1243,17 +1238,20 @@ class DataPipelineMultimer: ...@@ -1243,17 +1238,20 @@ class DataPipelineMultimer:
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
return np_example return np_example
def get_mmcif_features( def get_mmcif_features(
self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict: ) -> FeatureDict:
...@@ -1284,18 +1282,21 @@ class DataPipelineMultimer: ...@@ -1284,18 +1282,21 @@ class DataPipelineMultimer:
alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1 is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items(): for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id]) desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features: if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy( all_chain_features[desc] = copy.deepcopy(
sequence_features[seq] sequence_features[seq]
) )
continue continue
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
...@@ -1304,23 +1305,29 @@ class DataPipelineMultimer: ...@@ -1304,23 +1305,29 @@ class DataPipelineMultimer:
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
) )
mmcif_feats = self.get_mmcif_features(mmcif, chain_id) mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats) chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
return np_example return np_example
\ No newline at end of file
import os, argparse, pickle, tempfile, concurrent
from openfold.data import parsers
from concurrent.futures import ProcessPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str) -> dict:
    """Parse a single Stockholm-format MSA file.

    Args:
        alignment_dir: Directory containing the alignment files.
        stockholm_file: File name (with ``.sto`` extension) inside
            ``alignment_dir``.

    Returns:
        A single-entry dict mapping the file's stem (name without extension)
        to the parsed MSA object.
    """
    path = os.path.join(alignment_dir, stockholm_file)
    file_name, _ = os.path.splitext(stockholm_file)
    # The context manager closes the file; the explicit close() the original
    # added was redundant.
    with open(path, "r") as infile:
        msa = parsers.parse_stockholm(infile.read())
    return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str) -> dict:
    """Parse a single A3M-format MSA file.

    Args:
        alignment_dir: Directory containing the alignment files.
        a3m_file: File name (with ``.a3m`` extension) inside
            ``alignment_dir``.

    Returns:
        A single-entry dict mapping the file's stem (name without extension)
        to the parsed MSA object.
    """
    path = os.path.join(alignment_dir, a3m_file)
    file_name, _ = os.path.splitext(a3m_file)
    # The context manager closes the file; the explicit close() the original
    # added was redundant.
    with open(path, "r") as infile:
        msa = parsers.parse_a3m(infile.read())
    return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir: str) -> dict:
    """Parse all MSA files in parallel using a process pool.

    Args:
        stockholm_files: Names of ``.sto`` files inside ``alignment_dir``.
        a3m_files: Names of ``.a3m`` files inside ``alignment_dir``.
        alignment_dir: Directory containing all the alignment files.

    Returns:
        Dict mapping each file stem to its parsed MSA. Files whose parsing
        raised an exception are reported on stdout and omitted.
    """
    msa_results = {}
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
    n_tasks = len(a3m_tasks) + len(sto_tasks)
    if n_tasks == 0:
        # ProcessPoolExecutor raises ValueError for max_workers=0; an empty
        # alignment dir should simply yield an empty result.
        return msa_results
    # Cap workers at the CPU count: one process per file (the original
    # max_workers) oversubscribes the machine for large alignment dirs.
    max_workers = min(n_tasks, os.cpu_count() or 1)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {}
        for task in a3m_tasks:
            futures[executor.submit(parse_a3m_file, *task)] = task
        for task in sto_tasks:
            futures[executor.submit(parse_stockholm_file, *task)] = task
        for future in concurrent.futures.as_completed(futures):
            try:
                msa_results.update(future.result())
            except Exception as exc:
                # Best-effort: report and continue with the remaining files,
                # matching the original behavior.
                print(f'Task generated an exception: {exc}')
    return msa_results
def main():
    """Parse every MSA file under ``--alignment_dir`` and emit a pickle path.

    The parsed results are dumped to a ``NamedTemporaryFile(delete=False)``
    and that file's path is printed to stdout so the parent process can
    ``pickle.load`` it. NOTE(review): the caller is responsible for removing
    the temporary file — it is never deleted here.
    """
    parser = argparse.ArgumentParser(description='Process msa files in parallel')
    parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
    args = parser.parse_args()
    alignment_dir = args.alignment_dir
    # Exclude both hmm_output and uniprot_hits by exact stem match, restoring
    # parity with the sequential loop this script replaced (which skipped
    # ["uniprot_hits", "hmm_output"]); the substring test on the full file
    # name used previously also let uniprot_hits.sto leak into the results.
    excluded_stems = {"uniprot_hits", "hmm_output"}
    stockholm_files = []
    a3m_files = []
    for f in os.listdir(alignment_dir):
        stem, ext = os.path.splitext(f)
        if ext == '.sto' and stem not in excluded_stems:
            stockholm_files.append(f)
        elif ext == '.a3m':
            a3m_files.append(f)
    msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
        pickle.dump(msa_data, outfile)
        print(outfile.name)

if __name__ == "__main__":
    main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment