Commit 57f869d6 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz

Continue work on AlphaFold-Multimer

parent 100485dd
@@ -74,6 +74,8 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.model.update(multimer_model_config_update)
else:
raise ValueError("Invalid model name")
@@ -493,3 +495,11 @@ config = mlc.ConfigDict(
"ema": {"decay": 0.999},
}
)
multimer_model_config_update = mlc.ConfigDict({
"relative_encoding": {
"enabled": True,
"max_relative_chain": 2,
"max_relative_idx": 32,
},
})
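For context, ConfigDict.update performs a recursive merge, so the `c.model.update(...)` call above grafts the relative-encoding keys onto the existing model config. A minimal sketch of the semantics, assuming only ml_collections (key names here are illustrative):

import ml_collections as mlc

base = mlc.ConfigDict({"model": {"embedder": {"enabled": True}}})
update = mlc.ConfigDict({"relative_encoding": {"max_relative_idx": 32}})
# update() merges nested keys in place, adding new subtrees and
# overwriting existing leaves.
base.model.update(update)
assert base.model.relative_encoding.max_relative_idx == 32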
@@ -25,6 +25,7 @@ from openfold.data import (
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
@@ -277,11 +278,13 @@ class AlignmentRunner:
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
template_searcher: Optional[TemplateSearcher] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
"""
Args:
@@ -320,6 +323,7 @@ class AlignmentRunner:
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
uniprot_database_path,
],
},
"hhblits": {
@@ -339,6 +343,7 @@ class AlignmentRunner:
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.uniprot_max_hits = uniprot_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
@@ -381,6 +386,13 @@ class AlignmentRunner:
n_cpu=no_cpus,
)
self.jackhmmer_uniprot_runner = None
if(uniprot_database_path is not None):
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path
)
if(template_searcher is not None and
self.jackhmmer_uniref90_runner is None
):
@@ -456,6 +468,148 @@ class AlignmentRunner:
msa_format="a3m",
)
if(self.jackhmmer_uniprot_runner is not None):
uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
result = run_msa_tool(
self.jackhmmer_uniprot_runner,
input_fasta_path=input_fasta_path,
msa_out_path=uniprot_out_path,
msa_format='sto',
max_sto_sequences=self.uniprot_max_hits,
)
@dataclasses.dataclass(frozen=True)
class _FastaChain:
sequence: str
description: str
def _make_chain_id_map(*,
sequences: Sequence[str],
descriptions: Sequence[str],
) -> Mapping[str, _FastaChain]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if len(sequences) != len(descriptions):
raise ValueError('sequences and descriptions must have equal length. '
f'Got {len(sequences)} != {len(descriptions)}.')
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError('Cannot process more chains than the PDB format supports. '
f'Got {len(sequences)} chains.')
chain_id_map = {}
for chain_id, sequence, description in zip(
protein.PDB_CHAIN_IDS, sequences, descriptions):
chain_id_map[chain_id] = _FastaChain(
sequence=sequence, description=description)
return chain_id_map
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
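temp_fasta_file yields a path to a NamedTemporaryFile that lives only as long as the with-block; a minimal usage sketch (the sequence is invented):

with temp_fasta_file(">chainA\nMKTAYIAKQR\n") as fasta_path:
    with open(fasta_path) as f:
        assert f.read().startswith(">chainA")
# The temporary file is deleted on exit from the context manager.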
def convert_monomer_features(
monomer_features: FeatureDict,
chain_id: str
) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
converted[feature_name] = feature
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style
naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)
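As a quick check of the scheme, the least-significant letter comes first ("reverse spreadsheet"), matching the examples in the docstring above:

assert int_id_to_str_id(1) == 'A'
assert int_id_to_str_id(26) == 'Z'
assert int_id_to_str_id(27) == 'AA'
assert int_id_to_str_id(28) == 'BA'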
def add_assembly_features(
all_chain_features: MutableMapping[str, FeatureDict],
) -> MutableMapping[str, FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_id, chain_features in all_chain_features.items():
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
for sym_id, chain_features in enumerate(group_chain_features, start=1):
new_all_chain_features[
f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
seq_length = chain_features['seq_length']
chain_features['asym_id'] = chain_id * np.ones(seq_length)
chain_features['sym_id'] = sym_id * np.ones(seq_length)
chain_features['entity_id'] = entity_id * np.ones(seq_length)
chain_id += 1
return new_all_chain_features
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
return np_example
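pad_msa zero-pads the MSA-shaped features up to min_num_seq rows. A small sketch with toy shapes, using the feature names from the function above:

import numpy as np

toy = {
    'msa': np.zeros((3, 10)),
    'deletion_matrix': np.zeros((3, 10)),
    'bert_mask': np.zeros((3, 10)),
    'msa_mask': np.zeros((3, 10)),
    'cluster_bias_mask': np.zeros((3,)),
}
padded = pad_msa(toy, min_num_seq=8)
assert padded['msa'].shape == (8, 10)
assert padded['cluster_bias_mask'].shape == (8,)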
class DataPipeline:
"""Assembles input features."""
@@ -579,10 +733,9 @@ class DataPipeline:
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
-msa_features = make_msa_features(
-msas=msas,
-deletion_matrices=deletion_matrices,
-)
msa_objects = [Msa(m, d) for m, d in zip(msas, deletion_matrices)]
msa_features = make_msa_features(msa_objects)
return msa_features
@@ -722,3 +875,126 @@ class DataPipeline:
return {**core_feats, **template_features, **msa_features}
class DataPipelineMultimer:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: DataPipeline,
jackhmmer_binary_path: str,
uniprot_database_path: str,
max_uniprot_hits: int = 50000,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
"""
self._monomer_data_pipeline = monomer_data_pipeline
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
msa_output_dir: str,
is_homomer_or_monomer: bool
) -> FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
if not os.path.exists(chain_msa_output_dir):
raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta(
input_fasta_path=chain_fasta_path,
alignment_dir=chain_msa_output_dir
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
chain_msa_output_dir)
chain_features.update(all_seq_msa_features)
return chain_features
def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path = os.path.join(msa_output_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_uniprot_accession_identifiers',
'msa_species_identifiers',
)
feats = {
f'{k}_all_seq': v for k, v in all_seq_features.items()
if k in valid_feats
}
return feats
def process(self,
input_fasta_path: str,
msa_output_dir: str,
is_prokaryote: bool = False
) -> FeatureDict:
"""Runs alignment tools on the input sequences and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
chain_id_map = _make_chain_id_map(
sequences=input_seqs,
descriptions=input_descs
)
chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
with open(chain_id_map_path, 'w') as f:
chain_id_map_dict = {
chain_id: dataclasses.asdict(fasta_chain)
for chain_id, fasta_chain in chain_id_map.items()
}
json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
for chain_id, fasta_chain in chain_id_map.items():
if fasta_chain.sequence in sequence_features:
all_chain_features[chain_id] = copy.deepcopy(
sequence_features[fasta_chain.sequence])
continue
chain_features = self._process_single_chain(
chain_id=chain_id,
sequence=fasta_chain.sequence,
description=fasta_chain.description,
msa_output_dir=msa_output_dir,
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=chain_id
)
all_chain_features[chain_id] = chain_features
sequence_features[fasta_chain.sequence] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features,
is_prokaryote=is_prokaryote,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
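A sketch of how this class is meant to be driven, assuming per-chain alignments (including uniprot_hits.sto for pairing) already exist under msa_output_dir/<chain_id>/; the paths are placeholders and the monomer pipeline construction is elided:

multimer_pipeline = DataPipelineMultimer(
    monomer_data_pipeline=monomer_pipeline,       # a configured DataPipeline
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    uniprot_database_path="/data/uniprot.fasta",
)
np_example = multimer_pipeline.process(
    input_fasta_path="heterodimer.fasta",         # one record per chain
    msa_output_dir="alignments/",
)
# MSA features are padded to at least 512 rows by pad_msa above.
print(np_example["msa"].shape)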
@@ -476,6 +476,20 @@ def get_atom_coords(
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1
cd = residue_constants.atom_order['CD']
nh1 = residue_constants.atom_order['NH1']
nh2 = residue_constants.atom_order['NH2']
if(
res.get_resname() == 'ARG' and
all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
(np.linalg.norm(pos[nh1] - pos[cd]) >
np.linalg.norm(pos[nh2] - pos[cd]))
):
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
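The arginine fix above only swaps when all three atoms are present and NH1 is the farther of the two nitrogens from CD. A toy check of the geometry test (coordinates invented):

import numpy as np

cd = np.zeros(3)
nh1 = np.array([3.0, 0.0, 0.0])   # mislabeled: farther from CD
nh2 = np.array([1.0, 0.0, 0.0])
# This is exactly the inequality that triggers the coordinate/mask swap.
assert np.linalg.norm(nh1 - cd) > np.linalg.norm(nh2 - cd)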
@@ -14,8 +14,10 @@
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import json
import logging
@@ -65,10 +67,6 @@ class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
-class PdbIdError(PrefilterError):
-"""An error indicating that the hit PDB ID was identical to the query."""
class AlignRatioError(PrefilterError):
"""An error indicating that the hit align ratio to the query was too small."""
@@ -188,7 +186,6 @@ def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
-query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
@@ -202,7 +199,6 @@ def _assess_hhsearch_hit(
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
-query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
@@ -214,7 +210,6 @@ def _assess_hhsearch_hit(
Raises:
DateError: If the hit date was after the max allowed date.
-PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
@@ -239,10 +234,6 @@ def _assess_hhsearch_hit(
f"({release_date_cutoff})."
)
-if query_pdb_code is not None:
-if query_pdb_code.lower() == hit_pdb_code.lower():
-raise PdbIdError("PDB code identical to Query PDB code.")
if align_ratio <= min_align_ratio:
raise AlignRatioError(
"Proportion of residues aligned to query too small. "
@@ -408,9 +399,10 @@ def _realign_pdb_template_to_query(
)
try:
-(old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
parsed_a3m = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence])
)
old_aligned_template, new_aligned_template = parsed_a3m.sequences
except Exception as e:
raise QueryToTemplateAlignError(
"Could not align old template %s to template %s (%s_%s). Error: %s"
@@ -752,7 +744,6 @@ class SingleHitResult:
def _prefilter_hit(
query_sequence: str,
-query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
@@ -773,7 +764,6 @@ def _prefilter_hit(
hit=hit,
hit_pdb_code=hit_pdb_code,
query_sequence=query_sequence,
-query_pdb_code=query_pdb_code,
release_dates=release_dates,
release_date_cutoff=max_template_date,
)
@@ -781,9 +771,7 @@ def _prefilter_hit(
hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
logging.info(msg)
-if strict_error_check and isinstance(
-e, (DateError, PdbIdError, DuplicateError)
-):
if strict_error_check and isinstance(e, (DateError, DuplicateError)):
# In strict mode we treat some prefilter cases as errors.
return PrefilterResult(valid=False, error=msg, warning=None)
@@ -792,9 +780,16 @@ def _prefilter_hit(
return PrefilterResult(valid=True, error=None, warning=None)
@functools.lru_cache(16, typed=False)
def _read_file(path):
with open(path, 'r') as f:
file_data = f.read()
return file_data
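The lru_cache above memoizes up to 16 file reads keyed by path, so repeated template hits against the same mmCIF file skip disk I/O; the cache holds whole file strings in memory. A hypothetical usage of the helper, including the standard functools introspection API:

cif_a = _read_file("/data/mmcif/1abc.cif")  # first read hits the disk
cif_b = _read_file("/data/mmcif/1abc.cif")  # served from the cache
print(_read_file.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=16, currsize=1)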
def _process_single_hit(
query_sequence: str,
-query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
@@ -832,8 +827,7 @@ def _process_single_hit(
template_sequence,
)
# Fail if we can't find the mmCIF file.
-with open(cif_path, "r") as cif_file:
-cif_string = cif_file.read()
cif_string = _read_file(cif_path)
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string
@@ -866,7 +860,11 @@ def _process_single_hit(
kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
)
-features["template_sum_probs"] = [hit.sum_probs]
if hit.sum_probs is None:
features["template_sum_probs"] = [0]
else:
features["template_sum_probs"] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
@@ -920,8 +918,8 @@ class TemplateSearchResult:
warnings: Sequence[str]
-class TemplateHitFeaturizer:
-"""A class for turning hhr hits to template features."""
class TemplateHitFeaturizer(abc.ABC):
"""An abstract base class for turning template hits to features."""
def __init__(
self,
mmcif_dir: str,
@@ -993,10 +991,18 @@ class TemplateHitFeaturizer:
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
@abc.abstractmethod
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult:
"""Turns a list of template hits into template features. Implemented by subclasses."""
pass

class HhsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,
query_sequence: str,
-query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
@@ -1025,7 +1031,6 @@ class TemplateHitFeaturizer:
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
-query_pdb_code=query_pdb_code,
hit=hit,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
@@ -1105,3 +1110,88 @@ class TemplateHitFeaturizer:
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult:
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
already_seen = set()
errors = []
warnings = []
if not hits or hits[0].sum_probs is None:
sorted_hits = hits
else:
sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
for hit in sorted_hits:
if(len(already_seen) >= self._max_hits):
break
result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path,
)
if result.error:
errors.append(result.error)
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.debug(
"Skipped invalid hit %s, error: %s, warning: %s",
hit.name, result.error, result.warning,
)
else:
already_seen_key = result.features["template_sequence"]
if(already_seen_key in already_seen):
continue
# Increment the hit counter, since we got features out of this hit.
already_seen.add(already_seen_key)
for k in template_features:
template_features[k].append(result.features[k])
if already_seen:
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
"template_aatype": np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_masks": np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.array([0], dtype=np.float32),
}
return TemplateSearchResult(
features=template_features,
errors=errors,
warnings=warnings,
)
@@ -18,7 +18,7 @@ import glob
import logging
import os
import subprocess
-from typing import Any, Mapping, Optional, Sequence
from typing import Any, List, Mapping, Optional, Sequence
from openfold.data.tools import utils
@@ -99,9 +99,9 @@ class HHBlits:
self.p = p
self.z = z
-def query(self, input_fasta_path: str) -> Mapping[str, Any]:
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
"""Queries the database using HHblits."""
-with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
@@ -172,4 +172,4 @@ class HHBlits:
n_iter=self.n_iter,
e_value=self.e_value,
)
-return raw_output
return [raw_output]
@@ -20,6 +20,7 @@ import os
import subprocess
from typing import Sequence
from openfold.data import parsers
from openfold.data.tools import utils
@@ -62,9 +63,17 @@ class HHSearch:
f"Could not find HHsearch database {database_path}"
)
@property
def output_format(self) -> str:
return 'hhr'
@property
def input_format(self) -> str:
return 'a3m'
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
-with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
with open(input_path, "w") as f:
@@ -104,3 +113,11 @@ class HHSearch:
with open(hhr_path) as f:
hhr = f.read()
return hhr
def get_template_hits(self,
output_string: str,
input_sequence: str
) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool"""
del input_sequence # Used by hmmsearch but not needed for hhsearch
return parsers.parse_hhr(output_string)
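HHSearch and Hmmsearch (new file below) now share a duck-typed searcher interface: input_format and output_format tell the caller which MSA format to feed in and what comes back, and get_template_hits parses the raw output. A minimal sketch of pipeline code written against that interface; the function and variable names here are illustrative, not part of the commit:

def search_templates(searcher, query_sequence: str, msa: str):
    # msa must already be in searcher.input_format
    # ('a3m' for HHSearch, 'sto' for Hmmsearch).
    raw_output = searcher.query(msa)
    hits = searcher.get_template_hits(
        output_string=raw_output,
        input_sequence=query_sequence,
    )
    return hits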
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import os
import re
import subprocess
from absl import logging
from openfold.data.tools import utils
class Hmmbuild(object):
"""Python wrapper of the hmmbuild binary."""
def __init__(self,
*,
binary_path: str,
singlemx: bool = False):
"""Initializes the Python hmmbuild wrapper.
Args:
binary_path: The path to the hmmbuild executable.
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
just use a common substitution score matrix.
Raises:
RuntimeError: If hmmbuild binary not found within the path.
"""
self.binary_path = binary_path
self.singlemx = singlemx
def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
sto: A string with the aligned sequences in the Stockholm format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
return self._build_profile(sto, model_construction=model_construction)
def build_profile_from_a3m(self, a3m: str) -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
a3m: A string with the aligned sequences in the A3M format.
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
lines = []
for line in a3m.splitlines():
if not line.startswith('>'):
line = re.sub('[a-z]+', '', line) # Remove inserted residues.
lines.append(line + '\n')
msa = ''.join(lines)
return self._build_profile(msa, model_construction='fast')
def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
"""Builds a HMM for the aligned sequences given as an MSA string.
Args:
msa: A string with the aligned sequences, in A3M or STO format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
ValueError: If unspecified arguments are provided.
"""
if model_construction not in {'hand', 'fast'}:
raise ValueError(f'Invalid model_construction {model_construction} - only '
'hand and fast supported.')
with utils.tmpdir_manager() as query_tmp_dir:
input_query = os.path.join(query_tmp_dir, 'query.msa')
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
with open(input_query, 'w') as f:
f.write(msa)
cmd = [self.binary_path]
# If adding flags, we have to do so before the output and input:
if model_construction == 'hand':
cmd.append(f'--{model_construction}')
if self.singlemx:
cmd.append('--singlemx')
cmd.extend([
'--amino',
output_hmm_path,
input_query,
])
logging.info('Launching subprocess %s', cmd)
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
with utils.timing('hmmbuild query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'), stderr.decode('utf-8'))
if retcode:
raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(output_hmm_path, encoding='utf-8') as f:
hmm = f.read()
return hmm
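A short usage sketch for the wrapper above; the binary path and input alignment are placeholders:

hmmbuild_runner = Hmmbuild(binary_path="/usr/bin/hmmbuild")
with open("uniref90_hits.sto") as f:   # hypothetical Stockholm MSA
    sto = f.read()
# 'hand' uses the MSA's reference annotation to pick consensus columns,
# matching how Hmmsearch.query (below) invokes this wrapper.
hmm_profile = hmmbuild_runner.build_profile_from_sto(sto, model_construction='hand')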
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
from typing import Optional, Sequence
from absl import logging
from openfold.data import parsers
from openfold.data.tools import hmmbuild
from openfold.data.tools import utils
class Hmmsearch(object):
"""Python wrapper of the hmmsearch binary."""
def __init__(self,
*,
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
flags: Optional[Sequence[str]] = None):
"""Initializes the Python hmmsearch wrapper.
Args:
binary_path: The path to the hmmsearch executable.
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
an hmm from an input a3m.
database_path: The path to the hmmsearch database (FASTA format).
flags: List of flags to be used by hmmsearch.
Raises:
RuntimeError: If hmmsearch binary not found within the path.
"""
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = ['--F1', '0.1',
'--F2', '0.1',
'--F3', '0.1',
'--incE', '100',
'-E', '100',
'--domE', '100',
'--incdomE', '100']
self.flags = flags
if not os.path.exists(self.database_path):
logging.error('Could not find hmmsearch database %s', database_path)
raise ValueError(f'Could not find hmmsearch database {database_path}')
@property
def output_format(self) -> str:
return 'sto'
@property
def input_format(self) -> str:
return 'sto'
def query(self, msa_sto: str) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
model_construction='hand')
return self.query_with_hmm(hmm)
def query_with_hmm(self, hmm: str) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
out_path = os.path.join(query_tmp_dir, 'output.sto')
with open(hmm_input_path, 'w') as f:
f.write(hmm)
cmd = [
self.binary_path,
'--noali', # Don't include the alignment in stdout.
'--cpu', '8'
]
# If adding flags, we have to do so before the output and input:
if self.flags:
cmd.extend(self.flags)
cmd.extend([
'-A', out_path,
hmm_input_path,
self.database_path,
])
logging.info('Launching sub-process %s', cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing(
f'hmmsearch ({os.path.basename(self.database_path)}) query'):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(out_path) as f:
out_msa = f.read()
return out_msa
def get_template_hits(self,
output_string: str,
input_sequence: str
) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = parsers.convert_stockholm_to_a3m(
output_string,
remove_first_row_gaps=False
)
template_hits = parsers.parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False
)
return template_hits
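Taken together, the intended template-search flow for multimer inputs might look like this sketch; the paths, the input MSA, and the query sequence are placeholders:

searcher = Hmmsearch(
    binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    database_path="/data/pdb_seqres.txt",   # FASTA-format template database
)
sto_output = searcher.query(uniref90_sto_msa)   # build HMM, search, return .sto
template_hits = searcher.get_template_hits(
    output_string=sto_output,
    input_sequence=query_sequence,
)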
@@ -23,6 +23,7 @@ import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from openfold.data import parsers
from openfold.data.tools import utils
@@ -93,10 +94,13 @@ class Jackhmmer:
self.streaming_callback = streaming_callback
def _query_chunk(
-self, input_fasta_path: str, database_path: str
self,
input_fasta_path: str,
database_path: str,
max_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
-with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, "output.sto")
# The F1/F2/F3 are the expected proportion to pass each of the filtering
@@ -167,8 +171,11 @@ class Jackhmmer:
with open(tblout_path) as f:
tbl = f.read()
-with open(sto_path) as f:
-sto = f.read()
if(max_sequences is None):
with open(sto_path) as f:
sto = f.read()
else:
sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
raw_output = dict(
sto=sto,
@@ -180,10 +187,16 @@ class Jackhmmer:
return raw_output
-def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None
) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
-return [self._query_chunk(input_fasta_path, self.database_path)]
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences,
)
return [single_chunk_result]
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
@@ -217,12 +230,20 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
-self._query_chunk(input_fasta_path, db_local_chunk(i))
self._query_chunk(
input_fasta_path,
db_local_chunk(i),
max_sequences
)
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
-future = next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if(i < self.num_streamed_chunks):
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)." % (s, len(s))
)
-with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
-from typing import Optional, Tuple
from typing import Optional, Tuple, Union
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
@@ -151,6 +151,40 @@ class AngleResnet(nn.Module):
return unnormalized_s, s
class PointProjection(nn.Module):
def __init__(self,
c_hidden: int,
num_points: int,
no_heads: int,
return_local_points: bool = False,
):
super().__init__()
self.return_local_points = return_local_points
self.no_heads = no_heads
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
def forward(self,
activations: torch.Tensor,
rigids: Rigid3Array,
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
# [*, N_res, H, P * 3]
points_local = points_local.reshape(
points_local.shape[:-1] + (self.no_heads, -1)
)
# Split into three chunks of size P, one per coordinate
points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1
)
points_local = Vec3Array(*points_local)
points_global = rigids[..., None, None].apply_to_point(points_local)
if(self.return_local_points):
return points_global, points_local
return points_global
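For intuition, the projection maps per-residue activations to per-head 3D points in the global frame. A shape sketch, assuming Vec3Array holds per-component x/y/z tensors and Rigid3Array broadcasts over the trailing singleton dims added above (both from OpenFold's geometry utilities); all values here are hypothetical:

# s: [batch, N_res, 384] single representation
# r: Rigid3Array with shape [batch, N_res]
proj = PointProjection(c_hidden=384, num_points=4, no_heads=12)
pts = proj(s, r)
# pts.x, pts.y, pts.z: [batch, N_res, 12, 4] global-frame coordinates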
class InvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
@@ -200,13 +234,23 @@ class InvariantPointAttention(nn.Module):
self.linear_q = Linear(self.c_s, hc)
self.linear_kv = Linear(self.c_s, 2 * hc)
-hpq = self.no_heads * self.no_qk_points * 3
-self.linear_q_points = Linear(self.c_s, hpq)
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
)
-hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
-self.linear_kv_points = Linear(self.c_s, hpkv)
self.linear_k_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
)
-hpv = self.no_heads * self.no_v_points * 3
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points,
self.no_heads,
)
self.linear_b = Linear(self.c_z, self.no_heads)
@@ -257,35 +301,14 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
-# [*, N_res, H * P_q * 3]
-q_pts = self.linear_q_points(s)
-# This is kind of clunky, but it's how the original does it
-# [*, N_res, H * P_q, 3]
-q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
-q_pts = torch.stack(q_pts, dim=-1)
-q_pts = r[..., None].apply(q_pts)
-# [*, N_res, H, P_q, 3]
-q_pts = q_pts.view(
-q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
-)
-# [*, N_res, H * (P_q + P_v) * 3]
-kv_pts = self.linear_kv_points(s)
-# [*, N_res, H * (P_q + P_v), 3]
-kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
-kv_pts = torch.stack(kv_pts, dim=-1)
-kv_pts = r[..., None].apply(kv_pts)
-# [*, N_res, H, (P_q + P_v), 3]
-kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
-# [*, N_res, H, P_q/P_v, 3]
-k_pts, v_pts = torch.split(
-kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
-)
# [*, N_res, H, P_qk]
q_pts = self.linear_q_points(s, r)
# [*, N_res, H, P_qk, 3]
k_pts = self.linear_k_points(s, r)
# [*, N_res, H, P_v, 3]
v_pts = self.linear_v_points(s, r)
##########################
# Compute attention scores
@@ -302,8 +325,8 @@ class InvariantPointAttention(nn.Module):
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
-pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
-pt_att = pt_att ** 2
pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
pt_att = pt_att * pt_att + self.eps
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
@@ -340,26 +363,20 @@ class InvariantPointAttention(nn.Module):
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
-# [*, H, 3, N_res, P_v]
-o_pt = torch.sum(
-(
-a[..., None, :, :, None]
-* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-),
-dim=-2,
-)
-# [*, N_res, H, P_v, 3]
-o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
-o_pt = r[..., None, None].invert_apply(o_pt)
-# [*, N_res, H * P_v]
-o_pt_norm = flatten_final_dims(
-torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
-)
-# [*, N_res, H * P_v, 3]
-o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H, P_v]
o_pt = v_pts.tensor_dot(
permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
)
o_pt = o_pt.sum(dim=-3)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
@@ -370,7 +387,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
-(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
(o, *o_pt, o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
)
@@ -24,8 +24,6 @@ from importlib import resources
import numpy as np
import tree
-# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
@@ -1309,3 +1307,179 @@ def aatype_to_str_sequence(aatype):
restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in restypes:
residue_name = restype_1to3[residue_name]
residue_chi_angles = chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return np.array(chi_atom_indices)
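A quick sanity check of the table's layout, using only names defined in this module (row 0 is alanine, which has no chi angles; arginine's chi1 is built from N, CA, CB, CG):

chi_atom_indices = _make_chi_atom_indices()
assert chi_atom_indices[0].tolist() == [[0, 0, 0, 0]] * 4      # ALA: no chis
assert chi_atom_indices[restype_order['R']][0].tolist() == \
    [atom_order[a] for a in chi_angles_atoms['ARG'][0]]        # ARG chi1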
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
restype_1to3[res] for res in restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
def _make_restype_atom37_mask():
"""Mask of which atoms are present for which residue type in atom37."""
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
return restype_atom37_mask
def _make_restype_atom14_mask():
"""Mask of which atoms are present for which residue type in atom14."""
restype_atom14_mask = []
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
restype_atom14_mask.append([0.] * 14)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
return restype_atom14_mask
def _make_restype_atom37_to_atom14():
"""Map from atom37 to atom14 per residue type."""
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in atom_types
])
restype_atom37_to_atom14.append([0] * 37)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
return restype_atom37_to_atom14
def _make_restype_atom14_to_atom37():
"""Map from atom14 to atom37 per residue type."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
restype_atom14_to_atom37.append([
(atom_order[name] if name else 0)
for name in atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
return restype_atom14_to_atom37
def _make_restype_atom14_is_ambiguous():
"""Mask which atoms are ambiguous in atom14."""
# create an ambiguous atoms mask. shape: (21, 14)
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = restype_order[
restype_3to1[resname]]
atom_idx1 = restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
return restype_atom14_is_ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
"""Create Map from rigidgroups to atom37 indices."""
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for chi_idx in range(4):
if chi_angles_mask[restype][chi_idx]:
atom_names = chi_angles_atoms[resname][chi_idx]
base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]
# Translate atom names into atom37 indices.
lookuptable = atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
base_atom_names)
return restype_rigidgroup_base_atom37_idx
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()
# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask
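A small sketch of how these tables are typically consumed, converting a per-residue atom14 coordinate array to atom37 layout with a gather; the function name and shapes (aatype [N], atom14 positions [N, 14, 3]) are assumptions for illustration:

import numpy as np

def atom14_to_atom37(atom14_positions, aatype):
    # For each residue, where each atom37 slot lives in the atom14
    # array (index 0 for atoms the residue type does not have).
    residx = RESTYPE_ATOM37_TO_ATOM14[aatype]               # [N, 37]
    atom37 = np.take_along_axis(
        atom14_positions, residx[..., None].repeat(3, axis=-1), axis=1
    )                                                        # [N, 37, 3]
    # Zero out slots that do not exist for the residue type.
    atom37 *= RESTYPE_ATOM37_MASK[aatype][..., None]
    return atom37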