Unverified Commit 6835c248 authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

add multimer workflow (#70)

* add multimer workflow

* support multimer data workflow
parent 9ab281fe
......@@ -16,6 +16,8 @@ dependencies:
- typing-extensions==3.10.0.2
- einops
- colossalai
- ray==2.0.0
- pyarrow
- pandas
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
......
......@@ -42,7 +42,6 @@ from fastfold.common import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def empty_template_feats(n_res) -> FeatureDict:
......@@ -466,12 +465,14 @@ class AlignmentRunnerMultimer:
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hmmsearch_binary_path: Optional[str] = None,
hmmbuild_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
template_searcher: Optional[TemplateSearcher] = None,
pdb_seqres_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
......@@ -524,6 +525,12 @@ class AlignmentRunnerMultimer:
bfd_database_path if not use_small_bfd else None,
],
},
"hmmsearch": {
"binary": hmmsearch_binary_path,
"dbs": [
pdb_seqres_database_path,
],
},
}
for name, dic in db_map.items():
......@@ -585,14 +592,13 @@ class AlignmentRunnerMultimer:
database_path=uniprot_database_path
)
if(template_searcher is not None and
self.jackhmmer_uniref90_runner is None
):
raise ValueError(
"Uniref90 runner must be specified to run template search"
self.hmmsearch_pdb_runner = None
if(pdb_seqres_database_path is not None):
self.hmmsearch_pdb_runner = hmmsearch.Hmmsearch(
binary_path=hmmsearch_binary_path,
hmmbuild_binary_path=hmmbuild_binary_path,
database_path=pdb_seqres_database_path,
)
self.template_searcher = template_searcher
def run(
self,
......@@ -617,25 +623,11 @@ class AlignmentRunnerMultimer:
template_msa
)
if(self.template_searcher is not None):
if(self.template_searcher.input_format == "sto"):
pdb_templates_result = self.template_searcher.query(
template_msa,
output_dir=output_dir
)
elif(self.template_searcher.input_format == "a3m"):
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
template_msa
)
pdb_templates_result = self.template_searcher.query(
uniref90_msa_as_a3m,
output_dir=output_dir
)
else:
fmt = self.template_searcher.input_format
raise ValueError(
f"Unrecognized template input format: {fmt}"
)
if(self.hmmsearch_pdb_runner is not None):
pdb_templates_result = self.hmmsearch_pdb_runner.query(
template_msa,
output_dir=output_dir
)
if(self.jackhmmer_mgnify_runner is not None):
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
......@@ -835,7 +827,6 @@ class DataPipeline:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
......
......@@ -86,6 +86,14 @@ class HHSearch:
raise ValueError(
f"Could not find HHsearch database {database_path}"
)
@property
def output_format(self) -> str:
return 'hhr'
@property
def input_format(self) -> str:
return 'a3m'
def query(self, a3m: str, gen_atab: bool = False) -> Union[str, tuple]:
"""Queries the database using HHsearch using a given a3m."""
......
......@@ -32,6 +32,7 @@ class Hmmsearch(object):
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
n_cpu: int=8,
flags: Optional[Sequence[str]] = None
):
"""Initializes the Python hmmsearch wrapper.
......@@ -49,6 +50,7 @@ class Hmmsearch(object):
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
self.database_path = database_path
self.n_cpu = n_cpu
if flags is None:
# Default hmmsearch run settings.
flags = ['--F1', '0.1',
......@@ -95,7 +97,7 @@ class Hmmsearch(object):
cmd = [
self.binary_path,
'--noali', # Don't include the alignment in stdout.
'--cpu', '8'
'--cpu', str(self.n_cpu)
]
# If adding flags, we have to do so before the output and input:
if self.flags:
......
......@@ -2,4 +2,5 @@ from .task_factory import TaskFactory
from .hhblits import HHBlitsFactory
from .hhsearch import HHSearchFactory
from .jackhmmer import JackHmmerFactory
from .hhfilter import HHfilterFactory
\ No newline at end of file
from .hhfilter import HHfilterFactory
from .hmmsearch import HmmSearchFactory
\ No newline at end of file
from typing import List
import inspect
import ray
from ray.dag.function_node import FunctionNode
from fastfold.data.tools import hmmsearch, hmmbuild
from fastfold.data import parsers
from fastfold.workflow.factory import TaskFactory
from typing import Optional
class HmmSearchFactory(TaskFactory):
keywords = ['binary_path', 'hmmbuild_binary_path', 'database_path', 'n_cpu']
def gen_node(self, msa_sto_path: str, output_dir: Optional[str] = None, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
params = { k: self.config.get(k) for k in inspect.getfullargspec(hmmsearch.Hmmsearch.__init__).kwonlyargs if self.config.get(k) }
# setup runner with a filtered config dict
runner = hmmsearch.Hmmsearch(
**params
)
# generate function node
@ray.remote
def hmmsearch_node_func(after: List[FunctionNode]) -> None:
with open(msa_sto_path, "r") as f:
msa_sto = f.read()
msa_sto = parsers.deduplicate_stockholm_msa(msa_sto)
msa_sto = parsers.remove_empty_columns_from_stockholm_msa(
msa_sto
)
hmmsearch_result = runner.query(msa_sto, output_dir=output_dir)
return hmmsearch_node_func.bind(after)
......@@ -13,7 +13,7 @@ class JackHmmerFactory(TaskFactory):
keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None, output_format: str="a3m") -> FunctionNode:
self.isReady()
......@@ -28,11 +28,17 @@ class JackHmmerFactory(TaskFactory):
@ray.remote
def jackhmmer_node_func(after: List[FunctionNode]) -> None:
result = runner.query(fasta_path)[0]
uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
result['sto'],
max_sequences=self.config['uniref_max_hits']
)
with open(output_path, "w") as f:
f.write(uniref90_msa_a3m)
if output_format == "a3m":
uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
result['sto'],
max_sequences=self.config['uniref_max_hits']
)
with open(output_path, "w") as f:
f.write(uniref90_msa_a3m)
elif output_format == "sto":
template_msa = result['sto']
with open(output_path, "w") as f:
f.write(template_msa)
return jackhmmer_node_func.bind(after)
from .fastfold_data_workflow import FastFoldDataWorkFlow
\ No newline at end of file
from .fastfold_data_workflow import FastFoldDataWorkFlow
from .fastfold_multimer_data_workflow import FastFoldMultimerDataWorkFlow
\ No newline at end of file
......@@ -118,12 +118,13 @@ class FastFoldDataWorkFlow:
self.jackhmmer_small_bfd_factory = JackHmmerFactory(config=jh_config)
def run(self, fasta_path: str, output_dir: str, alignment_dir: str=None, storage_dir: str=None) -> None:
# storage_dir = "file:///tmp/ray/lcmql/workflow_data"
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/lcmql/workflow_data"
if storage_dir is not None:
if not os.path.exists(storage_dir):
os.makedirs(storage_dir)
ray.init(storage=storage_dir)
if not ray.is_initialized():
ray.init(storage=storage_dir)
localtime = time.asctime(time.localtime(time.time()))
workflow_id = 'fastfold_data_workflow ' + str(localtime)
......@@ -135,13 +136,6 @@ class FastFoldDataWorkFlow:
print("Workflow not found. Clean. Skipping")
pass
# prepare alignment directory for alignment outputs
if alignment_dir is None:
alignment_dir = os.path.join(output_dir, "alignment")
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
# Run JackHmmer on UNIREF90
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
# generate the workflow with i/o path
......@@ -167,7 +161,7 @@ class FastFoldDataWorkFlow:
# Run Jackhmmer on small_bfd
bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
# generate workflow for STEP4_2
bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path)
bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path, output_format="sto")
# run workflow
batch_run(workflow_id=workflow_id, dags=[hhs_node, mgnify_node, bfd_node])
......
import os
import time
from multiprocessing import cpu_count
import ray
from ray import workflow
from fastfold.data.tools import hmmsearch
from fastfold.workflow.factory import JackHmmerFactory, HHBlitsFactory, HmmSearchFactory
from fastfold.workflow import batch_run
from typing import Optional, Union
class FastFoldMultimerDataWorkFlow:
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hmmsearch_binary_path: Optional[str] = None,
hmmbuild_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
pdb_seqres_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
uniprot_database_path,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hmmsearch": {
"binary": hmmsearch_binary_path,
"dbs": [
pdb_seqres_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hmmsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.use_small_bfd = use_small_bfd
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
if(no_cpus is None):
self.no_cpus = cpu_count()
else:
self.no_cpus = no_cpus
# create JackHmmer workflow generator
self.jackhmmer_uniref90_factory = None
if jackhmmer_binary_path is not None and uniref90_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": uniref90_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": uniref_max_hits,
}
self.jackhmmer_uniref90_factory = JackHmmerFactory(config = jh_config)
# create HMMSearch workflow generator
self.hmmsearch_pdb_factory = None
if pdb_seqres_database_path is not None:
hmm_config = {
"binary_path": db_map["hmmsearch"]["binary"],
"hmmbuild_binary_path": hmmbuild_binary_path,
"database_path": pdb_seqres_database_path,
"n_cpu": self.no_cpus,
}
self.hmmsearch_pdb_factory = HmmSearchFactory(config=hmm_config)
self.jackhmmer_mgnify_factory = None
if jackhmmer_binary_path is not None and mgnify_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": mgnify_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": mgnify_max_hits,
}
self.jackhmmer_mgnify_factory = JackHmmerFactory(config=jh_config)
if bfd_database_path is not None:
if not use_small_bfd:
hhb_config = {
"binary_path": db_map["hhblits"]["binary"],
"databases": db_map["hhblits"]["dbs"],
"n_cpu": self.no_cpus,
}
self.hhblits_bfd_factory = HHBlitsFactory(config=hhb_config)
else:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": bfd_database_path,
"n_cpu": no_cpus,
}
self.jackhmmer_small_bfd_factory = JackHmmerFactory(config=jh_config)
self.jackhmmer_uniprot_factory = None
if jackhmmer_binary_path is not None and uniprot_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": uniprot_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": uniprot_max_hits,
}
self.jackhmmer_uniprot_factory = JackHmmerFactory(config=jh_config)
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/lcmql/workflow_data"
if storage_dir is not None:
if not os.path.exists(storage_dir):
os.makedirs(storage_dir)
if not ray.is_initialized():
ray.init(storage=storage_dir)
localtime = time.asctime(time.localtime(time.time()))
workflow_id = 'fastfold_data_workflow ' + str(localtime)
# clearing remaining ray workflow data
try:
workflow.cancel(workflow_id)
workflow.delete(workflow_id)
except:
print("Workflow not found. Clean. Skipping")
pass
# Run JackHmmer on UNIREF90
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.sto")
# generate the workflow with i/o path
uniref90_node = self.jackhmmer_uniref90_factory.gen_node(fasta_path, uniref90_out_path, output_format="sto")
#Run HmmSearch on STEP1's result with PDB"""
# generate the workflow (STEP2 depend on STEP1)
hmm_node = self.hmmsearch_pdb_factory.gen_node(uniref90_out_path, output_dir=alignment_dir,after=[uniref90_node])
# Run JackHmmer on MGNIFY
mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.sto")
# generate workflow for STEP3
mgnify_node = self.jackhmmer_mgnify_factory.gen_node(fasta_path, mgnify_out_path, output_format="sto")
if not self.use_small_bfd:
# Run HHBlits on BFD
bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
# generate workflow for STEP4
bfd_node = self.hhblits_bfd_factory.gen_node(fasta_path, bfd_out_path)
else:
# Run Jackhmmer on small_bfd
bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.sto")
# generate workflow for STEP4_2
bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path, output_format="sto")
# Run JackHmmer on UNIPROT
uniprot_out_path = os.path.join(alignment_dir, "uniprot_hits.sto")
# generate workflow for STEP5
uniprot_node = self.jackhmmer_uniprot_factory.gen_node(fasta_path, uniprot_out_path, output_format="sto")
# run workflow
batch_run(workflow_id=workflow_id, dags=[hmm_node, mgnify_node, bfd_node, uniprot_node])
return
\ No newline at end of file
......@@ -26,6 +26,7 @@ import numpy as np
import torch
import torch.multiprocessing as mp
import pickle
import shutil
from fastfold.model.hub import AlphaFold
import fastfold
......@@ -35,7 +36,8 @@ from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.tools import hhsearch, hmmsearch
from fastfold.workflow.template import FastFoldDataWorkFlow
from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
from fastfold.utils import inject_fastnn
from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_
......@@ -145,15 +147,6 @@ def inference_multimer_model(args):
predict_max_templates = 4
if not args.use_precomputed_alignments:
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
else:
template_searcher = None
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
......@@ -164,18 +157,37 @@ def inference_multimer_model(args):
)
if(not args.use_precomputed_alignments):
alignment_runner = data_pipeline.AlignmentRunnerMultimer(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus,
)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_runner = FastFoldMultimerDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus
)
else:
alignment_runner = data_pipeline.AlignmentRunnerMultimer(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus
)
else:
alignment_runner = None
......@@ -221,12 +233,20 @@ def inference_multimer_model(args):
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
else:
shutil.rmtree(local_alignment_dir)
os.makedirs(local_alignment_dir)
chain_fasta_str = f'>chain_{tag}\n{seq}\n'
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
alignment_runner.run(
chain_fasta_path, local_alignment_dir
)
if args.enable_workflow:
print("Running alignment with ray workflow...")
t = time.perf_counter()
alignment_runner.run(chain_fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner.run(chain_fasta_path, local_alignment_dir)
print(f"Finished running alignment for {tag}")
local_alignment_dir = alignment_dir
......@@ -351,7 +371,7 @@ def inference_monomer_model(args):
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment