"...resnet50_tensorflow.git" did not exist on "ccf7da9deaf4fd37d7578d539595c543800d26f8"
Unverified commit d3df8e69 authored by Fazzie-Maqianli, committed by GitHub

add multimer inference (#59)

* add multimer inference

* add data pipeline
parent 444c548a
@@ -19,6 +19,7 @@ import contextlib
import dataclasses
import datetime
import json
import copy
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
@@ -56,7 +57,7 @@ def empty_template_feats(n_res) -> FeatureDict:
def make_template_features(
input_sequence: str,
hits: Sequence[Any],
template_featurizer: Any,
template_featurizer: Union[hhsearch.HHSearch, hmmsearch.Hmmsearch],
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
@@ -64,12 +65,18 @@ def make_template_features(
if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence))
else:
if type(template_featurizer) == hhsearch.HHSearch:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=query_pdb_code,
query_release_date=query_release_date,
hits=hits_cat,
)
else:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
hits=hits_cat,
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
@@ -242,7 +249,7 @@ def run_msa_tool(
if(msa_format == "sto" and max_sto_sequences is not None):
result = msa_runner.query(fasta_path, max_sto_sequences)[0]
else:
result = msa_runner.query(fasta_path)[0]
result = msa_runner.query(fasta_path)
with open(msa_out_path, "w") as f:
f.write(result[msa_format])
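For context, this is the helper the multimer alignment runner below calls for each database. A hedged standalone sketch, not part of this commit; the FASTA, output, and database paths are placeholders:

# Hedged sketch mirroring the uniref90 step in AlignmentRunnerMultimer.run
# (placeholder paths, not part of this commit).
uniref90_runner = jackhmmer.Jackhmmer(
    binary_path="/usr/bin/jackhmmer",
    database_path="/data/uniref90/uniref90.fasta",
)
uniref90_result = run_msa_tool(
    msa_runner=uniref90_runner,
    fasta_path="query.fasta",
    msa_out_path="alignments/uniref90_hits.sto",
    msa_format="sto",
    max_sto_sequences=10000,
)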
@@ -262,7 +269,6 @@ class AlignmentRunner:
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
template_searcher: Optional[TemplateSearcher] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
@@ -447,6 +453,225 @@ class AlignmentRunner:
f.write(hhblits_bfd_uniclust_result["a3m"])
class AlignmentRunnerMultimer(AlignmentRunner):
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
template_searcher: Optional[TemplateSearcher] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
"""
Args:
jackhmmer_binary_path:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
mgnify_database_path:
Path to mgnify database. If provided, jackhmmer_binary_path
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
uniprot_database_path:
Path to the UniProt database. If provided, jackhmmer_binary_path
must also be provided.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniclust30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
uniref_max_hits:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
uniprot_max_hits:
Max number of uniprot hits
"""
# super().__init__()
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
uniprot_database_path,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.uniprot_max_hits = uniprot_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniclust_runner = None
if(bfd_database_path is not None):
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=bfd_database_path,
n_cpu=no_cpus,
)
else:
dbs = [bfd_database_path]
if(uniclust30_database_path is not None):
dbs.append(uniclust30_database_path)
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=dbs,
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = None
if(mgnify_database_path is not None):
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_uniprot_runner = None
if(uniprot_database_path is not None):
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path
)
if(template_searcher is not None and
self.jackhmmer_uniref90_runner is None
):
raise ValueError(
"Uniref90 runner must be specified to run template search"
)
self.template_searcher = template_searcher
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
fasta_path=fasta_path,
msa_out_path=uniref90_out_path,
msa_format='sto',
max_sto_sequences=self.uniref_max_hits,
)
template_msa = jackhmmer_uniref90_result["sto"]
template_msa = parsers.deduplicate_stockholm_msa(template_msa)
template_msa = parsers.remove_empty_columns_from_stockholm_msa(
template_msa
)
if(self.template_searcher is not None):
if(self.template_searcher.input_format == "sto"):
pdb_templates_result = self.template_searcher.query(
template_msa,
output_dir=output_dir
)
elif(self.template_searcher.input_format == "a3m"):
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
template_msa
)
pdb_templates_result = self.template_searcher.query(
uniref90_msa_as_a3m,
output_dir=output_dir
)
else:
fmt = self.template_searcher.input_format
raise ValueError(
f"Unrecognized template input format: {fmt}"
)
if(self.jackhmmer_mgnify_runner is not None):
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
fasta_path=fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
max_sto_sequences=self.mgnify_max_hits
)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="sto",
)
elif(self.hhblits_bfd_uniclust_runner is not None):
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
hhblits_bfd_uniclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniclust_runner,
fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="a3m",
)
if(self.jackhmmer_uniprot_runner is not None):
uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
result = run_msa_tool(
self.jackhmmer_uniprot_runner,
fasta_path=fasta_path,
msa_out_path=uniprot_out_path,
msa_format='sto',
max_sto_sequences=self.uniprot_max_hits,
)
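For reference, a minimal usage sketch of the new runner follows. It is not part of this commit: all database paths are placeholders, and the Hmmsearch constructor arguments are assumed to mirror the OpenFold/AlphaFold tool API.

# Hedged usage sketch (placeholder paths; Hmmsearch arguments assumed).
from fastfold.data import data_pipeline
from fastfold.data.tools import hmmsearch

template_searcher = hmmsearch.Hmmsearch(
    binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    database_path="/data/pdb_seqres/pdb_seqres.txt",
)

alignment_runner = data_pipeline.AlignmentRunnerMultimer(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    hhblits_binary_path="/usr/bin/hhblits",
    uniref90_database_path="/data/uniref90/uniref90.fasta",
    mgnify_database_path="/data/mgnify/mgy_clusters.fa",
    bfd_database_path="/data/small_bfd/bfd-first_non_consensus_sequences.fasta",
    uniprot_database_path="/data/uniprot/uniprot.fasta",
    template_searcher=template_searcher,
    use_small_bfd=True,  # small BFD is searched with jackhmmer instead of hhblits
    no_cpus=8,
)

# Writes uniref90_hits.sto, mgnify_hits.sto, small_bfd_hits.sto, uniprot_hits.sto
# and the template search output into the given directory.
alignment_runner.run("chain_A.fasta", "alignments/chain_A")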
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
@@ -722,7 +947,12 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index,
)
template_features = make_template_features(
input_sequence,
hits,
@@ -893,8 +1123,8 @@ class DataPipelineMultimer:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features(msa, deletion_matrix)
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
)
@@ -188,9 +188,9 @@ def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
query_pdb_code: Optional[str] = None,
max_subsequence_ratio: float = 0.95,
min_align_ratio: float = 0.1,
) -> bool:
@@ -752,12 +752,12 @@ class SingleHitResult:
def _prefilter_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
strict_error_check: bool = False,
query_pdb_code: Optional[str] = None,
):
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
@@ -794,7 +794,6 @@ def _prefilter_hit(
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
@@ -803,6 +802,7 @@
kalign_binary_path: str,
strict_error_check: bool = False,
_zero_center_positions: bool = True,
query_pdb_code: Optional[str] = None,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
@@ -996,9 +996,9 @@ class TemplateHitFeaturizer:
def get_templates(
self,
query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit],
query_pdb_code: Optional[str] = None,
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info("Searching for template for: %s", query_pdb_code)
@@ -1155,7 +1155,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
idx[:stk] = np.random.permutation(idx[:stk])
for i in idx:
if(len(already_seen) >= self._max_hits):
if(len(already_seen) >= self.max_hits):
break
hit = filtered[i]
@@ -23,6 +23,7 @@ import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from fastfold.data import parsers
from fastfold.data.tools import utils
@@ -93,7 +94,10 @@ class Jackhmmer:
self.streaming_callback = streaming_callback
def _query_chunk(
self, input_fasta_path: str, database_path: str
self,
input_fasta_path: str,
database_path: str,
max_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
@@ -167,8 +171,11 @@
with open(tblout_path) as f:
tbl = f.read()
if(max_sequences is None):
with open(sto_path) as f:
sto = f.read()
else:
sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
raw_output = dict(
sto=sto,
@@ -180,10 +187,16 @@
return raw_output
def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None
) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
return [self._query_chunk(input_fasta_path, self.database_path)]
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences,
)
return [single_chunk_result]
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
@@ -217,12 +230,20 @@
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i))
self._query_chunk(
input_fasta_path,
db_local_chunk(i),
max_sequences
)
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
future = next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if(i < self.num_streamed_chunks):
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
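A short, hedged usage sketch of the extended query API (the FASTA and database paths are placeholders, not part of this commit):

# Illustrative sketch; paths are placeholders.
from fastfold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path="/usr/bin/jackhmmer",
    database_path="/data/uniref90/uniref90.fasta",
    n_cpu=8,
)

# query() returns one result dict per database chunk; with max_sequences set,
# the "sto" entry is truncated via parsers.truncate_stockholm_msa.
results = runner.query("query.fasta", max_sequences=10000)
stockholm_msa = results[0]["sto"]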
@@ -19,6 +19,8 @@ import random
import sys
import time
from datetime import date
import tempfile
import contextlib
import numpy as np
import torch
@@ -39,6 +41,12 @@ from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
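The helper above is exercised later in this file when per-chain alignments are generated; a minimal standalone sketch (the sequence is a placeholder):

# Minimal sketch of temp_fasta_file (placeholder sequence).
with temp_fasta_file(">chain_A\nMKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\n") as fasta_path:
    # fasta_path names a temporary .fasta file that exists only inside this
    # block; hand it to any tool that expects a FASTA on disk.
    print(fasta_path)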
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
@@ -66,10 +74,22 @@ def add_data_args(parser: argparse.ArgumentParser):
type=str,
default=None,
)
parser.add_argument(
"--pdb_seqres_database_path",
type=str,
default=None,
)
parser.add_argument(
"--uniprot_database_path",
type=str,
default=None,
)
parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
parser.add_argument("--hmmsearch_binary_path", type=str, default="hmmsearch")
parser.add_argument("--hmmbuild_binary_path", type=str, default="hmmbuild")
parser.add_argument(
'--max_template_date',
type=str,
@@ -79,6 +99,7 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument('--release_dates_path', type=str, default=None)
parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')
def inference_model(rank, world_size, result_q, batch, args):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
@@ -120,7 +141,7 @@ def main(args):
def inference_multimer_model(args):
print("running in multimer mode...")
config = model_config(args.model_name)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
predict_max_templates = 4
@@ -143,6 +164,81 @@ def inference_multimer_model(args):
obsolete_pdbs_path=args.obsolete_pdbs_path,
)
if(not args.use_precomputed_alignments):
alignment_runner = data_pipeline.AlignmentRunnerMultimer(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus,
)
else:
alignment_runner = None
monomer_data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=monomer_data_processor,
)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(
config.data
)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if(not args.use_precomputed_alignments):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
fasta_path = args.fasta_path
with open(fasta_path, "r") as fp:
data = fp.read()
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
for tag, seq in zip(tags, seqs):
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
chain_fasta_str = f'>chain_{tag}\n{seq}\n'
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
alignment_runner.run(
chain_fasta_path, local_alignment_dir
)
print(f"Finished running alignment for {tag}")
local_alignment_dir = alignment_dir
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=True,
)
def inference_monomer_model(args):
print("running in monomer mode...")
@@ -282,6 +378,7 @@ def inference_monomer_model(args):
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(