"TensorFlow2x/vscode:/vscode.git/clone" did not exist on "2a379c65392140bfe4dc74ca34278949c453612d"
Unverified Commit 4693058b authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

support multimer (#63)

parent c80a4df5
...@@ -34,6 +34,7 @@ from fastfold.data import ( ...@@ -34,6 +34,7 @@ from fastfold.data import (
msa_pairing, msa_pairing,
feature_processing_multimer, feature_processing_multimer,
) )
from fastfold.data import templates
from fastfold.data.parsers import Msa from fastfold.data.parsers import Msa
from fastfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from fastfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from fastfold.data.tools.utils import to_date from fastfold.data.tools.utils import to_date
...@@ -57,7 +58,7 @@ def empty_template_feats(n_res) -> FeatureDict: ...@@ -57,7 +58,7 @@ def empty_template_feats(n_res) -> FeatureDict:
def make_template_features( def make_template_features(
input_sequence: str, input_sequence: str,
hits: Sequence[Any], hits: Sequence[Any],
template_featurizer: Union[hhsearch.HHSearch, hmmsearch.Hmmsearch], template_featurizer: Union[templates.TemplateHitFeaturizer, templates.HmmsearchHitFeaturizer],
query_pdb_code: Optional[str] = None, query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None, query_release_date: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
...@@ -65,7 +66,7 @@ def make_template_features( ...@@ -65,7 +66,7 @@ def make_template_features(
if(len(hits_cat) == 0 or template_featurizer is None): if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence)) template_features = empty_template_feats(len(input_sequence))
else: else:
if type(template_featurizer) == hhsearch.HHSearch: if type(template_featurizer) == templates.TemplateHitFeaturizer:
templates_result = template_featurizer.get_templates( templates_result = template_featurizer.get_templates(
query_sequence=input_sequence, query_sequence=input_sequence,
query_pdb_code=query_pdb_code, query_pdb_code=query_pdb_code,
...@@ -202,32 +203,35 @@ def make_pdb_features( ...@@ -202,32 +203,35 @@ def make_pdb_features(
return pdb_feats return pdb_feats
def make_msa_features( def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features.""" """Constructs a feature dict of MSA features."""
if not msas: if not msas:
raise ValueError("At least one MSA must be provided.") raise ValueError("At least one MSA must be provided.")
int_msa = [] int_msa = []
deletion_matrix = [] deletion_matrix = []
species_ids = []
seen_sequences = set() seen_sequences = set()
for msa_index, msa in enumerate(msas): for msa_index, msa in enumerate(msas):
if not msa: if not msa:
raise ValueError( raise ValueError(
f"MSA {msa_index} must contain at least one sequence." f"MSA {msa_index} must contain at least one sequence."
) )
for sequence_index, sequence in enumerate(msa): for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences: if sequence in seen_sequences:
continue continue
seen_sequences.add(sequence) seen_sequences.add(sequence)
int_msa.append( int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence] [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
) )
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0]) deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index]
)
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa) num_alignments = len(int_msa)
features = {} features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32) features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
...@@ -235,9 +239,9 @@ def make_msa_features( ...@@ -235,9 +239,9 @@ def make_msa_features(
features["num_alignments"] = np.array( features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32 [num_alignments] * num_res, dtype=np.int32
) )
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
return features return features
def run_msa_tool( def run_msa_tool(
msa_runner, msa_runner,
fasta_path: str, fasta_path: str,
...@@ -455,7 +459,7 @@ class AlignmentRunner: ...@@ -455,7 +459,7 @@ class AlignmentRunner:
class AlignmentRunnerMultimer(AlignmentRunner): class AlignmentRunnerMultimer:
"""Runs alignment tools and saves the results""" """Runs alignment tools and saves the results"""
def __init__( def __init__(
...@@ -504,7 +508,6 @@ class AlignmentRunnerMultimer(AlignmentRunner): ...@@ -504,7 +508,6 @@ class AlignmentRunnerMultimer(AlignmentRunner):
mgnify_max_hits: mgnify_max_hits:
Max number of mgnify hits Max number of mgnify hits
""" """
# super().__init__()
db_map = { db_map = {
"jackhmmer": { "jackhmmer": {
"binary": jackhmmer_binary_path, "binary": jackhmmer_binary_path,
...@@ -810,43 +813,41 @@ class DataPipeline: ...@@ -810,43 +813,41 @@ class DataPipeline:
return msa return msa
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1] filename, ext = os.path.splitext(name)
if(ext == ".a3m"): if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m( msa = parsers.parse_a3m(
read_msa(start, size) read_msa(start, size)
) )
data = {"msa": msa, "deletion_matrix": deletion_matrix} # The "hmm_output" exception is a crude way to exclude
elif(ext == ".sto"): # multimer template hits.
msa, deletion_matrix, _ = parsers.parse_stockholm( elif(ext == ".sto" and not "hmm_output" == filename):
msa = parsers.parse_stockholm(
read_msa(start, size) read_msa(start, size)
) )
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else: else:
continue continue
msa_data[name] = data msa_data[name] =msa
fp.close() fp.close()
else: else:
for f in os.listdir(alignment_dir): for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f) path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1] filename, ext = os.path.splitext(f)
if(ext == ".a3m"): if(ext == ".a3m"):
with open(path, "r") as fp: with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read()) msa = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix} elif(ext == ".sto" and not "hmm_output" == filename):
elif(ext == ".sto"):
with open(path, "r") as fp: with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm( msa = parsers.parse_stockholm(
fp.read() fp.read()
) )
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else: else:
continue continue
msa_data[f] = data msa_data[f] = msa
return msa_data return msa_data
...@@ -913,19 +914,13 @@ class DataPipeline: ...@@ -913,19 +914,13 @@ class DataPipeline:
must be provided. must be provided.
""" """
) )
msa_data["dummy"] = { msa_data["dummy"] = Msa(
"msa": [input_sequence], [input_sequence],
"deletion_matrix": [[0 for _ in input_sequence]], [[0 for _ in input_sequence]],
} ["dummy"]
)
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values() msa_features = make_msa_features(list(msa_data.values()))
])
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
return msa_features return msa_features
...@@ -996,7 +991,10 @@ class DataPipeline: ...@@ -996,7 +991,10 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id) mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -1014,13 +1012,24 @@ class DataPipeline: ...@@ -1014,13 +1012,24 @@ class DataPipeline:
alignment_dir: str, alignment_dir: str,
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None, _alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a PDB file. Assembles features for a protein in a PDB file.
""" """
with open(pdb_path, 'r') as f: if(_structure_index is not None):
pdb_str = f.read() db_dir = os.path.dirname(pdb_path)
db = _structure_index["db"]
db_path = os.path.join(db_dir, db)
fp = open(db_path, "rb")
_, offset, length = _structure_index["files"][0]
fp.seek(offset)
pdb_str = fp.read(length).decode("utf-8")
fp.close()
else:
with open(pdb_path, 'r') as f:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id) protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype) input_sequence = _aatype_to_str_sequence(protein_object.aatype)
...@@ -1028,10 +1037,14 @@ class DataPipeline: ...@@ -1028,10 +1037,14 @@ class DataPipeline:
pdb_feats = make_pdb_features( pdb_feats = make_pdb_features(
protein_object, protein_object,
description, description,
is_distillation is_distillation=is_distillation
) )
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -1059,7 +1072,11 @@ class DataPipeline: ...@@ -1059,7 +1072,11 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper() description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description) core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -1123,8 +1140,8 @@ class DataPipelineMultimer: ...@@ -1123,8 +1140,8 @@ class DataPipelineMultimer:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto") uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp: with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read() uniprot_msa_string = fp.read()
msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string) msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features(msa, deletion_matrix) all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + ( valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers', 'msa_species_identifiers',
) )
......
...@@ -76,8 +76,9 @@ def np_example_to_features( ...@@ -76,8 +76,9 @@ def np_example_to_features(
mode: str, mode: str,
): ):
np_example = dict(np_example) np_example = dict(np_example)
print("np_example seq_length", np_example["seq_length"])
if is_multimer: if is_multimer:
num_res = int(np_example["seq_length"]) num_res = int(np_example["seq_length"][0])
else: else:
num_res = int(np_example["seq_length"][0]) num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
......
...@@ -96,9 +96,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -96,9 +96,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions return sequences, descriptions
def parse_stockholm( def parse_stockholm(stockholm_string: str) -> Msa:
stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment. """Parses sequences and deletion matrix from stockholm format alignment.
Args: Args:
...@@ -153,10 +151,14 @@ def parse_stockholm( ...@@ -153,10 +151,14 @@ def parse_stockholm(
deletion_count = 0 deletion_count = 0
deletion_matrix.append(deletion_vec) deletion_matrix.append(deletion_vec)
return msa, deletion_matrix, list(name_to_sequence.keys()) return Msa(
sequences=msa,
deletion_matrix=deletion_matrix,
descriptions=list(name_to_sequence.keys())
)
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment. """Parses sequences and deletion matrix from a3m format alignment.
Args: Args:
...@@ -171,7 +173,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: ...@@ -171,7 +173,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
at `deletion_matrix[i][j]` is the number of residues deleted from at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j. the aligned sequence i at residue position j.
""" """
sequences, _ = parse_fasta(a3m_string) sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = [] deletion_matrix = []
for msa_sequence in sequences: for msa_sequence in sequences:
deletion_vec = [] deletion_vec = []
...@@ -187,8 +189,12 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: ...@@ -187,8 +189,12 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
# Make the MSA matrix out of aligned (deletion-free) sequences. # Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans("", "", string.ascii_lowercase) deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences] aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix return Msa(
sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions
)
def _convert_sto_seq_to_a3m( def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str query_non_gaps: Sequence[bool], sto_seq: str
......
import os
import glob
import importlib as importlib
# Collect every ``*.py`` file that sits next to this ``__init__.py``.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))

# Public API of the package: the module name (file name without ``.py``) of
# every sibling source file except ``__init__.py`` itself.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]

# Import each submodule and bind it as an attribute of the package. The
# generator keeps its loop variable private, so nothing extra leaks into the
# module namespace and the clean-up below cannot fail with a NameError when
# the package contains no submodules at all.
globals().update(
    (_name, importlib.import_module("." + _name, __name__))
    for _name in __all__
)

# Avoid needlessly cluttering the global namespace
del _files
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
import re
from fastfold.np import residue_constants
from Bio.PDB import PDBParser
import numpy as np
# Type aliases shared by the helpers in this module.
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Scale factor applied to ProteinNet tertiary coordinates (1 pm == 0.01 A).
PICO_TO_ANGSTROM = 0.01

# One-character chain identifiers, indexed by integer chain index.
PDB_CHAIN_IDS = (
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
)
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
# Sanity check: 26 upper-case + 26 lower-case + 10 digits.
assert PDB_MAX_CHAINS == 62
@dataclasses.dataclass(frozen=True)
class Protein:
    """Immutable bundle of per-residue arrays describing a protein structure.

    Raises:
        ValueError: on construction, if ``chain_index`` contains more than
            ``PDB_MAX_CHAINS`` distinct values.
    """

    # Atom coordinates (Cartesian, angstroms). Atom slots follow
    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Integer amino-acid type per residue, between 0 and 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask: 1.0 where an atom is present, 0.0 where it is not.
    # Intended for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue numbering as used in PDB; not necessarily contiguous or
    # 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number of the chain each residue belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors (temperature factors) in square angstroms: displacement of
    # the residue from its ground truth mean value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        """Reject instances that hold more chains than PDB can label."""
        distinct_chains = np.unique(self.chain_index)
        if len(distinct_chains) > PDB_MAX_CHAINS:
            raise ValueError(
                f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
                "chains because these cannot be written to PDB format"
            )
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Build a `Protein` from the text of a PDB file.

    WARNING: residue names that are not recognised are mapped to the unknown
    type, and atoms whose names are not in `residue_constants.atom_types`
    are dropped.

    Args:
        pdb_str: The contents of the pdb file
        chain_id: When given (e.g. A), only the chain with that id is read;
            otherwise every chain is read.

    Returns:
        A new `Protein` parsed from the pdb contents.

    Raises:
        ValueError: if the file does not hold exactly one model, or if any
            residue carries an insertion code.
    """
    structure = PDBParser(QUIET=True).get_structure("none", io.StringIO(pdb_str))
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    num_atom_types = residue_constants.atom_type_num

    # Per-residue accumulators, one entry appended for every kept residue.
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        wanted = chain_id is None or chain.id == chain_id
        if not wanted:
            continue
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(
                    f"PDB contains an insertion code at chain {chain.id} and residue "
                    f"index {res.id[1]}. These are not supported."
                )
            one_letter = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                one_letter, residue_constants.restype_num
            )

            pos = np.zeros((num_atom_types, 3))
            mask = np.zeros((num_atom_types,))
            res_b_factors = np.zeros((num_atom_types,))
            for atom in res:
                if atom.name in residue_constants.atom_types:
                    slot = residue_constants.atom_order[atom.name]
                    pos[slot] = atom.coord
                    mask[slot] = 1.0
                    res_b_factors[slot] = atom.bfactor

            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue
                # then skip it.
                continue

            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints; np.unique
    # returns them sorted, which fixes the numbering.
    chain_id_mapping = {
        cid: number for number, cid in enumerate(np.unique(chain_ids))
    }
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
    )
def from_proteinnet_string(proteinnet_str: str) -> Protein:
    """Build a `Protein` from a ProteinNet text record.

    The record is split on ``[TAG]`` header lines. Three sections are used:

    * ``[PRIMARY]``  - one line holding the one-letter sequence.
    * ``[TERTIARY]`` - three lines (one per coordinate axis) of floats, with
      the N, CA and C values of each residue interleaved; values are scaled
      by ``PICO_TO_ANGSTROM``.
    * ``[MASK]``     - one line of ``+`` / ``-`` flags, one per residue.

    Args:
        proteinnet_str: Text of a single ProteinNet record.

    Returns:
        A `Protein` with a single chain (``chain_index`` all zeros),
        ``residue_index`` counting from 0 and ``b_factors`` set to ``None``.
        ``atom_positions`` / ``atom_mask`` stay ``None`` when the matching
        section is absent.

    Raises:
        ValueError: if the record has no ``[PRIMARY]`` section.
    """
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [
        tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
    ]
    # Pair each tag with the lines of the section that follows it.
    groups = zip(tags[0::2], [body.split('\n') for body in tags[1::2]])

    atoms = ['N', 'CA', 'C']
    aatype = None
    atom_positions = None
    atom_mask = None
    for tag, lines in groups:
        if tag == "[PRIMARY]":
            seq = lines[0].strip()
            # Strings are immutable, so unknown residue letters are replaced
            # by 'X' while iterating rather than by assigning into `seq`.
            aatype = np.array([
                residue_constants.restype_order.get(
                    symbol if symbol in residue_constants.restypes else 'X',
                    residue_constants.restype_num,
                )
                for symbol in seq
            ])
        elif tag == "[TERTIARY]":
            tertiary = [
                [float(value) for value in lines[axis].split()]
                for axis in range(3)
            ]
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
            ).astype(np.float32)
            # Every third column belongs to the same backbone atom.
            for offset, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = (
                    np.transpose(tertiary_np[:, offset::3])
                )
            atom_positions *= PICO_TO_ANGSTROM
        elif tag == "[MASK]":
            mask = np.array(
                [{'-': 0, '+': 1}.get(flag) for flag in lines[0].strip()]
            )
            atom_mask = np.zeros(
                (len(mask), residue_constants.atom_type_num,)
            ).astype(np.float32)
            for atom in atoms:
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    if aatype is None:
        raise ValueError("ProteinNet record has no [PRIMARY] section.")

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        # `Protein` requires a chain index; the record describes one chain.
        chain_index=np.zeros_like(aatype),
        b_factors=None,
    )
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}'
)
def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    One ``ATOM`` record is written for every atom whose mask value is at
    least 0.5, a ``TER`` record is written whenever the chain index changes
    and after the last residue, and the whole structure is wrapped in
    ``MODEL`` / ``ENDMDL`` followed by ``END``. Every line is padded to 80
    characters.

    Args:
        prot: The protein to convert to PDB. When ``prot.b_factors`` is
            ``None`` all B-factors are written as 0.

    Returns:
        PDB string.

    Raises:
        ValueError: if any ``aatype`` exceeds ``residue_constants.restype_num``
            or any chain index is not below ``PDB_MAX_CHAINS``.
    """
    restypes = residue_constants.restypes + ["X"]

    def res_1to3(r):
        # Three-letter residue name for an integer residue type.
        return residue_constants.restype_1to3.get(restypes[r], "UNK")

    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors
    if b_factors is None:
        # Some constructors leave B-factors unset; write zeros instead of
        # failing on indexing None.
        b_factors = np.zeros_like(atom_mask)

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
            )
        chain_ids[i] = PDB_CHAIN_IDS[i]

    # The model serial number belongs in columns 11-14 of the MODEL record.
    pdb_lines.append("MODEL     1")
    atom_index = 1
    last_chain_index = chain_index[0]
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        # Close the previous chain if in a multichain PDB.
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1]
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            # Protein supports only C, N, O, S, so the first letter of the
            # atom name is its element.
            element = atom_name[0]
            charge = ""
            # PDB is a columnar format, every space matters here!
            # Columns 28-30 (after the insertion code) and 67-76 (after the
            # B-factor) are blank.
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the final chain.
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1]
        )
    )

    pdb_lines.append("ENDMDL")
    pdb_lines.append("END")

    # Pad all lines to 80 characters
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Look up the expected atom mask for the protein's sequence.

    `Protein.atom_mask` typically records which atoms were reported in the
    PDB. This function instead returns, for every residue, the mask of heavy
    atoms that should be present for that amino-acid type.

    Args:
        prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
        An ideal atom mask.
    """
    mask_by_restype = residue_constants.STANDARD_ATOM_MASK
    return mask_by_restype[prot.aatype]
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True,
) -> Protein:
    """Build a `Protein` from model inputs and model outputs.

    Args:
        features: Dictionary holding model inputs.
        result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein; zeros shaped
            like ``result["final_atom_mask"]`` are used when omitted.
        remove_leading_feature_dimension: Whether to remove the leading
            dimension of the `features` values

    Returns:
        A protein instance.
    """
    def _strip(arr: np.ndarray) -> np.ndarray:
        # Drop the leading dimension only when the caller asked for it.
        if remove_leading_feature_dimension:
            return arr[0]
        return arr

    aatype = _strip(features["aatype"])

    # Without an explicit asym_id, every residue goes to chain 0.
    if 'asym_id' not in features:
        chain_index = np.zeros_like(aatype)
    else:
        chain_index = _strip(features["asym_id"])

    final_mask = result["final_atom_mask"]
    if b_factors is None:
        b_factors = np.zeros_like(final_mask)

    return Protein(
        aatype=aatype,
        atom_positions=result["final_atom_positions"],
        atom_mask=final_mask,
        # Feature residue indices are shifted by one for the output protein.
        residue_index=_strip(features["residue_index"]) + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )
import os
import glob
import importlib as importlib
# Collect every ``*.py`` file that sits next to this ``__init__.py``.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))

# Public API of the package: the module name (file name without ``.py``) of
# every sibling source file except ``__init__.py`` itself.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]

# Import each submodule and bind it as an attribute of the package. The
# generator keeps its loop variable private, so nothing extra leaks into the
# module namespace and the clean-up below cannot fail with a NameError when
# the package contains no submodules at all.
globals().update(
    (_name, importlib.import_module("." + _name, __name__))
    for _name in __all__
)

# Avoid needlessly cluttering the global namespace
del _files
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Restrained Amber Minimization of a structure."""
import io
import time
from typing import Collection, Optional, Sequence
from absl import logging
from openfold.np import (
protein,
residue_constants,
)
import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils
import ml_collections
import numpy as np
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
# Units used by this module when converting OpenMM quantities to plain
# numbers (see the `value_in_unit` calls below) and when comparing the
# restraint stiffness against zero.
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
    """Tell whether `atom` belongs to the restraint set named by `rset`.

    ``"non_hydrogen"`` selects every atom whose element is not hydrogen and
    ``"c_alpha"`` selects atoms named ``CA``. Any other `rset` falls through
    and yields ``None``.
    """
    if rset == "c_alpha":
        return atom.name == "CA"
    if rset == "non_hydrogen":
        return atom.element.name != "hydrogen"
    return None
def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int],
):
    """Attach a harmonic positional restraint force to `system`.

    Each selected atom is pulled towards its position in `reference_pdb`
    with spring constant `stiffness`.

    Args:
        system: OpenMM system that receives the new force.
        reference_pdb: Structure providing the topology and the anchor
            positions.
        stiffness: Value of the global parameter ``k`` of the potential.
        rset: Restraint set, either ``"non_hydrogen"`` or ``"c_alpha"``.
        exclude_residues: Residue indices whose atoms are never restrained.
    """
    assert rset in ["non_hydrogen", "c_alpha"]

    restraint = openmm.CustomExternalForce(
        "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
    )
    restraint.addGlobalParameter("k", stiffness)
    for anchor_name in ("x0", "y0", "z0"):
        restraint.addPerParticleParameter(anchor_name)

    for particle, atom in enumerate(reference_pdb.topology.atoms()):
        excluded = atom.residue.index in exclude_residues
        if not excluded and will_restrain(atom, rset):
            restraint.addParticle(particle, reference_pdb.positions[particle])

    logging.info(
        "Restraining %d / %d particles.",
        restraint.getNumParticles(),
        system.getNumParticles(),
    )
    system.addForce(restraint)
def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int],
    use_gpu: bool,
):
    """Run an OpenMM energy minimization on a PDB string.

    Args:
        pdb_str: Structure to minimize, as PDB text.
        max_iterations: Passed to ``minimizeEnergy`` as ``maxIterations``.
        tolerance: Passed to ``minimizeEnergy`` as ``tolerance``.
        stiffness: Restraint spring constant; restraints are only added when
            it is greater than zero.
        restraint_set: Which atoms to restrain (see `_add_restraints`).
        exclude_residues: Residue indices left unrestrained.
        use_gpu: Selects the ``"CUDA"`` platform when true, ``"CPU"``
            otherwise.

    Returns:
        A dict with the potential energy and positions before minimization
        (``einit``, ``posinit``), after minimization (``efinal``, ``pos``)
        and the minimized structure as PDB text (``min_pdb``).
    """
    pdb = openmm_app.PDBFile(io.StringIO(pdb_str))

    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        pdb.topology, constraints=openmm_app.HBonds
    )
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

    integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
    platform_name = "CUDA" if use_gpu else "CPU"
    platform = openmm.Platform.getPlatformByName(platform_name)
    simulation = openmm_app.Simulation(
        pdb.topology, system, integrator, platform
    )
    simulation.context.setPositions(pdb.positions)

    # Snapshot of the starting point.
    before = simulation.context.getState(getEnergy=True, getPositions=True)
    result = {
        "einit": before.getPotentialEnergy().value_in_unit(ENERGY),
        "posinit": before.getPositions(asNumpy=True).value_in_unit(LENGTH),
    }

    simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)

    # Snapshot of the minimized state.
    after = simulation.context.getState(getEnergy=True, getPositions=True)
    result["efinal"] = after.getPotentialEnergy().value_in_unit(ENERGY)
    result["pos"] = after.getPositions(asNumpy=True).value_in_unit(LENGTH)
    result["min_pdb"] = _get_pdb_string(
        simulation.topology, after.getPositions()
    )
    return result
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
    """Serialize an OpenMM topology plus positions to PDB text."""
    with io.StringIO() as buffer:
        openmm_app.PDBFile.writeFile(topology, positions, buffer)
        pdb_text = buffer.getvalue()
    return pdb_text
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
    """Verify that cleaning left the coordinates of shared atoms untouched.

    Residues of the two structures are paired in order. Within each pair,
    atoms are matched by name and must have exactly equal coordinates.

    Args:
        pdb_cleaned_string: PDB text of the cleaned structure.
        pdb_ref_string: PDB text of the structure before cleaning.

    Raises:
        ValueError: if a same-named atom has different coordinates.
    """
    cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
    reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

    cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
    ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

    residue_pairs = zip(
        reference.topology.residues(), cleaned.topology.residues()
    )
    for ref_res, cl_res in residue_pairs:
        assert ref_res.name == cl_res.name
        for rat in ref_res.atoms():
            for cat in cl_res.atoms():
                if cat.name != rat.name:
                    continue
                if np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
                    continue
                raise ValueError(
                    f"Coordinates of cleaned atom {cat} do not match "
                    f"coordinates of reference atom {rat}."
                )
def _check_residues_are_well_defined(prot: protein.Protein):
    """Raise `ValueError` if any residue has an all-zero atom mask."""
    atoms_per_residue = prot.atom_mask.sum(axis=-1)
    has_empty_residue = (atoms_per_residue == 0).any()
    if has_empty_residue:
        raise ValueError(
            "Amber minimization can only be performed on proteins with"
            " well-defined residues. This protein contains at least"
            " one residue with no atoms."
        )
def _check_atom_mask_is_ideal(prot):
    """Compare the protein's atom mask with the ideal one (OXT aside).

    The comparison itself is delegated to
    `utils.assert_equal_nonterminal_atom_types`.
    """
    expected_mask = protein.ideal_atom_mask(prot)
    utils.assert_equal_nonterminal_atom_types(prot.atom_mask, expected_mask)
def clean_protein(prot: protein.Protein, checks: bool = True):
    """Run the PDB clean-up pipeline on a `Protein` (adds missing atoms).

    The protein is written to PDB text, passed through `cleanup.fix_pdb` and
    `cleanup.clean_structure`, and written back to PDB text.

    Args:
        prot: A `protein.Protein` instance.
        checks: A `bool` specifying whether to verify afterwards that the
            clean-up did not move any atom present in the input.

    Returns:
        pdb_string: A string of the cleaned protein.
    """
    _check_atom_mask_is_ideal(prot)

    # Clean pdb.
    original_pdb = protein.to_pdb(prot)
    alterations = {}
    repaired_pdb = cleanup.fix_pdb(io.StringIO(original_pdb), alterations)
    structure = PdbStructure(io.StringIO(repaired_pdb))
    cleanup.clean_structure(structure, alterations)

    logging.info("alterations info: %s", alterations)

    # Write pdb file of cleaned structure.
    cleaned = openmm_app.PDBFile(structure)
    cleaned_pdb = _get_pdb_string(
        cleaned.getTopology(), cleaned.getPositions()
    )
    if checks:
        _check_cleaned_atoms(cleaned_pdb, original_pdb)
    return cleaned_pdb
def make_atom14_positions(prot):
    """Constructs denser atom positions (14 dimensions instead of 37).

    The feature dict is modified in place and also returned.

    Args:
        prot: A mutable mapping that provides at least "aatype",
            "all_atom_positions" and "all_atom_mask". "aatype" is used to
            index per-residue-type lookup tables that have one row per entry
            of `residue_constants.restypes` plus a trailing 'UNK' row.
            NOTE(review): presumably "aatype" is an integer array of shape
            (num_res,) and the "all_atom_*" arrays use the 37-atom layout;
            confirm against the callers.

    Returns:
        The same `prot` object with these keys added:
        "atom14_atom_exists", "atom14_gt_exists", "atom14_gt_positions",
        "residx_atom14_to_atom37", "residx_atom37_to_atom14",
        "atom37_atom_exists", "atom14_alt_gt_positions",
        "atom14_alt_gt_exists" and "atom14_atom_is_ambiguous".
    """
    restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
    restype_atom14_mask = []

    # Build one row per residue type. Empty atom14 slots (falsy names) map to
    # index 0 and are masked out with 0.0.
    for rt in residue_constants.restypes:
        atom_names = residue_constants.restype_name_to_atom14_names[
            residue_constants.restype_1to3[rt]
        ]
        restype_atom14_to_atom37.append(
            [
                (residue_constants.atom_order[name] if name else 0)
                for name in atom_names
            ]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in residue_constants.atom_types
            ]
        )
        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'.
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = np.array(
        restype_atom14_to_atom37, dtype=np.int32
    )
    restype_atom37_to_atom14 = np.array(
        restype_atom37_to_atom14, dtype=np.int32
    )
    restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

    # Create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein.
    residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
    residx_atom14_mask = restype_atom14_mask[prot["aatype"]]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
        prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
    ).astype(np.float32)

    # Gather the ground truth positions. Multiplying by the mask zeroes the
    # coordinates of atoms without a known ground truth position.
    residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
        np.take_along_axis(
            prot["all_atom_positions"],
            residx_atom14_to_atom37[..., None],
            axis=1,
        )
    )

    prot["atom14_atom_exists"] = residx_atom14_mask
    prot["atom14_gt_exists"] = residx_atom14_gt_mask
    prot["atom14_gt_positions"] = residx_atom14_gt_positions

    prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)

    # Create the gather indices for mapping back.
    residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
    prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)

    # Create the corresponding mask. The last row (index 20) stays all-zero
    # because the loop below only covers residue_constants.restypes.
    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
    for restype, restype_letter in enumerate(residue_constants.restypes):
        restype_name = residue_constants.restype_1to3[restype_letter]
        atom_names = residue_constants.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = residue_constants.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
    prot["atom37_atom_exists"] = residx_atom37_mask

    # As the atom naming is ambiguous for some amino acids (those listed in
    # residue_constants.residue_atom_renaming_swaps), provide alternative
    # ground truth coordinates where the naming is swapped
    restype_3 = [
        residue_constants.restype_1to3[res]
        for res in residue_constants.restypes
    ]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms. Residue types without a swap
    # entry keep the identity matrix.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        correspondences = np.arange(14)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(source_atom_swap)
            target_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
            # The permutation matrix is rebuilt after every swap pair; the
            # version built after the last pair is the one that is stored.
            renaming_matrix = np.zeros((14, 14), dtype=np.float32)
            for index, correspondence in enumerate(correspondences):
                renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix.astype(np.float32)
    renaming_matrices = np.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[prot["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = np.einsum(
        "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
    )
    prot["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = np.einsum(
        "ra,rab->rb", residx_atom14_gt_mask, renaming_transform
    )
    prot["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask.  shape: (21, 14).
    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = residue_constants.restype_order[
                residue_constants.restype_3to1[resname]
            ]
            atom_idx1 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name1)
            atom_idx2 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        prot["aatype"]
    ]

    return prot
def find_violations(prot_np: protein.Protein):
    """Compute structural violations and their summary metrics for a protein.

    The protein's own coordinates (converted to the atom14 layout) are used
    as the "predicted" positions that are checked.

    Args:
        prot_np: A protein.

    Returns:
        violations: A `dict` of structure components with structural
            violations.
        violation_metrics: A `dict` of violation metrics.
    """
    features = make_atom14_positions(
        {
            "aatype": prot_np.aatype,
            "all_atom_positions": prot_np.atom_positions.astype(np.float32),
            "all_atom_mask": prot_np.atom_mask.astype(np.float32),
            "residue_index": prot_np.residue_index,
            "seq_mask": np.ones_like(prot_np.aatype, np.float32),
        }
    )
    positions = features["atom14_gt_positions"]

    # Both values are taken from the model config.
    violation_config = ml_collections.ConfigDict(
        {
            "violation_tolerance_factor": 12,
            "clash_overlap_tolerance": 1.5,
        }
    )

    violations = loss.find_structural_violations_np(
        batch=features,
        atom14_pred_positions=positions,
        config=violation_config,
    )
    violation_metrics = loss.compute_violation_metrics_np(
        batch=features,
        atom14_pred_positions=positions,
        violations=violations,
    )
    return violations, violation_metrics
def get_violation_metrics(prot: protein.Protein):
    """Return the violation metrics of `prot`, extended with residue details.

    On top of the metrics produced by `find_violations`, the result carries
    the indices of violating residues ("residue_violations"), their count
    ("num_residue_violations") and the raw violations dict
    ("structural_violations").
    """
    violations, metrics = find_violations(prot)
    violating_residues = np.flatnonzero(
        violations["total_per_residue_violations_mask"]
    )

    metrics["residue_violations"] = violating_residues
    metrics["num_residue_violations"] = len(violating_residues)
    metrics["structural_violations"] = violations
    return metrics
def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None,
    use_gpu: bool,
):
    """Runs the minimization pipeline.

    The minimization is retried until it succeeds or `max_attempts` attempts
    have been made.

    Args:
        pdb_string: A pdb string.
        max_iterations: An `int` specifying the maximum number of L-BFGS
            iterations. A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts.
        exclude_residues: An optional list of zero-indexed residues to exclude
            from restraints.
        use_gpu: Whether to run relaxation on GPU

    Returns:
        A `dict` of minimization info, extended with "opt_time" (seconds spent
        across all attempts) and "min_attempts" (number of attempts made).

    Raises:
        ValueError: If no attempt succeeded. The exception raised by the last
            failed attempt, if any, is attached as `__cause__`.
    """
    exclude_residues = exclude_residues or []

    # Assign physical dimensions.
    tolerance = tolerance * ENERGY
    stiffness = stiffness * ENERGY / (LENGTH ** 2)

    start = time.perf_counter()
    minimized = False
    attempts = 0
    last_error = None
    while not minimized and attempts < max_attempts:
        attempts += 1
        try:
            logging.info(
                "Minimizing protein, attempt %d of %d.", attempts, max_attempts
            )
            ret = _openmm_minimize(
                pdb_string,
                max_iterations=max_iterations,
                tolerance=tolerance,
                stiffness=stiffness,
                restraint_set=restraint_set,
                exclude_residues=exclude_residues,
                use_gpu=use_gpu,
            )
            minimized = True
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort retry: remember the failure so that it can be
            # reported if every attempt fails, and keep going.
            last_error = e
            logging.warning(
                "Minimization attempt %d of %d failed: %s",
                attempts,
                max_attempts,
                e,
            )
    if not minimized:
        raise ValueError(
            f"Minimization failed after {max_attempts} attempts."
        ) from last_error
    ret["opt_time"] = time.perf_counter() - start
    ret["min_attempts"] = attempts
    return ret
def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    use_gpu: bool,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Iteratively relax a protein with restrained Amber minimization.

    Relax iterations are repeated until no violations remain or
    `max_outer_iterations` is reached. After every iteration the residues
    that take part in a violation are added to the restraint exclusions used
    by the next one.

    Args:
        prot: A protein to be relaxed.
        stiffness: kcal/mol A**2, the restraint stiffness.
        use_gpu: Whether to run on GPU
        max_outer_iterations: The maximum number of iterative minimization.
        place_hydrogens_every_iteration: Whether hydrogens are re-initialized
            prior to every minimization.
        max_iterations: An `int` specifying the maximum number of L-BFGS steps
            per relax iteration. A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
            The default value is the OpenMM default.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts per
            iteration.
        checks: Whether to perform cleaning checks.
        exclude_residues: An optional list of zero-indexed residues to exclude
            from restraints.

    Returns:
        out: A dictionary of output values from the last iteration.
    """
    # `protein.to_pdb` will strip any poorly-defined residues so we need to
    # perform this check before `clean_protein`.
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    excluded = set(exclude_residues or [])
    remaining_violations = np.inf
    iteration = 0

    # NOTE(review): `ret` is only bound inside the loop, so a
    # `max_outer_iterations` below 1 fails at the final `return`.
    while remaining_violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(
            pdb_string=pdb_string,
            exclude_residues=excluded,
            max_iterations=max_iterations,
            tolerance=tolerance,
            stiffness=stiffness,
            restraint_set=restraint_set,
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )
        prot = protein.from_pdb_string(ret["min_pdb"])

        # NOTE(review): the re-clean always runs with checks=True, regardless
        # of the `checks` argument - confirm this is intended.
        pdb_string = (
            clean_protein(prot, checks=True)
            if place_hydrogens_every_iteration
            else ret["min_pdb"]
        )

        ret.update(get_violation_metrics(prot))
        ret.update(num_exclusions=len(excluded), iteration=iteration)

        remaining_violations = ret["violations_per_residue"]
        excluded = excluded.union(ret["residue_violations"])

        logging.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"],
            ret["efinal"],
            ret["opt_time"],
            ret["num_residue_violations"],
            ret["num_exclusions"],
        )
        iteration += 1
    return ret
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
fix_pdb uses a third-party tool. We also support fixing some additional edge
cases like removing chains of length one (see clean_structure).
"""
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
    """Apply pdbfixer to the contents of a PDB file; return a PDB string result.

    1) Replaces nonstandard residues.
    2) Removes heterogens (non protein residues) including water.
    3) Adds missing residues and missing atoms within existing residues.
    4) Adds hydrogens assuming pH=7.0.
    5) KeepIds is currently true, so the fixer must keep the existing chain and
       residue identifiers. This will fail for some files in wider PDB that have
       invalid IDs.

    Args:
        pdbfile: Input PDB file handle.
        alterations_info: A dict that will store details of changes made. The
            keys written here are "nonstandard_residues", "missing_residues",
            "missing_heavy_atoms" and "missing_terminals", plus
            "removed_heterogens" (written by `_remove_heterogens`).

    Returns:
        A PDB string representing the fixed structure.
    """
    fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)

    # Record the nonstandard residues before they are replaced.
    fixer.findNonstandardResidues()
    alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
    fixer.replaceNonstandardResidues()

    _remove_heterogens(fixer, alterations_info, keep_water=False)

    # The find* calls populate the fixer attributes that are recorded below
    # and that the add* calls act on, so the order of these steps matters.
    fixer.findMissingResidues()
    alterations_info["missing_residues"] = fixer.missingResidues
    fixer.findMissingAtoms()
    alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
    alterations_info["missing_terminals"] = fixer.missingTerminals
    # Fixed seed; presumably this keeps the placement of added atoms
    # reproducible - TODO confirm against the pdbfixer documentation.
    fixer.addMissingAtoms(seed=0)
    fixer.addMissingHydrogens()

    # Serialise the fixed structure, keeping chain and residue identifiers.
    out_handle = io.StringIO()
    app.PDBFile.writeFile(
        fixer.topology, fixer.positions, out_handle, keepIds=True
    )
    return out_handle.getvalue()
def clean_structure(pdb_structure, alterations_info):
    """Apply the extra edge-case fixes to an OpenMM structure, in place.

    Runs `_replace_met_se` followed by `_remove_chains_of_length_one`.

    Args:
        pdb_structure: An OpenMM structure to modify and fix.
        alterations_info: A dict that will store details of changes made.
    """
    for apply_fix in (_replace_met_se, _remove_chains_of_length_one):
        apply_fix(pdb_structure, alterations_info)
def _remove_heterogens(fixer, alterations_info, keep_water):
    """Strip heterogens via pdbfixer and record which residue names vanished.

    Args:
        fixer: A Pdbfixer instance.
        alterations_info: A dict that will store details of changes made; the
            set of residue names that are present before but not after the
            removal is stored under "removed_heterogens".
        keep_water: If True, water (HOH) is not considered to be a heterogen.
    """

    def residue_names():
        # Snapshot of every residue name currently in the fixer's topology.
        return {
            residue.name
            for chain in fixer.topology.chains()
            for residue in chain.residues()
        }

    names_before = residue_names()
    fixer.removeHeterogens(keepWater=keep_water)
    names_after = residue_names()
    alterations_info["removed_heterogens"] = names_before - names_after
def _replace_met_se(pdb_structure, alterations_info):
    """Turn selenium back into sulfur on the SD atom of MET residues.

    Args:
        pdb_structure: An OpenMM pdb_structure; matching atoms are modified
            in place.
        alterations_info: A dict that will store details of changes made; the
            residue numbers of the modified atoms are stored under
            "Se_in_MET".
    """
    fixed_residue_numbers = []
    for residue in pdb_structure.iter_residues():
        if residue.get_name_with_spaces().strip() != "MET":
            continue
        sd_atom = residue.get_atom("SD")
        if sd_atom.element_symbol != "Se":
            continue
        sd_atom.element_symbol = "S"
        sd_atom.element = element.get_by_symbol("S")
        fixed_residue_numbers.append(sd_atom.residue_number)
    alterations_info["Se_in_MET"] = fixed_residue_numbers
def _remove_chains_of_length_one(pdb_structure, alterations_info):
    """Drop every chain whose length is at most one from each model.

    A single amino acid in a chain is both N and C terminus. There is no force
    template for this case.

    Args:
        pdb_structure: An OpenMM pdb_structure to modify and fix.
        alterations_info: A dict that will store details of changes made; a
            mapping from model number to the list of removed chain ids is
            stored under "removed_chains".
    """
    removed_by_model = {}
    for model in pdb_structure.iter_models():
        kept_chains = []
        dropped_ids = []
        for chain in model.iter_chains():
            if len(chain) > 1:
                kept_chains.append(chain)
            else:
                dropped_ids.append(chain.chain_id)

        # Update both the chain list and the id lookup of the model.
        model.chains = kept_chains
        for chain_id in dropped_ids:
            model.chains_by_id.pop(chain_id)
        removed_by_model[model.number] = dropped_ids
    alterations_info["removed_chains"] = removed_by_model
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Amber relaxation."""
from typing import Any, Dict, Sequence, Tuple
from openfold.np import protein
from openfold.np.relax import amber_minimize, utils
import numpy as np
class AmberRelaxation:
    """Relaxes a predicted structure through `amber_minimize.run_pipeline`."""

    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
        use_gpu: bool,
    ):
        """Store the relaxation settings.

        Args:
            max_iterations: Maximum number of L-BFGS iterations. 0 means no
                max.
            tolerance: kcal/mol, the energy tolerance of L-BFGS.
            stiffness: kcal/mol A**2, spring constant of heavy atom
                restraining potential.
            exclude_residues: Residues to exclude from per-atom restraining.
                Zero-indexed.
            max_outer_iterations: Maximum number of violation-informed relax
                iterations. A value of 1 will run the non-iterative procedure
                used in CASP14. Use 20 so that >95% of the bad cases are
                relaxed. Relax finishes as soon as there are no violations,
                hence in most cases this causes no slowdown. In the worst
                case we do 20 outer iterations.
            use_gpu: Whether to run on GPU
        """
        self._use_gpu = use_gpu
        self._max_outer_iterations = max_outer_iterations
        self._exclude_residues = exclude_residues
        self._stiffness = stiffness
        self._tolerance = tolerance
        self._max_iterations = max_iterations

    def process(
        self, *, prot: protein.Protein
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Relax `prot` and return the minimized PDB with diagnostics.

        Returns:
            A tuple of the minimized PDB string (with the B-factors of `prot`
            written back), a debug dict with the keys "initial_energy",
            "final_energy", "attempts" and "rmsd", and the per-residue
            violation mask of the final structure.
        """
        result = amber_minimize.run_pipeline(
            prot=prot,
            stiffness=self._stiffness,
            use_gpu=self._use_gpu,
            max_outer_iterations=self._max_outer_iterations,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            exclude_residues=self._exclude_residues,
        )

        final_positions = result["pos"]
        initial_positions = result["posinit"]
        squared_shift = (initial_positions - final_positions) ** 2
        rmsd = np.sqrt(np.sum(squared_shift) / initial_positions.shape[0])
        debug_data = {
            "initial_energy": result["einit"],
            "final_energy": result["efinal"],
            "attempts": result["min_attempts"],
            "rmsd": rmsd,
        }

        # Re-clean the input protein to get a PDB template, then write the
        # minimized coordinates and the original B-factors into it.
        template_pdb = amber_minimize.clean_protein(prot)
        min_pdb = utils.overwrite_pdb_coordinates(template_pdb, final_positions)
        min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
        utils.assert_equal_nonterminal_atom_types(
            protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
        )

        violations = result["structural_violations"][
            "total_per_residue_violations_mask"
        ]
        return min_pdb, debug_data, violations
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for minimization."""
import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
    """Re-serialise the topology of `pdb_str` with the positions `pos`.

    Args:
        pdb_str: An input PDB string; only its topology is reused.
        pos: Positions handed to `PDBFile.writeFile` together with that
            topology.

    Returns:
        A new PDB string.
    """
    structure = PdbStructure(io.StringIO(pdb_str))
    topology = openmm_app.PDBFile(structure).getTopology()

    buffer = io.StringIO()
    try:
        openmm_app.PDBFile.writeFile(topology, pos, buffer)
        return buffer.getvalue()
    finally:
        buffer.close()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
    """Overwrites the B-factors in pdb_str with contents of bfactors array.

    Every atom of a residue receives the value stored for that residue's CA
    entry.

    Args:
        pdb_str: An input PDB string.
        bfactors: A numpy array with shape [n_residues, 37]. We assume that
            the B-factors are per residue; i.e. that the nonzero entries are
            identical in [i, :].

    Returns:
        A new PDB string with the B-factors replaced.

    Raises:
        ValueError: If the last dimension of `bfactors` is not
            `residue_constants.atom_type_num`, or if the PDB contains more
            residues than `bfactors` has rows.
    """
    if bfactors.shape[-1] != residue_constants.atom_type_num:
        raise ValueError(
            f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
        )

    parser = PDB.PDBParser(QUIET=True)
    handle = io.StringIO(pdb_str)
    structure = parser.get_structure("", handle)

    # Walk the atoms in file order and advance the row index whenever the
    # residue id of the parent changes.
    curr_resid = ("", "", "")
    idx = -1
    for atom in structure.get_atoms():
        atom_resid = atom.parent.get_id()
        if atom_resid != curr_resid:
            idx += 1
            if idx >= bfactors.shape[0]:
                raise ValueError(
                    "Index into bfactors exceeds number of residues. "
                    f"B-factors shape: {bfactors.shape}, idx: {idx}."
                )
        curr_resid = atom_resid
        atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]

    new_pdb = io.StringIO()
    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(new_pdb)
    return new_pdb.getvalue()
def assert_equal_nonterminal_atom_types(
    atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
    """Checks that pre- and post-minimized proteins have same atom set.

    Args:
        atom_mask: Atom mask to check; its last axis is indexed by
            `residue_constants.atom_order`.
        ref_atom_mask: Reference atom mask of the same shape.

    Raises:
        AssertionError: If the masks differ anywhere outside the OXT column.
    """
    # Ignore any terminal OXT atoms which may have been added by minimization.
    oxt = residue_constants.atom_order["OXT"]
    # Use the builtin `bool`: the `np.bool` alias no longer exists in current
    # NumPy releases.
    no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
    no_oxt_mask[..., oxt] = False
    np.testing.assert_almost_equal(
        ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
    )
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used in AlphaFold."""
import collections
import functools
import os
from typing import Mapping, List, Tuple
from importlib import resources
import numpy as np
import tree
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
"ALA": [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
"ARG": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "NE"],
["CG", "CD", "NE", "CZ"],
],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLU": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "CE"],
["CG", "CD", "CE", "NZ"],
],
"MET": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "SD"],
["CB", "CG", "SD", "CE"],
],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
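# If chi angles are given in a fixed-length array, this matrix determines how
# to mask them for each AA type. The order is as per restype_order (see below).
# Note: this table has 20 rows (no UNK row), whereas chi_pi_periodic below has
# a 21st row for UNK.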
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
"ALA": [
["N", 0, (-0.525, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.529, -0.774, -1.205)],
["O", 3, (0.627, 1.062, 0.000)],
],
"ARG": [
["N", 0, (-0.524, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.524, -0.778, -1.209)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.616, 1.390, -0.000)],
["CD", 5, (0.564, 1.414, 0.000)],
["NE", 6, (0.539, 1.357, -0.000)],
["NH1", 7, (0.206, 2.301, 0.000)],
["NH2", 7, (2.078, 0.978, -0.000)],
["CZ", 7, (0.758, 1.093, -0.000)],
],
"ASN": [
["N", 0, (-0.536, 1.357, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.531, -0.787, -1.200)],
["O", 3, (0.625, 1.062, 0.000)],
["CG", 4, (0.584, 1.399, 0.000)],
["ND2", 5, (0.593, -1.188, 0.001)],
["OD1", 5, (0.633, 1.059, 0.000)],
],
"ASP": [
["N", 0, (-0.525, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, 0.000, -0.000)],
["CB", 0, (-0.526, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.593, 1.398, -0.000)],
["OD1", 5, (0.610, 1.091, 0.000)],
["OD2", 5, (0.592, -1.101, -0.003)],
],
"CYS": [
["N", 0, (-0.522, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, 0.000)],
["CB", 0, (-0.519, -0.773, -1.212)],
["O", 3, (0.625, 1.062, -0.000)],
["SG", 4, (0.728, 1.653, 0.000)],
],
"GLN": [
["N", 0, (-0.526, 1.361, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.525, -0.779, -1.207)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.615, 1.393, 0.000)],
["CD", 5, (0.587, 1.399, -0.000)],
["NE2", 6, (0.593, -1.189, -0.001)],
["OE1", 6, (0.634, 1.060, 0.000)],
],
"GLU": [
["N", 0, (-0.528, 1.361, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.526, -0.781, -1.207)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.615, 1.392, 0.000)],
["CD", 5, (0.600, 1.397, 0.000)],
["OE1", 6, (0.607, 1.095, -0.000)],
["OE2", 6, (0.589, -1.104, -0.001)],
],
"GLY": [
["N", 0, (-0.572, 1.337, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.517, -0.000, -0.000)],
["O", 3, (0.626, 1.062, -0.000)],
],
"HIS": [
["N", 0, (-0.527, 1.360, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.525, -0.778, -1.208)],
["O", 3, (0.625, 1.063, 0.000)],
["CG", 4, (0.600, 1.370, -0.000)],
["CD2", 5, (0.889, -1.021, 0.003)],
["ND1", 5, (0.744, 1.160, -0.000)],
["CE1", 5, (2.030, 0.851, 0.002)],
["NE2", 5, (2.145, -0.466, 0.004)],
],
"ILE": [
["N", 0, (-0.493, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.536, -0.793, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.534, 1.437, -0.000)],
["CG2", 4, (0.540, -0.785, -1.199)],
["CD1", 5, (0.619, 1.391, 0.000)],
],
"LEU": [
["N", 0, (-0.520, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.522, -0.773, -1.214)],
["O", 3, (0.625, 1.063, -0.000)],
["CG", 4, (0.678, 1.371, 0.000)],
["CD1", 5, (0.530, 1.430, -0.000)],
["CD2", 5, (0.535, -0.774, 1.200)],
],
"LYS": [
["N", 0, (-0.526, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.524, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.619, 1.390, 0.000)],
["CD", 5, (0.559, 1.417, 0.000)],
["CE", 6, (0.560, 1.416, 0.000)],
["NZ", 7, (0.554, 1.387, 0.000)],
],
"MET": [
["N", 0, (-0.521, 1.364, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.210)],
["O", 3, (0.625, 1.062, -0.000)],
["CG", 4, (0.613, 1.391, -0.000)],
["SD", 5, (0.703, 1.695, 0.000)],
["CE", 6, (0.320, 1.786, -0.000)],
],
"PHE": [
["N", 0, (-0.518, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, -0.000)],
["CB", 0, (-0.525, -0.776, -1.212)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.607, 1.377, 0.000)],
["CD1", 5, (0.709, 1.195, -0.000)],
["CD2", 5, (0.706, -1.196, 0.000)],
["CE1", 5, (2.102, 1.198, -0.000)],
["CE2", 5, (2.098, -1.201, -0.000)],
["CZ", 5, (2.794, -0.003, -0.001)],
],
"PRO": [
["N", 0, (-0.566, 1.351, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, 0.000)],
["CB", 0, (-0.546, -0.611, -1.293)],
["O", 3, (0.621, 1.066, 0.000)],
["CG", 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
"SER": [
["N", 0, (-0.529, 1.360, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.518, -0.777, -1.211)],
["O", 3, (0.626, 1.062, -0.000)],
["OG", 4, (0.503, 1.325, 0.000)],
],
"THR": [
["N", 0, (-0.517, 1.364, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, -0.000)],
["CB", 0, (-0.516, -0.793, -1.215)],
["O", 3, (0.626, 1.062, 0.000)],
["CG2", 4, (0.550, -0.718, -1.228)],
["OG1", 4, (0.472, 1.353, 0.000)],
],
"TRP": [
["N", 0, (-0.521, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.212)],
["O", 3, (0.627, 1.062, 0.000)],
["CG", 4, (0.609, 1.370, -0.000)],
["CD1", 5, (0.824, 1.091, 0.000)],
["CD2", 5, (0.854, -1.148, -0.005)],
["CE2", 5, (2.186, -0.678, -0.007)],
["CE3", 5, (0.622, -2.530, -0.007)],
["NE1", 5, (2.140, 0.690, -0.004)],
["CH2", 5, (3.028, -2.890, -0.013)],
["CZ2", 5, (3.283, -1.543, -0.011)],
["CZ3", 5, (1.715, -3.389, -0.011)],
],
"TYR": [
["N", 0, (-0.522, 1.362, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, -0.000, -0.000)],
["CB", 0, (-0.522, -0.776, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG", 4, (0.607, 1.382, -0.000)],
["CD1", 5, (0.716, 1.195, -0.000)],
["CD2", 5, (0.713, -1.194, -0.001)],
["CE1", 5, (2.107, 1.200, -0.002)],
["CE2", 5, (2.104, -1.201, -0.003)],
["OH", 5, (4.168, -0.002, -0.005)],
["CZ", 5, (2.791, -0.001, -0.003)],
],
"VAL": [
["N", 0, (-0.494, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.533, -0.795, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.540, 1.429, -0.000)],
["CG2", 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
"N",
"NE1",
"O",
],
"TYR": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"N",
"O",
"OH",
],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple(
"Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
BondAngle = collections.namedtuple(
"BondAngle",
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
    Mapping[str, List[Bond]],
    Mapping[str, List[Bond]],
    Mapping[str, List[BondAngle]],
]:
    """Parse stereo_chemical_props.txt into per-residue bond tables.

    The text is read as a package resource. It is expected to hold a header
    line, the bond-length rows, a "-" separator line, an empty line, another
    header line, the bond-angle rows and a closing "-" line. Every bond angle
    is additionally converted into a "virtual bond": the distance between the
    two outer atoms of the angle (the opposite edge of the triangle).

    The result is memoised by ``functools.lru_cache``, so repeated calls
    return the very same objects.

    Returns:
        residue_bonds: Dict that maps resname -> list of Bond tuples
        residue_virtual_bonds: Dict that maps resname -> list of Bond tuples
        residue_bond_angles: Dict that maps resname -> list of BondAngle tuples
    """
    # TODO: this file should be downloaded in a setup script
    raw_text = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
    rows = iter(raw_text.splitlines())

    def section_rows():
        """Yield the split fields of each row up to the next "-" line."""
        for row in rows:
            if row.strip() == "-":
                return
            yield row.split()

    def bond_key(first_atom, second_atom):
        """Order-independent lookup key for a pair of atom names."""
        return "-".join(sorted([first_atom, second_atom]))

    # Section 1: bond lengths.
    next(rows)  # Skip header line.
    residue_bonds = {}
    for bond, resname, length, stddev in section_rows():
        atom1, atom2 = bond.split("-")
        residue_bonds.setdefault(resname, []).append(
            Bond(atom1, atom2, float(length), float(stddev))
        )
    residue_bonds["UNK"] = []

    # Section 2: bond angles, converted from degrees to radians.
    next(rows)  # Skip empty line.
    next(rows)  # Skip header line.
    residue_bond_angles = {}
    for bond, resname, angle_degree, stddev_degree in section_rows():
        atom1, atom2, atom3 = bond.split("-")
        residue_bond_angles.setdefault(resname, []).append(
            BondAngle(
                atom1,
                atom2,
                atom3,
                float(angle_degree) / 180.0 * np.pi,
                float(stddev_degree) / 180.0 * np.pi,
            )
        )
    residue_bond_angles["UNK"] = []

    # Section 3: translate bond angles into distances ("virtual bonds").
    residue_virtual_bonds = {}
    for resname, angles in residue_bond_angles.items():
        # Fast lookup of this residue's real bonds by unordered atom pair.
        bonds_by_key = {
            bond_key(known.atom1_name, known.atom2_name): known
            for known in residue_bonds[resname]
        }
        virtual_bonds = []
        residue_virtual_bonds[resname] = virtual_bonds
        for angle in angles:
            side_a = bonds_by_key[bond_key(angle.atom1_name, angle.atom2_name)]
            side_b = bonds_by_key[bond_key(angle.atom2_name, angle.atom3name)]

            # Law of cosines: c^2 = a^2 + b^2 - 2ab*cos(gamma).
            gamma = angle.angle_rad
            length = np.sqrt(
                side_a.length ** 2
                + side_b.length ** 2
                - 2 * side_a.length * side_b.length * np.cos(gamma)
            )

            # Propagation of uncertainty assuming uncorrelated errors.
            dl_outer = 0.5 / length
            dl_dgamma = (
                2 * side_a.length * side_b.length * np.sin(gamma)
            ) * dl_outer
            dl_da = (
                2 * side_a.length - 2 * side_b.length * np.cos(gamma)
            ) * dl_outer
            dl_db = (
                2 * side_b.length - 2 * side_a.length * np.cos(gamma)
            ) * dl_outer
            stddev = np.sqrt(
                (dl_dgamma * angle.stddev) ** 2
                + (dl_da * side_a.stddev) ** 2
                + (dl_db * side_b.stddev) ** 2
            )
            virtual_bonds.append(
                Bond(angle.atom1_name, angle.atom3name, length, stddev)
            )

    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
# Between-residue bond lengths: the first element is for general bonds, the
# second element is for Proline.
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]

# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353]  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311]  # degrees: 116.568 +- 1.995

# Fixed atom ordering, used when atom data must be stored in a format that
# requires the same atom data size for every residue (e.g. a numpy array).
atom_types = [
    "N", "CA", "C", "CB", "O",
    "CG", "CG1", "CG2", "OG", "OG1", "SG",
    "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD",
    "CE", "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2",
    "CH2", "NH1", "NH2", "OH",
    "CZ", "CZ2", "CZ3", "NZ",
    "OXT",
]
# Atom name -> position of that name in atom_types.
atom_order = dict(zip(atom_types, range(len(atom_types))))
atom_type_num = len(atom_types)  # := 37.
# A compact atom encoding with 14 columns.
# Only the atoms a residue actually uses are spelled out below; each list is
# then right-padded with "" so that every residue ends up with exactly 14
# entries ("UNK" is entirely padding).
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
    resname: used_atoms + [""] * (14 - len(used_atoms))
    for resname, used_atoms in {
        "ALA": ["N", "CA", "C", "O", "CB"],
        "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
        "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
        "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
        "CYS": ["N", "CA", "C", "O", "CB", "SG"],
        "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
        "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
        "GLY": ["N", "CA", "C", "O"],
        "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
        "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
        "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
        "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
        "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
        "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
        "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
        "SER": ["N", "CA", "C", "O", "CB", "OG"],
        "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
        "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],
        "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
        "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
        "UNK": [],
    }.items()
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = list("ARNDCQEGHILKMFPSTWYV")
# One-letter code -> its position in restypes.
restype_order = dict(zip(restypes, range(len(restypes))))
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

# Same ordering with the unknown code "X" appended as the last entry.
restypes_with_x = [*restypes, "X"]
restype_order_with_x = dict(zip(restypes_with_x, range(len(restypes_with_x))))
def sequence_to_onehot(
    sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
    """One-hot encode an amino acid sequence.

    Args:
        sequence: An amino acid sequence.
        mapping: A dictionary mapping amino acids to integers.
        map_unknown_to_x: If True, every upper-case alphabetic character that
            is missing from the mapping is encoded as 'X' (the mapping must
            then contain 'X', otherwise a KeyError is raised), and any other
            character raises a ValueError. If False, a character missing from
            the mapping raises a KeyError.

    Returns:
        An int32 numpy array of shape (seq_len, num_unique_aas) with exactly
        one 1 per row.

    Raises:
        ValueError: If the mapping's values are not exactly the integers
            0 .. num_unique_aas - 1 without gaps, or if map_unknown_to_x is
            set and the sequence holds a character that is not an upper-case
            letter.
    """
    num_entries = max(mapping.values()) + 1
    expected_ids = list(range(num_entries))
    if sorted(set(mapping.values())) != expected_ids:
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 "
            "without any gaps. Got: %s" % sorted(mapping.values())
        )

    encoded = np.zeros((len(sequence), num_entries), dtype=np.int32)
    for row, residue in enumerate(sequence):
        if not map_unknown_to_x:
            column = mapping[residue]
        elif residue.isalpha() and residue.isupper():
            column = mapping.get(residue, mapping["X"])
        else:
            raise ValueError(
                f"Invalid character in the sequence: {residue}"
            )
        encoded[row, column] = 1
    return encoded
# One-letter residue code -> three-letter residue name.
restype_1to3 = dict(
    A="ALA", R="ARG", N="ASN", D="ASP", C="CYS",
    Q="GLN", E="GLU", G="GLY", H="HIS", I="ILE",
    L="LEU", K="LYS", M="MET", F="PHE", P="PRO",
    S="SER", T="THR", W="TRP", Y="TYR", V="VAL",
)
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {three: one for one, three in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = "UNK"
# Three-letter names in restypes order, with the unknown-residue name last,
# and the reverse lookup from name to its position in that list.
resnames = [*(restype_1to3[letter] for letter in restypes), unk_restype]
resname_to_idx = dict(zip(resnames, range(len(resnames))))
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
    "A": 0, "B": 2, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6, "I": 7,
    "J": 20, "K": 8, "L": 9, "M": 10, "N": 11, "O": 20, "P": 12, "Q": 13,
    "R": 14, "S": 15, "T": 16, "U": 1, "V": 17, "W": 18, "X": 20, "Y": 19,
    "Z": 3, "-": 21,
}

# Partial inversion of HHBLITS_AA_TO_ID: the character at position i of the
# string is the letter reported for id i. Id 1 ("C") also covers U, id 2 ("D")
# also covers B, id 3 ("E") also covers Z and id 20 ("X") includes J and O.
ID_TO_HHBLITS_AA = dict(enumerate("ACDEFGHIKLMNPQRSTVWYX-"))
# The restypes ordering extended by the unknown code "X" and the gap code "-".
restypes_with_x_and_gap = [*restypes, "X", "-"]
# Entry i is the position in restypes_with_x_and_gap of the letter that
# ID_TO_HHBLITS_AA assigns to hhblits id i.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
    [
        restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[hhblits_id])
        for hhblits_id, _ in enumerate(restypes_with_x_and_gap)
    ]
)
def _make_standard_atom_mask() -> np.ndarray:
    """Build the [num_res_types, num_atom_types] atom presence mask.

    Row i marks with 1 every atom type listed in residue_atoms for residue
    type i (in restypes order). One extra trailing row is reserved for the
    unknown residue type and stays all zeros.
    """
    atom_mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
    for row, letter in enumerate(restypes):
        present_columns = [
            atom_order[atom_name]
            for atom_name in residue_atoms[restype_1to3[letter]]
        ]
        atom_mask[row, present_columns] = 1
    return atom_mask


STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations.

    For every residue in restypes, the atom at position ``atom_index`` of each
    chi-angle atom tuple is looked up in atom_types and turned into a one-hot
    row. A final all-zero block is appended for residue `X`.

    Args:
        atom_index: Position within each chi-angle atom tuple to encode.

    Returns:
        Array whose first axis runs over the residues plus `X`; the last two
        axes of the stacked (4, atom_type_num) blocks are swapped before
        returning.
    """
    identity = np.eye(atom_type_num)

    padded_indices = {}
    for resname, chi_atom_tuples in chi_angles_atoms.items():
        picked = [
            atom_types.index(names[atom_index]) for names in chi_atom_tuples
        ]
        # NOTE(review): missing chi angles are padded with index -1, which
        # selects the LAST identity row below rather than an all-zero row.
        # Kept as-is to preserve the existing values - confirm intent.
        padded_indices[resname] = picked + [-1] * (4 - len(picked))

    blocks = [
        identity[padded_indices[restype_1to3[letter]]] for letter in restypes
    ]
    blocks.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.
    return np.transpose(np.stack(blocks, axis=0), [0, 2, 1])


chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
# Residues are taken in restypes order; residues with fewer than four chi
# angles are padded with [0, 0, 0, 0] rows so every residue has four rows.
chi_angles_atom_indices = np.array(
    [
        chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
        for chi_atoms in tree.map_structure(
            lambda atom_name: atom_order[atom_name],
            [chi_angles_atoms[restype_1to3[r]] for r in restypes],
        )
    ]
)

# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group. An atom that takes part in several chi
# groups gets one (chi_group_i, atom_i) entry per group, in encounter order.
chi_groups_for_atom = {}
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
    for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
        for atom_i, atom in enumerate(chi_group):
            chi_groups_for_atom.setdefault((res_name, atom), []).append(
                (chi_group_i, atom_i)
            )
def _make_rigid_transformation_4x4(ex, ey, translation):
    """Create a rigid 4x4 transformation matrix from two axes and transl.

    The x axis is ``ex`` scaled to unit length, the y axis is ``ey`` with its
    component along the x axis removed and then scaled to unit length, and the
    z axis is their cross product. The three axes and ``translation`` become
    the four columns of the upper 3x4 part; the bottom row is [0, 0, 0, 1].

    Args:
        ex: Array of 3 values giving the direction of the x axis.
        ey: Array of 3 values; together with ``ex`` it fixes the y axis.
        translation: 3 values placed in the last column.

    Returns:
        The 4x4 transformation matrix as a numpy array.
    """
    x_axis = ex / np.linalg.norm(ex)
    # Gram-Schmidt step: strip the x component from ey, then normalise.
    y_axis = ey - np.dot(ey, x_axis) * x_axis
    y_axis = y_axis / np.linalg.norm(y_axis)
    z_axis = np.cross(x_axis, y_axis)
    upper_rows = np.column_stack([x_axis, y_axis, z_axis, translation])
    return np.vstack([upper_rows, [[0.0, 0.0, 0.0, 1.0]]])
# Lookup tables with one row per residue type (21 rows each). They are
# allocated as zeros here and filled in place by
# _make_rigid_group_constants():
#   - (restype, atomtype) --> rigid_group_idx, for the 37- and 14-atom layouts
#   - (restype, atomtype) --> 1 where the atom is present
#   - (restype, atomtype, coord) for the atom positions
#   - (restype, rigid group) --> affine transformation matrix (4,4) from one
#     rigid group to the previous group
# NOTE: the integer tables use the builtin ``int`` as dtype. The ``np.int``
# alias meant exactly that type, but it was deprecated in NumPy 1.20 and
# removed in NumPy 1.24, where referencing it raises AttributeError.
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
    """Fill the module-level rigid-group lookup arrays in place.

    Pass 1 walks rigid_group_atom_positions and records, for each atom of each
    residue type, its rigid group index, a presence flag and its position, in
    both the atom37 and the atom14 tables.

    Pass 2 writes the default 4x4 frame of each of the 8 rigid groups per
    residue type into restype_rigid_group_default_frame. Frames of chi angles
    that chi_angles_mask marks as absent are left untouched.
    """
    # Pass 1: per-atom group index, mask and position.
    for res_idx, letter in enumerate(restypes):
        resname = restype_1to3[letter]
        for atom_name, group, position in rigid_group_atom_positions[resname]:
            col37 = atom_order[atom_name]
            restype_atom37_to_rigid_group[res_idx, col37] = group
            restype_atom37_mask[res_idx, col37] = 1
            restype_atom37_rigid_group_positions[res_idx, col37, :] = position

            col14 = restype_name_to_atom14_names[resname].index(atom_name)
            restype_atom14_to_rigid_group[res_idx, col14] = group
            restype_atom14_mask[res_idx, col14] = 1
            restype_atom14_rigid_group_positions[res_idx, col14, :] = position

    # Pass 2: default frames.
    for res_idx, letter in enumerate(restypes):
        resname = restype_1to3[letter]
        coords = {
            atom_name: np.array(xyz)
            for atom_name, _, xyz in rigid_group_atom_positions[resname]
        }
        # Basic indexing returns a view, so writes land in the global array.
        frames = restype_rigid_group_default_frame[res_idx]

        # backbone to backbone is the identity transform
        frames[0] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        frames[1] = np.eye(4)

        # phi-frame to backbone
        frames[2] = _make_rigid_transformation_4x4(
            ex=coords["N"] - coords["CA"],
            ey=np.array([1.0, 0.0, 0.0]),
            translation=coords["N"],
        )

        # psi-frame to backbone
        frames[3] = _make_rigid_transformation_4x4(
            ex=coords["C"] - coords["CA"],
            ey=coords["CA"] - coords["N"],
            translation=coords["C"],
        )

        # chi1-frame to backbone
        if chi_angles_mask[res_idx][0]:
            chi1_coords = [
                coords[atom_name] for atom_name in chi_angles_atoms[resname][0]
            ]
            frames[4] = _make_rigid_transformation_4x4(
                ex=chi1_coords[2] - chi1_coords[1],
                ey=chi1_coords[0] - chi1_coords[1],
                translation=chi1_coords[2],
            )

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in (1, 2, 3):
            if not chi_angles_mask[res_idx][chi_idx]:
                continue
            axis_end = coords[chi_angles_atoms[resname][chi_idx][2]]
            frames[4 + chi_idx] = _make_rigid_transformation_4x4(
                ex=axis_end,
                ey=np.array([-1.0, 0.0, 0.0]),
                translation=axis_end,
            )


_make_rigid_group_constants()
def make_atom14_dists_bounds(
    overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
    """compute upper and lower bounds for bonds to assess violations.

    For every pair of distinct atoms of a residue the lower bound starts as
    the sum of the two Van der Waals radii (chosen by the first character of
    each atom name) minus ``overlap_tolerance``, and the upper bound as 1e10.
    Pairs joined by a bond or a virtual bond from load_stereo_chemical_props()
    are then overwritten with ``length -/+ bond_length_tolerance_factor *
    stddev`` and their stddev is recorded.

    Args:
        overlap_tolerance: Amount subtracted from the summed radii.
        bond_length_tolerance_factor: Number of stddevs allowed around the
            bond length.

    Returns:
        Dict with float32 arrays "lower_bound", "upper_bound" and "stddev",
        each of shape (21, 14, 14).
    """
    lower_bounds = np.zeros([21, 14, 14], np.float32)
    upper_bounds = np.zeros([21, 14, 14], np.float32)
    stddevs = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()

    def write_symmetric(table, res_idx, first, second, value):
        """Store value for the atom pair in both index orders."""
        table[res_idx, first, second] = value
        table[res_idx, second, first] = value

    for res_idx, letter in enumerate(restypes):
        resname = restype_1to3[letter]
        names14 = restype_name_to_atom14_names[resname]

        # create lower and upper bounds for clashes
        for first_idx, first_name in enumerate(names14):
            if not first_name:
                continue
            first_radius = van_der_waals_radius[first_name[0]]
            for second_idx, second_name in enumerate(names14):
                if first_idx == second_idx or not second_name:
                    continue
                second_radius = van_der_waals_radius[second_name[0]]
                write_symmetric(
                    lower_bounds, res_idx, first_idx, second_idx,
                    first_radius + second_radius - overlap_tolerance,
                )
                write_symmetric(
                    upper_bounds, res_idx, first_idx, second_idx, 1e10
                )

        # overwrite lower and upper bounds for bonds and angles
        for bond in residue_bonds[resname] + residue_virtual_bonds[resname]:
            first_idx = names14.index(bond.atom1_name)
            second_idx = names14.index(bond.atom2_name)
            write_symmetric(
                lower_bounds, res_idx, first_idx, second_idx,
                bond.length - bond_length_tolerance_factor * bond.stddev,
            )
            write_symmetric(
                upper_bounds, res_idx, first_idx, second_idx,
                bond.length + bond_length_tolerance_factor * bond.stddev,
            )
            write_symmetric(
                stddevs, res_idx, first_idx, second_idx, bond.stddev
            )

    return {
        "lower_bound": lower_bounds,  # shape (21,14,14)
        "upper_bound": upper_bounds,  # shape (21,14,14)
        "stddev": stddevs,  # shape (21,14,14)
    }
# Per-residue-type (21 rows) tables over the 14 atom slots, filled in place by
# _make_atom14_ambiguity_feats():
#   - restype_atom14_ambiguous_atoms: starts as all zeros.
#   - restype_atom14_ambiguous_atoms_swap_idx: every row starts as the
#     identity 0..13, i.e. each slot maps to itself.
# The index table uses the builtin ``int`` dtype; the equivalent ``np.int``
# alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
    np.arange(14, dtype=int), (21, 1)
)
def _make_atom14_ambiguity_feats():
    """Fill the atom14 ambiguity tables from residue_atom_renaming_swaps.

    For each swap pair of each listed residue, both atom slots are flagged
    with 1 in restype_atom14_ambiguous_atoms, and each slot's entry in
    restype_atom14_ambiguous_atoms_swap_idx is set to its partner's slot.
    Both module-level arrays are modified in place.
    """
    for resname, swaps in residue_atom_renaming_swaps.items():
        row = restype_order[restype_3to1[resname]]
        for first_name, second_name in swaps.items():
            names14 = restype_name_to_atom14_names[resname]
            first = names14.index(first_name)
            second = names14.index(second_name)
            restype_atom14_ambiguous_atoms[row, [first, second]] = 1
            restype_atom14_ambiguous_atoms_swap_idx[row, [first, second]] = [
                second,
                first,
            ]


_make_atom14_ambiguity_feats()
def aatype_to_str_sequence(aatype):
    """Turn a sequence of residue-type indices into a one-letter string.

    Each element ``aatype[i]`` is used as an index into restypes_with_x and
    the resulting letters are concatenated in order.
    """
    letters = []
    for position in range(len(aatype)):
        letters.append(restypes_with_x[aatype[position]])
    return ''.join(letters)
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A numpy array of shape [residue_types=21, chis=4, atoms=4]. The
        residue types follow restypes, with the unknown residue type as the
        last entry. For chi angles which are not defined on the residue, the
        atom indices are all set to 0.
    """
    table = []
    for letter in restypes:
        chi_definitions = chi_angles_atoms[restype_1to3[letter]]
        rows = [
            [atom_order[atom] for atom in chi_angle]
            for chi_angle in chi_definitions
        ]
        # Pad up to four chi angles; each filler row is its own fresh list.
        rows.extend([0, 0, 0, 0] for _ in range(4 - len(rows)))
        table.append(rows)

    table.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.
    return np.array(table)
def _make_renaming_matrices():
    """Matrices to map atoms to symmetry partners in ambiguous case.

    One 14x14 float32 matrix is built per residue type, in restypes order
    followed by 'UNK'. Residues listed in residue_atom_renaming_swaps get a
    permutation matrix that exchanges each swap pair of atom14 slots; every
    other residue gets the identity.

    Returns:
        The matrices stacked along a new leading axis.
    """
    resnames3 = [restype_1to3[letter] for letter in restypes] + ['UNK']

    # Start from the identity for everyone, then overwrite the swapped ones.
    matrices = {name: np.eye(14, dtype=np.float32) for name in resnames3}
    for resname, swaps in residue_atom_renaming_swaps.items():
        names14 = restype_name_to_atom14_names[resname]
        permutation = np.arange(14)
        for atom_a, atom_b in swaps.items():
            slot_a = names14.index(atom_a)
            slot_b = names14.index(atom_b)
            permutation[slot_a] = slot_b
            permutation[slot_b] = slot_a
        # Row i of the result is one-hot at column permutation[i].
        matrices[resname] = np.eye(14, dtype=np.float32)[permutation]

    return np.stack([matrices[name] for name in resnames3])
def _make_restype_atom37_mask():
    """Mask of which atoms are present for which residue type in atom37.

    Returns:
        A float32 array of shape (21, 37). Row i holds 1 for every atom type
        listed in residue_atoms for residue type i (restypes order); the last
        row has no residue assigned and stays zero.
    """
    mask = np.zeros([21, 37], dtype=np.float32)
    for row, letter in enumerate(restypes):
        present_columns = [
            atom_order[atom_name]
            for atom_name in residue_atoms[restype_1to3[letter]]
        ]
        mask[row, present_columns] = 1
    return mask
def _make_restype_atom14_mask():
    """Mask of which atoms are present for which residue type in atom14.

    Returns:
        A float32 array with one row per residue in restypes plus a trailing
        all-zero row; an entry is 1 where the atom14 slot has a non-empty
        atom name and 0 otherwise.
    """
    rows = [
        [
            (1. if atom_name else 0.)
            for atom_name in restype_name_to_atom14_names[restype_1to3[letter]]
        ]
        for letter in restypes
    ]
    rows.append([0.] * 14)
    return np.array(rows, dtype=np.float32)
def _make_restype_atom37_to_atom14():
    """Map from atom37 to atom14 per residue type.

    Returns:
        An int32 array: entry (restype, atom37) is the atom14 slot of that
        atom name for the residue, or 0 when the residue does not have the
        atom. A trailing all-zero row of 37 entries is appended after the
        residues in restypes.
    """
    rows = []
    for letter in restypes:
        names14 = restype_name_to_atom14_names[restype_1to3[letter]]
        slot_of_name = {name: slot for slot, name in enumerate(names14)}
        rows.append([slot_of_name.get(name, 0) for name in atom_types])

    rows.append([0] * 37)
    return np.array(rows, dtype=np.int32)
def _make_restype_atom14_to_atom37():
    """Map from atom14 to atom37 per residue type.

    Returns:
        An int32 array: entry (restype, atom14) is the atom_order index of the
        atom name in that slot, or 0 for an empty slot. A dummy all-zero row
        for restype 'UNK' is appended after the residues in restypes.
    """
    rows = [
        [
            (atom_order[atom_name] if atom_name else 0)
            for atom_name in restype_name_to_atom14_names[restype_1to3[letter]]
        ]
        for letter in restypes
    ]
    # Add dummy mapping for restype 'UNK'
    rows.append([0] * 14)
    return np.array(rows, dtype=np.int32)
def _make_restype_atom14_is_ambiguous():
    """Mask which atoms are ambiguous in atom14.

    Returns:
        A float32 array of shape (21, 14) holding 1 at the atom14 slot of
        every atom that appears in a swap pair of
        residue_atom_renaming_swaps, and 0 everywhere else.
    """
    is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swaps in residue_atom_renaming_swaps.items():
        for swap_pair in swaps.items():
            row = restype_order[restype_3to1[resname]]
            names14 = restype_name_to_atom14_names[resname]
            for atom_name in swap_pair:
                is_ambiguous[row, names14.index(atom_name)] = 1
    return is_ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
    """Create Map from rigidgroups to atom37 indices.

    A (21, 8, 3) table of atom names is assembled first: group 0 (backbone
    frame) and group 3 ('psi-group') use fixed names for every residue type,
    and groups 4-7 ('chi1,2,3,4-group') take the last three atoms of each chi
    angle that chi_angles_mask marks as present. All other entries stay ''.

    Returns:
        An array of the same shape with each name replaced by its atom_order
        index; '' is translated to 0.
    """
    names = np.full([21, 8, 3], '', dtype=object)
    names[:, 0, :] = ['C', 'CA', 'N']  # 0: backbone frame
    names[:, 3, :] = ['CA', 'C', 'O']  # 3: 'psi-group'

    # 4,5,6,7: 'chi1,2,3,4-group'
    for res_idx, letter in enumerate(restypes):
        resname = restype_1to3[letter]
        for chi_idx in range(4):
            if not chi_angles_mask[res_idx][chi_idx]:
                continue
            names[res_idx, chi_idx + 4, :] = chi_angles_atoms[resname][chi_idx][1:]

    # Translate atom names into atom37 indices.
    index_of_name = {**atom_order, '': 0}
    return np.vectorize(lambda name: index_of_name[name])(names)
# Lookup tables, each computed once when the module is imported.
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()

RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()

RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()

RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()

# Create mask for existing rigid groups: groups 0 and 3 are switched on for
# all 21 rows, and groups 4-7 of the first 20 rows are copied from
# chi_angles_mask. Everything else stays 0.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, [0, 3]] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask
...@@ -18,7 +18,12 @@ from typing import Dict, Text, Tuple ...@@ -18,7 +18,12 @@ from typing import Dict, Text, Tuple
import torch import torch
from fastfold.np import residue_constants as rc from fastfold.common import residue_const
ants as rc
from fastfold.utils import geometry, tensor_utils from fastfold.utils import geometry, tensor_utils
import numpy as np import numpy as np
......
...@@ -142,7 +142,6 @@ def main(args): ...@@ -142,7 +142,6 @@ def main(args):
def inference_multimer_model(args): def inference_multimer_model(args):
print("running in multimer mode...") print("running in multimer mode...")
config = model_config(args.model_name) config = model_config(args.model_name)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
predict_max_templates = 4 predict_max_templates = 4
...@@ -235,11 +234,55 @@ def inference_multimer_model(args): ...@@ -235,11 +234,55 @@ def inference_multimer_model(args):
feature_dict = data_processor.process_fasta( feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir fasta_path=fasta_path, alignment_dir=local_alignment_dir
) )
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=True, feature_dict, mode='predict', is_multimer=True,
) )
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
def inference_monomer_model(args): def inference_monomer_model(args):
print("running in monomer mode...") print("running in monomer mode...")
config = model_config(args.model_name) config = model_config(args.model_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment