Unverified Commit b254d465 authored by LuGY, committed by GitHub

merge data workflow to main (#48)

Added data workflow for fastfold
parent a37c8b4c
@@ -18,14 +18,16 @@ FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.
3. Ease of use
    * Huge performance gains with only a few lines of code changed
    * You don't need to care about how the parallel part is implemented
4. Faster data processing, about 3x faster than the original pipeline

## Installation

To install and use FastFold, you will need:

+ Python 3.8 or 3.9
+ [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above
+ PyTorch 1.10 or above

For now, you can install FastFold:

### Using Conda (Recommended)
@@ -116,6 +118,32 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
    --hhsearch_binary_path `which hhsearch` \
    --kalign_binary_path `which kalign`
```
Alternatively, run the script `./inference.sh`; you can change the parameters in the script as needed, especially the data paths.
```shell
./inference.sh
```
#### Inference with the data workflow

AlphaFold's data pre-processing takes a lot of time, so we speed it up with a [Ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, which achieves about a 3x speedup. To run inference with the Ray workflow, install the following packages and add the parameter `--enable_workflow` to the command line or to the shell script `./inference.sh`:
```shell
pip install ray==1.13.0 pyarrow
```
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--output_dir ./ \
--gpus 2 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign` \
--enable_workflow
```
## Performance Benchmark

...
@@ -9,7 +9,7 @@ RUN conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pyt
    && conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda

RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 numpy==1.21.2 \
    PyYAML==5.4.1 requests==2.26.0 scipy==1.7.1 tqdm==4.62.2 typing-extensions==3.10.0.2 einops ray==1.13.0 pyarrow

RUN pip install colossalai==0.1.8+torch1.10cu11.3 -f https://release.colossalai.org

...
# fastfold/workflow/__init__.py (path inferred from the imports in this diff)
from .workflow_run import batch_run
# fastfold/workflow/factory/__init__.py (path inferred from the imports in this diff)
from .task_factory import TaskFactory
from .hhblits import HHBlitsFactory
from .hhsearch import HHSearchFactory
from .jackhmmer import JackHmmerFactory
from .hhfilter import HHfilterFactory
# fastfold/workflow/factory/hhblits.py (path inferred from the imports in this diff)
from typing import List

from ray import workflow
from ray.workflow.common import Workflow

import fastfold.data.tools.hhblits as ffHHBlits
from fastfold.workflow.factory import TaskFactory

class HHBlitsFactory(TaskFactory):

    keywords = ['binary_path', 'databases', 'n_cpu']

    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow] = None) -> Workflow:

        self.isReady()

        # set up the HHBlits runner from the factory config
        runner = ffHHBlits.HHBlits(
            binary_path=self.config['binary_path'],
            databases=self.config['databases'],
            n_cpu=self.config['n_cpu']
        )

        # generate the step function; passing `after` as an argument makes
        # Ray schedule this step only after the upstream workflows finish
        @workflow.step
        def hhblits_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
            result = runner.query(fasta_path)
            with open(output_path, "w") as f:
                f.write(result["a3m"])

        return hhblits_step.step(fasta_path, output_path, after)
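The `after` parameter is the dependency hook here: upstream `Workflow` objects passed into the step become arguments Ray must resolve first, so the step runs only once they finish. A minimal sketch of chaining two factory-generated tasks follows; the binary and database paths are hypothetical placeholders, and nothing executes at `gen_task` time; the returned `Workflow` objects only run when handed to `batch_run` (defined further below).

```python
# Hypothetical usage sketch; binary/database paths are placeholders.
from fastfold.workflow.factory import HHBlitsFactory

fac = HHBlitsFactory(config={
    "binary_path": "/usr/bin/hhblits",       # assumed install location
    "databases": ["data/bfd/bfd_database"],  # placeholder database prefix
    "n_cpu": 4,
})
first = fac.gen_task("target.fasta", "first.a3m")
# `second` lists `first` in `after`, so Ray runs it only once `first` is done
second = fac.gen_task("target.fasta", "second.a3m", after=[first])
```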
# fastfold/workflow/factory/hhfilter.py (path inferred from the imports in this diff)
import logging
import subprocess
from typing import List

from ray import workflow
from ray.workflow.common import Workflow

from fastfold.workflow.factory import TaskFactory

class HHfilterFactory(TaskFactory):

    keywords = ['binary_path']

    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow] = None) -> Workflow:

        self.isReady()

        # generate the step function
        @workflow.step
        def hhfilter_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
            # build the hhfilter command line; '-id' and '-cov' are
            # optional identity/coverage thresholds taken from the config
            cmd = [
                self.config.get('binary_path'),
            ]
            if 'id' in self.config:
                cmd += ['-id', str(self.config.get('id'))]
            if 'cov' in self.config:
                cmd += ['-cov', str(self.config.get('cov'))]
            cmd += ['-i', fasta_path, '-o', output_path]
            logging.info(f"HHfilter start: {' '.join(cmd)}")
            subprocess.run(cmd)

        return hhfilter_step.step(fasta_path, output_path, after)
# fastfold/workflow/factory/hhsearch.py (path inferred from the imports in this diff)
from typing import List

from ray import workflow
from ray.workflow.common import Workflow

import fastfold.data.tools.hhsearch as ffHHSearch
from fastfold.workflow.factory import TaskFactory

class HHSearchFactory(TaskFactory):

    keywords = ['binary_path', 'databases', 'n_cpu']

    def gen_task(self, a3m_path: str, output_path: str, after: List[Workflow] = None) -> Workflow:

        self.isReady()

        # set up the HHSearch runner from the factory config
        runner = ffHHSearch.HHSearch(
            binary_path=self.config['binary_path'],
            databases=self.config['databases'],
            n_cpu=self.config['n_cpu']
        )

        # generate the step function; note that gen_task does not currently
        # forward atab_path, so the atab branch below is never taken here
        @workflow.step
        def hhsearch_step(a3m_path: str, output_path: str, after: List[Workflow], atab_path: str = None) -> None:
            with open(a3m_path, "r") as f:
                a3m = f.read()
            if atab_path:
                hhsearch_result, atab = runner.query(a3m, gen_atab=True)
            else:
                hhsearch_result = runner.query(a3m)
            with open(output_path, "w") as f:
                f.write(hhsearch_result)
            if atab_path:
                with open(atab_path, "w") as f:
                    f.write(atab)

        return hhsearch_step.step(a3m_path, output_path, after)
# fastfold/workflow/factory/jackhmmer.py (path inferred from the imports in this diff)
from typing import List

from ray import workflow
from ray.workflow.common import Workflow

import fastfold.data.tools.jackhmmer as ffJackHmmer
from fastfold.data import parsers
from fastfold.workflow.factory import TaskFactory

class JackHmmerFactory(TaskFactory):

    keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']

    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow] = None) -> Workflow:

        self.isReady()

        # set up the Jackhmmer runner from the factory config
        runner = ffJackHmmer.Jackhmmer(
            binary_path=self.config['binary_path'],
            database_path=self.config['database_path'],
            n_cpu=self.config['n_cpu']
        )

        # generate the step function
        @workflow.step
        def jackhmmer_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
            result = runner.query(fasta_path)[0]
            # convert the Stockholm output to A3M, capping the sequence count
            uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
                result['sto'],
                max_sequences=self.config['uniref_max_hits']
            )
            with open(output_path, "w") as f:
                f.write(uniref90_msa_a3m)

        return jackhmmer_step.step(fasta_path, output_path, after)
# fastfold/workflow/factory/task_factory.py (path inferred from the imports in this diff)
import json
from typing import Any, List

from ray.workflow.common import Workflow

class TaskFactory:

    # config keys this factory requires; subclasses override this
    keywords = []

    def __init__(self, config: dict = None, config_path: str = None) -> None:

        # skip if no keyword is required from the config file
        if not self.__class__.keywords:
            return

        # set the config for the factory
        if config is not None:
            self.config = config
        elif config_path is not None:
            self.loadConfig(config_path)
        else:
            self.loadConfig()

    def configure(self, keyword: str, value: Any) -> None:
        # update a single config entry, e.g. to point a factory at another database
        self.config[keyword] = value

    def reconfigure(self, config: dict, purge: bool = False) -> None:
        # replace (purge=True) or merge in (purge=False) a whole config dict
        if purge:
            self.config = config
        else:
            self.config.update(config)

    def gen_task(self, after: List[Workflow] = None, *args, **kwargs) -> Workflow:
        raise NotImplementedError

    def isReady(self):
        for key in self.__class__.keywords:
            if key not in self.config:
                raise KeyError(f"{self.__class__.__name__} not ready: \"{key}\" not specified")

    def loadConfig(self, config_path='./config.json'):
        with open(config_path) as configFile:
            globalConfig = json.load(configFile)
            if 'tools' not in globalConfig:
                raise KeyError("\"tools\" not found in global config file")
            # strip the "Factory" suffix, e.g. "HHBlitsFactory" -> "HHBlits"
            factoryName = self.__class__.__name__[:-7]
            if factoryName not in globalConfig['tools']:
                raise KeyError(f"\"{factoryName}\" not found in the \"tools\" section in config")
            self.config = globalConfig['tools'][factoryName]
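`loadConfig` implies a specific `config.json` layout: a top-level `tools` object keyed by the factory class name minus its `Factory` suffix, where each entry supplies that factory's required `keywords`. A hypothetical config that would satisfy `HHBlitsFactory` and `JackHmmerFactory` (every path below is a placeholder, not a project default):

```json
{
    "tools": {
        "HHBlits": {
            "binary_path": "/usr/bin/hhblits",
            "databases": ["data/bfd/bfd_database"],
            "n_cpu": 8
        },
        "JackHmmer": {
            "binary_path": "/usr/bin/jackhmmer",
            "database_path": "data/uniref90/uniref90.fasta",
            "n_cpu": 8,
            "uniref_max_hits": 10000
        }
    }
}
```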
# fastfold/workflow/template/__init__.py (path inferred from the imports in this diff)
from .fastfold_data_workflow import FastFoldDataWorkFlow
# fastfold/workflow/template/fastfold_data_workflow.py (path inferred from the imports in this diff)
import os
import time
from multiprocessing import cpu_count
from typing import Optional

from ray import workflow

from fastfold.workflow import batch_run
from fastfold.workflow.factory import JackHmmerFactory, HHSearchFactory, HHBlitsFactory

class FastFoldDataWorkFlow:
    def __init__(
        self,
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,
        hhsearch_binary_path: Optional[str] = None,
        uniref90_database_path: Optional[str] = None,
        mgnify_database_path: Optional[str] = None,
        bfd_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        pdb70_database_path: Optional[str] = None,
        use_small_bfd: Optional[bool] = None,
        no_cpus: Optional[int] = None,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
    ):
        self.db_map = {
            "jackhmmer": {
                "binary": jackhmmer_binary_path,
                "dbs": [
                    uniref90_database_path,
                    mgnify_database_path,
                    bfd_database_path if use_small_bfd else None,
                ],
            },
            "hhblits": {
                "binary": hhblits_binary_path,
                "dbs": [
                    bfd_database_path if not use_small_bfd else None,
                ],
            },
            "hhsearch": {
                "binary": hhsearch_binary_path,
                "dbs": [
                    pdb70_database_path,
                ],
            },
        }

        for name, dic in self.db_map.items():
            binary, dbs = dic["binary"], dic["dbs"]
            if binary is None and not all(x is None for x in dbs):
                raise ValueError(
                    f"{name} DBs provided but {name} binary is None"
                )

        if (not all(x is None for x in self.db_map["hhsearch"]["dbs"])
                and uniref90_database_path is None):
            raise ValueError(
                "uniref90_database_path must be specified in order to perform template search"
            )

        self.use_small_bfd = use_small_bfd
        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits

        if no_cpus is None:
            self.no_cpus = cpu_count()
        else:
            self.no_cpus = no_cpus

    def run(self, fasta_path: str, output_dir: str, alignment_dir: str = None) -> None:
        localtime = time.asctime(time.localtime(time.time()))
        workflow_id = 'fastfold_data_workflow ' + str(localtime)

        # clear any leftover ray workflow data from a previous run with this id
        try:
            workflow.cancel(workflow_id)
            workflow.delete(workflow_id)
        except Exception:
            print("Workflow not found. Clean. Skipping")

        # prepare the alignment directory for alignment outputs
        if alignment_dir is None:
            alignment_dir = os.path.join(output_dir, "alignment")
        if not os.path.exists(alignment_dir):
            os.makedirs(alignment_dir)

        # STEP 1: run JackHmmer on UNIREF90
        jh_config = {
            "binary_path": self.db_map["jackhmmer"]["binary"],
            "database_path": self.db_map["jackhmmer"]["dbs"][0],
            "n_cpu": self.no_cpus,
            "uniref_max_hits": self.uniref_max_hits,
        }
        jh_fac = JackHmmerFactory(config=jh_config)
        # set the jackhmmer output path
        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
        # generate the workflow with the i/o paths
        wf1 = jh_fac.gen_task(fasta_path, uniref90_out_path)

        # STEP 2: run HHSearch on STEP 1's result with PDB70
        hhs_config = {
            "binary_path": self.db_map["hhsearch"]["binary"],
            "databases": self.db_map["hhsearch"]["dbs"],
            "n_cpu": self.no_cpus,
        }
        hhs_fac = HHSearchFactory(config=hhs_config)
        # set the HHSearch output path
        pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
        # generate the workflow (STEP 2 depends on STEP 1)
        wf2 = hhs_fac.gen_task(uniref90_out_path, pdb70_out_path, after=[wf1])

        # STEP 3: run JackHmmer on MGNIFY
        # reconfigure the jackhmmer factory to use the MGNIFY DB instead;
        # wf1's runner already captured the old config, so this is safe
        jh_fac.configure('database_path', self.db_map["jackhmmer"]["dbs"][1])
        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
        wf3 = jh_fac.gen_task(fasta_path, mgnify_out_path)

        # STEP 4: run HHBlits on BFD
        hhb_config = {
            "binary_path": self.db_map["hhblits"]["binary"],
            "databases": self.db_map["hhblits"]["dbs"],
            "n_cpu": self.no_cpus,
        }
        hhb_fac = HHBlitsFactory(config=hhb_config)
        bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
        wf4 = hhb_fac.gen_task(fasta_path, bfd_out_path)

        # run the independent branches (wf2 pulls in wf1 as a dependency)
        batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id)
        return
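For reference, a minimal driver for this class might look like the sketch below. The binary paths and the BFD path are placeholders; the other database paths mirror the README example. `run` builds wf1-wf4 and hands wf2/wf3/wf4 to `batch_run`, which is what finally executes the graph.

```python
# Minimal driver sketch; binary paths and the BFD path are placeholders.
from fastfold.workflow.template import FastFoldDataWorkFlow

runner = FastFoldDataWorkFlow(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    hhblits_binary_path="/usr/bin/hhblits",
    hhsearch_binary_path="/usr/bin/hhsearch",
    uniref90_database_path="data/uniref90/uniref90.fasta",
    mgnify_database_path="data/mgnify/mgy_clusters_2018_12.fa",
    bfd_database_path="data/bfd/bfd_database",
    pdb70_database_path="data/pdb70/pdb70",
    use_small_bfd=False,
    no_cpus=8,
)
# Alignments land in ./out/alignment; depending on the Ray version,
# workflow.init() may need to be called before this.
runner.run("target.fasta", output_dir="./out")
```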
# fastfold/workflow/workflow_run.py (path inferred from the imports in this diff)
from typing import Callable, List

from ray import workflow
from ray.workflow.common import Workflow

def batch_run(wfs: List[Workflow], workflow_id: str) -> None:
    # fan-in step: passing the workflows as an argument makes Ray execute
    # all of them before this (no-op) step can run
    @workflow.step
    def batch_step(wfs) -> None:
        return

    batch_wf = batch_step.step(wfs)
    batch_wf.run(workflow_id=workflow_id)

def wf(after: List[Workflow] = None):
    # decorator that wraps a plain callable into a workflow step
    # scheduled after the given upstream workflows
    def decorator(f: Callable):
        @workflow.step
        def step_func(after: List[Workflow]) -> None:
            f()
        return step_func.step(after)
    return decorator
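`batch_run` is the fan-in primitive used by `FastFoldDataWorkFlow.run`: the no-op `batch_step` takes every branch workflow as an argument, so Ray must resolve all of them, plus their upstream dependencies (such as wf1 behind wf2), before the batch step itself runs. A small self-contained sketch under the same Ray 1.13 step API; the step bodies are hypothetical stand-ins for real alignment tasks:

```python
# Fan-in sketch with two hypothetical independent branches.
from ray import workflow
from fastfold.workflow import batch_run

@workflow.step
def branch_a() -> None:
    print("branch A done")

@workflow.step
def branch_b() -> None:
    print("branch B done")

workflow.init()  # one-time setup in the Ray 1.x workflow alpha
batch_run(wfs=[branch_a.step(), branch_b.step()], workflow_id="demo_batch")
```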
@@ -31,6 +31,7 @@ from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.workflow.template import FastFoldDataWorkFlow
from fastfold.utils import inject_fastnn
from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_
@@ -74,7 +75,7 @@ def add_data_args(parser: argparse.ArgumentParser):
    )
    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
    parser.add_argument('--release_dates_path', type=str, default=None)
    parser.add_argument('--enable_workflow', default=False, action='store_true',
                        help='run inference with the ray workflow')
def inference_model(rank, world_size, result_q, batch, args):
    os.environ['RANK'] = str(rank)
@@ -157,20 +158,37 @@ def main(args):
    if (args.use_precomputed_alignments is None):
        if not os.path.exists(local_alignment_dir):
            os.makedirs(local_alignment_dir)

        if args.enable_workflow:
            print("Running alignment with ray workflow...")
            alignment_data_workflow_runner = FastFoldDataWorkFlow(
                jackhmmer_binary_path=args.jackhmmer_binary_path,
                hhblits_binary_path=args.hhblits_binary_path,
                hhsearch_binary_path=args.hhsearch_binary_path,
                uniref90_database_path=args.uniref90_database_path,
                mgnify_database_path=args.mgnify_database_path,
                bfd_database_path=args.bfd_database_path,
                uniclust30_database_path=args.uniclust30_database_path,
                pdb70_database_path=args.pdb70_database_path,
                use_small_bfd=use_small_bfd,
                no_cpus=args.cpus,
            )
            t = time.perf_counter()
            alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base,
                                               alignment_dir=local_alignment_dir)
            print(f"Alignment data workflow time: {time.perf_counter() - t}")
        else:
            alignment_runner = data_pipeline.AlignmentRunner(
                jackhmmer_binary_path=args.jackhmmer_binary_path,
                hhblits_binary_path=args.hhblits_binary_path,
                hhsearch_binary_path=args.hhsearch_binary_path,
                uniref90_database_path=args.uniref90_database_path,
                mgnify_database_path=args.mgnify_database_path,
                bfd_database_path=args.bfd_database_path,
                uniclust30_database_path=args.uniclust30_database_path,
                pdb70_database_path=args.pdb70_database_path,
                use_small_bfd=use_small_bfd,
                no_cpus=args.cpus,
            )
            alignment_runner.run(fasta_path, local_alignment_dir)

    feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
                                                alignment_dir=local_alignment_dir)
...
# inference.sh (the script referenced in the README above)
python inference.py target.fasta /data/pdb_mmcif/mmcif_files \
    --output_dir ./ \
    --gpus 2 \
    --uniref90_database_path data/uniref90/uniref90.fasta \
    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
    --pdb70_database_path data/pdb70/pdb70 \
    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
    --jackhmmer_binary_path `which jackhmmer` \
    --hhblits_binary_path `which hhblits` \
    --hhsearch_binary_path `which hhsearch` \
    --kalign_binary_path `which kalign`
    # append `--enable_workflow` (with a trailing `\` on the line above) to use the ray workflow