Unverified Commit 369f3e70 authored by Fazzie-Maqianli, committed by GitHub

support ray2.0 (#52)

parent 56a32c3b
@@ -126,7 +126,7 @@ or run the script `./inference.sh`, you can change the parameter in the script,
 #### inference with data workflow
 Alphafold's data pre-processing takes a lot of time, so we speed up the data pre-processing with a [ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, which achieves about a 3x speedup. To run the inference with the ray workflow, install the package and add the parameter `--enable_workflow` to the command line or to the shell script `./inference.sh`
 ```shell
-pip install ray==1.13.0 pyarrow
+pip install ray==2.0.0 pyarrow
 ```
 ```shell
 python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
...
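Every file touched below follows the same migration: the Ray 1.x workflow API (`@workflow.step`, `.step()`, `Workflow`) is replaced by the Ray 2.0 DAG API (`@ray.remote`, `.bind()`, `FunctionNode`, `workflow.run`). A minimal, self-contained sketch of that pattern (not FastFold code; the task names, storage path, and workflow id are placeholders chosen for illustration):

```python
# Minimal sketch of the Ray 2.0 workflow pattern this commit adopts.
# Not FastFold code: task names, the storage path, and the workflow id
# are illustrative placeholders.
import ray
from ray import workflow

@ray.remote
def preprocess(x: int) -> int:
    # stands in for a data-preparation task such as an MSA search
    return x + 1

@ray.remote
def postprocess(y: int) -> int:
    # stands in for a task that consumes preprocess()'s result
    return y * 2

if __name__ == "__main__":
    # illustrative local storage path for workflow checkpoints
    ray.init(storage="/tmp/ray_workflow_data")
    # .bind() only records the DAG; nothing executes until workflow.run()
    dag = postprocess.bind(preprocess.bind(1))
    print(workflow.run(dag, workflow_id="demo_workflow"))  # -> 4
```

Accordingly, the factories below now return bound `FunctionNode` objects from `gen_node()` instead of `Workflow` objects from `gen_task()`.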
-from ray import workflow
 from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
 from fastfold.workflow.factory import TaskFactory
-from ray.workflow.common import Workflow
 import fastfold.data.tools.hhblits as ffHHBlits
 class HHBlitsFactory(TaskFactory):
     keywords = ['binary_path', 'databases', 'n_cpu']
-    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
         self.isReady()
         # setup runner
         runner = ffHHBlits.HHBlits(
-            binary_path=self.config['binary_path'],
-            databases=self.config['databases'],
-            n_cpu=self.config['n_cpu']
+            **self.config
         )
-        # generate step function
-        @workflow.step
-        def hhblits_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+        # generate function node
+        @ray.remote
+        def hhblits_node_func(after: List[FunctionNode]) -> None:
             result = runner.query(fasta_path)
-            with open(output_path, "w") as f:
-                f.write(result["a3m"])
+            with open(output_path, 'w') as f:
+                f.write(result['a3m'])
-        return hhblits_step.step(fasta_path, output_path, after)
+        return hhblits_node_func.bind(after)
 import subprocess
 import logging
-from ray import workflow
 from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
 from fastfold.workflow.factory import TaskFactory
-from ray.workflow.common import Workflow
 class HHfilterFactory(TaskFactory):
     keywords = ['binary_path']
-    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
         self.isReady()
-        # generate step function
-        @workflow.step
-        def hhfilter_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+        # generate function node
+        @ray.remote
+        def hhfilter_node_func(after: List[FunctionNode]) -> None:
             cmd = [
                 self.config.get('binary_path'),
@@ -26,8 +28,6 @@ class HHfilterFactory(TaskFactory):
             cmd += ['-cov', str(self.config.get('cov'))]
             cmd += ['-i', fasta_path, '-o', output_path]
-            logging.info(f"HHfilter start: {' '.join(cmd)}")
-            subprocess.run(cmd)
+            subprocess.run(cmd, shell=True)
-        return hhfilter_step.step(fasta_path, output_path, after)
+        return hhfilter_node_func.bind(after)
\ No newline at end of file
-from fastfold.workflow.factory import TaskFactory
-from ray import workflow
-from ray.workflow.common import Workflow
-import fastfold.data.tools.hhsearch as ffHHSearch
 from typing import List
+import inspect
+import ray
+from ray.dag.function_node import FunctionNode
+import fastfold.data.tools.hhsearch as ffHHSearch
+from fastfold.workflow.factory import TaskFactory
 class HHSearchFactory(TaskFactory):
     keywords = ['binary_path', 'databases', 'n_cpu']
-    def gen_task(self, a3m_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+    def gen_node(self, a3m_path: str, output_path: str, atab_path: str = None, after: List[FunctionNode]=None) -> FunctionNode:
         self.isReady()
-        # setup runner
+        params = { k: self.config.get(k) for k in inspect.getfullargspec(ffHHSearch.HHSearch.__init__).kwonlyargs if self.config.get(k) }
+        # setup runner with a filtered config dict
         runner = ffHHSearch.HHSearch(
-            binary_path=self.config['binary_path'],
-            databases=self.config['databases'],
-            n_cpu=self.config['n_cpu']
+            **params
         )
-        # generate step function
-        @workflow.step
-        def hhsearch_step(a3m_path: str, output_path: str, after: List[Workflow], atab_path: str = None) -> None:
+        # generate function node
+        @ray.remote
+        def hhsearch_node_func(after: List[FunctionNode]) -> None:
             with open(a3m_path, "r") as f:
                 a3m = f.read()
@@ -35,4 +39,4 @@ class HHSearchFactory(TaskFactory):
             with open(atab_path, "w") as f:
                 f.write(atab)
-        return hhsearch_step.step(a3m_path, output_path, after)
+        return hhsearch_node_func.bind(after)
-from fastfold.workflow.factory import TaskFactory
-from ray import workflow
-from ray.workflow.common import Workflow
+import inspect
+from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
 import fastfold.data.tools.jackhmmer as ffJackHmmer
 from fastfold.data import parsers
-from typing import List
+from fastfold.workflow.factory import TaskFactory
 class JackHmmerFactory(TaskFactory):
     keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
-    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
         self.isReady()
+        params = { k: self.config.get(k) for k in inspect.getfullargspec(ffJackHmmer.Jackhmmer.__init__).kwonlyargs if self.config.get(k) }
         # setup runner
         runner = ffJackHmmer.Jackhmmer(
-            binary_path=self.config['binary_path'],
-            database_path=self.config['database_path'],
-            n_cpu=self.config['n_cpu']
+            **params
         )
-        # generate step function
-        @workflow.step
-        def jackhmmer_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+        # generate function node
+        @ray.remote
+        def jackhmmer_node_func(after: List[FunctionNode]) -> None:
             result = runner.query(fasta_path)[0]
             uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
                 result['sto'],
@@ -31,4 +35,4 @@ class JackHmmerFactory(TaskFactory):
             with open(output_path, "w") as f:
                 f.write(uniref90_msa_a3m)
-        return jackhmmer_step.step(fasta_path, output_path, after)
+        return jackhmmer_node_func.bind(after)
 from ast import keyword
 import json
-from ray.workflow.common import Workflow
 from os import path
 from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
 class TaskFactory:
@@ -31,7 +32,7 @@ class TaskFactory:
     def configure(self, keyword: str, value: any) -> None:
         self.config[keyword] = value
-    def gen_task(self, after: List[Workflow]=None, *args, **kwargs) -> Workflow:
+    def gen_task(self, after: List[FunctionNode]=None, *args, **kwargs) -> FunctionNode:
         raise NotImplementedError
     def isReady(self):
...
@@ -98,7 +98,7 @@ class FastFoldDataWorkFlow:
         # set jackhmmer output path
         uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
         # generate the workflow with i/o path
-        wf1 = jh_fac.gen_task(fasta_path, uniref90_out_path)
+        jh_node_1 = jh_fac.gen_node(fasta_path, uniref90_out_path)
         # Run HHSearch on STEP1's result with PDB70
         # create HHSearch workflow generator
@@ -111,7 +111,7 @@ class FastFoldDataWorkFlow:
         # set HHSearch output path
         pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
         # generate the workflow (STEP2 depends on STEP1)
-        wf2 = hhs_fac.gen_task(uniref90_out_path, pdb70_out_path, after=[wf1])
+        hhs_node = hhs_fac.gen_node(uniref90_out_path, pdb70_out_path, after=[jh_node_1])
         # Run JackHmmer on MGNIFY
         # reconfigure jackhmmer factory to use MGNIFY DB instead
@@ -119,7 +119,7 @@ class FastFoldDataWorkFlow:
         # set jackhmmer output path
         mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
         # generate workflow for STEP3
-        wf3 = jh_fac.gen_task(fasta_path, mgnify_out_path)
+        jh_node_2 = jh_fac.gen_node(fasta_path, mgnify_out_path)
         # Run HHBlits on BFD
         # create HHBlits workflow generator
@@ -132,9 +132,9 @@ class FastFoldDataWorkFlow:
         # set HHBlits output path
         bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
         # generate workflow for STEP4
-        wf4 = hhb_fac.gen_task(fasta_path, bfd_out_path)
+        hhb_node = hhb_fac.gen_node(fasta_path, bfd_out_path)
         # run workflow
-        batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id)
+        batch_run(workflow_id=workflow_id, dags=[hhs_node, jh_node_2, hhb_node])
         return
\ No newline at end of file
 from ast import Call
 from typing import Callable, List
-from ray.workflow.common import Workflow
+import ray
+from ray.dag.function_node import FunctionNode
 from ray import workflow
-def batch_run(wfs: List[Workflow], workflow_id: str) -> None:
-    @workflow.step
-    def batch_step(wfs) -> None:
+def batch_run(workflow_id: str, dags: List[FunctionNode]) -> None:
+    @ray.remote
+    def batch_dag_func(dags) -> None:
         return
-    batch_wf = batch_step.step(wfs)
-    batch_wf.run(workflow_id=workflow_id)
+    batch = batch_dag_func.bind(dags)
+    workflow.run(batch, workflow_id=workflow_id)
-def wf(after: List[Workflow]=None):
-    def decorator(f: Callable):
-        @workflow.step
-        def step_func(after: List[Workflow]) -> None:
-            f()
-        return step_func.step(after)
-    return decorator
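For reference, a hedged usage sketch (not part of the commit) of how the new `batch_run` signature is meant to be called: `batch_run` is reproduced from the diff above, while the two remote tasks, the storage path, and the workflow id are illustrative placeholders. Passing upstream nodes as arguments, as the factories do through `after`, makes them part of the same DAG, so everything reachable from `dags` runs under one workflow id.

```python
# Hedged usage sketch for the new batch_run() signature. batch_run is copied
# from the diff above; the remote tasks, storage path, and workflow id are
# placeholders, not FastFold code.
from typing import List

import ray
from ray import workflow
from ray.dag.function_node import FunctionNode

def batch_run(workflow_id: str, dags: List[FunctionNode]) -> None:
    # bind all DAG nodes into one no-op batch node and run it as a workflow
    @ray.remote
    def batch_dag_func(dags) -> None:
        return
    batch = batch_dag_func.bind(dags)
    workflow.run(batch, workflow_id=workflow_id)

@ray.remote
def make_alignment(name: str) -> str:
    return f"{name}: aligned"

@ray.remote
def search_templates(after: str) -> str:
    # `after` receives the upstream node's result, mirroring how the
    # factories thread dependencies through gen_node(..., after=[...])
    return f"templates built after '{after}'"

if __name__ == "__main__":
    ray.init(storage="/tmp/ray_workflow_data")  # illustrative storage path
    upstream = make_alignment.bind("uniref90")
    downstream = search_templates.bind(upstream)
    # every node reachable from `dags` executes under one named workflow,
    # like the batch_run call in FastFoldDataWorkFlow above
    batch_run(workflow_id="fastfold_data_demo", dags=[downstream])
```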