Commit 57f869d6 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Continue work on AlphaFold-Multimer

parent 100485dd
......@@ -74,6 +74,8 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.model.update(multimer_model_config_update)
else:
raise ValueError("Invalid model name")
......@@ -493,3 +495,11 @@ config = mlc.ConfigDict(
"ema": {"decay": 0.999},
}
)
# Config overrides applied on top of the monomer model config when a
# "multimer" model is requested (see model_config()).
# NOTE: ConfigDict takes an initial mapping; the original code passed bare
# `key: value` pairs as call arguments, which is a Python syntax error.
multimer_model_config_update = mlc.ConfigDict(
    {
        # Relative positional encoding across residues and chains, as used by
        # AlphaFold-Multimer.
        "relative_encoding": {
            "enabled": True,
            "max_relative_chain": 2,
            "max_relative_idx": 32,
        }
    }
)
......@@ -25,6 +25,7 @@ from openfold.data import (
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
......@@ -277,11 +278,13 @@ class AlignmentRunner:
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
template_searcher: Optional[TemplateSearcher] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
"""
Args:
......@@ -320,6 +323,7 @@ class AlignmentRunner:
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
uniprot_database_path,
],
},
"hhblits": {
......@@ -339,6 +343,7 @@ class AlignmentRunner:
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.uniprot_max_hits = uniprot_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
......@@ -381,6 +386,13 @@ class AlignmentRunner:
n_cpu=no_cpus,
)
self._uniprot_msa_runner = None
if(uniprot_database_path is not None):
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path
)
if(template_searcher is not None and
self.jackhmmer_uniref90_runner is None
):
......@@ -456,6 +468,148 @@ class AlignmentRunner:
msa_format="a3m",
)
if(self.jackhmmer_uniprot_runner is not None):
uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
result = run_msa_tool(
self.jackhmmer_uniprot_runner,
input_fasta_path=input_fasta_path,
msa_out_path=uniprot_out_path,
msa_format='sto',
max_sto_sequences=self.uniprot_max_hits,
)
@dataclasses.dataclass(frozen=True)
class _FastaChain:
    """Sequence and description of one chain parsed from a multimer FASTA."""
    # Amino-acid sequence of the chain.
    sequence: str
    # FASTA description (header) line for the chain.
    description: str
def _make_chain_id_map(*,
    sequences: Sequence[str],
    descriptions: Sequence[str],
) -> Mapping[str, _FastaChain]:
    """Makes a mapping from PDB-format chain ID to sequence and description.

    Chain IDs are assigned in PDB order (A, B, C, ...).

    Raises:
        ValueError: If the inputs disagree in length, or there are more
            chains than the PDB format can represent.
    """
    if len(sequences) != len(descriptions):
        raise ValueError('sequences and descriptions must have equal length. '
                         f'Got {len(sequences)} != {len(descriptions)}.')
    if len(sequences) > protein.PDB_MAX_CHAINS:
        raise ValueError('Cannot process more chains than the PDB format supports. '
                         f'Got {len(sequences)} chains.')
    return {
        pdb_chain_id: _FastaChain(sequence=seq, description=desc)
        for pdb_chain_id, seq, desc in zip(
            protein.PDB_CHAIN_IDS, sequences, descriptions
        )
    }
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
    """Yields the path of a temporary FASTA file containing `fasta_str`.

    The file is removed when the context exits.
    """
    handle = tempfile.NamedTemporaryFile('w', suffix='.fasta')
    try:
        handle.write(fasta_str)
        # Make the contents visible to readers opening the file by path.
        handle.flush()
        yield handle.name
    finally:
        handle.close()
def convert_monomer_features(
    monomer_features: FeatureDict,
    chain_id: str
) -> FeatureDict:
    """Reshapes and modifies monomer features for multimer models."""
    # These features carry a redundant leading dimension of size 1 in the
    # monomer pipeline; keep only the first (and only) entry.
    squeeze_feats = {'sequence', 'domain_name', 'num_alignments', 'seq_length'}
    converted = {'auth_chain_id': np.asarray(chain_id, dtype=np.object_)}
    for name, value in monomer_features.items():
        if name in squeeze_feats:
            # asarray ensures it's a np.ndarray.
            value = np.asarray(value[0], dtype=value.dtype)
        elif name == 'aatype':
            # The multimer model performs the one-hot operation itself.
            value = np.argmax(value, axis=-1).astype(np.int32)
        elif name == 'template_aatype':
            # De-one-hot, then remap from HHblits to our amino-acid ordering.
            value = np.argmax(value, axis=-1).astype(np.int32)
            value = np.take(
                residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE,
                value.astype(np.int32),
                axis=0,
            )
        elif name == 'template_all_atom_masks':
            # Renamed in the multimer feature set.
            name = 'template_all_atom_mask'
        converted[name] = value
    return converted
def int_id_to_str_id(num: int) -> str:
    """Encodes a number as a string, using reverse spreadsheet style naming.

    Args:
        num: A positive integer.

    Returns:
        A string that encodes the positive integer using reverse spreadsheet
        style, naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ...
        This is the usual way to encode chain IDs in mmCIF files.

    Raises:
        ValueError: If `num` is not positive.
    """
    if num <= 0:
        raise ValueError(f'Only positive integers allowed, got {num}.')
    chars = []
    remainder = num - 1  # Switch to 0-based indexing.
    while True:
        remainder, digit = divmod(remainder, 26)
        chars.append(chr(ord('A') + digit))
        remainder -= 1
        if remainder < 0:
            break
    return ''.join(chars)
def add_assembly_features(
    all_chain_features: MutableMapping[str, FeatureDict],
) -> MutableMapping[str, FeatureDict]:
    """Add features to distinguish between chains.

    Args:
        all_chain_features: A dictionary which maps chain_id to a dictionary of
            features for each chain.

    Returns:
        all_chain_features: A dictionary which maps strings of the form
        `<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
        chains from a homodimer would have keys A_1 and A_2. Two chains from a
        heterodimer would have keys A_1 and B_1.
    """
    # Group chains by sequence: each distinct sequence gets an entity id.
    seq_to_entity_id = {}
    grouped_chains = collections.defaultdict(list)
    for chain_features in all_chain_features.values():
        seq = str(chain_features['sequence'])
        entity_id = seq_to_entity_id.setdefault(seq, len(seq_to_entity_id) + 1)
        grouped_chains[entity_id].append(chain_features)

    renamed = {}
    chain_counter = 1  # Globally unique chain index across all entities.
    for entity_id, group in grouped_chains.items():
        for sym_id, chain_features in enumerate(group, start=1):
            renamed[f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
            seq_length = chain_features['seq_length']
            # Per-residue chain identity features expected by the multimer
            # featurization.
            chain_features['asym_id'] = chain_counter * np.ones(seq_length)
            chain_features['sym_id'] = sym_id * np.ones(seq_length)
            chain_features['entity_id'] = entity_id * np.ones(seq_length)
            chain_counter += 1

    return renamed
def pad_msa(np_example, min_num_seq):
    """Pads MSA-indexed features up to a minimum number of sequences.

    Zero-pads along the sequence axis so downstream MSA processing never sees
    fewer than `min_num_seq` rows. Returns a shallow copy; the input mapping
    itself is not mutated.
    """
    padded = dict(np_example)
    deficit = min_num_seq - padded['msa'].shape[0]
    if deficit > 0:
        for key in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
            padded[key] = np.pad(padded[key], ((0, deficit), (0, 0)))
        padded['cluster_bias_mask'] = np.pad(
            padded['cluster_bias_mask'], ((0, deficit),)
        )
    return padded
class DataPipeline:
"""Assembles input features."""
......@@ -579,10 +733,9 @@ class DataPipeline:
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
msa_objects = [Msa(m, d) for m, d in zip(msas, deletion_matrices)]
msa_features = make_msa_features(msa_objects)
return msa_features
......@@ -722,3 +875,126 @@ class DataPipeline:
return {**core_feats, **template_features, **msa_features}
class DataPipelineMultimer:
    """Runs the alignment tools and assembles the input features."""
    def __init__(self,
        monomer_data_pipeline: DataPipeline,
        jackhmmer_binary_path: str,
        uniprot_database_path: str,
        max_uniprot_hits: int = 50000,
    ):
        """Initializes the data pipeline.

        Args:
            monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
                the data pipeline for the monomer AlphaFold system.
            jackhmmer_binary_path: Location of the jackhmmer binary.
            uniprot_database_path: Location of the unclustered uniprot sequences, that
                will be searched with jackhmmer and used for MSA pairing.
            max_uniprot_hits: The maximum number of hits to return from uniprot.
        """
        # NOTE(review): jackhmmer_binary_path, uniprot_database_path and
        # max_uniprot_hits are accepted but never stored or used — this
        # pipeline expects alignments (including uniprot_hits.sto) to be
        # precomputed on disk (see _process_single_chain /
        # _all_seq_msa_features). Confirm whether these parameters should be
        # removed or wired up.
        self._monomer_data_pipeline = monomer_data_pipeline

    def _process_single_chain(
        self,
        chain_id: str,
        sequence: str,
        description: str,
        msa_output_dir: str,
        is_homomer_or_monomer: bool
    ) -> FeatureDict:
        """Runs the monomer pipeline on a single chain."""
        # `description` is currently unused here; the chain FASTA header is
        # synthesized from the chain ID.
        chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
        # Precomputed alignments are expected in one subdirectory per chain.
        chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
        if not os.path.exists(chain_msa_output_dir):
            raise ValueError(f"Alignments for {chain_id} not found...")
        with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
            chain_features = self._monomer_data_pipeline.process_fasta(
                input_fasta_path=chain_fasta_path,
                alignment_dir=chain_msa_output_dir
            )

            # We only construct the pairing features if there are 2 or more unique
            # sequences.
            if not is_homomer_or_monomer:
                all_seq_msa_features = self._all_seq_msa_features(
                    chain_fasta_path,
                    chain_msa_output_dir
                )
                chain_features.update(all_seq_msa_features)
        return chain_features

    def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
        """Get MSA features for unclustered uniprot, for pairing."""
        # `input_fasta_path` is currently unused; hits are read from a
        # precomputed uniprot_hits.sto in msa_output_dir.
        uniprot_msa_path = os.path.join(msa_output_dir, "uniprot_hits.sto")
        with open(uniprot_msa_path, "r") as fp:
            uniprot_msa_string = fp.read()
        msa = parsers.parse_stockholm(uniprot_msa_string)
        all_seq_features = make_msa_features([msa])
        # Keep only the features relevant for cross-chain MSA pairing,
        # suffixed with "_all_seq" to distinguish them from the per-chain
        # MSA features.
        valid_feats = msa_pairing.MSA_FEATURES + (
            'msa_uniprot_accession_identifiers',
            'msa_species_identifiers',
        )
        feats = {
            f'{k}_all_seq': v for k, v in all_seq_features.items()
            if k in valid_feats
        }
        return feats

    def process(self,
        input_fasta_path: str,
        msa_output_dir: str,
        is_prokaryote: bool = False
    ) -> FeatureDict:
        """Runs alignment tools on the input sequences and creates features."""
        with open(input_fasta_path) as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
        chain_id_map = _make_chain_id_map(
            sequences=input_seqs,
            descriptions=input_descs
        )
        # Persist the PDB-chain-ID -> (sequence, description) mapping for
        # later reference.
        chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
        with open(chain_id_map_path, 'w') as f:
            chain_id_map_dict = {
                chain_id: dataclasses.asdict(fasta_chain)
                for chain_id, fasta_chain in chain_id_map.items()
            }
            json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)

        all_chain_features = {}
        sequence_features = {}
        # Pairing features are skipped when all chains share one sequence.
        is_homomer_or_monomer = len(set(input_seqs)) == 1
        for chain_id, fasta_chain in chain_id_map.items():
            # Reuse already-computed features for duplicate sequences.
            # NOTE(review): the deep copy retains the first chain's
            # 'auth_chain_id' — confirm downstream code does not rely on it
            # matching this chain's ID.
            if fasta_chain.sequence in sequence_features:
                all_chain_features[chain_id] = copy.deepcopy(
                    sequence_features[fasta_chain.sequence])
                continue
            chain_features = self._process_single_chain(
                chain_id=chain_id,
                sequence=fasta_chain.sequence,
                description=fasta_chain.description,
                msa_output_dir=msa_output_dir,
                is_homomer_or_monomer=is_homomer_or_monomer
            )

            chain_features = convert_monomer_features(
                chain_features,
                chain_id=chain_id
            )
            all_chain_features[chain_id] = chain_features
            sequence_features[fasta_chain.sequence] = chain_features

        all_chain_features = add_assembly_features(all_chain_features)

        np_example = feature_processing.pair_and_merge(
            all_chain_features=all_chain_features,
            is_prokaryote=is_prokaryote,
        )

        # Pad MSA to avoid zero-sized extra_msa.
        np_example = pad_msa(np_example, 512)

        return np_example
......@@ -476,6 +476,20 @@ def get_atom_coords(
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1
cd = residue_constants.atom_order['CD']
nh1 = residue_constants.atom_order['NH1']
nh2 = residue_constants.atom_order['NH2']
if(
res.get_resname() == 'ARG' and
all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
(np.linalg.norm(pos[nh1] - pos[cd]) >
np.linalg.norm(pos[nh2] - pos[cd]))
):
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
......
......@@ -14,8 +14,10 @@
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import json
import logging
......@@ -65,10 +67,6 @@ class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class PdbIdError(PrefilterError):
"""An error indicating that the hit PDB ID was identical to the query."""
class AlignRatioError(PrefilterError):
"""An error indicating that the hit align ratio to the query was too small."""
......@@ -188,7 +186,6 @@ def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
......@@ -202,7 +199,6 @@ def _assess_hhsearch_hit(
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
......@@ -214,7 +210,6 @@ def _assess_hhsearch_hit(
Raises:
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
......@@ -239,10 +234,6 @@ def _assess_hhsearch_hit(
f"({release_date_cutoff})."
)
if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError("PDB code identical to Query PDB code.")
if align_ratio <= min_align_ratio:
raise AlignRatioError(
"Proportion of residues aligned to query too small. "
......@@ -408,9 +399,10 @@ def _realign_pdb_template_to_query(
)
try:
(old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
parsed_a3m = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence])
)
old_aligned_template, new_aligned_template = parsed_a3m.sequences
except Exception as e:
raise QueryToTemplateAlignError(
"Could not align old template %s to template %s (%s_%s). Error: %s"
......@@ -752,7 +744,6 @@ class SingleHitResult:
def _prefilter_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
......@@ -773,7 +764,6 @@ def _prefilter_hit(
hit=hit,
hit_pdb_code=hit_pdb_code,
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
release_dates=release_dates,
release_date_cutoff=max_template_date,
)
......@@ -781,9 +771,7 @@ def _prefilter_hit(
hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)
):
if strict_error_check and isinstance(e, (DateError, DuplicateError)):
# In strict mode we treat some prefilter cases as errors.
return PrefilterResult(valid=False, error=msg, warning=None)
......@@ -792,9 +780,16 @@ def _prefilter_hit(
return PrefilterResult(valid=True, error=None, warning=None)
@functools.lru_cache(16, typed=False)
def _read_file(path):
with open(path, 'r') as f:
file_data = f.read()
return file_data
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
......@@ -832,8 +827,7 @@ def _process_single_hit(
template_sequence,
)
# Fail if we can't find the mmCIF file.
with open(cif_path, "r") as cif_file:
cif_string = cif_file.read()
cif_string = _read_file(cif_path)
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string
......@@ -866,6 +860,10 @@ def _process_single_hit(
kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
)
if hit.sum_probs is None:
features["template_sum_probs"] = [0]
else:
features["template_sum_probs"] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
......@@ -920,8 +918,8 @@ class TemplateSearchResult:
warnings: Sequence[str]
class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features."""
class TemplateHitFeaturizer(abc.ABC):
"""An abstract base class for turning template hits to features."""
def __init__(
self,
mmcif_dir: str,
......@@ -993,10 +991,18 @@ class TemplateHitFeaturizer:
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
@abc.abstractmethod
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult:
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,
query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
......@@ -1025,7 +1031,6 @@ class TemplateHitFeaturizer:
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
......@@ -1105,3 +1110,88 @@ class TemplateHitFeaturizer:
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
    """Turns hmmsearch template hits into template model features."""

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit]
    ) -> TemplateSearchResult:
        """Computes the template features for a list of hmmsearch hits.

        Args:
            query_sequence: Amino acid sequence of the query.
            hits: Parsed hmmsearch template hits.

        Returns:
            A TemplateSearchResult with stacked template features plus any
            errors/warnings collected while processing hits. If no hit
            survives filtering, a single all-zero dummy template is returned
            so downstream code always sees a non-empty template batch.
        """
        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        already_seen = set()
        errors = []
        warnings = []

        # Process hits most-confident first when scores are available.
        if not hits or hits[0].sum_probs is None:
            sorted_hits = hits
        else:
            sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)

        for hit in sorted_hits:
            # Stop once enough unique template sequences were accepted.
            if(len(already_seen) >= self._max_hits):
                break

            result = _process_single_hit(
                query_sequence=query_sequence,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date = self._max_template_date,
                release_dates = self._release_dates,
                obsolete_pdbs = self._obsolete_pdbs,
                strict_error_check = self._strict_error_check,
                kalign_binary_path = self._kalign_binary_path
            )

            if result.error:
                errors.append(result.error)

            if result.warning:
                warnings.append(result.warning)

            if result.features is None:
                logging.debug(
                    "Skipped invalid hit %s, error: %s, warning: %s",
                    hit.name, result.error, result.warning,
                )
            else:
                # Deduplicate hits by template sequence.
                already_seen_key = result.features["template_sequence"]
                if(already_seen_key in already_seen):
                    continue
                # Increment the hit counter, since we got features out of this hit.
                already_seen.add(already_seen_key)
                for k in template_features:
                    template_features[k].append(result.features[k])

        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            # Use the builtin `object` dtype: the np.object alias was
            # deprecated in NumPy 1.20 and removed in 1.24.
            template_features = {
                "template_aatype": np.zeros(
                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
                    np.float32
                ),
                "template_all_atom_masks": np.zeros(
                    (1, num_res, residue_constants.atom_type_num), np.float32
                ),
                "template_all_atom_positions": np.zeros(
                    (1, num_res, residue_constants.atom_type_num, 3), np.float32
                ),
                "template_domain_names": np.array([''.encode()], dtype=object),
                "template_sequence": np.array([''.encode()], dtype=object),
                "template_sum_probs": np.array([0], dtype=np.float32),
            }

        return TemplateSearchResult(
            features=template_features,
            errors=errors,
            warnings=warnings,
        )
......@@ -18,7 +18,7 @@ import glob
import logging
import os
import subprocess
from typing import Any, Mapping, Optional, Sequence
from typing import Any, List, Mapping, Optional, Sequence
from openfold.data.tools import utils
......@@ -99,9 +99,9 @@ class HHBlits:
self.p = p
self.z = z
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
......@@ -172,4 +172,4 @@ class HHBlits:
n_iter=self.n_iter,
e_value=self.e_value,
)
return raw_output
return [raw_output]
......@@ -20,6 +20,7 @@ import os
import subprocess
from typing import Sequence
from openfold.data import parsers
from openfold.data.tools import utils
......@@ -62,9 +63,17 @@ class HHSearch:
f"Could not find HHsearch database {database_path}"
)
@property
def output_format(self) -> str:
    # HHsearch reports hits in hhr format (see query(), which reads
    # output.hhr).
    return 'hhr'
@property
def input_format(self) -> str:
    # Queries are supplied as a3m alignments (see query(), which writes
    # query.a3m).
    return 'a3m'
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
with open(input_path, "w") as f:
......@@ -104,3 +113,11 @@ class HHSearch:
with open(hhr_path) as f:
hhr = f.read()
return hhr
def get_template_hits(self,
    output_string: str,
    input_sequence: str
) -> Sequence[parsers.TemplateHit]:
    """Gets parsed template hits from the raw string output by the tool."""
    del input_sequence  # Used by hmmsearch but not needed for hhsearch
    return parsers.parse_hhr(output_string)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import os
import re
import subprocess
from absl import logging
from openfold.data.tools import utils
class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self,
        *,
        binary_path: str,
        singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
            binary_path: The path to the hmmbuild executable.
            singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
                just use a common substitution score matrix.

        Raises:
            RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds an HMM for the aligned sequences given as a Stockholm string.

        Args:
            sto: A string with the aligned sequences in the Stockholm format.
            model_construction: Whether to use reference annotation in the msa to
                determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds an HMM for the aligned sequences given as an A3M string.

        Args:
            a3m: A string with the aligned sequences in the A3M format.

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds an HMM for the aligned sequences given as an MSA string.

        Args:
            msa: A string with the aligned sequences, in A3M or STO format.
            model_construction: Whether to use reference annotation in the msa to
                determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            # Bug fix: the original implicit string concatenation produced
            # "...onlyhand and fast supported." (missing space after 'only').
            raise ValueError(f'Invalid model_construction {model_construction} - only '
                             'hand and fast supported.')

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                                   % (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

        return hmm
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
from typing import Optional, Sequence
from absl import logging
from openfold.data import parsers
from openfold.data.tools import hmmbuild
from openfold.data.tools import utils
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
                an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings: very permissive E-value
            # thresholds and relaxed filter stages, so many hits are kept.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        # Hit alignments are returned in Stockholm format (written via -A).
        return 'sto'

    @property
    def input_format(self) -> str:
        # Queries are built from Stockholm MSAs (see query()).
        return 'sto'

    def query(self, msa_sto: str) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        # Build an HMM profile first; 'hand' uses the MSA's reference
        # annotation to choose consensus columns.
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto,
            model_construction='hand')
        return self.query_with_hmm(hmm)

    def query_with_hmm(self, hmm: str) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            out_path = os.path.join(query_tmp_dir, 'output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                # NOTE(review): CPU count is hard-coded — confirm whether it
                # should be configurable.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,  # Write the multiple alignment of hits here.
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    def get_template_hits(self,
        output_string: str,
        input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        # Convert the Stockholm output to a3m without removing first-row gaps
        # so the hits can be parsed against the original query sequence.
        a3m_string = parsers.convert_stockholm_to_a3m(
            output_string,
            remove_first_row_gaps=False
        )
        template_hits = parsers.parse_hmmsearch_a3m(
            query_sequence=input_sequence,
            a3m_string=a3m_string,
            skip_first=False
        )
        return template_hits
......@@ -23,6 +23,7 @@ import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from openfold.data import parsers
from openfold.data.tools import utils
......@@ -93,10 +94,13 @@ class Jackhmmer:
self.streaming_callback = streaming_callback
def _query_chunk(
self, input_fasta_path: str, database_path: str
self,
input_fasta_path: str,
database_path: str,
max_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, "output.sto")
# The F1/F2/F3 are the expected proportion to pass each of the filtering
......@@ -167,8 +171,11 @@ class Jackhmmer:
with open(tblout_path) as f:
tbl = f.read()
if(max_sequences is None):
with open(sto_path) as f:
sto = f.read()
else:
sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
raw_output = dict(
sto=sto,
......@@ -180,10 +187,16 @@ class Jackhmmer:
return raw_output
def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None
) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
return [self._query_chunk(input_fasta_path, self.database_path)]
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences,
)
return [single_chunk_result]
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
......@@ -217,12 +230,20 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i))
self._query_chunk(
input_fasta_path,
db_local_chunk(i),
max_sequences
)
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
future = next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if(i < self.num_streamed_chunks):
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
......@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)." % (s, len(s))
)
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
......
......@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
......@@ -151,6 +151,40 @@ class AngleResnet(nn.Module):
return unnormalized_s, s
class PointProjection(nn.Module):
    """
    Projects per-residue activations to sets of 3D points per attention
    head, returning them in the global frame (and optionally also in the
    local frame).
    """
    def __init__(self,
        c_hidden: int,
        num_points: int,
        no_heads: int,
        return_local_points: bool = False,
    ):
        """
        Args:
            c_hidden: Channel dimension of the input activations
            num_points: Number of points generated per head
            no_heads: Number of attention heads
            return_local_points: Whether to also return the frame-local
                points alongside the global ones
        """
        super().__init__()
        self.return_local_points = return_local_points
        self.no_heads = no_heads

        # One (x, y, z) triple per point per head
        self.linear = Linear(c_hidden, no_heads * 3 * num_points)

    def forward(self,
        activations: torch.Tensor,
        rigids: Rigid3Array,
    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
        # TODO: Needs to run in high precision during training
        # [*, N_res, H, 3 * P]
        points_local = self.linear(activations)
        points_local = points_local.reshape(
            *points_local.shape[:-1],
            self.no_heads,
            -1,
        )

        # Split the channel dim into the three coordinate components.
        # NOTE: torch.split's second argument is a chunk *size*, not a
        # number of chunks (unlike jnp.split), so divide by 3 explicitly.
        points_local = torch.split(
            points_local, points_local.shape[-1] // 3, dim=-1
        )
        points_local = Vec3Array(*points_local)

        # Map the points from the local residue frames into the global frame
        points_global = rigids[..., None, None].apply_to_point(points_local)

        if(self.return_local_points):
            return points_global, points_local

        return points_global
class InvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
......@@ -200,13 +234,23 @@ class InvariantPointAttention(nn.Module):
self.linear_q = Linear(self.c_s, hc)
self.linear_kv = Linear(self.c_s, 2 * hc)
hpq = self.no_heads * self.no_qk_points * 3
self.linear_q_points = Linear(self.c_s, hpq)
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads
)
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
self.linear_kv_points = Linear(self.c_s, hpkv)
self.linear_k_points = PointProjection(
self.c_s,
self.no_qk_points
self.no_heads,
)
hpv = self.no_heads * self.no_v_points * 3
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points
self.no_heads,
)
self.linear_b = Linear(self.c_z, self.no_heads)
......@@ -257,35 +301,14 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# [*, N_res, H * P_q * 3]
q_pts = self.linear_q_points(s)
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts = self.linear_kv_points(s)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, P_qk]
q_pts = self.linear_q_points(s, r)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, P_qk, 3]
k_pts = self.linear_k_points(s, r)
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
)
# [*, N_res, H, P_v, 3]
v_pts = self.linear_v_points(s, r)
##########################
# Compute attention scores
......@@ -302,8 +325,8 @@ class InvariantPointAttention(nn.Module):
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
pt_att = pt_att * pt_att + self.eps
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
......@@ -340,26 +363,20 @@ class InvariantPointAttention(nn.Module):
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
# [*, N_res, H, P_v]
o_pt = v_pts.tensor_dot(
permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
)
o_pt = o_pt.sum(dim=-3)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
......@@ -370,7 +387,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
(o, *o_pt, o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
)
......
......@@ -24,8 +24,6 @@ from importlib import resources
import numpy as np
import tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
......@@ -1309,3 +1307,179 @@ def aatype_to_str_sequence(aatype):
restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types
      are in the order specified in residue_constants.restypes + unknown
      residue type at the end. For chi angles which are not defined on the
      residue, the position indices are by default set to 0.
    """
    indices = []
    for one_letter in restypes:
        chis = chi_angles_atoms[restype_1to3[one_letter]]
        per_residue = [
            [atom_order[atom] for atom in chi] for chi in chis
        ]
        # Pad so every residue lists exactly four chi angles.
        per_residue.extend([[0, 0, 0, 0]] * (4 - len(per_residue)))
        indices.append(per_residue)

    # Dummy entry for the UNKNOWN residue type.
    indices.append([[0, 0, 0, 0]] * 4)

    return np.array(indices)
def _make_renaming_matrices():
    """Matrices to map atoms to symmetry partners in ambiguous case."""
    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative groundtruth coordinates where the naming is swapped
    restype_3 = [restype_1to3[res] for res in restypes] + ['UNK']

    # Default to the identity: nothing to rename for unambiguous residues.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_atom_renaming_swaps.items():
        atom_names = restype_name_to_atom14_names[resname]
        correspondences = np.arange(14)
        for source_name, target_name in swap.items():
            i = atom_names.index(source_name)
            j = atom_names.index(target_name)
            correspondences[i], correspondences[j] = j, i

        # Build the permutation matrix for this residue's atom swap.
        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
        renaming_matrix[np.arange(14), correspondences] = 1.
        all_matrices[resname] = renaming_matrix

    return np.stack([all_matrices[res] for res in restype_3])
def _make_restype_atom37_mask():
    """Mask of which atoms are present for which residue type in atom37."""
    # 21 rows: the 20 standard residues plus an all-zero UNK row.
    mask = np.zeros([21, 37], dtype=np.float32)
    for res_idx, res_letter in enumerate(restypes):
        for atom_name in residue_atoms[restype_1to3[res_letter]]:
            mask[res_idx, atom_order[atom_name]] = 1
    return mask
def _make_restype_atom14_mask():
    """Mask of which atoms are present for which residue type in atom14."""
    # Empty atom14 slots are encoded as '' in restype_name_to_atom14_names.
    rows = [
        [1. if name else 0.
         for name in restype_name_to_atom14_names[restype_1to3[rt]]]
        for rt in restypes
    ]
    rows.append([0.] * 14)  # UNK has no atoms
    return np.array(rows, dtype=np.float32)
def _make_restype_atom37_to_atom14():
    """Map from atom37 to atom14 per residue type."""
    mapping = []  # (restype, atom37) --> atom14
    for rt in restypes:
        atom14_names = restype_name_to_atom14_names[restype_1to3[rt]]
        idx14 = {name: i for i, name in enumerate(atom14_names)}
        # Atoms absent from this residue's atom14 set map to slot 0.
        mapping.append([idx14.get(name, 0) for name in atom_types])

    mapping.append([0] * 37)  # dummy mapping for restype 'UNK'
    return np.array(mapping, dtype=np.int32)
def _make_restype_atom14_to_atom37():
    """Map from atom14 to atom37 per residue type."""
    mapping = []  # (restype, atom14) --> atom37
    for rt in restypes:
        # '' marks an unused atom14 slot; those map to index 0.
        mapping.append([
            atom_order[name] if name else 0
            for name in restype_name_to_atom14_names[restype_1to3[rt]]
        ])

    # Add dummy mapping for restype 'UNK'
    mapping.append([0] * 14)
    return np.array(mapping, dtype=np.int32)
def _make_restype_atom14_is_ambiguous():
    """Mask which atoms are ambiguous in atom14."""
    # create an ambiguous atoms mask. shape: (21, 14)
    ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_atom_renaming_swaps.items():
        res_idx = restype_order[restype_3to1[resname]]
        atom14_names = restype_name_to_atom14_names[resname]
        # Both members of each swap pair are ambiguous.
        for atom_name1, atom_name2 in swap.items():
            ambiguous[res_idx, atom14_names.index(atom_name1)] = 1
            ambiguous[res_idx, atom14_names.index(atom_name2)] = 1

    return ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
    """Create Map from rigidgroups to atom37 indices."""
    # Array of atom names defining each rigid group's frame.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    base_atom_names[:, 0, :] = ['C', 'CA', 'N']
    # 3: 'psi-group'
    base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group' — only where the chi angle is defined
    for res_idx, res_letter in enumerate(restypes):
        resname = restype_1to3[res_letter]
        for chi_idx in range(4):
            if chi_angles_mask[res_idx][chi_idx]:
                # The last three atoms of the chi dihedral define the frame.
                base_atom_names[res_idx, chi_idx + 4, :] = (
                    chi_angles_atoms[resname][chi_idx][1:]
                )

    # Translate atom names into atom37 indices ('' maps to 0).
    lookup = dict(atom_order)
    lookup[''] = 0
    return np.vectorize(lambda name: lookup[name])(base_atom_names)
# Precomputed per-residue-type lookup tables (21 rows: 20 standard AAs + UNK).
CHI_ATOM_INDICES = _make_chi_atom_indices()  # [21, 4, 4] atom indices defining each chi dihedral
RENAMING_MATRICES = _make_renaming_matrices()  # [21, 14, 14] permutations swapping ambiguous atom14 atoms
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()  # [21, 14] atom14 slot -> atom37 index
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()  # [21, 37] atom37 slot -> atom14 index
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()  # [21, 37] 1.0 where the atom exists for the residue
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()  # [21, 14] 1.0 where the atom14 slot is occupied
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()  # [21, 14] symmetry-ambiguous atoms
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()  # [21, 8, 3] frame-defining atoms

# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1  # backbone frame exists for every residue
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1  # psi group exists for every residue
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask  # chi groups only where the chi angle is defined
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment