Unverified commit 369f3e70 authored by Fazzie-Maqianli, committed by GitHub

support ray2.0 (#52)

parent 56a32c3b
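The diff below migrates FastFold's data workflow from Ray 1.13's `workflow.step` API to Ray 2.0's DAG API: task functions become plain `@ray.remote` functions, dependencies are wired with `.bind()` into `FunctionNode` DAGs, and execution moves from `Workflow.run()` to `workflow.run(dag, workflow_id=...)`. The following is a minimal, self-contained sketch of that API change, not FastFold code; the function name, FASTA path, and storage path are made up for illustration:

```python
import ray
from ray import workflow

# Ray 1.13 style (removed in this commit):
#   @workflow.step
#   def make_msa(fasta_path): ...
#   step = make_msa.step("target.fasta")
#   step.run(workflow_id="fastfold_data")

# Ray 2.0 style (introduced in this commit): a plain remote function,
# composed into a DAG with .bind() and executed via workflow.run().
@ray.remote
def make_msa(fasta_path: str) -> str:
    # stand-in for an alignment step such as jackhmmer or hhblits
    return f"a3m for {fasta_path}"

if __name__ == "__main__":
    # Ray 2.0 workflows need a storage location; this path is illustrative
    ray.init(storage="/tmp/fastfold_workflow_demo")
    dag = make_msa.bind("target.fasta")
    print(workflow.run(dag, workflow_id="fastfold_data_demo"))
```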
@@ -126,7 +126,7 @@ or run the script `./inference.sh`, you can change the parameter in the script,
#### inference with data workflow
AlphaFold's data pre-processing takes a lot of time, so we accelerate it with a [Ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, which achieves roughly a 3x speedup. To run inference with the Ray workflow, install the package and add the `--enable_workflow` parameter to the command line or to the shell script `./inference.sh`:
```shell
-pip install ray==1.13.0 pyarrow
+pip install ray==2.0.0 pyarrow
```
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
......
-from ray import workflow
from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
from fastfold.workflow.factory import TaskFactory
-from ray.workflow.common import Workflow
import fastfold.data.tools.hhblits as ffHHBlits
class HHBlitsFactory(TaskFactory):
keywords = ['binary_path', 'databases', 'n_cpu']
-def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
# setup runner
runner = ffHHBlits.HHBlits(
-binary_path=self.config['binary_path'],
-databases=self.config['databases'],
-n_cpu=self.config['n_cpu']
+**self.config
)
-# generate step function
-@workflow.step
-def hhblits_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+# generate function node
+@ray.remote
+def hhblits_node_func(after: List[FunctionNode]) -> None:
result = runner.query(fasta_path)
with open(output_path, "w") as f:
f.write(result["a3m"])
with open(output_path, 'w') as f:
f.write(result['a3m'])
return hhblits_step.step(fasta_path, output_path, after)
return hhblits_node_func.bind(after)
import subprocess
import logging
-from ray import workflow
from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
from fastfold.workflow.factory import TaskFactory
-from ray.workflow.common import Workflow
class HHfilterFactory(TaskFactory):
keywords = ['binary_path']
-def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
-# generate step function
-@workflow.step
-def hhfilter_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+# generate function node
+@ray.remote
+def hhfilter_node_func(after: List[FunctionNode]) -> None:
cmd = [
self.config.get('binary_path'),
@@ -26,8 +28,6 @@ class HHfilterFactory(TaskFactory):
cmd += ['-cov', str(self.config.get('cov'))]
cmd += ['-i', fasta_path, '-o', output_path]
logging.info(f"HHfilter start: {' '.join(cmd)}")
-subprocess.run(cmd)
+subprocess.run(cmd, shell=True)
-return hhfilter_step.step(fasta_path, output_path, after)
+return hhfilter_node_func.bind(after)
\ No newline at end of file
-from fastfold.workflow.factory import TaskFactory
-from ray import workflow
-from ray.workflow.common import Workflow
-import fastfold.data.tools.hhsearch as ffHHSearch
from typing import List
+import inspect
+import ray
+from ray.dag.function_node import FunctionNode
+import fastfold.data.tools.hhsearch as ffHHSearch
+from fastfold.workflow.factory import TaskFactory
class HHSearchFactory(TaskFactory):
keywords = ['binary_path', 'databases', 'n_cpu']
-def gen_task(self, a3m_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+def gen_node(self, a3m_path: str, output_path: str, atab_path: str = None, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
-# setup runner
+params = { k: self.config.get(k) for k in inspect.getfullargspec(ffHHSearch.HHSearch.__init__).kwonlyargs if self.config.get(k) }
+# setup runner with a filtered config dict
runner = ffHHSearch.HHSearch(
-binary_path=self.config['binary_path'],
-databases=self.config['databases'],
-n_cpu=self.config['n_cpu']
+**params
)
-# generate step function
-@workflow.step
-def hhsearch_step(a3m_path: str, output_path: str, after: List[Workflow], atab_path: str = None) -> None:
+# generate function node
+@ray.remote
+def hhsearch_node_func(after: List[FunctionNode]) -> None:
with open(a3m_path, "r") as f:
a3m = f.read()
@@ -35,4 +39,4 @@ class HHSearchFactory(TaskFactory):
with open(atab_path, "w") as f:
f.write(atab)
-return hhsearch_step.step(a3m_path, output_path, after)
+return hhsearch_node_func.bind(after)
-from fastfold.workflow.factory import TaskFactory
-from ray import workflow
-from ray.workflow.common import Workflow
+import inspect
+from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
import fastfold.data.tools.jackhmmer as ffJackHmmer
from fastfold.data import parsers
-from typing import List
+from fastfold.workflow.factory import TaskFactory
class JackHmmerFactory(TaskFactory):
keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
-def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
+params = { k: self.config.get(k) for k in inspect.getfullargspec(ffJackHmmer.Jackhmmer.__init__).kwonlyargs if self.config.get(k) }
# setup runner
runner = ffJackHmmer.Jackhmmer(
-binary_path=self.config['binary_path'],
-database_path=self.config['database_path'],
-n_cpu=self.config['n_cpu']
+**params
)
-# generate step function
-@workflow.step
-def jackhmmer_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+# generate function node
+@ray.remote
+def jackhmmer_node_func(after: List[FunctionNode]) -> None:
result = runner.query(fasta_path)[0]
uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
result['sto'],
@@ -31,4 +35,4 @@ class JackHmmerFactory(TaskFactory):
with open(output_path, "w") as f:
f.write(uniref90_msa_a3m)
-return jackhmmer_step.step(fasta_path, output_path, after)
+return jackhmmer_node_func.bind(after)
from ast import keyword
import json
-from ray.workflow.common import Workflow
from os import path
from typing import List
+import ray
+from ray.dag.function_node import FunctionNode
class TaskFactory:
@@ -31,7 +32,7 @@ class TaskFactory:
def configure(self, keyword: str, value: any) -> None:
self.config[keyword] = value
-def gen_task(self, after: List[Workflow]=None, *args, **kwargs) -> Workflow:
+def gen_task(self, after: List[FunctionNode]=None, *args, **kwargs) -> FunctionNode:
raise NotImplementedError
def isReady(self):
......
@@ -98,7 +98,7 @@ class FastFoldDataWorkFlow:
# set jackhmmer output path
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
# generate the workflow with i/o path
-wf1 = jh_fac.gen_task(fasta_path, uniref90_out_path)
+jh_node_1 = jh_fac.gen_node(fasta_path, uniref90_out_path)
# Run HHSearch on STEP1's result with PDB70
# create HHSearch workflow generator
@@ -111,7 +111,7 @@ class FastFoldDataWorkFlow:
# set HHSearch output path
pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
# generate the workflow (STEP2 depend on STEP1)
-wf2 = hhs_fac.gen_task(uniref90_out_path, pdb70_out_path, after=[wf1])
+hhs_node = hhs_fac.gen_node(uniref90_out_path, pdb70_out_path, after=[jh_node_1])
# Run JackHmmer on MGNIFY
# reconfigure jackhmmer factory to use MGNIFY DB instead
@@ -119,7 +119,7 @@ class FastFoldDataWorkFlow:
# set jackhmmer output path
mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
# generate workflow for STEP3
-wf3 = jh_fac.gen_task(fasta_path, mgnify_out_path)
+jh_node_2 = jh_fac.gen_node(fasta_path, mgnify_out_path)
# Run HHBlits on BFD
# create HHBlits workflow generator
@@ -132,9 +132,9 @@ class FastFoldDataWorkFlow:
# set HHBlits output path
bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
# generate workflow for STEP4
-wf4 = hhb_fac.gen_task(fasta_path, bfd_out_path)
+hhb_node = hhb_fac.gen_node(fasta_path, bfd_out_path)
# run workflow
-batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id)
+batch_run(workflow_id=workflow_id, dags=[hhs_node, jh_node_2, hhb_node])
return
\ No newline at end of file
from ast import Call
from typing import Callable, List
-from ray.workflow.common import Workflow
+import ray
+from ray.dag.function_node import FunctionNode
from ray import workflow
-def batch_run(wfs: List[Workflow], workflow_id: str) -> None:
+def batch_run(workflow_id: str, dags: List[FunctionNode]) -> None:
-@workflow.step
-def batch_step(wfs) -> None:
+@ray.remote
+def batch_dag_func(dags) -> None:
return
-batch_wf = batch_step.step(wfs)
-batch_wf.run(workflow_id=workflow_id)
-def wf(after: List[Workflow]=None):
-def decorator(f: Callable):
-@workflow.step
-def step_func(after: List[Workflow]) -> None:
-f()
-return step_func.step(after)
-return decorator
+batch = batch_dag_func.bind(dags)
+workflow.run(batch, workflow_id=workflow_id)
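For reference, here is a hedged sketch of how the `after=[...]` dependency pattern used by the new `gen_node` factories and `batch_run` composes under Ray 2.0: upstream `FunctionNode`s are passed into a downstream node's `after` argument so they are resolved before the downstream task runs, and a final no-op node joins the branches the way `batch_dag_func` does above. All function names and the storage path are illustrative, not FastFold's:

```python
from typing import List
import ray
from ray import workflow
from ray.dag.function_node import FunctionNode

@ray.remote
def search_uniref90(after: List[FunctionNode] = None) -> str:
    # stand-in for a JackHmmer alignment node
    return "uniref90_hits.a3m"

@ray.remote
def search_pdb70(after: List[FunctionNode] = None) -> str:
    # stand-in for an HHSearch node that depends on the step above
    return "pdb70_hits.hhr"

@ray.remote
def batch_dag_func(dags) -> None:
    # no-op join node, mirroring batch_run above
    return

if __name__ == "__main__":
    ray.init(storage="/tmp/fastfold_workflow_demo")  # illustrative storage path
    jh_node = search_uniref90.bind()
    hhs_node = search_pdb70.bind(after=[jh_node])  # runs only after jh_node finishes
    workflow.run(batch_dag_func.bind([hhs_node]), workflow_id="demo_batch")
```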