Commit b14e47f4 authored by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
import inspect
from typing import List
import ray
from ray.dag.function_node import FunctionNode
import fastfold.data.tools.jackhmmer as ffJackHmmer
from fastfold.data import parsers
from fastfold.workflow.factory import TaskFactory
class JackHmmerFactory(TaskFactory):
keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None, output_format: str="a3m") -> FunctionNode:
self.isReady()
params = { k: self.config.get(k) for k in inspect.getfullargspec(ffJackHmmer.Jackhmmer.__init__).kwonlyargs if self.config.get(k) }
# setup runner
runner = ffJackHmmer.Jackhmmer(
**params
)
# generate function node
@ray.remote
def jackhmmer_node_func(after: List[FunctionNode]) -> None:
result = runner.query(fasta_path)[0]
if output_format == "a3m":
uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
result['sto'],
max_sequences=self.config['uniref_max_hits']
)
with open(output_path, "w") as f:
f.write(uniref90_msa_a3m)
elif output_format == "sto":
template_msa = result['sto']
with open(output_path, "w") as f:
f.write(template_msa)
return jackhmmer_node_func.bind(after)
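
# --- Editorial usage sketch (not part of the original module) -----------------------------------
# A minimal example of how a JackHmmerFactory might be configured and executed as a Ray
# workflow; the binary/database paths and the workflow id below are placeholders.
def _example_jackhmmer_dag() -> None:
    # local import to avoid a circular import at module load time
    from fastfold.workflow import batch_run

    factory = JackHmmerFactory(config={
        "binary_path": "/usr/bin/jackhmmer",              # placeholder path
        "database_path": "data/uniref90/uniref90.fasta",  # placeholder path
        "n_cpu": 8,
        "uniref_max_hits": 10000,
    })
    # build a DAG node that runs jackhmmer on the query and writes the hits as a3m
    node = factory.gen_node("target.fasta", "uniref90_hits.a3m")
    # execute the (single-node) DAG as a named Ray workflow
    batch_run(workflow_id="jackhmmer_example", dags=[node])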
import json
from typing import Any, List
import ray
from ray.dag.function_node import FunctionNode
class TaskFactory:
keywords = []
def __init__(self, config: dict = None, config_path: str = None) -> None:
    self.config = {}
    # skip loading if the factory requires no keywords from the config file
    if not self.__class__.keywords:
        return
# setting config for factory
if config is not None:
self.config = config
elif config_path is not None:
self.loadConfig(config_path)
else:
self.loadConfig()
def configure(self, config: dict, purge=False) -> None:
    if purge:
        self.config = config
    else:
        self.config.update(config)

def configure_keyword(self, keyword: str, value: Any) -> None:
    # set a single config entry
    self.config[keyword] = value
def gen_node(self, after: List[FunctionNode] = None, *args, **kwargs) -> FunctionNode:
    # subclasses build and return a Ray DAG node here
    raise NotImplementedError
def isReady(self):
for key in self.__class__.keywords:
if key not in self.config:
raise KeyError(f"{self.__class__.__name__} not ready: \"{key}\" not specified")
def loadConfig(self, config_path='./config.json'):
with open(config_path) as configFile:
globalConfig = json.load(configFile)
if 'tools' not in globalConfig:
raise KeyError("\"tools\" not found in global config file")
# factory name is the class name with the trailing "Factory" stripped
factoryName = self.__class__.__name__[:-7]
if factoryName not in globalConfig['tools']:
raise KeyError(f"\"{factoryName}\" not found in the \"tools\" section in config")
self.config = globalConfig['tools'][factoryName]
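
# --- Editorial note (not part of the original module) -------------------------------------------
# Illustrative layout of the config.json file that loadConfig() expects: keys under "tools" are
# the factory class names with the trailing "Factory" stripped, and each value becomes that
# factory's config dict. The values below are placeholders.
#
#     {
#         "tools": {
#             "JackHmmer": {
#                 "binary_path": "/usr/bin/jackhmmer",
#                 "database_path": "data/uniref90/uniref90.fasta",
#                 "n_cpu": 8,
#                 "uniref_max_hits": 10000
#             }
#         }
#     }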
from .fastfold_data_workflow import FastFoldDataWorkFlow
from .fastfold_multimer_data_workflow import FastFoldMultimerDataWorkFlow
import os
import time
from multiprocessing import cpu_count
import ray
from ray import workflow
from fastfold.workflow.factory import JackHmmerFactory, HHSearchFactory, HHBlitsFactory
from fastfold.workflow import batch_run
from typing import Optional
class FastFoldDataWorkFlow:
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hhsearch_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniref30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
):
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hhsearch": {
"binary": hhsearch_binary_path,
"dbs": [
pdb70_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.use_small_bfd = use_small_bfd
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
if(no_cpus is None):
self.no_cpus = cpu_count()
else:
self.no_cpus = no_cpus
# create JackHmmer workflow generator
self.jackhmmer_uniref90_factory = None
if jackhmmer_binary_path is not None and uniref90_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": uniref90_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": uniref_max_hits,
}
self.jackhmmer_uniref90_factory = JackHmmerFactory(config = jh_config)
# create HHSearch workflow generator
self.hhsearch_pdb_factory = None
if pdb70_database_path is not None:
hhs_config = {
"binary_path": db_map["hhsearch"]["binary"],
"databases": db_map["hhsearch"]["dbs"],
"n_cpu": self.no_cpus,
}
self.hhsearch_pdb_factory = HHSearchFactory(config=hhs_config)
self.jackhmmer_mgnify_factory = None
if jackhmmer_binary_path is not None and mgnify_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": mgnify_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": mgnify_max_hits,
}
self.jackhmmer_mgnify_factory = JackHmmerFactory(config=jh_config)
if bfd_database_path is not None:
if not use_small_bfd:
hhb_config = {
"binary_path": db_map["hhblits"]["binary"],
"databases": db_map["hhblits"]["dbs"],
"n_cpu": self.no_cpus,
}
self.hhblits_bfd_factory = HHBlitsFactory(config=hhb_config)
else:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": bfd_database_path,
"n_cpu": no_cpus,
}
self.jackhmmer_small_bfd_factory = JackHmmerFactory(config=jh_config)
def run(self, fasta_path: str, alignment_dir: str = None, storage_dir: str = None) -> None:
    if storage_dir is None:
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        storage_dir = "file:///tmp/ray/" + timestamp + "/workflow_data"
    # strip the "file://" scheme to create the local directory backing Ray's workflow storage
    os.makedirs(storage_dir[len("file://"):], exist_ok=True)
    if not ray.is_initialized():
        ray.init(storage=storage_dir)
localtime = time.asctime(time.localtime(time.time()))
workflow_id = 'fastfold_data_workflow ' + str(localtime)
# clearing remaining ray workflow data
try:
workflow.cancel(workflow_id)
workflow.delete(workflow_id)
except Exception:
    # no stale workflow with this id; nothing to clean up
    print("Workflow not found. Skipping cleanup.")
# Run JackHmmer on UNIREF90
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
# generate the workflow with i/o path
uniref90_node = self.jackhmmer_uniref90_factory.gen_node(fasta_path, uniref90_out_path)
# Run HHSearch on STEP1's result with PDB70
pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
# generate the workflow (STEP2 depends on STEP1)
hhs_node = self.hhsearch_pdb_factory.gen_node(uniref90_out_path, pdb70_out_path, after=[uniref90_node])
# Run JackHmmer on MGNIFY
mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
# generate workflow for STEP3
mgnify_node = self.jackhmmer_mgnify_factory.gen_node(fasta_path, mgnify_out_path)
if not self.use_small_bfd:
# Run HHBlits on BFD
bfd_out_path = os.path.join(alignment_dir, "bfd_uniref_hits.a3m")
# generate workflow for STEP4
bfd_node = self.hhblits_bfd_factory.gen_node(fasta_path, bfd_out_path)
else:
# Run Jackhmmer on small_bfd
bfd_out_path = os.path.join(alignment_dir, "bfd_uniref_hits.a3m")
# generate workflow for STEP4_2
bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path, output_format="sto")
# run workflow
batch_run(workflow_id=workflow_id, dags=[hhs_node, mgnify_node, bfd_node])
return
import os
import time
from multiprocessing import cpu_count
import ray
from ray import workflow
from fastfold.data.tools import hmmsearch
from fastfold.workflow.factory import JackHmmerFactory, HHBlitsFactory, HmmSearchFactory
from fastfold.workflow import batch_run
from typing import Optional, Union
class FastFoldMultimerDataWorkFlow:
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hmmsearch_binary_path: Optional[str] = None,
hmmbuild_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniref30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
pdb_seqres_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
uniprot_database_path,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hmmsearch": {
"binary": hmmsearch_binary_path,
"dbs": [
pdb_seqres_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hmmsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.use_small_bfd = use_small_bfd
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
if(no_cpus is None):
self.no_cpus = cpu_count()
else:
self.no_cpus = no_cpus
# create JackHmmer workflow generator
self.jackhmmer_uniref90_factory = None
if jackhmmer_binary_path is not None and uniref90_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": uniref90_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": uniref_max_hits,
}
self.jackhmmer_uniref90_factory = JackHmmerFactory(config = jh_config)
# create HMMSearch workflow generator
self.hmmsearch_pdb_factory = None
if pdb_seqres_database_path is not None:
hmm_config = {
"binary_path": db_map["hmmsearch"]["binary"],
"hmmbuild_binary_path": hmmbuild_binary_path,
"database_path": pdb_seqres_database_path,
"n_cpu": self.no_cpus,
}
self.hmmsearch_pdb_factory = HmmSearchFactory(config=hmm_config)
self.jackhmmer_mgnify_factory = None
if jackhmmer_binary_path is not None and mgnify_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": mgnify_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": mgnify_max_hits,
}
self.jackhmmer_mgnify_factory = JackHmmerFactory(config=jh_config)
if bfd_database_path is not None:
if not use_small_bfd:
hhb_config = {
"binary_path": db_map["hhblits"]["binary"],
"databases": db_map["hhblits"]["dbs"],
"n_cpu": self.no_cpus,
}
self.hhblits_bfd_factory = HHBlitsFactory(config=hhb_config)
else:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": bfd_database_path,
"n_cpu": no_cpus,
}
self.jackhmmer_small_bfd_factory = JackHmmerFactory(config=jh_config)
self.jackhmmer_uniprot_factory = None
if jackhmmer_binary_path is not None and uniprot_database_path is not None:
jh_config = {
"binary_path": db_map["jackhmmer"]["binary"],
"database_path": uniprot_database_path,
"n_cpu": no_cpus,
"uniref_max_hits": uniprot_max_hits,
}
self.jackhmmer_uniprot_factory = JackHmmerFactory(config=jh_config)
def run(self, fasta_path: str, alignment_dir: str = None, storage_dir: str = None) -> None:
    if storage_dir is None:
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        storage_dir = "file:///tmp/ray/" + timestamp + "/workflow_data"
    # strip the "file://" scheme to create the local directory backing Ray's workflow storage
    os.makedirs(storage_dir[len("file://"):], exist_ok=True)
    if not ray.is_initialized():
        ray.init(storage=storage_dir)
localtime = time.asctime(time.localtime(time.time()))
workflow_id = 'fastfold_data_workflow ' + str(localtime)
# clearing remaining ray workflow data
try:
workflow.cancel(workflow_id)
workflow.delete(workflow_id)
except Exception:
    # no stale workflow with this id; nothing to clean up
    print("Workflow not found. Skipping cleanup.")
# Run JackHmmer on UNIREF90
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.sto")
# generate the workflow with i/o path
uniref90_node = self.jackhmmer_uniref90_factory.gen_node(fasta_path, uniref90_out_path, output_format="sto")
# Run HmmSearch on STEP1's result with the PDB seqres database
# generate the workflow (STEP2 depends on STEP1)
hmm_node = self.hmmsearch_pdb_factory.gen_node(uniref90_out_path, output_dir=alignment_dir, after=[uniref90_node])
# Run JackHmmer on MGNIFY
mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.sto")
# generate workflow for STEP3
mgnify_node = self.jackhmmer_mgnify_factory.gen_node(fasta_path, mgnify_out_path, output_format="sto")
if not self.use_small_bfd:
# Run HHBlits on BFD
bfd_out_path = os.path.join(alignment_dir, "bfd_uniref_hits.a3m")
# generate workflow for STEP4
bfd_node = self.hhblits_bfd_factory.gen_node(fasta_path, bfd_out_path)
else:
# Run Jackhmmer on small_bfd
bfd_out_path = os.path.join(alignment_dir, "bfd_uniref_hits.sto")
# generate workflow for STEP4_2
bfd_node = self.jackhmmer_small_bfd_factory.gen_node(fasta_path, bfd_out_path, output_format="sto")
# Run JackHmmer on UNIPROT
uniprot_out_path = os.path.join(alignment_dir, "uniprot_hits.sto")
# generate workflow for STEP5
uniprot_node = self.jackhmmer_uniprot_factory.gen_node(fasta_path, uniprot_out_path, output_format="sto")
# run workflow
batch_run(workflow_id=workflow_id, dags=[hmm_node, mgnify_node, bfd_node, uniprot_node])
return
from typing import List
import ray
from ray.dag.function_node import FunctionNode
from ray import workflow
def batch_run(workflow_id: str, dags: List[FunctionNode]) -> None:
    # Wrap the DAG leaves in a single no-op node: binding `dags` as an argument makes the
    # wrapper depend on every node in the list, so one workflow.run() executes all of them.
    @ray.remote
    def batch_dag_func(dags) -> None:
        return

    batch = batch_dag_func.bind(dags)
    workflow.run(batch, workflow_id=workflow_id)
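
# --- Editorial usage sketch (not part of the original module) -----------------------------------
# A self-contained example of composing two Ray DAG nodes and running them through batch_run;
# the function names and workflow id are placeholders.
def _example_batch_run() -> None:
    @ray.remote
    def step_a() -> int:
        return 1

    @ray.remote
    def step_b(dep: int) -> int:
        # passing step_a's node as an argument makes step_b depend on it in the DAG
        return dep + 1

    a = step_a.bind()
    b = step_b.bind(a)
    batch_run(workflow_id="example_workflow", dags=[b])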
import time
import habana_frameworks.torch as ht
class hpu_perf:
def __init__(self, module, log=True, mark_step=True, memoryinfo=False, sync=False):
if log:
print(f" {module}: start")
self.module = module
self.stime = time.perf_counter()
self.mark = mark_step
self.mem = memoryinfo
self.sync = sync
self.log = log
if self.mem:
ht.hpu.reset_peak_memory_stats()
self.prelog = None
def checknow(self, log):
if self.mark:
ht.core.mark_step()
if self.sync:
ht.core.hpu.default_stream().synchronize()
if self.mem:
print(ht.hpu.memory_summary())
tmp = time.perf_counter()
if self.log:
print(" {}: {} takes {:.2f} ms".format(self.module, log, (tmp - self.stime)*1000))
self.stime = tmp
def checkahead(self, log):
if self.mark:
ht.core.mark_step()
if self.sync:
ht.core.hpu.default_stream().synchronize()
if self.mem:
print(ht.hpu.memory_summary())
tmp = time.perf_counter()
if self.prelog is not None and self.log:
print(" {}: {} takes {:.2f} ms".format(self.module, self.prelog, (tmp - self.stime)*1000))
self.stime = tmp
self.prelog = log
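
# --- Editorial usage sketch (not part of the original module) -----------------------------------
# Minimal example of timing two stages of HPU work with hpu_perf; the module name and stage
# labels are placeholders.
#
#     perf = hpu_perf("example module")
#     ...                                # first chunk of HPU work
#     perf.checknow("stage 1")           # marks a step and prints the time since construction
#     ...                                # second chunk of HPU work
#     perf.checknow("stage 2")           # prints the time since the previous check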
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import pickle
import random
import shutil
import sys
import tempfile
import time
from datetime import date
import numpy as np
import torch
import torch.multiprocessing as mp
import habana_frameworks.torch.core as htcore
import fastfold.habana as habana
import fastfold.relax.relax as relax
from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.parsers import parse_fasta
from fastfold.habana.distributed import init_dist
from fastfold.habana.fastnn.ops import set_chunk_size
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold
from fastfold.model.nn.triangular_multiplicative_update import \
set_fused_triangle_multiplication
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.workflow.template import (FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow)
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--uniref90_database_path',
type=str,
default=None,
)
parser.add_argument(
'--mgnify_database_path',
type=str,
default=None,
)
parser.add_argument(
'--pdb70_database_path',
type=str,
default=None,
)
parser.add_argument(
'--uniclust30_database_path',
type=str,
default=None,
)
parser.add_argument(
'--bfd_database_path',
type=str,
default=None,
)
parser.add_argument(
"--pdb_seqres_database_path",
type=str,
default=None,
)
parser.add_argument(
"--uniprot_database_path",
type=str,
default=None,
)
parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
parser.add_argument("--hmmsearch_binary_path", type=str, default="hmmsearch")
parser.add_argument("--hmmbuild_binary_path", type=str, default="hmmbuild")
parser.add_argument(
'--max_template_date',
type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
parser.add_argument('--release_dates_path', type=str, default=None)
parser.add_argument('--chunk_size', type=int, default=None)
parser.add_argument('--enable_workflow',
default=False,
action='store_true',
help='run inference with ray workflow or not')
parser.add_argument('--inplace', default=False, action='store_true')
def inference_model(rank, world_size, result_q, batch, args):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
habana.enable_habana()
init_dist()
device = torch.device("hpu")
config = model_config(args.model_name)
if args.chunk_size:
config.globals.chunk_size = args.chunk_size
if "v3" in args.param_path:
set_fused_triangle_multiplication()
config.globals.inplace = False
config.globals.is_multimer = args.model_preset == 'multimer'
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_habana(model)
model = model.eval()
model = model.to(device=device)
set_chunk_size(model.globals.chunk_size)
with torch.no_grad():
batch = {k: torch.as_tensor(v).to(device=device) for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
htcore.mark_step()
print(f"Inference time: {time.perf_counter() - t}")
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
result_q.put(out)
torch.distributed.barrier()
def main(args):
if args.model_preset == "multimer":
inference_multimer_model(args)
else:
inference_monomer_model(args)
def inference_multimer_model(args):
print("running in multimer mode...")
config = model_config(args.model_name)
predict_max_templates = 4
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=predict_max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path,
)
if (not args.use_precomputed_alignments):
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_runner = FastFoldMultimerDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniclust30_database_path,  # FastFoldMultimerDataWorkFlow exposes this path as uniref30_database_path
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus)
else:
alignment_runner = data_pipeline.AlignmentRunnerMultimer(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus)
else:
alignment_runner = None
monomer_data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=monomer_data_processor,)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if (not args.use_precomputed_alignments):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
fasta_path = args.fasta_path
with open(fasta_path, "r") as fp:
data = fp.read()
lines = [
    l.replace('\n', '') for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
for tag, seq in zip(tags, seqs):
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
else:
shutil.rmtree(local_alignment_dir)
os.makedirs(local_alignment_dir)
chain_fasta_str = f'>chain_{tag}\n{seq}\n'
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
if args.enable_workflow:
print("Running alignment with ray workflow...")
t = time.perf_counter()
alignment_runner.run(chain_fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner.run(chain_fasta_path, local_alignment_dir)
print(f"Finished running alignment for {tag}")
local_alignment_dir = alignment_dir
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
is_multimer=True,
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model,
nprocs=args.hpus,
args=(args.hpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=False,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
def inference_monomer_model(args):
print("running in monomer mode...")
config = model_config(args.model_name)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path)
use_small_bfd = args.preset == 'reduced_dbs' # (args.bfd_database_path is None)
if use_small_bfd:
assert args.bfd_database_path is not None
else:
assert args.bfd_database_path is not None
assert args.uniclust30_database_path is not None
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if (args.use_precomputed_alignments is None):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
with open(args.fasta_path, "r") as fp:
fasta = fp.read()
seqs, tags = parse_fasta(fasta)
seq, tag = seqs[0], tags[0]
print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")
batch = [None]
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_data_workflow_runner = FastFoldDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniclust30_database_path,  # FastFoldDataWorkFlow exposes this path as uniref30_database_path
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model,
nprocs=args.hpus,
args=(args.hpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=False,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_path",
type=str,
)
parser.add_argument(
"template_mmcif_dir",
type=str,
)
parser.add_argument("--use_precomputed_alignments",
type=str,
default=None,
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored.""")
parser.add_argument(
"--output_dir",
type=str,
default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
)
parser.add_argument("--model_name",
type=str,
default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm or model_{1-5}_multimer, as defined on the AlphaFold GitHub.""")
parser.add_argument("--param_path",
type=str,
default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
./data/params""")
parser.add_argument("--cpus",
type=int,
default=12,
help="""Number of CPUs with which to run alignment tools""")
parser.add_argument("--hpus",
type=int,
default=1,
help="""Number of GPUs with which to run inference""")
parser.add_argument('--preset',
type=str,
default='full_dbs',
choices=('reduced_dbs', 'full_dbs'))
parser.add_argument('--data_random_seed', type=str, default=None)
parser.add_argument(
"--model_preset",
type=str,
default="monomer",
choices=["monomer", "multimer"],
help="Choose preset model configuration - the monomer model, the monomer model with "
"extra ensembling, monomer model with pTM head, or multimer model",
)
add_data_args(parser)
args = parser.parse_args()
if (args.param_path is None):
args.param_path = os.path.join("data", "params", "params_" + args.model_name + ".npz")
main(args)
export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export PYTHONPATH=./:$PYTHONPATH
# add '--hpus [N]' to use N HPUs for inference
# add '--enable_workflow' to use parallel workflow for data processing
# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
# add '--chunk_size [N]' to use chunk to reduce peak memory
# add '--inplace' to use inplace to save memory
python habana/inference.py target.fasta data/pdb_mmcif/mmcif_files \
--output_dir ./ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign`
import pickle
import time
import habana_frameworks.torch.core as htcore
import torch
import fastfold.habana as habana
from fastfold.config import model_config
from fastfold.habana.distributed import init_dist
from fastfold.habana.fastnn.ops import set_chunk_size
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold
def main():
habana.enable_habana()
init_dist()
batch = pickle.load(open('./test_batch.pkl', 'rb'))
model_name = "model_1"
device = torch.device("hpu")
config = model_config(model_name)
config.globals.inplace = False
config.globals.chunk_size = 512
# habana.enable_hmp()
model = AlphaFold(config)
model = inject_habana(model)
model = model.eval()
model = model.to(device=device)
if config.globals.chunk_size is not None:
set_chunk_size(model.globals.chunk_size + 1)
if habana.is_hmp():
from habana_frameworks.torch.hpex import hmp
hmp.convert(opt_level='O1',
bf16_file_path='./habana/ops_bf16.txt',
fp32_file_path='./habana/ops_fp32.txt',
isVerbose=False)
print("========= AMP ENABLED!!")
with torch.no_grad():
batch = {k: torch.as_tensor(v).to(device=device) for k, v in batch.items()}
for _ in range(5):
t = time.perf_counter()
out = model(batch)
htcore.mark_step()
htcore.hpu.default_stream().synchronize()
print(f"Inference time: {time.perf_counter() - t}")
if __name__ == '__main__':
main()
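
# Op list below: presumably one of the precision op lists (e.g. habana/ops_bf16.txt or
# habana/ops_fp32.txt) consumed by hmp.convert() in the scripts above and below; which file it
# is cannot be determined from this excerpt.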
addmm
conv2d
max_pool2d
sum
relu
mm
bmm
mv
linear
t
mul
sub
add
truediv
layer_norm
cross_entropy
log_softmax
nll_loss
softmax
import argparse
import logging
import random
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import fastfold.habana as habana
from fastfold.config import model_config
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.habana.distributed import init_dist, get_data_parallel_world_size
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold, AlphaFoldLoss, AlphaFoldLRScheduler
from fastfold.utils.tensor_utils import tensor_tree_map
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpex import hmp
logging.disable(logging.WARNING)
torch.multiprocessing.set_sharing_strategy('file_system')
from habana.hpuhelper import *
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--from_torch', default=False, action='store_true')
parser.add_argument("--template_mmcif_dir",
type=str,
help="Directory containing mmCIF files to search for templates")
parser.add_argument("--max_template_date",
type=str,
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target''')
parser.add_argument("--train_data_dir",
type=str,
help="Directory containing training mmCIF files")
parser.add_argument("--train_alignment_dir",
type=str,
help="Directory containing precomputed training alignments")
parser.add_argument(
"--train_chain_data_cache_path",
type=str,
default=None,
)
parser.add_argument("--distillation_data_dir",
type=str,
default=None,
help="Directory containing training PDB files")
parser.add_argument("--distillation_alignment_dir",
type=str,
default=None,
help="Directory containing precomputed distillation alignments")
parser.add_argument(
"--distillation_chain_data_cache_path",
type=str,
default=None,
)
parser.add_argument("--val_data_dir",
type=str,
default=None,
help="Directory containing validation mmCIF files")
parser.add_argument("--val_alignment_dir",
type=str,
default=None,
help="Directory containing precomputed validation alignments")
parser.add_argument("--kalign_binary_path",
type=str,
default='/usr/bin/kalign',
help="Path to the kalign binary")
parser.add_argument("--train_filter_path",
type=str,
default=None,
help='''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set''')
parser.add_argument("--distillation_filter_path",
type=str,
default=None,
help="""See --train_filter_path""")
parser.add_argument("--obsolete_pdbs_file_path",
type=str,
default=None,
help="""Path to obsolete.dat file containing list of obsolete PDBs and
their replacements.""")
parser.add_argument("--template_release_dates_cache_path",
type=str,
default=None,
help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files.""")
parser.add_argument("--train_epoch_len",
type=int,
default=10000,
help=("The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."))
parser.add_argument("--_alignment_index_path",
type=str,
default=None,
help="Training alignment index. See the README for instructions.")
parser.add_argument("--config_preset",
type=str,
default="initial_training",
help=('Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'))
parser.add_argument(
"--_distillation_structure_index_path",
type=str,
default=None,
)
parser.add_argument("--distillation_alignment_index_path",
type=str,
default=None,
help="Distillation alignment index. See the README for instructions.")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# habana arguments
parser.add_argument("--hmp",
action='store_true',
default=False,
help="Whether to use habana mixed precision")
parser.add_argument("--hmp-bf16",
type=str,
default="./habana/ops_bf16.txt",
help="Path to bf16 ops list in hmp O1 mode")
parser.add_argument("--hmp-fp32",
type=str,
default="./habana/ops_fp32.txt",
help="Path to fp32 ops list in hmp O1 mode")
parser.add_argument("--hmp-opt-level",
type=str,
default='O1',
help="Choose optimization level for hmp")
parser.add_argument("--hmp-verbose",
action='store_true',
default=False,
help='Enable verbose mode for hmp')
args = parser.parse_args()
habana.enable_habana()
init_dist()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
config = model_config(args.config_preset, train=True)
config.globals.inplace = False
model = AlphaFold(config)
model = inject_habana(model)
model = model.to(device="hpu")
if get_data_parallel_world_size() > 1:
model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=400)
train_dataset, test_dataset = SetupTrainDataset(
config=config.data,
template_mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
train_data_dir=args.train_data_dir,
train_alignment_dir=args.train_alignment_dir,
train_chain_data_cache_path=args.train_chain_data_cache_path,
distillation_data_dir=args.distillation_data_dir,
distillation_alignment_dir=args.distillation_alignment_dir,
distillation_chain_data_cache_path=args.distillation_chain_data_cache_path,
val_data_dir=args.val_data_dir,
val_alignment_dir=args.val_alignment_dir,
kalign_binary_path=args.kalign_binary_path,
# train_mapping_path=args.train_mapping_path,
# distillation_mapping_path=args.distillation_mapping_path,
obsolete_pdbs_file_path=args.obsolete_pdbs_file_path,
template_release_dates_cache_path=args.template_release_dates_cache_path,
train_epoch_len=args.train_epoch_len,
_alignment_index_path=args._alignment_index_path,
)
train_dataloader, test_dataloader = TrainDataLoader(
config=config.data,
train_dataset=train_dataset,
test_dataset=test_dataset,
batch_seed=args.seed,
)
criterion = AlphaFoldLoss(config.loss)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
from habana_frameworks.torch.hpex.optimizers import FusedAdamW
optimizer = FusedAdamW(model.parameters(), lr=1e-3, eps=1e-8)
lr_scheduler = AlphaFoldLRScheduler(optimizer)
if args.hmp:
hmp.convert(opt_level='O1',
bf16_file_path=args.hmp_bf16,
fp32_file_path=args.hmp_fp32,
isVerbose=args.hmp_verbose)
print("========= HMP ENABLED!!")
idx = 0
for epoch in range(200):
model.train()
train_dataloader = tqdm(train_dataloader)
for batch in train_dataloader:
perf = hpu_perf("train step")
batch = {k: torch.as_tensor(v).to(device="hpu", non_blocking=True) for k, v in batch.items()}
optimizer.zero_grad()
perf.checknow("prepare input and zero grad")
output = model(batch)
perf.checknow("forward")
batch = tensor_tree_map(lambda t: t[..., -1], batch)
perf.checknow("prepare loss input")
loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
perf.checknow("loss")
loss.backward()
if idx % 10 == 0:
train_dataloader.set_postfix(loss=float(loss))
perf.checknow("backward")
with hmp.disable_casts():
optimizer.step()
perf.checknow("optimizer")
idx += 1
lr_scheduler.step()
if test_dataloader is not None:
    model.eval()
    test_dataloader = tqdm(test_dataloader)
    for batch in test_dataloader:
        batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()}
        with torch.no_grad():
            output = model(batch)
            batch = tensor_tree_map(lambda t: t[..., -1], batch)
            loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
        htcore.mark_step()
        test_dataloader.set_postfix(loss=float(loss))
if __name__ == "__main__":
main()
export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export PYTHONPATH=./:$PYTHONPATH
DATA_DIR=../FastFold-dataset/train
hpus_per_node=1
max_template_date=2021-10-10
train_data_dir=${DATA_DIR}/mmcif_dir # specify the dir contains *.cif or *.pdb
train_alignment_dir=${DATA_DIR}/alignment_dir # a dir to save template and features.pkl of training sequence
mkdir -p ${train_alignment_dir}
# val_data_dir=${PROJECT_DIR}/dataset/val_pdb
# val_alignment_dir=${PROJECT_DIR}/dataset/alignment_val_pdb # a dir to save template and features.pkl of vld sequence
template_mmcif_dir=${DATA_DIR}/data/pdb_mmcif/mmcif_files
template_release_dates_cache_path=${DATA_DIR}/mmcif_cache.json # a cache used to pre-filter templates
train_chain_data_cache_path=${DATA_DIR}/chain_data_cache.json # a separate chain-level cache with data used for training-time data filtering
train_epoch_len=10000 # virtual length of each training epoch, which affects frequency of validation & checkpointing
mpirun --allow-run-as-root --bind-to none -np ${hpus_per_node} python habana/train.py \
--from_torch \
--template_mmcif_dir=${template_mmcif_dir} \
--max_template_date=${max_template_date} \
--train_data_dir=${train_data_dir} \
--train_alignment_dir=${train_alignment_dir} \
--train_chain_data_cache_path=${train_chain_data_cache_path} \
--template_release_dates_cache_path=${template_release_dates_cache_path} \
--train_epoch_len=${train_epoch_len} \
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
import sys
import time
from datetime import date
import tempfile
import contextlib
import numpy as np
import torch
import torch.multiprocessing as mp
import pickle
import shutil
from fastfold.model.hub import AlphaFold
import fastfold
import fastfold.relax.relax as relax
from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.model.nn.triangular_multiplicative_update import set_fused_triangle_multiplication
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.tools import hhsearch, hmmsearch
from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
# enable TF32 matmul kernels on PyTorch >= 1.12
if tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (1, 12):
    torch.backends.cuda.matmul.allow_tf32 = True
def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--uniref90_database_path',
type=str,
default=None,
)
parser.add_argument(
'--mgnify_database_path',
type=str,
default=None,
)
parser.add_argument(
'--pdb70_database_path',
type=str,
default=None,
)
parser.add_argument(
'--uniref30_database_path',
type=str,
default=None,
)
parser.add_argument(
'--bfd_database_path',
type=str,
default=None,
)
parser.add_argument(
"--pdb_seqres_database_path",
type=str,
default=None,
)
parser.add_argument(
"--uniprot_database_path",
type=str,
default=None,
)
parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
parser.add_argument("--hmmsearch_binary_path", type=str, default="hmmsearch")
parser.add_argument("--hmmbuild_binary_path", type=str, default="hmmbuild")
parser.add_argument(
'--max_template_date',
type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
parser.add_argument('--release_dates_path', type=str, default=None)
parser.add_argument('--chunk_size', type=int, default=None)
parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')
parser.add_argument('--inplace', default=False, action='store_true')
def inference_model(rank, world_size, result_q, batch, args):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
torch.cuda.set_device(rank)
config = model_config(args.model_name)
if args.chunk_size:
config.globals.chunk_size = args.chunk_size
if "v3" in args.param_path:
set_fused_triangle_multiplication()
config.globals.inplace = args.inplace
config.globals.is_multimer = args.model_preset == 'multimer'
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_fastnn(model)
model = model.eval()
model = model.cuda()
set_chunk_size(model.globals.chunk_size)
with torch.no_grad():
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
result_q.put(out)
torch.distributed.barrier()
torch.cuda.synchronize()
def main(args):
if args.model_preset == "multimer":
inference_multimer_model(args)
else:
inference_monomer_model(args)
def inference_multimer_model(args):
print("running in multimer mode...")
config = model_config(args.model_name)
predict_max_templates = 4
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=predict_max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path,
)
if(not args.use_precomputed_alignments):
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_runner = FastFoldMultimerDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniref30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus
)
else:
alignment_runner = data_pipeline.AlignmentRunnerMultimer(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniref30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus
)
else:
alignment_runner = None
monomer_data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=monomer_data_processor,
)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
# seed_torch(seed=1029)
feature_processor = feature_pipeline.FeaturePipeline(
config.data
)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if(not args.use_precomputed_alignments):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
fasta_path = args.fasta_path
with open(fasta_path, "r") as fp:
data = fp.read()
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
output_prefix = "_and_".join(tags)
for tag, seq in zip(tags, seqs):
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
else:
shutil.rmtree(local_alignment_dir)
os.makedirs(local_alignment_dir)
chain_fasta_str = f'>chain_{tag}\n{seq}\n'
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
if args.enable_workflow:
print("Running alignment with ray workflow...")
t = time.perf_counter()
alignment_runner.run(chain_fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner.run(chain_fasta_path, local_alignment_dir)
print(f"Finished running alignment for {tag}")
local_alignment_dir = alignment_dir
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=True,
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
out = result_q.get()
if args.save_prediction_result:
# Save the prediction result .pkl
prediction_result_path = os.path.join(args.output_dir,
f'{output_prefix}_{args.model_name}.pkl')
with open(prediction_result_path, 'wb') as f:
pickle.dump(out, f)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{output_prefix}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
if(args.relaxation):
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{output_prefix}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
def inference_monomer_model(args):
print("running in monomer mode...")
config = model_config(args.model_name)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
use_small_bfd = args.preset == 'reduced_dbs' # (args.bfd_database_path is None)
if use_small_bfd:
assert args.bfd_database_path is not None
else:
assert args.bfd_database_path is not None
assert args.uniref30_database_path is not None
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
# seed_torch(seed=1029)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if (args.use_precomputed_alignments is None):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
with open(args.fasta_path, "r") as fp:
fasta = fp.read()
seqs, tags = parse_fasta(fasta)
seq, tag = seqs[0], tags[0]
print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")
batch = [None]
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_data_workflow_runner = FastFoldDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniref30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniref30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
    plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)

    unrelaxed_protein = protein.from_prediction(features=batch,
                                                result=out,
                                                b_factors=plddt_b_factors)

    # Save the unrelaxed PDB.
    unrelaxed_output_path = os.path.join(args.output_dir,
                                         f'{tag}_{args.model_name}_unrelaxed.pdb')
    with open(unrelaxed_output_path, 'w') as f:
        f.write(protein.to_pdb(unrelaxed_protein))
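    # Optionally refine the predicted structure with Amber relaxation (enabled via --relaxation).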
    if args.relaxation:
        amber_relaxer = relax.AmberRelaxation(
            use_gpu=True,
            **config.relax,
        )

        # Relax the prediction.
        t = time.perf_counter()
        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
        print(f"Relaxation time: {time.perf_counter() - t}")

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(args.output_dir,
                                           f'{tag}_{args.model_name}_relaxed.pdb')
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)

    if args.save_prediction_result:
        # Save the prediction result .pkl
        prediction_result_path = os.path.join(
            args.output_dir, f'{tag}_{args.model_name}.pkl'
        )
        with open(prediction_result_path, "wb") as fp:
            pickle.dump(out, fp)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_path",
        type=str,
    )
    parser.add_argument(
        "template_mmcif_dir",
        type=str,
    )
    parser.add_argument("--use_precomputed_alignments",
                        type=str,
                        default=None,
                        help="""Path to alignment directory. If provided, alignment computation
                                is skipped and database path arguments are ignored.""")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument("--model_name",
                        type=str,
                        default="model_1",
                        help="""Name of a model config. Choose one of model_{1-5} or
                                model_{1-5}_ptm or model_{1-5}_multimer, as defined on the AlphaFold GitHub.""")
    parser.add_argument("--param_path",
                        type=str,
                        default=None,
                        help="""Path to model parameters. If None, parameters are selected
                                automatically according to the model name from
                                ./data/params""")
    parser.add_argument(
        "--relaxation", action="store_true", default=False,
        help="""Run Amber relaxation on the unrelaxed prediction""",
    )
    parser.add_argument("--cpus",
                        type=int,
                        default=12,
                        help="""Number of CPUs with which to run alignment tools""")
    parser.add_argument("--gpus",
                        type=int,
                        default=1,
                        help="""Number of GPUs with which to run inference""")
    parser.add_argument('--preset',
                        type=str,
                        default='full_dbs',
                        choices=('reduced_dbs', 'full_dbs'))
    parser.add_argument('--save_prediction_result',
                        type=bool,
                        default=True)
    parser.add_argument('--data_random_seed', type=str, default=None)
    parser.add_argument(
        "--model_preset",
        type=str,
        default="monomer",
        choices=["monomer", "multimer"],
        help="Choose preset model configuration - the monomer model or the multimer model",
    )
    add_data_args(parser)
    args = parser.parse_args()

    if args.param_path is None:
        args.param_path = os.path.join("data", "params", "params_" + args.model_name + ".npz")

    main(args)
# add '--gpus [N]' to use N gpus for inference
# add '--enable_workflow' to use parallel workflow for data processing
# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
# add '--chunk_size [N]' to use chunk to reduce peak memory
# add '--inplace' to use inplace to save memory
python inference.py target.fasta data/pdb_mmcif/mmcif_files \
--output_dir ./outputs \
--gpus 2 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign` \
--enable_workflow \
--inplace
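For reference, the alignment stage that '--enable_workflow' switches on can also be driven directly from Python through FastFoldDataWorkFlow. The sketch below is illustrative only: the import path, tool locations, and database paths mirror the command above and are assumptions, not a documented API.

from fastfold.workflow import FastFoldDataWorkFlow

# Assumed tool and database locations; adjust to your installation.
workflow = FastFoldDataWorkFlow(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    hhblits_binary_path="/usr/bin/hhblits",
    hhsearch_binary_path="/usr/bin/hhsearch",
    uniref90_database_path="data/uniref90/uniref90.fasta",
    mgnify_database_path="data/mgnify/mgy_clusters_2022_05.fa",
    bfd_database_path="data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt",
    uniref30_database_path="data/uniref30/UniRef30_2021_03",
    pdb70_database_path="data/pdb70/pdb70",
    use_small_bfd=False,
    no_cpus=12,
)

# Write MSAs and template hits for target.fasta into the chosen alignment directory.
workflow.run("target.fasta", alignment_dir="./outputs/alignments/target")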
# add '--gpus [N]' to use N gpus for inference
# add '--enable_workflow' to use parallel workflow for data processing
# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
# add '--chunk_size [N]' to use chunk to reduce peak memory
# add '--inplace' to use inplace to save memory
python inference.py target.fasta data/pdb_mmcif/mmcif_files \
--output_dir ./ \
--gpus 1 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb70_database_path data/pdb70/pdb70 \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--uniprot_database_path data/uniprot/uniprot.fasta \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign` \
--model_preset multimer \
--param_path data/params/params_model_1_multimer_v3.npz \
--model_name model_1_multimer \
biopython==1.79
dm-tree==0.1.6
ml-collections==0.1.0
scipy==1.7.1
pandas
pytest
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips all required data for AlphaFold.
#
# Usage: bash download_all_data.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
  echo "Error: download directory must be provided as an input argument."
  exit 1
fi

if ! command -v aria2c &> /dev/null ; then
  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
  exit 1
fi

DOWNLOAD_DIR="$1"
DOWNLOAD_MODE="${2:-full_dbs}"  # Default mode to full_dbs.
if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]]
then
  echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized."
  exit 1
fi

SCRIPT_DIR="$(dirname "$(realpath "$0")")"

echo "Downloading AlphaFold parameters..."
bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"

if [[ "${DOWNLOAD_MODE}" = reduced_dbs ]] ; then
  echo "Downloading Small BFD..."
  bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}"
else
  echo "Downloading BFD..."
  bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
fi
echo "Downloading MGnify..."
bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB70..."
bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref30..."
bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
# UniProt and PDB SeqRes for multimer version
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."