Unverified commit 6835c248, authored by Fazzie-Maqianli, committed by GitHub

add multimer workflow (#70)

* add multimer workflow

* support multimer data workflow
parent 9ab281fe
@@ -16,6 +16,8 @@ dependencies:
   - typing-extensions==3.10.0.2
   - einops
   - colossalai
+  - ray==2.0.0
+  - pyarrow
   - pandas
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
...
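The ray==2.0.0 pin brings in the Ray DAG API that the workflow factories in this commit build on. A minimal sketch of that API with toy functions (not FastFold code): .bind() produces a lazy FunctionNode, and nothing runs until .execute() is called on the final node.

import ray

# Toy two-step DAG in the style the factories below use.
@ray.remote
def make_msa(tag: str) -> str:
    return f"msa-for-{tag}"

@ray.remote
def search_templates(msa: str) -> str:
    return f"templates-from-{msa}"

dag = search_templates.bind(make_msa.bind("chain_A"))
print(ray.get(dag.execute()))  # templates-from-msa-for-chain_A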
@@ -42,7 +42,6 @@ from fastfold.common import residue_constants, protein
 FeatureDict = Mapping[str, np.ndarray]
-TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
 
 def empty_template_feats(n_res) -> FeatureDict:
@@ -466,12 +465,14 @@ class AlignmentRunnerMultimer:
         self,
         jackhmmer_binary_path: Optional[str] = None,
         hhblits_binary_path: Optional[str] = None,
+        hmmsearch_binary_path: Optional[str] = None,
+        hmmbuild_binary_path: Optional[str] = None,
         uniref90_database_path: Optional[str] = None,
         mgnify_database_path: Optional[str] = None,
         bfd_database_path: Optional[str] = None,
         uniclust30_database_path: Optional[str] = None,
         uniprot_database_path: Optional[str] = None,
-        template_searcher: Optional[TemplateSearcher] = None,
+        pdb_seqres_database_path: Optional[str] = None,
         use_small_bfd: Optional[bool] = None,
         no_cpus: Optional[int] = None,
         uniref_max_hits: int = 10000,
@@ -524,6 +525,12 @@ class AlignmentRunnerMultimer:
                     bfd_database_path if not use_small_bfd else None,
                 ],
             },
+            "hmmsearch": {
+                "binary": hmmsearch_binary_path,
+                "dbs": [
+                    pdb_seqres_database_path,
+                ],
+            },
         }
 
         for name, dic in db_map.items():
@@ -585,15 +592,14 @@ class AlignmentRunnerMultimer:
                 database_path=uniprot_database_path
             )
 
-        if(template_searcher is not None and
-            self.jackhmmer_uniref90_runner is None
-        ):
-            raise ValueError(
-                "Uniref90 runner must be specified to run template search"
-            )
-        self.template_searcher = template_searcher
+        self.hmmsearch_pdb_runner = None
+        if(pdb_seqres_database_path is not None):
+            self.hmmsearch_pdb_runner = hmmsearch.Hmmsearch(
+                binary_path=hmmsearch_binary_path,
+                hmmbuild_binary_path=hmmbuild_binary_path,
+                database_path=pdb_seqres_database_path,
+            )
 
     def run(
         self,
         fasta_path: str,
@@ -617,25 +623,11 @@ class AlignmentRunnerMultimer:
                 template_msa
             )
 
-            if(self.template_searcher is not None):
-                if(self.template_searcher.input_format == "sto"):
-                    pdb_templates_result = self.template_searcher.query(
-                        template_msa,
-                        output_dir=output_dir
-                    )
-                elif(self.template_searcher.input_format == "a3m"):
-                    uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
-                        template_msa
-                    )
-                    pdb_templates_result = self.template_searcher.query(
-                        uniref90_msa_as_a3m,
-                        output_dir=output_dir
-                    )
-                else:
-                    fmt = self.template_searcher.input_format
-                    raise ValueError(
-                        f"Unrecognized template input format: {fmt}"
-                    )
+            if(self.hmmsearch_pdb_runner is not None):
+                pdb_templates_result = self.hmmsearch_pdb_runner.query(
+                    template_msa,
+                    output_dir=output_dir
+                )
 
         if(self.jackhmmer_mgnify_runner is not None):
             mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
@@ -835,7 +827,6 @@ class DataPipeline:
         for f in os.listdir(alignment_dir):
             path = os.path.join(alignment_dir, f)
             filename, ext = os.path.splitext(f)
 
             if(ext == ".a3m"):
                 with open(path, "r") as fp:
                     msa = parsers.parse_a3m(fp.read())
...
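For orientation, a hedged sketch of driving the reworked multimer runner. Every path below is a placeholder, and the second positional argument of run() is assumed to be the alignment output directory (matching the inference code later in this diff). Runners for databases that are not configured are skipped.

from fastfold.data import data_pipeline

# Placeholder binaries and databases; the template searcher is now built
# internally from pdb_seqres_database_path instead of being passed in.
runner = data_pipeline.AlignmentRunnerMultimer(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    hmmsearch_binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    uniref90_database_path="/data/uniref90/uniref90.fasta",
    pdb_seqres_database_path="/data/pdb_seqres/pdb_seqres.txt",
    use_small_bfd=True,
    no_cpus=8,
)
runner.run("chain_A.fasta", "alignments/chain_A")  # fasta path, output dir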
@@ -87,6 +87,14 @@ class HHSearch:
                 f"Could not find HHsearch database {database_path}"
             )
 
+    @property
+    def output_format(self) -> str:
+        return 'hhr'
+
+    @property
+    def input_format(self) -> str:
+        return 'a3m'
+
     def query(self, a3m: str, gen_atab: bool = False) -> Union[str, tuple]:
         """Queries the database using HHsearch using a given a3m."""
         with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
...
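The new properties let pipeline code branch on MSA format generically instead of special-casing searcher classes. A short sketch; the HHSearch constructor arguments follow the AlphaFold-style wrapper and are assumptions here, as are the paths.

from fastfold.data.tools import hhsearch

# Placeholder binary and database paths.
searcher = hhsearch.HHSearch(
    binary_path="/usr/bin/hhsearch",
    databases=["/data/pdb70/pdb70"],
)
# Callers can now dispatch on format without isinstance checks:
assert searcher.input_format == 'a3m'
assert searcher.output_format == 'hhr'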
@@ -32,6 +32,7 @@ class Hmmsearch(object):
         binary_path: str,
         hmmbuild_binary_path: str,
         database_path: str,
+        n_cpu: int=8,
         flags: Optional[Sequence[str]] = None
     ):
         """Initializes the Python hmmsearch wrapper.
@@ -49,6 +50,7 @@ class Hmmsearch(object):
         self.binary_path = binary_path
         self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
         self.database_path = database_path
+        self.n_cpu = n_cpu
 
         if flags is None:
             # Default hmmsearch run settings.
             flags = ['--F1', '0.1',
@@ -95,7 +97,7 @@ class Hmmsearch(object):
         cmd = [
             self.binary_path,
             '--noali',  # Don't include the alignment in stdout.
-            '--cpu', '8'
+            '--cpu', str(self.n_cpu)
         ]
         # If adding flags, we have to do so before the output and input:
         if self.flags:
...
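With this change, n_cpu flows through to hmmsearch's '--cpu' flag instead of a hard-coded '8'. A hedged construction sketch (placeholder paths):

from fastfold.data.tools import hmmsearch

# Placeholder binaries and database; n_cpu controls the '--cpu' flag.
runner = hmmsearch.Hmmsearch(
    binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    database_path="/data/pdb_seqres/pdb_seqres.txt",
    n_cpu=16,
)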
@@ -3,3 +3,4 @@ from .hhblits import HHBlitsFactory
 from .hhsearch import HHSearchFactory
 from .jackhmmer import JackHmmerFactory
 from .hhfilter import HHfilterFactory
+from .hmmsearch import HmmSearchFactory
\ No newline at end of file
...
new file: fastfold/workflow/factory/hmmsearch.py

from typing import List, Optional
import inspect

import ray
from ray.dag.function_node import FunctionNode

from fastfold.data.tools import hmmsearch, hmmbuild
from fastfold.data import parsers
from fastfold.workflow.factory import TaskFactory


class HmmSearchFactory(TaskFactory):

    keywords = ['binary_path', 'hmmbuild_binary_path', 'database_path', 'n_cpu']

    def gen_node(self, msa_sto_path: str, output_dir: Optional[str] = None, after: List[FunctionNode] = None) -> FunctionNode:

        self.isReady()

        # set up the runner with the config dict filtered down to
        # Hmmsearch's keyword-only constructor arguments
        params = {
            k: self.config.get(k)
            for k in inspect.getfullargspec(hmmsearch.Hmmsearch.__init__).kwonlyargs
            if self.config.get(k)
        }
        runner = hmmsearch.Hmmsearch(**params)

        # generate the function node: read the Stockholm MSA, deduplicate
        # it, drop empty columns, then run the hmmsearch query
        @ray.remote
        def hmmsearch_node_func(after: List[FunctionNode]) -> None:
            with open(msa_sto_path, "r") as f:
                msa_sto = f.read()
            msa_sto = parsers.deduplicate_stockholm_msa(msa_sto)
            msa_sto = parsers.remove_empty_columns_from_stockholm_msa(
                msa_sto
            )
            runner.query(msa_sto, output_dir=output_dir)

        return hmmsearch_node_func.bind(after)
...
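A possible use of the new factory; the config keys mirror its keywords list, all paths are placeholders, and the returned node is a lazy FunctionNode meant for batch_run() rather than immediate execution.

from fastfold.workflow.factory import HmmSearchFactory

# Placeholder config; keys match HmmSearchFactory.keywords.
factory = HmmSearchFactory(config={
    "binary_path": "/usr/bin/hmmsearch",
    "hmmbuild_binary_path": "/usr/bin/hmmbuild",
    "database_path": "/data/pdb_seqres/pdb_seqres.txt",
    "n_cpu": 8,
})
hmm_node = factory.gen_node("alignments/uniref90_hits.sto", output_dir="alignments")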
@@ -13,7 +13,7 @@ class JackHmmerFactory(TaskFactory):
 
     keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
 
-    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
+    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None, output_format: str="a3m") -> FunctionNode:
 
         self.isReady()
 
@@ -28,11 +28,17 @@ class JackHmmerFactory(TaskFactory):
         @ray.remote
         def jackhmmer_node_func(after: List[FunctionNode]) -> None:
             result = runner.query(fasta_path)[0]
-            uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
-                result['sto'],
-                max_sequences=self.config['uniref_max_hits']
-            )
-            with open(output_path, "w") as f:
-                f.write(uniref90_msa_a3m)
+            if output_format == "a3m":
+                uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
+                    result['sto'],
+                    max_sequences=self.config['uniref_max_hits']
+                )
+                with open(output_path, "w") as f:
+                    f.write(uniref90_msa_a3m)
+            elif output_format == "sto":
+                template_msa = result['sto']
+                with open(output_path, "w") as f:
+                    f.write(template_msa)
 
         return jackhmmer_node_func.bind(after)
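A sketch of requesting Stockholm output from the factory, as the multimer workflow below does; the default "a3m" keeps the old monomer behavior. Paths are placeholders.

from fastfold.workflow.factory import JackHmmerFactory

# Placeholder config; keys match JackHmmerFactory.keywords.
jh = JackHmmerFactory(config={
    "binary_path": "/usr/bin/jackhmmer",
    "database_path": "/data/uniref90/uniref90.fasta",
    "n_cpu": 8,
    "uniref_max_hits": 10000,
})
# Raw Stockholm hits, ready for hmmsearch downstream:
sto_node = jh.gen_node("query.fasta", "alignments/uniref90_hits.sto", output_format="sto")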
 from .fastfold_data_workflow import FastFoldDataWorkFlow
+from .fastfold_multimer_data_workflow import FastFoldMultimerDataWorkFlow
\ No newline at end of file
...
@@ -118,11 +118,12 @@ class FastFoldDataWorkFlow:
             self.jackhmmer_small_bfd_factory = JackHmmerFactory(config=jh_config)
 
-    def run(self, fasta_path: str, output_dir: str, alignment_dir: str=None, storage_dir: str=None) -> None:
-        # storage_dir = "file:///tmp/ray/lcmql/workflow_data"
+    def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
+        storage_dir = "file:///tmp/ray/lcmql/workflow_data"
         if storage_dir is not None:
             if not os.path.exists(storage_dir):
                 os.makedirs(storage_dir)
+        if not ray.is_initialized():
             ray.init(storage=storage_dir)
 
         localtime = time.asctime(time.localtime(time.time()))
@@ -135,13 +136,6 @@ class FastFoldDataWorkFlow:
             print("Workflow not found. Clean. Skipping")
             pass
 
-        # prepare alignment directory for alignment outputs
-        if alignment_dir is None:
-            alignment_dir = os.path.join(output_dir, "alignment")
-        if not os.path.exists(alignment_dir):
-            os.makedirs(alignment_dir)
-
         # Run JackHmmer on UNIREF90
         uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
         # generate the workflow with i/o path
@@ -167,7 +161,7 @@ class FastFoldDataWorkFlow:
             # Run Jackhmmer on small_bfd
             bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
             # generate workflow for STEP4_2
-            bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path)
+            bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path, output_format="sto")
         # run workflow
         batch_run(workflow_id=workflow_id, dags=[hhs_node, mgnify_node, bfd_node])
...
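The added ray.is_initialized() guard makes run() safe to call repeatedly in one process, since ray.init() raises if a connection already exists. In isolation the pattern is (the storage URI below is a placeholder):

import ray

# Only initialize Ray when no cluster connection is active.
if not ray.is_initialized():
    ray.init(storage="file:///tmp/ray/workflow_data")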
new file: fastfold/workflow/template/fastfold_multimer_data_workflow.py

import os
import time
from multiprocessing import cpu_count
from typing import Optional

import ray
from ray import workflow

from fastfold.workflow.factory import JackHmmerFactory, HHBlitsFactory, HmmSearchFactory
from fastfold.workflow import batch_run


class FastFoldMultimerDataWorkFlow:
    def __init__(
        self,
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,
        hmmsearch_binary_path: Optional[str] = None,
        hmmbuild_binary_path: Optional[str] = None,
        uniref90_database_path: Optional[str] = None,
        mgnify_database_path: Optional[str] = None,
        bfd_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        uniprot_database_path: Optional[str] = None,
        pdb_seqres_database_path: Optional[str] = None,
        use_small_bfd: Optional[bool] = None,
        no_cpus: Optional[int] = None,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
        uniprot_max_hits: int = 50000,
    ):
        db_map = {
            "jackhmmer": {
                "binary": jackhmmer_binary_path,
                "dbs": [
                    uniref90_database_path,
                    mgnify_database_path,
                    bfd_database_path if use_small_bfd else None,
                    uniprot_database_path,
                ],
            },
            "hhblits": {
                "binary": hhblits_binary_path,
                "dbs": [
                    bfd_database_path if not use_small_bfd else None,
                ],
            },
            "hmmsearch": {
                "binary": hmmsearch_binary_path,
                "dbs": [
                    pdb_seqres_database_path,
                ],
            },
        }

        for name, dic in db_map.items():
            binary, dbs = dic["binary"], dic["dbs"]
            if(binary is None and not all([x is None for x in dbs])):
                raise ValueError(
                    f"{name} DBs provided but {name} binary is None"
                )

        if(not all([x is None for x in db_map["hmmsearch"]["dbs"]])
            and uniref90_database_path is None):
            raise ValueError(
                """uniref90_database_path must be specified in order to perform
                   template search"""
            )

        self.use_small_bfd = use_small_bfd
        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits

        if(no_cpus is None):
            self.no_cpus = cpu_count()
        else:
            self.no_cpus = no_cpus

        # create JackHmmer workflow generator for UNIREF90
        self.jackhmmer_uniref90_factory = None
        if jackhmmer_binary_path is not None and uniref90_database_path is not None:
            jh_config = {
                "binary_path": db_map["jackhmmer"]["binary"],
                "database_path": uniref90_database_path,
                "n_cpu": no_cpus,
                "uniref_max_hits": uniref_max_hits,
            }
            self.jackhmmer_uniref90_factory = JackHmmerFactory(config=jh_config)

        # create HmmSearch workflow generator for PDB seqres templates
        self.hmmsearch_pdb_factory = None
        if pdb_seqres_database_path is not None:
            hmm_config = {
                "binary_path": db_map["hmmsearch"]["binary"],
                "hmmbuild_binary_path": hmmbuild_binary_path,
                "database_path": pdb_seqres_database_path,
                "n_cpu": self.no_cpus,
            }
            self.hmmsearch_pdb_factory = HmmSearchFactory(config=hmm_config)

        # create JackHmmer workflow generator for MGNIFY
        self.jackhmmer_mgnify_factory = None
        if jackhmmer_binary_path is not None and mgnify_database_path is not None:
            jh_config = {
                "binary_path": db_map["jackhmmer"]["binary"],
                "database_path": mgnify_database_path,
                "n_cpu": no_cpus,
                "uniref_max_hits": mgnify_max_hits,
            }
            self.jackhmmer_mgnify_factory = JackHmmerFactory(config=jh_config)

        # create the BFD generator: HHBlits on full BFD, or JackHmmer on small BFD
        if bfd_database_path is not None:
            if not use_small_bfd:
                hhb_config = {
                    "binary_path": db_map["hhblits"]["binary"],
                    "databases": db_map["hhblits"]["dbs"],
                    "n_cpu": self.no_cpus,
                }
                self.hhblits_bfd_factory = HHBlitsFactory(config=hhb_config)
            else:
                jh_config = {
                    "binary_path": db_map["jackhmmer"]["binary"],
                    "database_path": bfd_database_path,
                    "n_cpu": no_cpus,
                }
                self.jackhmmer_small_bfd_factory = JackHmmerFactory(config=jh_config)

        # create JackHmmer workflow generator for UNIPROT
        self.jackhmmer_uniprot_factory = None
        if jackhmmer_binary_path is not None and uniprot_database_path is not None:
            jh_config = {
                "binary_path": db_map["jackhmmer"]["binary"],
                "database_path": uniprot_database_path,
                "n_cpu": no_cpus,
                "uniref_max_hits": uniprot_max_hits,
            }
            self.jackhmmer_uniprot_factory = JackHmmerFactory(config=jh_config)

    def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
        storage_dir = "file:///tmp/ray/lcmql/workflow_data"
        if storage_dir is not None:
            if not os.path.exists(storage_dir):
                os.makedirs(storage_dir)
        if not ray.is_initialized():
            ray.init(storage=storage_dir)

        localtime = time.asctime(time.localtime(time.time()))
        workflow_id = 'fastfold_data_workflow ' + str(localtime)

        # clear remaining ray workflow data
        try:
            workflow.cancel(workflow_id)
            workflow.delete(workflow_id)
        except Exception:
            print("Workflow not found. Clean. Skipping")

        # STEP 1: run JackHmmer on UNIREF90
        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.sto")
        # generate the workflow with i/o path
        uniref90_node = self.jackhmmer_uniref90_factory.gen_node(fasta_path, uniref90_out_path, output_format="sto")

        # STEP 2: run HmmSearch on STEP 1's result against PDB seqres (depends on STEP 1)
        hmm_node = self.hmmsearch_pdb_factory.gen_node(uniref90_out_path, output_dir=alignment_dir, after=[uniref90_node])

        # STEP 3: run JackHmmer on MGNIFY
        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.sto")
        mgnify_node = self.jackhmmer_mgnify_factory.gen_node(fasta_path, mgnify_out_path, output_format="sto")

        # STEP 4: run HHBlits on BFD, or JackHmmer on small BFD
        if not self.use_small_bfd:
            bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
            bfd_node = self.hhblits_bfd_factory.gen_node(fasta_path, bfd_out_path)
        else:
            bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.sto")
            bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path, output_format="sto")

        # STEP 5: run JackHmmer on UNIPROT
        uniprot_out_path = os.path.join(alignment_dir, "uniprot_hits.sto")
        uniprot_node = self.jackhmmer_uniprot_factory.gen_node(fasta_path, uniprot_out_path, output_format="sto")

        # run the workflow
        batch_run(workflow_id=workflow_id, dags=[hmm_node, mgnify_node, bfd_node, uniprot_node])
        return
...
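A hedged end-to-end sketch of the new workflow class. Every path is a placeholder; alignment_dir is assumed to already exist, since run() does not create it (the inference code below makedirs it first); and with use_small_bfd=True, bfd_database_path should point at small BFD so the JackHmmer small-BFD branch can be built.

from fastfold.workflow.template import FastFoldMultimerDataWorkFlow

# Placeholder binaries and databases throughout.
wf = FastFoldMultimerDataWorkFlow(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    hmmsearch_binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    uniref90_database_path="/data/uniref90/uniref90.fasta",
    mgnify_database_path="/data/mgnify/mgy_clusters.fa",
    bfd_database_path="/data/small_bfd/bfd-first_non_consensus_sequences.fasta",
    uniprot_database_path="/data/uniprot/uniprot.fasta",
    pdb_seqres_database_path="/data/pdb_seqres/pdb_seqres.txt",
    use_small_bfd=True,
    no_cpus=8,
)
wf.run("chain_A.fasta", alignment_dir="alignments/chain_A")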
@@ -26,6 +26,7 @@ import numpy as np
 import torch
 import torch.multiprocessing as mp
 import pickle
+import shutil
 
 from fastfold.model.hub import AlphaFold
 import fastfold
@@ -35,7 +36,8 @@ from fastfold.config import model_config
 from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
 from fastfold.data.tools import hhsearch, hmmsearch
-from fastfold.workflow.template import FastFoldDataWorkFlow
+from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
 from fastfold.utils import inject_fastnn
 from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
@@ -145,15 +147,6 @@ def inference_multimer_model(args):
     predict_max_templates = 4
 
-    if not args.use_precomputed_alignments:
-        template_searcher = hmmsearch.Hmmsearch(
-            binary_path=args.hmmsearch_binary_path,
-            hmmbuild_binary_path=args.hmmbuild_binary_path,
-            database_path=args.pdb_seqres_database_path,
-        )
-    else:
-        template_searcher = None
-
     template_featurizer = templates.HmmsearchHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,
         max_template_date=args.max_template_date,
@@ -164,17 +157,36 @@ def inference_multimer_model(args):
     )
     if(not args.use_precomputed_alignments):
-        alignment_runner = data_pipeline.AlignmentRunnerMultimer(
-            jackhmmer_binary_path=args.jackhmmer_binary_path,
-            hhblits_binary_path=args.hhblits_binary_path,
-            uniref90_database_path=args.uniref90_database_path,
-            mgnify_database_path=args.mgnify_database_path,
-            bfd_database_path=args.bfd_database_path,
-            uniclust30_database_path=args.uniclust30_database_path,
-            uniprot_database_path=args.uniprot_database_path,
-            template_searcher=template_searcher,
-            use_small_bfd=(args.bfd_database_path is None),
-            no_cpus=args.cpus,
-        )
+        if args.enable_workflow:
+            print("Running alignment with ray workflow...")
+            alignment_runner = FastFoldMultimerDataWorkFlow(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hmmsearch_binary_path=args.hmmsearch_binary_path,
+                hmmbuild_binary_path=args.hmmbuild_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                uniprot_database_path=args.uniprot_database_path,
+                pdb_seqres_database_path=args.pdb_seqres_database_path,
+                use_small_bfd=(args.bfd_database_path is None),
+                no_cpus=args.cpus
+            )
+        else:
+            alignment_runner = data_pipeline.AlignmentRunnerMultimer(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hmmsearch_binary_path=args.hmmsearch_binary_path,
+                hmmbuild_binary_path=args.hmmbuild_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                uniprot_database_path=args.uniprot_database_path,
+                pdb_seqres_database_path=args.pdb_seqres_database_path,
+                use_small_bfd=(args.bfd_database_path is None),
+                no_cpus=args.cpus
+            )
     else:
         alignment_runner = None
@@ -221,12 +233,20 @@ def inference_multimer_model(args):
         if(args.use_precomputed_alignments is None):
             if not os.path.exists(local_alignment_dir):
                 os.makedirs(local_alignment_dir)
+            else:
+                shutil.rmtree(local_alignment_dir)
+                os.makedirs(local_alignment_dir)
 
             chain_fasta_str = f'>chain_{tag}\n{seq}\n'
             with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
-                alignment_runner.run(
-                    chain_fasta_path, local_alignment_dir
-                )
+                if args.enable_workflow:
+                    print("Running alignment with ray workflow...")
+                    t = time.perf_counter()
+                    alignment_runner.run(chain_fasta_path, alignment_dir=local_alignment_dir)
+                    print(f"Alignment data workflow time: {time.perf_counter() - t}")
+                else:
+                    alignment_runner.run(chain_fasta_path, local_alignment_dir)
 
             print(f"Finished running alignment for {tag}")
 
     local_alignment_dir = alignment_dir
@@ -351,7 +371,7 @@ def inference_monomer_model(args):
                 no_cpus=args.cpus,
             )
             t = time.perf_counter()
-            alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
+            alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
             print(f"Alignment data workflow time: {time.perf_counter() - t}")
         else:
             alignment_runner = data_pipeline.AlignmentRunner(
...