Commit 100485dd authored by Gustaf Ahdritz

Continue multimer implementation

parent e1b69c13
@@ -16,17 +16,24 @@
import os
import datetime
from multiprocessing import cpu_count
from typing import Mapping, Optional, Sequence, Any
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data import (
templates,
parsers,
mmcif_parsing,
msa_identifiers,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def empty_template_feats(n_res) -> FeatureDict:
return {
@@ -180,32 +187,39 @@ def make_pdb_features(
return pdb_feats
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
uniprot_accession_ids = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
)
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0])
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index]
)
uniprot_accession_ids.append(
identifiers.uniprot_accession_id.encode('utf-8')
)
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
@@ -213,8 +227,45 @@ def make_msa_features(
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
features["msa_uniprot_accession_identifiers"] = np.array(
uniprot_accession_ids, dtype=np.object_
)
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
return features
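The reworked make_msa_features consumes parsers.Msa objects instead of parallel sequence and deletion-matrix lists, and now also records UniProt accession and species identifiers for each kept row. A minimal usage sketch, assuming parsers.Msa is a dataclass with sequences, deletion_matrix, and descriptions fields (as the attribute accesses above imply) and that this function lives in openfold.data.data_pipeline; the description strings are illustrative:

from openfold.data import parsers
from openfold.data.data_pipeline import make_msa_features

# Toy two-row alignment; deletion counts are all zero here.
msa = parsers.Msa(
    sequences=["MKTAYIAK", "MKTAYIAR"],
    deletion_matrix=[[0] * 8, [0] * 8],
    descriptions=["query", "tr|Q9XYZ1|Q9XYZ1_YEAST hypothetical hit"],
)
feats = make_msa_features([msa])
# feats["deletion_matrix_int"] is an int32 (num_alignments, num_res) array;
# feats["msa_species_identifiers"] is an object array of byte strings
# (e.g. b"YEAST" if msa_identifiers parses the UniProt-style description).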
def run_msa_tool(
msa_runner,
input_fasta_path: str,
msa_out_path: str,
msa_format: str,
use_precomputed_msas: bool,
max_sto_sequences: Optional[int] = None,
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
if(msa_format == "sto" and max_sto_sequences is not None):
result = msa_runner.query(input_fasta_path, max_sto_sequences)[0]
else:
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, "w") as f:
f.write(result[msa_format])
else:
logging.warning("Reading MSA from file %s", msa_out_path)
if(msa_format == "sto" and max_sto_sequences is not None):
precomputed_msa = parsers.truncate_stockholm_msa(
msa_out_path,
max_sto_sequences,
)
result = {"sto": precomputed_msa}
else:
with open(msa_out_path, "r") as f:
result = {msa_format: f.read()}
return result
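A hedged usage sketch of the caching wrapper; the Jackhmmer constructor follows the openfold.data.tools import at the top of the file, the module path for run_msa_tool is assumed, and all paths are placeholders:

from openfold.data.tools import jackhmmer
from openfold.data.data_pipeline import run_msa_tool

runner = jackhmmer.Jackhmmer(
    binary_path="/usr/bin/jackhmmer",               # placeholder
    database_path="/data/uniref90/uniref90.fasta",  # placeholder
)
result = run_msa_tool(
    msa_runner=runner,
    input_fasta_path="query.fasta",
    msa_out_path="alignments/uniref90_hits.sto",
    msa_format="sto",
    use_precomputed_msas=True,
    max_sto_sequences=10000,
)
# First call: runs jackhmmer and writes the Stockholm hits to disk.
# Later calls with use_precomputed_msas=True reload (and, for "sto"
# with a sequence cap, truncate) the existing file instead.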
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
@@ -222,12 +273,11 @@
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hhsearch_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
template_searcher: Optional[TemplateSearcher] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
@@ -239,8 +289,6 @@
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
hhsearch_binary_path:
Path to hhsearch binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
@@ -254,8 +302,6 @@
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
pdb70_database_path:
Path to pdb70 database.
use_small_bfd:
Whether to search the small BFD database alone with jackhmmer
or the full BFD database in conjunction with uniclust30 using hhblits.
@@ -282,12 +328,6 @@
bfd_database_path if not use_small_bfd else None,
],
},
"hhsearch": {
"binary": hhsearch_binary_path,
"dbs": [
pdb70_database_path,
],
},
}
for name, dic in db_map.items():
@@ -297,13 +337,6 @@
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.use_small_bfd = use_small_bfd
@@ -348,14 +381,15 @@
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = None
if(pdb70_database_path is not None):
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path],
n_cpu=no_cpus,
if(template_searcher is not None and
self.jackhmmer_uniref90_runner is None
):
raise ValueError(
"Uniref90 runner must be specified to run template search"
)
self.template_searcher = template_searcher
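With the pdb70/hhsearch arguments replaced by an injected template_searcher, construction might look like the following sketch; the Hmmsearch constructor arguments, the AlignmentRunner module path, and every filesystem path are assumptions rather than something this diff specifies:

from openfold.data.tools import hmmsearch
from openfold.data.data_pipeline import AlignmentRunner

searcher = hmmsearch.Hmmsearch(  # assumed constructor signature
    binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    database_path="/data/pdb_seqres/pdb_seqres.txt",
)
runner = AlignmentRunner(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    uniref90_database_path="/data/uniref90/uniref90.fasta",
    # The check above requires a uniref90 runner whenever a
    # template searcher is supplied.
    template_searcher=searcher,
)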
def run(
self,
fasta_path: str,
@@ -363,52 +397,64 @@
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result["sto"],
max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
if(self.hhsearch_pdb70_runner is not None):
hhsearch_result = self.hhsearch_pdb70_runner.query(
uniref90_msa_as_a3m
)
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
input_fasta_path=fasta_path,
msa_out_path=uniref90_out_path,
msa_format='sto',
max_sto_sequences=self.uniref_max_hits,
)
if(self.jackhmmer_mgnify_runner is not None):
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result["sto"],
max_sequences=self.mgnify_max_hits
template_msa = jackhmmer_uniref90_result["sto"]
template_msa = parsers.deduplicate_stockholm_msa(template_msa)
template_msa = parsers.remove_empty_columns_from_stockholm_msa(
template_msa
)
if(self.template_searcher is not None):
if(self.template_searcher.input_format == "sto"):
pdb_templates_result = self.template_searcher.query(template_msa)
elif(self.template_searcher.input_format == "a3m"):
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
template_msa
)
pdb_templates_result = self.template_searcher.query(
uniref90_msa_as_a3m
)
else:
fmt = self.template_searcher.input_format
raise ValueError(
f"Unrecognized template input format: {fmt}"
)
if(self.jackhmmer_mgnify_runner is not None):
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
input_fasta_path=fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
max_sto_sequences=self.mgnify_max_hits
)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
input_fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="sto",
)
elif(self.hhblits_bfd_uniclust_runner is not None):
hhblits_bfd_uniclust_result = (
self.hhblits_bfd_uniclust_runner.query(fasta_path)
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
hhblits_bfd_uniclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniclust_runner,
input_fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="a3m",
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
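For orientation, the filenames hard-coded in the rewritten run method imply roughly the following contents of output_dir; which files actually appear depends on the configured runners, and this check is only a sketch with a placeholder directory name:

import os

candidates = [
    "uniref90_hits.a3m",      # jackhmmer vs. UniRef90
    "mgnify_hits.a3m",        # jackhmmer vs. MGnify
    "small_bfd_hits.sto",     # jackhmmer vs. small BFD (use_small_bfd)
    "bfd_uniclust_hits.a3m",  # hhblits vs. BFD + uniclust30 (otherwise)
]
present = [
    f for f in candidates
    if os.path.exists(os.path.join("alignments", f))
]
print(present)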
class DataPipeline:
......