Commit 42427a1c authored by Shenggan

add inference pipeline from openfold/alphafold

parent b3b3b445
@@ -41,10 +41,10 @@ python setup.py install
## Usage
You can use `Evoformer` as an `nn.Module` in your project after importing it with `from fastfold.model import Evoformer`:
You can use `Evoformer` as an `nn.Module` in your project after importing it with `from fastfold.model.fastnn import Evoformer`:
```python
from fastfold.model import Evoformer
from fastfold.model.fastnn import Evoformer
evoformer_layer = Evoformer()
```
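A minimal forward-pass sketch is shown below. The tensor shapes and the call arguments (an MSA representation and a pair representation) are illustrative assumptions, so check the `Evoformer` definition in `fastfold.model.fastnn` for the exact constructor and forward signature:
```python
import torch

from fastfold.model.fastnn import Evoformer

# Shapes and call signature below are assumptions for illustration only.
evoformer_layer = Evoformer().cuda().eval()

n_seq, n_res, c_m, c_z = 128, 256, 256, 128
msa = torch.randn(1, n_seq, n_res, c_m, device="cuda")   # MSA representation (assumed layout)
pair = torch.randn(1, n_res, n_res, c_z, device="cuda")  # pair representation (assumed layout)

with torch.no_grad():
    msa, pair = evoformer_layer(msa, pair)  # assumed forward signature
```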
@@ -58,15 +58,15 @@ init_dap(args.dap_size)
### Inference
You can use FastFold together with OpenFold via `inject_openfold`. This replaces the Evoformer in OpenFold with the high-performance Evoformer from FastFold.
You can use FastFold together with OpenFold via `inject_fastnn`. This replaces the Evoformer in OpenFold with the high-performance Evoformer from FastFold.
```python
from fastfold.utils import inject_openfold
from fastfold.utils import inject_fastnn
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_openfold(model)
model = inject_fastnn(model)
```
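After injection, the model is used like a regular OpenFold `AlphaFold` module. A minimal sketch, assuming the input features have already been prepared with OpenFold's data pipeline:
```python
import torch

model = model.eval().cuda()
with torch.no_grad():
    # `batch` is a processed OpenFold feature dict (preparation not shown here).
    out = model(batch)
```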
For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of parallel inference on 2 GPUs:
......
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from fastfold.distributed import init_dap
from fastfold.model import Evoformer
from fastfold.model.fastnn import Evoformer
def main():
......
name: fastfold
channels:
- conda-forge
- bioconda
- pytorch
dependencies:
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- numpy==1.21.2
- PyYAML==5.4.1
- requests==2.26.0
- scipy==1.7.1
- tqdm==4.62.2
- typing-extensions==3.10.0.2
- einops
- colossalai
- pytorch::pytorch=1.11.0
- conda-forge::python=3.8
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
import re
from fastfold.common import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, C.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in square angstroms),
# representing the displacement of the residue from its ground-truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(f"Only single model PDBs are supported. Found {len(models)} models.")
model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError("Only single chain PDBs are supported when chain_id not specified. "
f"Found {len(chains)} chains.")
else:
chain = chains[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
b_factors = []
for res in chain:
if res.id[2] != " ":
raise ValueError(f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported.")
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(res_shortname,
residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
b_factors.append(res_b_factors)
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
b_factors=np.array(b_factors),
)
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r'(\[[A-Z]+\]\n)'
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
atoms = ['N', 'CA', 'C']
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if ("[PRIMARY]" == g[0]):
seq = g[1][0].strip()
# Strings are immutable, so rebuild the sequence as a list with any
# non-standard residues replaced by 'X'.
seq = [s if s in residue_constants.restypes else 'X' for s in seq]
aatype = np.array([
residue_constants.restype_order.get(res_symbol, residue_constants.restype_num)
for res_symbol in seq
])
elif ("[TERTIARY]" == g[0]):
tertiary = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros(
(len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = (np.transpose(
tertiary_np[:, i::3]))
atom_positions *= PICO_TO_ANGSTROM
elif ("[MASK]" == g[0]):
mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
atom_mask = np.zeros((
len(mask),
residue_constants.atom_type_num,
)).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
aatype=aatype,
residue_index=np.arange(len(aatype)),
b_factors=None,
)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
pdb_lines.append("MODEL 1")
atom_index = 1
chain_id = "A"
# Add all atom sites.
for i in range(aatype.shape[0]):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i],
b_factors[i]):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ""
# PDB is a columnar format, every space matters here!
atom_line = (f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_id:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}")
pdb_lines.append(atom_line)
atom_index += 1
# Close the chain.
chain_end = "TER"
chain_termination_line = (f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} "
f"{chain_id:>1}{residue_index[-1]:>4}")
pdb_lines.append(chain_termination_line)
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors,
)
\ No newline at end of file
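# Minimal usage sketch for the helpers above, assuming a local single-chain
# PDB file named example.pdb (hypothetical):
if __name__ == "__main__":
    with open("example.pdb") as f:
        prot = from_pdb_string(f.read(), chain_id="A")
    print(prot.atom_positions.shape)  # (num_res, 37, 3); 37 == residue_constants.atom_type_num
    print(prot.atom_mask.shape)       # (num_res, 37)
    pdb_str = to_pdb(prot)            # re-serialize back to PDB text
    print(pdb_str.splitlines()[0])    # first record of the re-serialized file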
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used in AlphaFold."""
import os
import urllib.request
import collections
import functools
from typing import Mapping, List, Tuple
import numpy as np
import tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
"ALA": [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
"ARG": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "NE"],
["CG", "CD", "NE", "CZ"],
],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLU": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "CE"],
["CG", "CD", "CE", "NZ"],
],
"MET": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "SD"],
["CB", "CG", "SD", "CE"],
],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
"ALA": [
["N", 0, (-0.525, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.529, -0.774, -1.205)],
["O", 3, (0.627, 1.062, 0.000)],
],
"ARG": [
["N", 0, (-0.524, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.524, -0.778, -1.209)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.616, 1.390, -0.000)],
["CD", 5, (0.564, 1.414, 0.000)],
["NE", 6, (0.539, 1.357, -0.000)],
["NH1", 7, (0.206, 2.301, 0.000)],
["NH2", 7, (2.078, 0.978, -0.000)],
["CZ", 7, (0.758, 1.093, -0.000)],
],
"ASN": [
["N", 0, (-0.536, 1.357, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.531, -0.787, -1.200)],
["O", 3, (0.625, 1.062, 0.000)],
["CG", 4, (0.584, 1.399, 0.000)],
["ND2", 5, (0.593, -1.188, 0.001)],
["OD1", 5, (0.633, 1.059, 0.000)],
],
"ASP": [
["N", 0, (-0.525, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, 0.000, -0.000)],
["CB", 0, (-0.526, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.593, 1.398, -0.000)],
["OD1", 5, (0.610, 1.091, 0.000)],
["OD2", 5, (0.592, -1.101, -0.003)],
],
"CYS": [
["N", 0, (-0.522, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, 0.000)],
["CB", 0, (-0.519, -0.773, -1.212)],
["O", 3, (0.625, 1.062, -0.000)],
["SG", 4, (0.728, 1.653, 0.000)],
],
"GLN": [
["N", 0, (-0.526, 1.361, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.525, -0.779, -1.207)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.615, 1.393, 0.000)],
["CD", 5, (0.587, 1.399, -0.000)],
["NE2", 6, (0.593, -1.189, -0.001)],
["OE1", 6, (0.634, 1.060, 0.000)],
],
"GLU": [
["N", 0, (-0.528, 1.361, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.526, -0.781, -1.207)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.615, 1.392, 0.000)],
["CD", 5, (0.600, 1.397, 0.000)],
["OE1", 6, (0.607, 1.095, -0.000)],
["OE2", 6, (0.589, -1.104, -0.001)],
],
"GLY": [
["N", 0, (-0.572, 1.337, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.517, -0.000, -0.000)],
["O", 3, (0.626, 1.062, -0.000)],
],
"HIS": [
["N", 0, (-0.527, 1.360, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.525, -0.778, -1.208)],
["O", 3, (0.625, 1.063, 0.000)],
["CG", 4, (0.600, 1.370, -0.000)],
["CD2", 5, (0.889, -1.021, 0.003)],
["ND1", 5, (0.744, 1.160, -0.000)],
["CE1", 5, (2.030, 0.851, 0.002)],
["NE2", 5, (2.145, -0.466, 0.004)],
],
"ILE": [
["N", 0, (-0.493, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.536, -0.793, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.534, 1.437, -0.000)],
["CG2", 4, (0.540, -0.785, -1.199)],
["CD1", 5, (0.619, 1.391, 0.000)],
],
"LEU": [
["N", 0, (-0.520, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.522, -0.773, -1.214)],
["O", 3, (0.625, 1.063, -0.000)],
["CG", 4, (0.678, 1.371, 0.000)],
["CD1", 5, (0.530, 1.430, -0.000)],
["CD2", 5, (0.535, -0.774, 1.200)],
],
"LYS": [
["N", 0, (-0.526, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.524, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.619, 1.390, 0.000)],
["CD", 5, (0.559, 1.417, 0.000)],
["CE", 6, (0.560, 1.416, 0.000)],
["NZ", 7, (0.554, 1.387, 0.000)],
],
"MET": [
["N", 0, (-0.521, 1.364, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.210)],
["O", 3, (0.625, 1.062, -0.000)],
["CG", 4, (0.613, 1.391, -0.000)],
["SD", 5, (0.703, 1.695, 0.000)],
["CE", 6, (0.320, 1.786, -0.000)],
],
"PHE": [
["N", 0, (-0.518, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, -0.000)],
["CB", 0, (-0.525, -0.776, -1.212)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.607, 1.377, 0.000)],
["CD1", 5, (0.709, 1.195, -0.000)],
["CD2", 5, (0.706, -1.196, 0.000)],
["CE1", 5, (2.102, 1.198, -0.000)],
["CE2", 5, (2.098, -1.201, -0.000)],
["CZ", 5, (2.794, -0.003, -0.001)],
],
"PRO": [
["N", 0, (-0.566, 1.351, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, 0.000)],
["CB", 0, (-0.546, -0.611, -1.293)],
["O", 3, (0.621, 1.066, 0.000)],
["CG", 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
"SER": [
["N", 0, (-0.529, 1.360, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.518, -0.777, -1.211)],
["O", 3, (0.626, 1.062, -0.000)],
["OG", 4, (0.503, 1.325, 0.000)],
],
"THR": [
["N", 0, (-0.517, 1.364, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, -0.000)],
["CB", 0, (-0.516, -0.793, -1.215)],
["O", 3, (0.626, 1.062, 0.000)],
["CG2", 4, (0.550, -0.718, -1.228)],
["OG1", 4, (0.472, 1.353, 0.000)],
],
"TRP": [
["N", 0, (-0.521, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.212)],
["O", 3, (0.627, 1.062, 0.000)],
["CG", 4, (0.609, 1.370, -0.000)],
["CD1", 5, (0.824, 1.091, 0.000)],
["CD2", 5, (0.854, -1.148, -0.005)],
["CE2", 5, (2.186, -0.678, -0.007)],
["CE3", 5, (0.622, -2.530, -0.007)],
["NE1", 5, (2.140, 0.690, -0.004)],
["CH2", 5, (3.028, -2.890, -0.013)],
["CZ2", 5, (3.283, -1.543, -0.011)],
["CZ3", 5, (1.715, -3.389, -0.011)],
],
"TYR": [
["N", 0, (-0.522, 1.362, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, -0.000, -0.000)],
["CB", 0, (-0.522, -0.776, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG", 4, (0.607, 1.382, -0.000)],
["CD1", 5, (0.716, 1.195, -0.000)],
["CD2", 5, (0.713, -1.194, -0.001)],
["CE1", 5, (2.107, 1.200, -0.002)],
["CE2", 5, (2.104, -1.201, -0.003)],
["OH", 5, (4.168, -0.002, -0.005)],
["CZ", 5, (2.791, -0.001, -0.003)],
],
"VAL": [
["N", 0, (-0.494, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.533, -0.795, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.540, 1.429, -0.000)],
["CG2", 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
"N",
"NE1",
"O",
],
"TYR": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"N",
"O",
"OH",
],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple(
"Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
BondAngle = collections.namedtuple(
"BondAngle",
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
def get_cache_path():
cache_path = os.path.join(os.path.expanduser("~"), '.fastfold')
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
return cache_path
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]],
]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
stereo_chemical_props_path = os.path.join(get_cache_path(), 'stereo_chemical_props.txt')
if not os.path.exists(stereo_chemical_props_path):
url = "https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt"
urllib.request.urlretrieve(url, stereo_chemical_props_path)
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split("-")
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(
Bond(atom1, atom2, float(length), float(stddev))
)
residue_bonds["UNK"] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split("-")
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(
atom1,
atom2,
atom3,
float(angle_degree) / 180.0 * np.pi,
float(stddev_degree) / 180.0 * np.pi,
)
)
residue_bond_angles["UNK"] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return "-".join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(
bond1.length ** 2
+ bond2.length ** 2
- 2 * bond1.length * bond2.length * np.cos(gamma)
)
# Propagation of uncertainty assuming uncorrelated errors.
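# With c = sqrt(a^2 + b^2 - 2*a*b*cos(gamma)), where a = bond1.length and
# b = bond2.length, the partial derivatives are
#   dc/dgamma = a*b*sin(gamma) / c
#   dc/da     = (a - b*cos(gamma)) / c
#   dc/db     = (b - a*cos(gamma)) / c
# The shared factor 1/(2*c) is pulled out below as dl_outer.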
dl_outer = 0.5 / length
dl_dgamma = (
2 * bond1.length * bond2.length * np.sin(gamma)
) * dl_outer
dl_db1 = (
2 * bond1.length - 2 * bond2.length * np.cos(gamma)
) * dl_outer
dl_db2 = (
2 * bond2.length - 2 * bond1.length * np.cos(gamma)
) * dl_outer
stddev = np.sqrt(
(dl_dgamma * ba.stddev) ** 2
+ (dl_db1 * bond1.stddev) ** 2
+ (dl_db2 * bond2.stddev) ** 2
)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev)
)
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]
# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
"N",
"CA",
"C",
"CB",
"O",
"CG",
"CG1",
"CG2",
"OG",
"OG1",
"SG",
"CD",
"CD1",
"CD2",
"ND1",
"ND2",
"OD1",
"OD2",
"SD",
"CE",
"CE1",
"CE2",
"CE3",
"NE",
"NE1",
"NE2",
"OE1",
"OE2",
"CH2",
"NH1",
"NH2",
"OH",
"CZ",
"CZ2",
"CZ3",
"NZ",
"OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
"ARG": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"NE",
"CZ",
"NH1",
"NH2",
"",
"",
"",
],
"ASN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"ND2",
"",
"",
"",
"",
"",
"",
],
"ASP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"OD2",
"",
"",
"",
"",
"",
"",
],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
"GLN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"NE2",
"",
"",
"",
"",
"",
],
"GLU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"OE2",
"",
"",
"",
"",
"",
],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
"HIS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"ND1",
"CD2",
"CE1",
"NE2",
"",
"",
"",
"",
],
"ILE": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"CD1",
"",
"",
"",
"",
"",
"",
],
"LEU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"",
"",
"",
"",
"",
"",
],
"LYS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"CE",
"NZ",
"",
"",
"",
"",
"",
],
"MET": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"SD",
"CE",
"",
"",
"",
"",
"",
"",
],
"PHE": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"",
"",
"",
],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
"THR": [
"N",
"CA",
"C",
"O",
"CB",
"OG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"TRP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"NE1",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
],
"TYR": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"OH",
"",
"",
],
"VAL": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.
restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(
sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
"The mapping must have values from 0 to num_unique_aas-1 "
"without any gaps. Got: %s" % sorted(mapping.values())
)
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping["X"])
else:
raise ValueError(
f"Invalid character in the sequence: {aa_type}"
)
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
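# Example: with the 21-entry restype_order_with_x mapping defined above,
#   sequence_to_onehot("AX", restype_order_with_x, map_unknown_to_x=True)
# returns an int32 array of shape (2, 21) with ones at (0, 0) and (1, 20).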
restype_1to3 = {
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = "UNK"
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
"A": 0,
"B": 2,
"C": 1,
"D": 2,
"E": 3,
"F": 4,
"G": 5,
"H": 6,
"I": 7,
"J": 20,
"K": 8,
"L": 9,
"M": 10,
"N": 11,
"O": 20,
"P": 12,
"Q": 13,
"R": 14,
"S": 15,
"T": 16,
"U": 1,
"V": 17,
"W": 18,
"X": 20,
"Y": 19,
"Z": 3,
"-": 21,
}
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap))
)
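# Example: HHblits index 1 ('C', which also covers 'U') maps to our index 4,
# index 2 ('D', also covering 'B') maps to 3, and the gap index 21 stays 21.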
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree.map_structure(
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
)
chi_angles_atom_indices = np.array(
[
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices
]
)
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack(
[ex_normalized, ey_normalized, eznorm, translation]
).transpose()
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
return m
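# The upper 3x4 block of the returned matrix has columns (ex, ey, ez, translation),
# so applying it to homogeneous local coordinates maps a point from the
# rigid-group frame into the parent frame: p_parent = R @ p_local + translation.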
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int64)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int64)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[
resname
]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[
restype, atomtype, :
] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[
restype, atom14idx, :
] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {
name: np.array(pos)
for name, _, pos in rigid_group_atom_positions[resname]
}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["N"] - atom_positions["CA"],
ey=np.array([1.0, 0.0, 0.0]),
translation=atom_positions["N"],
)
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["C"] - atom_positions["CA"],
ey=atom_positions["CA"] - atom_positions["N"],
translation=atom_positions["C"],
)
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [
atom_positions[name] for name in base_atom_names
]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2],
)
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1.0, 0.0, 0.0]),
translation=axis_end_atom_position,
)
restype_rigid_group_default_frame[
restype, 4 + chi_idx, :, :
] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(
overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[
restype, atom1_idx, atom2_idx
] = lower
restype_atom14_bond_lower_bound[
restype, atom2_idx, atom1_idx
] = lower
restype_atom14_bond_upper_bound[
restype, atom1_idx, atom2_idx
] = upper
restype_atom14_bond_upper_bound[
restype, atom2_idx, atom1_idx
] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[
restype, atom1_idx, atom2_idx
] = lower
restype_atom14_bond_lower_bound[
restype, atom2_idx, atom1_idx
] = lower
restype_atom14_bond_upper_bound[
restype, atom1_idx, atom2_idx
] = upper
restype_atom14_bond_upper_bound[
restype, atom2_idx, atom1_idx
] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
}
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
np.arange(14, dtype=np.int64), (21, 1)
)
def _make_atom14_ambiguity_feats():
for res, pairs in residue_atom_renaming_swaps.items():
res_idx = restype_order[restype_3to1[res]]
for atom1, atom2 in pairs.items():
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
restype_atom14_ambiguous_atoms_swap_idx[
res_idx, atom1_idx
] = atom2_idx
restype_atom14_ambiguous_atoms_swap_idx[
res_idx, atom2_idx
] = atom1_idx
_make_atom14_ambiguity_feats()
def aatype_to_str_sequence(aatype):
return ''.join([
restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import ml_collections as mlc
def set_inf(c, inf):
for k, v in c.items():
if isinstance(v, mlc.ConfigDict):
set_inf(v, inf)
elif k == "inf":
c[k] = inf
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting
pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_2":
# AF2 Suppl. Table 5, Model 1.1.2
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False
elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_2_ptm":
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_5_ptm":
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e4)
return c
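# Example: build the inference-time config for the pTM variant of model 1 and
# read one of the shared globals (FieldReferences resolve when accessed):
#   c = model_config("model_1_ptm")
#   c.globals.chunk_size  # 4 by default; set to None when train=True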
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
config = mlc.ConfigDict(
{
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
"alt_chi_angles": [NUM_RES, None],
"atom14_alt_gt_exists": [NUM_RES, None],
"atom14_alt_gt_positions": [NUM_RES, None, None],
"atom14_atom_exists": [NUM_RES, None],
"atom14_atom_is_ambiguous": [NUM_RES, None],
"atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None],
"backbone_rigid_mask": [NUM_RES],
"backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None],
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_row_mask": [NUM_EXTRA_SEQ],
"is_distillation": [],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_row_mask": [NUM_MSA_SEQ],
"no_recycling_iters": [],
"pseudo_beta": [NUM_RES, None],
"pseudo_beta_mask": [NUM_RES],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
"rigidgroups_group_exists": [NUM_RES, None],
"rigidgroups_group_is_ambiguous": [NUM_RES, None],
"rigidgroups_gt_exists": [NUM_RES, None],
"rigidgroups_gt_frames": [NUM_RES, None, None, None],
"seq_length": [],
"seq_mask": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_mask": [NUM_TEMPLATES],
"template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
"template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
"template_sum_probs": [NUM_TEMPLATES, None],
"template_torsion_angles_mask": [
NUM_TEMPLATES, NUM_RES, None,
],
"template_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [],
},
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"max_recycling_iters": 3,
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True,
"template_features": [
"template_all_atom_positions",
"template_sum_probs",
"template_aatype",
"template_all_atom_mask",
],
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"supervised": {
"clamp_prob": 0.9,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
"resolution",
"use_clamped_fape",
"is_distillation",
],
},
"predict": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"supervised": False,
"uniform_recycling": False,
},
"eval": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"supervised": True,
"uniform_recycling": False,
},
"train": {
"fixed_size": True,
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"supervised": True,
"clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
},
"data_module": {
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 16,
},
},
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"eps": eps,
},
"model": {
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
"recycling_embedder": {
"c_z": c_z,
"c_m": c_m,
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": 1e8,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_angle_embedder": {
# DISCREPANCY: c_in is supposed to be 51.
"c_in": 57,
"c_out": c_m,
},
"template_pair_embedder": {
"c_in": 88,
"c_out": c_t,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
"c_z": c_z,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
"c_hidden": 16,
"no_heads": 4,
"inf": 1e5, # 1e9,
},
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
},
"masked_msa": {
"c_m": c_m,
"c_out": 23,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
},
"relax": {
"max_iterations": 0, # no max
"tolerance": 2.39,
"stiffness": 10.0,
"max_outer_iterations": 20,
"exclude_residues": [],
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"eps": eps, # 1e-6,
"weight": 0.0,
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.0,
"enabled": tm_enabled,
},
"eps": eps,
},
"ema": {"decay": 0.999},
}
)
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
from fastfold.data import (
data_pipeline,
feature_pipeline,
mmcif_parsing,
templates,
)
from fastfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
mode: str = "train",
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
                If set, the top k prefiltered template hits are shuffled
                uniformly before max_template_hits of them are parsed.
                Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
self._alignment_index = _alignment_index
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
)
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".pdb")):
data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb",
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_alignment_index=_alignment_index,
)
else:
raise ValueError("Invalid file type")
else:
            path = os.path.join(self.data_dir, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
)
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
return feats
def __len__(self):
return len(self._chain_ids)
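# Example (hypothetical usage sketch; the paths and the `config.data` access
# below are placeholders, not part of this file):
#
#   dataset = OpenFoldSingleDataset(
#       data_dir="data/mmcif",
#       alignment_dir="data/alignments",
#       template_mmcif_dir="data/mmcif_files",
#       max_template_date="2021-10-10",
#       config=config.data,
#       mode="train",
#   )
#   feats = dataset[0]  # processed FeatureDict for the first chain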
def deterministic_train_filter(
chain_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = chain_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
seq = chain_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any,
) -> float:
# Stochastic filters
probabilities = []
cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
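# Worked example (illustrative numbers): for a chain with cluster_size=4 and a
# 300-residue sequence, the returned probability is
# (1 / 4) * (300 / 512) ~= 0.146. OpenFoldDataset below accepts each candidate
# by sampling torch.multinomial over [1 - p, p].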
class OpenFoldDataset(torch.utils.data.Dataset):
"""
Implements the stochastic filters applied during AlphaFold's training.
Because samples are selected from constituent datasets randomly, the
    length of an OpenFoldDataset is arbitrary. Samples are selected
and filtered once at initialization.
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx]
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init):
self.reroll()
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self):
return self.epoch_len
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldBatchCollator:
def __init__(self, config, stage="train"):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator
self._prep_batch_properties_probs()
def _prep_batch_properties_probs(self):
keyed_probs = []
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(stage_cfg.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys
self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
dtype=torch.float32,
)
def _add_batch_properties(self, batch):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator
)
aatype = batch["aatype"]
batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(self.prop_keys):
sample = int(samples[i][0])
sample_tensor = torch.tensor(
sample,
device=aatype.device,
requires_grad=False
)
orig_shape = sample_tensor.shape
sample_tensor = sample_tensor.view(
(1,) * len(batch_dims) + sample_tensor.shape + (1,)
)
sample_tensor = sample_tensor.expand(
batch_dims + orig_shape + (recycling_dim,)
)
batch[key] = sample_tensor
if(key == "no_recycling_iters"):
no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
return batch
def __iter__(self):
it = super().__iter__()
def _batch_prop_gen(iterator):
for batch in iterator:
yield self._add_batch_properties(batch)
return _batch_prop_gen(it)
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
self.config = config
self.template_mmcif_dir = template_mmcif_dir
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.train_chain_data_cache_path = train_chain_data_cache_path
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_chain_data_cache_path = (
distillation_chain_data_cache_path
)
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path
self.distillation_mapping_path = distillation_mapping_path
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
'At least one of train_data_dir or predict_data_dir must be '
'specified'
)
self.training_mode = self.train_data_dir is not None
if(self.training_mode and train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
elif(not self.training_mode and predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
# An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
)
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
_alignment_index=self._alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
_output_raw=True,
)
d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
probabilities = [1 - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mapping_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
_output_raw=True,
)
else:
self.eval_dataset = None
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
mapping_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
def _gen_dataloader(self, stage):
generator = torch.Generator()
if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed)
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator(self.config, stage)
dl = OpenFoldDataLoader(
dataset,
config=self.config,
stage=stage,
generator=generator,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=batch_collator,
)
return dl
def train_dataloader(self):
return self._gen_dataloader("train")
def val_dataloader(self):
if(self.eval_dataset is not None):
return self._gen_dataloader("eval")
return None
def predict_dataloader(self):
return self._gen_dataloader("predict")
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, batch_path):
with open(batch_path, "rb") as f:
self.batch = pickle.load(f)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
class DummyDataLoader(pl.LightningDataModule):
def __init__(self, batch_path):
super().__init__()
self.dataset = DummyDataset(batch_path)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.dataset)
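# Example (hypothetical end-to-end sketch; directory and cache paths are
# placeholders, not part of this file):
#
#   data_module = OpenFoldDataModule(
#       config=config.data,
#       template_mmcif_dir="data/mmcif_files",
#       max_template_date="2021-10-10",
#       train_data_dir="data/mmcif",
#       train_alignment_dir="data/alignments",
#       train_chain_data_cache_path="data/chain_data_cache.json",
#       batch_seed=42,
#   )
#   data_module.setup()
#   train_loader = data_module.train_dataloader()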
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datetime
from multiprocessing import cpu_count
from typing import Mapping, Optional, Sequence, Any
import numpy as np
from fastfold.data import templates, parsers, mmcif_parsing
from fastfold.data.tools import jackhmmer, hhblits, hhsearch
from fastfold.data.tools.utils import to_date
from fastfold.common import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def make_template_features(
input_sequence: str,
hits: Sequence[Any],
template_featurizer: Any,
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=query_pdb_code,
query_release_date=query_release_date,
hits=hits_cat,
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
return template_features
def make_sequence_features(
sequence: str, description: str, num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
)
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
)
return features
def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(
make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
[mmcif_object.header["resolution"]], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats
def _aatype_to_str_sequence(aatype):
return ''.join([
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
def make_protein_features(
protein_object: protein.Protein,
description: str,
_is_distillation: bool = False,
) -> FeatureDict:
pdb_feats = {}
aatype = protein_object.aatype
sequence = _aatype_to_str_sequence(aatype)
pdb_feats.update(
make_sequence_features(
sequence=sequence,
description=description,
num_res=len(protein_object.aatype),
)
)
all_atom_positions = protein_object.atom_positions
all_atom_mask = protein_object.atom_mask
pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(
1. if _is_distillation else 0.
).astype(np.float32)
return pdb_feats
def make_pdb_features(
protein_object: protein.Protein,
description: str,
confidence_threshold: float = 0.5,
is_distillation: bool = True,
) -> FeatureDict:
pdb_feats = make_protein_features(
protein_object, description, _is_distillation=True
)
if(is_distillation):
high_confidence = protein_object.b_factors > confidence_threshold
high_confidence = np.any(high_confidence, axis=-1)
for i, confident in enumerate(high_confidence):
if(not confident):
pdb_feats["all_atom_mask"][i] = 0
return pdb_feats
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
)
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features["msa"] = np.array(int_msa, dtype=np.int32)
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
return features
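# Example (illustrative): two MSAs over a 3-residue query, with one duplicate
# sequence that is skipped by the seen_sequences check above:
#
#   feats = make_msa_features(
#       msas=[["MKV", "MRV"], ["MKV", "MQV"]],
#       deletion_matrices=[[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0]]],
#   )
#   # feats["msa"].shape == (3, 3); feats["num_alignments"] == [3, 3, 3]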
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hhsearch_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
):
"""
Args:
jackhmmer_binary_path:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
hhsearch_binary_path:
Path to hhsearch binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
mgnify_database_path:
Path to mgnify database. If provided, jackhmmer_binary_path
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
pdb70_database_path:
Path to pdb70 database.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniclust30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
uniref_max_hits:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
"""
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hhsearch": {
"binary": hhsearch_binary_path,
"dbs": [
pdb70_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniclust_runner = None
if(bfd_database_path is not None):
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=bfd_database_path,
n_cpu=no_cpus,
)
else:
dbs = [bfd_database_path]
if(uniclust30_database_path is not None):
dbs.append(uniclust30_database_path)
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=dbs,
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = None
if(mgnify_database_path is not None):
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = None
if(pdb70_database_path is not None):
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path],
n_cpu=no_cpus,
)
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result["sto"],
max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
if(self.hhsearch_pdb70_runner is not None):
hhsearch_result = self.hhsearch_pdb70_runner.query(
uniref90_msa_as_a3m
)
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
if(self.jackhmmer_mgnify_runner is not None):
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result["sto"],
max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
elif(self.hhblits_bfd_uniclust_runner is not None):
hhblits_bfd_uniclust_result = (
self.hhblits_bfd_uniclust_runner.query(fasta_path)
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
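# Example (hypothetical usage sketch; binary and database paths are
# placeholders):
#
#   runner = AlignmentRunner(
#       jackhmmer_binary_path="/usr/bin/jackhmmer",
#       hhsearch_binary_path="/usr/bin/hhsearch",
#       uniref90_database_path="data/uniref90/uniref90.fasta",
#       pdb70_database_path="data/pdb70/pdb70",
#       no_cpus=8,
#   )
#   runner.run("query.fasta", "alignments/query")  # writes .a3m/.hhr files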
class DataPipeline:
"""Assembles input features."""
def __init__(
self,
template_featurizer: Optional[templates.TemplateHitFeaturizer],
):
self.template_featurizer = template_featurizer
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[name] = data
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
return msa_data
def _parse_template_hits(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return all_hits
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
raise ValueError(
"""
If the alignment dir contains no MSAs, an input sequence
must be provided.
"""
)
msa_data["dummy"] = {
"msa": [input_sequence],
"deletion_matrix": [[0 for _ in input_sequence]],
}
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
return msa_features
def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {
**sequence_features,
**msa_features,
**template_features
}
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
        If chain_id is None, the first chain in the mmCIF object is used;
        a ValueError is raised if the object contains no chains.
"""
if chain_id is None:
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if chain is None:
raise ValueError("No chains in mmCIF file")
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**mmcif_feats, **template_features, **msa_features}
def process_pdb(
self,
pdb_path: str,
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
"""
with open(pdb_path, 'r') as f:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
        pdb_feats = make_pdb_features(
            protein_object,
            description,
            is_distillation=is_distillation,
        )
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**pdb_feats, **template_features, **msa_features}
def process_core(
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
"""
with open(core_path, 'r') as f:
core_str = f.read()
protein_object = protein.from_proteinnet_string(core_str)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**core_feats, **template_features, **msa_features}
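# Example (hypothetical inference-time sketch; a TemplateHitFeaturizer can be
# passed instead of None when templates are used):
#
#   pipeline = DataPipeline(template_featurizer=None)
#   feature_dict = pipeline.process_fasta(
#       fasta_path="query.fasta",
#       alignment_dir="alignments/query",
#   )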
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from functools import reduce, wraps
from operator import add
import numpy as np
import torch
from fastfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from fastfold.common import residue_constants as rc
from fastfold.utils.rigid_utils import Rotation, Rigid
from fastfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
batched_gather,
)
MSA_FEATURE_NAMES = [
"msa",
"deletion_matrix",
"msa_mask",
"msa_row_mask",
"bert_mask",
"true_msa",
]
def cast_to_64bit_ints(protein):
# We keep all ints as int64
for k, v in protein.items():
if v.dtype == torch.int32:
protein[k] = v.type(torch.int64)
return protein
def make_one_hot(x, num_classes):
x_one_hot = torch.zeros(*x.shape, num_classes)
x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
return x_one_hot
def make_seq_mask(protein):
protein["seq_mask"] = torch.ones(
protein["aatype"].shape, dtype=torch.float32
)
return protein
def make_template_mask(protein):
protein["template_mask"] = torch.ones(
protein["template_aatype"].shape[0], dtype=torch.float32
)
return protein
def curry1(f):
"""Supply all arguments but the first."""
@wraps(f)
def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs)
return fc
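# Example (illustrative): curry1 lets a transform be configured once and
# applied later, e.g.
#
#   transform = randomly_replace_msa_with_unknown(0.15)  # returns a callable
#   protein = transform(protein)  # applies f(protein, 0.15)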
def make_all_atom_aatype(protein):
protein["all_atom_aatype"] = protein["aatype"]
return protein
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein["template_aatype"].shape[0]
if(num_templates > 0):
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
num_templates, -1
)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as rc."""
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
[new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
).transpose(0, 1)
protein["msa"] = torch.gather(new_order, 0, protein["msa"])
perm_matrix = np.zeros((22, 22), dtype=np.float32)
perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
for k in protein:
if "profile" in k:
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            protein[k] = torch.matmul(
                protein[k],
                torch.tensor(perm_matrix[:num_dim, :num_dim]),
            )
return protein
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
for k in [
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1)
else:
protein[k] = np.squeeze(protein[k], axis=-1)
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a portion of the MSA with 'X'."""
msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
x_idx = 20
gap_idx = 21
msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
protein["msa"] = torch.where(
msa_mask,
torch.ones_like(protein["msa"]) * x_idx,
protein["msa"]
)
aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
protein["aatype"] = torch.where(
aatype_mask,
torch.ones_like(protein["aatype"]) * x_idx,
protein["aatype"],
)
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split(
index_order, [num_sel, num_seq - num_sel]
)
for k in MSA_FEATURE_NAMES:
if k in protein:
if keep_extra:
protein["extra_" + k] = torch.index_select(
protein[k], 0, not_sel_seq
)
protein[k] = torch.index_select(protein[k], 0, sel_seq)
return protein
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = distillation
return protein
@curry1
def sample_msa_distillation(protein, max_seq):
if(protein["is_distillation"] == 1):
protein = sample_msa(max_seq, keep_extra=False)(protein)
return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
num_seq = protein["extra_msa"].shape[0]
num_sel = min(max_extra_msa, num_seq)
select_indices = torch.randperm(num_seq)[:num_sel]
for k in MSA_FEATURE_NAMES:
if "extra_" + k in protein:
protein["extra_" + k] = torch.index_select(
protein["extra_" + k], 0, select_indices
)
return protein
def delete_extra_msa(protein):
for k in MSA_FEATURE_NAMES:
if "extra_" + k in protein:
del protein["extra_" + k]
return protein
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    num_seq = protein["msa"].shape[0]
    block_num_seq = int(
        torch.floor(
            torch.tensor(num_seq, dtype=torch.float32)
            * config.msa_fraction_per_block
        )
    )
    if config.randomize_num_blocks:
        nb = int(torch.randint(0, config.num_blocks + 1, (1,)))
    else:
        nb = config.num_blocks
    # Delete contiguous blocks of sequences starting at random rows
    del_block_starts = torch.randint(0, num_seq, (nb,))
    del_blocks = del_block_starts[:, None] + torch.arange(block_num_seq)
    del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
    del_indices = torch.unique(torch.reshape(del_blocks, (-1,)))
    # Make sure we keep the original (query) sequence at index 0
    combined = torch.cat((torch.arange(1, num_seq), del_indices))
    uniques, counts = combined.unique(return_counts=True)
    difference = uniques[counts == 1]
    keep_indices = torch.cat(
        (torch.tensor([0], dtype=torch.int64), difference[difference > 0])
    )
    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)
    return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
weights = torch.cat(
[torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
0,
)
# Make agreement score as weighted Hamming distance
msa_one_hot = make_one_hot(protein["msa"], 23)
sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
num_seq, num_res, _ = sample_one_hot.shape
extra_num_seq, _, _ = extra_one_hot.shape
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement = torch.matmul(
torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
torch.reshape(
sample_one_hot * weights, [num_seq, num_res * 23]
).transpose(0, 1),
)
# Assign each sequence in the extra sequences to the closest MSA sample
protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
torch.int64
)
return protein
def unsorted_segment_sum(data, segment_ids, num_segments):
"""
Computes the sum along segments of a tensor. Similar to
tf.unsorted_segment_sum, but only supports 1-D indices.
:param data: A tensor whose segments are to be summed.
:param segment_ids: The 1-D segment indices tensor.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
assert (
len(segment_ids.shape) == 1 and
segment_ids.shape[0] == data.shape[0]
)
segment_ids = segment_ids.view(
segment_ids.shape[0], *((1,) * len(data.shape[1:]))
)
segment_ids = segment_ids.expand(data.shape)
shape = [num_segments] + list(data.shape[1:])
tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
tensor = tensor.type(data.dtype)
return tensor
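# Example (illustrative): with data = [[1., 1.], [2., 2.], [3., 3.]] and
# segment_ids = [0, 0, 1], unsorted_segment_sum(data, segment_ids, 2) returns
# [[3., 3.], [3., 3.]].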
@curry1
def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = protein["msa"].shape[0]
def csum(x):
return unsorted_segment_sum(
x, protein["extra_cluster_assignment"], num_seq
)
mask = protein["extra_msa_mask"]
mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center
msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
del msa_sum
del_sum = csum(mask * protein["extra_deletion_matrix"])
del_sum += protein["deletion_matrix"] # Original sequence
protein["cluster_deletion_mean"] = del_sum / mask_counts
del del_sum
return protein
def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
protein["msa_row_mask"] = torch.ones(
(protein["msa"].shape[0]), dtype=torch.float32
)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
"""Create pseudo beta features."""
is_gly = torch.eq(aatype, rc.restype_order["G"])
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_mask is not None:
pseudo_beta_mask = torch.where(
is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
@curry1
def make_pseudo_beta(protein, prefix=""):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ["", "template_"]
(
protein[prefix + "pseudo_beta"],
protein[prefix + "pseudo_beta_mask"],
) = pseudo_beta_fn(
protein["template_aatype" if prefix else "aatype"],
protein[prefix + "all_atom_positions"],
protein["template_all_atom_mask" if prefix else "all_atom_mask"],
)
return protein
@curry1
def add_constant_field(protein, key, value):
protein[key] = torch.tensor(value)
return protein
def shaped_categorical(probs, epsilon=1e-10):
ds = probs.shape
num_classes = ds[-1]
distribution = torch.distributions.categorical.Categorical(
torch.reshape(probs + epsilon, [-1, num_classes])
)
counts = distribution.sample()
return torch.reshape(counts, ds[:-1])
def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if "hhblits_profile" in protein:
return protein
# Compute the profile for every residue (over all MSA sequences).
msa_one_hot = make_one_hot(protein["msa"], 22)
protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
categorical_probs = (
config.uniform_prob * random_aa
+ config.profile_prob * protein["hhblits_profile"]
+ config.same_prob * make_one_hot(protein["msa"], 22)
)
# Put all remaining probability on [MASK] which is a new column
pad_shapes = list(
reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
)
pad_shapes[1] = 1
mask_prob = (
1.0 - config.profile_prob - config.same_prob - config.uniform_prob
)
assert mask_prob >= 0.0
categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob
)
sh = protein["msa"].shape
mask_position = torch.rand(sh) < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
# Mix real and masked MSA
protein["bert_mask"] = mask_position.to(torch.float32)
protein["true_msa"] = protein["msa"]
protein["msa"] = bert_msa
return protein
@curry1
def make_fixed_size(
protein,
shape_schema,
msa_cluster_size,
extra_msa_size,
num_res=0,
num_templates=0,
):
"""Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
NUM_EXTRA_SEQ: extra_msa_size,
NUM_TEMPLATES: num_templates,
}
for k, v in protein.items():
# Don't transfer this to the accelerator.
if k == "extra_cluster_assignment":
continue
shape = list(v.shape)
schema = shape_schema[k]
msg = "Rank mismatch between shape and shape schema for"
assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
padding.reverse()
padding = list(itertools.chain(*padding))
if padding:
protein[k] = torch.nn.functional.pad(v, padding)
protein[k] = torch.reshape(protein[k], pad_size)
return protein
@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for
# compatibility with domain datasets.
has_break = torch.clip(
protein["between_segment_residues"].to(torch.float32), 0, 1
)
aatype_1hot = make_one_hot(protein["aatype"], 21)
target_feat = [
torch.unsqueeze(has_break, dim=-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = make_one_hot(protein["msa"], 23)
has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
2.0 / np.pi
)
msa_feat = [
msa_1hot,
torch.unsqueeze(has_deletion, dim=-1),
torch.unsqueeze(deletion_value, dim=-1),
]
if "cluster_profile" in protein:
deletion_mean_value = torch.atan(
protein["cluster_deletion_mean"] / 3.0
) * (2.0 / np.pi)
msa_feat.extend(
[
protein["cluster_profile"],
torch.unsqueeze(deletion_mean_value, dim=-1),
]
)
if "extra_deletion_matrix" in protein:
protein["extra_has_deletion"] = torch.clip(
protein["extra_deletion_matrix"], 0.0, 1.0
)
protein["extra_deletion_value"] = torch.atan(
protein["extra_deletion_matrix"] / 3.0
) * (2.0 / np.pi)
protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
protein["target_feat"] = torch.cat(target_feat, dim=-1)
return protein
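# Channel-count note (derived from the concatenations above): msa_feat has
# 23 + 1 + 1 = 25 channels, or 49 once cluster_profile (23) and the cluster
# deletion mean (1) are appended; target_feat has 1 + 21 = 22 channels.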
@curry1
def select_feat(protein, feature_list):
return {k: v for k, v in protein.items() if k in feature_list}
@curry1
def crop_templates(protein, max_templates):
for k, v in protein.items():
if k.startswith("template_"):
protein[k] = v[:max_templates]
return protein
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append(
[(rc.atom_order[name] if name else 0) for name in atom_names]
)
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
[
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in rc.atom_types
]
)
restype_atom14_mask.append(
[(1.0 if name else 0.0) for name in atom_names]
)
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
dtype=torch.float32,
device=protein["aatype"].device,
)
protein_aatype = protein['aatype'].to(torch.long)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros(
[21, 37], dtype=torch.float32, device=protein["aatype"].device
)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask
return protein
def make_atom14_masks_np(batch):
batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
def make_atom14_positions(protein):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
residx_atom14_mask = protein["atom14_atom_exists"]
residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
protein["all_atom_mask"],
residx_atom14_to_atom37,
dim=-1,
no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
)
# Gather the ground truth positions.
residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
batched_gather(
protein["all_atom_positions"],
residx_atom14_to_atom37,
dim=-2,
no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
)
)
protein["atom14_atom_exists"] = residx_atom14_mask
protein["atom14_gt_exists"] = residx_atom14_gt_mask
protein["atom14_gt_positions"] = residx_atom14_gt_positions
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {
res: torch.eye(
14,
dtype=protein["all_atom_mask"].dtype,
device=protein["all_atom_mask"].device,
)
for res in restype_3
}
for resname, swap in rc.residue_atom_renaming_swaps.items():
correspondences = torch.arange(
14, device=protein["all_atom_mask"].device
)
for source_atom_swap, target_atom_swap in swap.items():
source_index = rc.restype_name_to_atom14_names[resname].index(
source_atom_swap
)
target_index = rc.restype_name_to_atom14_names[resname].index(
target_atom_swap
)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3]
)
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform = renaming_matrices[protein["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = torch.einsum(
"...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
)
protein["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = torch.einsum(
"...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
)
protein["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
for resname, swap in rc.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = rc.restype_order[rc.restype_3to1[resname]]
atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
atom_name1
)
atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
atom_name2
)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
protein["aatype"]
]
return protein
def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"]
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if rc.chi_angles_mask[restype][chi_idx]:
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8),
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
rc.chi_angles_mask
)
lookuptable = rc.atom_order.copy()
lookuptable[""] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names,
)
restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
restype_rigidgroup_base_atom37_idx,
)
restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
)
)
residx_rigidgroup_base_atom37_idx = batched_gather(
restype_rigidgroup_base_atom37_idx,
aatype,
dim=-3,
no_batch_dims=batch_dims,
)
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather(
restype_rigidgroup_mask,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
gt_atoms_exist = batched_gather(
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1]),
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
(*((1,) * batch_dims), 21, 8, 1, 1),
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[rc.restype_3to1[resname]]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
residx_rigidgroup_is_ambiguous = batched_gather(
restype_rigidgroup_is_ambiguous,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = batched_gather(
restype_rigidgroup_rots,
aatype,
dim=-4,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists
protein["rigidgroups_group_exists"] = group_exists
protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
return protein
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
        A nested list of shape [residue_types=21, chis=4, atoms=4]. The residue
        types are in the order specified in rc.restypes, with the unknown residue
        type appended at the end. For chi angles that are not defined on a
        residue, the atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append(
[0, 0, 0, 0]
) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return chi_atom_indices
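# Illustrative sketch (added for exposition; not part of the original transform set):
# the nested [21, 4, 4] index list returned above is typically converted to a tensor
# and indexed by residue type, exactly as atom37_to_torsion_angles() does below.
def _example_chi_atom_index_lookup(aatype: torch.Tensor) -> torch.Tensor:
    # aatype: [num_res] integer residue types in the range [0, 20]
    chi_atom_indices = torch.as_tensor(get_chi_atom_indices())  # [21, 4, 4]
    return chi_atom_indices[aatype]  # [num_res, 4, 4] atom37 indices per chi angle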
@curry1
def atom37_to_torsion_angles(
protein,
prefix="",
):
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
Dict containing:
* (prefix)aatype:
[*, N_res] residue indices
* (prefix)all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
* (prefix)all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
The same dictionary updated with the following features:
"(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"(prefix)torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype = protein[prefix + "aatype"]
all_atom_positions = protein[prefix + "all_atom_positions"]
all_atom_mask = protein[prefix + "all_atom_mask"]
aatype = torch.clamp(aatype, max=20)
pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3]
)
prev_all_atom_positions = torch.cat(
[pad, all_atom_positions[..., :-1, :, :]], dim=-3
)
pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
dim=-2,
)
phi_atom_pos = torch.cat(
[prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
dim=-2,
)
psi_atom_pos = torch.cat(
[all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
dim=-2,
)
pre_omega_mask = torch.prod(
prev_all_atom_mask[..., 1:3], dim=-1
) * torch.prod(all_atom_mask[..., :2], dim=-1)
phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
)
psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
* all_atom_mask[..., 4]
)
chi_atom_indices = torch.as_tensor(
get_chi_atom_indices(), device=aatype.device
)
atom_indices = chi_atom_indices[..., aatype, :, :]
chis_atom_pos = batched_gather(
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
chi_angle_atoms_mask = batched_gather(
all_atom_mask,
atom_indices,
dim=-1,
no_batch_dims=len(atom_indices.shape[:-2]),
)
chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
)
chis_mask = chis_mask * chi_angle_atoms_mask
torsions_atom_pos = torch.cat(
[
pre_omega_atom_pos[..., None, :, :],
phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :],
chis_atom_pos,
],
dim=-3,
)
torsion_angles_mask = torch.cat(
[
pre_omega_mask[..., None],
phi_mask[..., None],
psi_mask[..., None],
chis_mask,
],
dim=-1,
)
torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
eps=1e-8,
)
fourth_atom_rel_pos = torsion_frames.invert().apply(
torsions_atom_pos[..., 3, :]
)
torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt(
torch.sum(
torch.square(torsion_angles_sin_cos),
dim=-1,
dtype=torsion_angles_sin_cos.dtype,
keepdims=True,
)
+ 1e-8
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
rc.chi_pi_periodic,
)[aatype, ...]
mirror_torsion_angles = torch.cat(
[
all_atom_mask.new_ones(*aatype.shape, 3),
1.0 - 2.0 * chi_is_ambiguous,
],
dim=-1,
)
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[..., None]
)
protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
return protein
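# Usage sketch (hypothetical batch; for exposition only): because of @curry1, the
# transform is first specialised with a prefix, and the resulting callable is then
# applied to the feature dict, mirroring how the input pipeline below invokes it.
def _example_apply_torsion_angle_transform(batch):
    # batch must contain "aatype", "all_atom_positions" and "all_atom_mask"
    transform = atom37_to_torsion_angles("")  # curried; returns a dict -> dict fn
    return transform(batch)  # adds (alt_)torsion_angles_sin_cos and the mask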
def get_backbone_frames(protein):
# DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, :
]
protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein
def get_chi_angles(protein):
dtype = protein["all_atom_mask"].dtype
protein["chi_angles_sin_cos"] = (
protein["torsion_angles_sin_cos"][..., 3:, :]
).to(dtype)
protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
return protein
@curry1
def random_crop_to_size(
protein,
crop_size,
max_templates,
shape_schema,
subsample_templates=False,
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
seq_length = protein["seq_length"]
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = 0
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
num_res_crop_size = min(int(seq_length), crop_size)
def _randint(lower, upper):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=protein["seq_length"].device,
generator=g,
)[0])
if subsample_templates:
templates_crop_start = _randint(0, num_templates)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
else:
templates_crop_start = 0
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
)
n = seq_length - num_res_crop_size
if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
right_anchor = n
else:
x = _randint(0, n)
right_anchor = n - x
num_res_crop_start = _randint(0, right_anchor)
for k, v in protein.items():
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
slices = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
else:
crop_start = num_res_crop_start if is_num_res else 0
crop_size = num_res_crop_size if is_num_res else dim
slices.append(slice(crop_start, crop_start + crop_size))
protein[k] = v[slices]
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
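# Usage sketch (hypothetical values; for exposition only): the crop transform is
# curried with its parameters and a shape schema that marks which dimension of each
# feature is the residue dimension (NUM_RES), then applied to the feature dict.
def _example_random_crop(batch, crop_feats):
    crop = random_crop_to_size(
        crop_size=256,             # hypothetical crop length
        max_templates=4,           # hypothetical template cap
        shape_schema=crop_feats,   # e.g. {"aatype": [NUM_RES], ...}
        subsample_templates=True,
        seed=0,
    )
    return crop(batch)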
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General-purpose errors used throughout the data pipeline"""
class Error(Exception):
"""Base class for exceptions."""
class MultipleChainsError(Error):
"""An error indicating that multiple chains were found for a given ID."""
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Mapping, Tuple, List, Optional, Dict, Sequence
import ml_collections
import numpy as np
import torch
from fastfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray],
features: Sequence[str],
) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
np_example: A dict of NumPy feature arrays.
features: A list of strings of feature names to be returned in the dataset.
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
}
return tensor_dict
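# Minimal sketch (hypothetical arrays; for exposition only): only whitelisted
# feature names survive the conversion, everything else is filtered out.
def _example_np_to_tensor_dict() -> TensorDict:
    np_example = {
        "aatype": np.zeros((7,), dtype=np.int64),           # kept
        "not_requested": np.zeros((3,), dtype=np.float32),  # dropped
    }
    return np_to_tensor_dict(np_example, features=["aatype"])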
def make_data_config(
config: ml_collections.ConfigDict,
mode: str,
num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
if cfg[mode].supervised:
feature_names += cfg.supervised.supervised_features
return cfg, feature_names
def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int"
).astype(np.float32)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
return {k: v for k, v in features.items()}
class FeaturePipeline:
def __init__(
self,
config: ml_collections.ConfigDict,
):
self.config = config
def process_features(
self,
raw_features: FeatureDict,
mode: str = "train",
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
)
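# Usage sketch (illustrative only; assumes an OpenFold-style data ConfigDict with
# mode sub-configs such as "train", as consumed by make_data_config above):
def _example_feature_pipeline(raw_features: FeatureDict, data_config) -> FeatureDict:
    pipeline = FeaturePipeline(data_config)
    # The mode selects the crop size, MSA limits and supervised features to use.
    return pipeline.process_features(raw_features, mode="train")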
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
from fastfold.data import data_transforms
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms.correct_msa_restypes,
data_transforms.squeeze_features,
data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask,
data_transforms.make_msa_mask,
data_transforms.make_hhblits_profile,
]
if common_cfg.use_templates:
transforms.extend(
[
data_transforms.fix_templates_aatype,
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta("template_"),
]
)
if common_cfg.use_template_torsion_angles:
transforms.extend(
[
data_transforms.atom37_to_torsion_angles("template_"),
]
)
transforms.extend(
[
data_transforms.make_atom14_masks,
]
)
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
if "max_distillation_msa_clusters" in mode_cfg:
transforms.append(
data_transforms.sample_msa_distillation(
mode_cfg.max_distillation_msa_clusters
)
)
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms.sample_msa(
max_msa_clusters,
keep_extra=True,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
)
)
if common_cfg.msa_cluster_features:
transforms.append(data_transforms.nearest_neighbor_clusters())
transforms.append(data_transforms.summarize_clusters())
# Crop after creating the cluster profiles.
if max_extra_msa:
transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
else:
transforms.append(data_transforms.delete_extra_msa)
transforms.append(data_transforms.make_msa_feat())
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(
data_transforms.random_crop_to_size(
mode_cfg.crop_size,
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
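# Minimal sketch (for exposition only): map_fn applies `fun` to every element of `x`
# and stacks each resulting feature along a trailing ensemble/recycling dimension,
# which is how process_tensors_from_config builds the recycling axis above.
def _example_map_fn():
    out = map_fn(lambda i: {"feat": torch.ones(2) * i}, torch.arange(3))
    # out["feat"] has shape [2, 3]; the last axis indexes the three ensembles
    return out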
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses the mmCIF file format."""
import collections
import dataclasses
import io
import json
import logging
import os
from typing import Any, Mapping, Optional, Sequence, Tuple
from Bio import PDB
from Bio.Data import SCOPData
import numpy as np
from fastfold.data.errors import MultipleChainsError
import fastfold.common.residue_constants as residue_constants
# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PdbStructure = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]
@dataclasses.dataclass(frozen=True)
class Monomer:
id: str
num: int
# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
residue_name: str
author_chain_id: str
mmcif_chain_id: str
author_seq_num: str
mmcif_seq_num: int
insertion_code: str
hetatm_atom: str
model_num: int
# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
chain_id: str
residue_number: int
insertion_code: str
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
position: Optional[ResiduePosition]
name: str
is_missing: bool
hetflag: str
@dataclasses.dataclass(frozen=True)
class MmcifObject:
"""Representation of a parsed mmCIF file.
Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
1: ResidueAtPosition,
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PdbStructure
chain_to_seqres: Mapping[ChainId, SeqRes]
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
raw_string: Any
@dataclasses.dataclass(frozen=True)
class ParsingResult:
"""Returned by the parse function.
Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]
class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(
prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols = []
data = []
for key, value in parsed_info.items():
if key.startswith(prefix):
cols.append(key)
data.append(value)
assert all([len(xs) == len(data[0]) for xs in data]), (
"mmCIF error: Not all loops are the same length: %s" % cols
)
return [dict(zip(cols, xs)) for xs in zip(*data)]
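# Minimal sketch (hypothetical parsed_info; for exposition only): every row of an
# mmCIF loop_ becomes one dict keyed by the full data-item names.
def _example_mmcif_loop_to_list() -> Sequence[Mapping[str, str]]:
    parsed_info = {
        "_entity_poly_seq.entity_id": ["1", "1"],
        "_entity_poly_seq.num": ["1", "2"],
        "_entity_poly_seq.mon_id": ["MET", "ALA"],
    }
    # -> [{"_entity_poly_seq.entity_id": "1", "_entity_poly_seq.num": "1", ...}, ...]
    return mmcif_loop_to_list("_entity_poly_seq.", parsed_info)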
def mmcif_loop_to_dict(
prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries = mmcif_loop_to_list(prefix, parsed_info)
return {entry[index]: entry for entry in entries}
def parse(
*, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
"""Entry point, parses an mmcif_string.
Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.
Returns:
A ParsingResult.
"""
errors = {}
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure("", handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info = parser._mmcif_dict # pylint:disable=protected-access
# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]
header = _get_header(parsed_info)
# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ""): "No protein chains found in this file."}
)
seq_start_num = {
chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()
}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
        #  used by the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != "1":
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
if atom.mmcif_chain_id in valid_chains:
hetflag = " "
if atom.hetatm_atom == "HETATM":
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ("HOH", "WAT"):
hetflag = "W"
else:
hetflag = "H_" + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = " "
position = ResiduePosition(
chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code,
)
seq_idx = (
int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
)
current = seq_to_structure_mappings.get(
atom.author_chain_id, {}
)
current[seq_idx] = ResidueAtPosition(
position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag,
)
seq_to_structure_mappings[atom.author_chain_id] = current
# Add missing residue information to seq_to_structure_mappings.
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(
position=None,
name=monomer.id,
is_missing=True,
hetflag=" ",
)
author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
seq.append(code if len(code) == 1 else "X")
seq = "".join(seq)
author_chain_to_sequence[author_chain] = seq
mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info,
)
return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, "")] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)
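# Usage sketch (hypothetical path; for exposition only): parse() takes the raw mmCIF
# text and returns a ParsingResult whose mmcif_object is None when nothing usable
# could be recovered.
def _example_parse_mmcif(path: str) -> ParsingResult:
    with open(path, "r") as fp:
        mmcif_string = fp.read()
    result = parse(file_id="1abc", mmcif_string=mmcif_string)
    if result.mmcif_object is None:
        raise ValueError(f"mmCIF parsing failed: {result.errors}")
    return result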
def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
return min(revision_dates)
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}
experiments = mmcif_loop_to_list("_exptl.", parsed_info)
header["structure_method"] = ",".join(
[experiment["_exptl.method"].lower() for experiment in experiments]
)
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if "_pdbx_audit_revision_history.revision_date" in parsed_info:
header["release_date"] = get_release_date(parsed_info)
else:
logging.warning(
"Could not determine release_date: %s", parsed_info["_entry.id"]
)
header["resolution"] = 0.00
for res_key in (
"_refine.ls_d_res_high",
"_em_3d_reconstruction.resolution",
"_reflns.d_resolution_high",
):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header["resolution"] = float(raw_resolution)
except ValueError:
logging.info(
"Invalid resolution format: %s", parsed_info[res_key]
)
return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [
AtomSite(*site)
for site in zip( # pylint:disable=g-complex-comprehension
parsed_info["_atom_site.label_comp_id"],
parsed_info["_atom_site.auth_asym_id"],
parsed_info["_atom_site.label_asym_id"],
parsed_info["_atom_site.auth_seq_id"],
parsed_info["_atom_site.label_seq_id"],
parsed_info["_atom_site.pdbx_PDB_ins_code"],
parsed_info["_atom_site.group_PDB"],
parsed_info["_atom_site.pdbx_PDB_model_num"],
)
]
def _get_protein_chains(
*, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.
Args:
parsed_info: _mmcif_dict produced by the Biopython parser.
Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)
polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
Monomer(
id=entity_poly_seq["_entity_poly_seq.mon_id"],
num=int(entity_poly_seq["_entity_poly_seq.num"]),
)
)
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)
entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym["_struct_asym.id"]
entity_id = struct_asym["_struct_asym.entity_id"]
entity_to_mmcif_chains[entity_id].append(chain_id)
# Identify and return the valid protein chains.
valid_chains = {}
for entity_id, seq_info in polymers.items():
chain_ids = entity_to_mmcif_chains[entity_id]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(
[
"peptide" in chem_comps[monomer.id]["_chem_comp.type"]
for monomer in seq_info
]
):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in (".", "?")
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
relevant_chains = [c for c in chains if c.id == chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f"Expected exactly one chain in structure with id {chain_id}."
)
chain = relevant_chains[0]
# Extract the coordinates
num_res = len(mmcif_object.chain_to_seqres[chain_id])
all_atom_positions = np.zeros(
[num_res, residue_constants.atom_type_num, 3], dtype=np.float32
)
all_atom_mask = np.zeros(
[num_res, residue_constants.atom_type_num], dtype=np.float32
)
for res_index in range(num_res):
pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
if not res_at_position.is_missing:
res = chain[
(
res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code,
)
]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
# Put the coords of the selenium atom in the sulphur column
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec
return all_atom_positions, all_atom_mask
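# Usage sketch (for exposition only): given a parsed MmcifObject, this returns
# atom37-format coordinates and a mask for one chain; positions are mean-centred
# over observed atoms unless _zero_center_positions is False.
def _example_get_atom_coords(mmcif_object: MmcifObject):
    coords, mask = get_atom_coords(mmcif_object, chain_id="A")
    # coords: [num_res, 37, 3] float32, mask: [num_res, 37] float32
    return coords, mask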
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for parsing various file formats."""
import collections
import dataclasses
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
sum_probs: float
query: str
hit_sequence: str
indices_query: List[int]
indices_hit: List[int]
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
"""Parses FASTA string and returns list of strings with amino-acid sequences.
Arguments:
fasta_string: The string contents of a FASTA file.
Returns:
A tuple of two lists:
* A list of sequences.
* A list of sequence descriptions taken from the comment lines. In the
same order as the sequences.
"""
sequences = []
descriptions = []
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith(">"):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append("")
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line
return sequences, descriptions
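# Minimal sketch (for exposition only): descriptions come from the ">" header lines
# and sequences are the concatenated body lines, returned in the same order.
def _example_parse_fasta():
    fasta = ">query some description\nMKT\nAILV\n>hit_1\nMKTA\n"
    seqs, descs = parse_fasta(fasta)
    # seqs == ["MKTAILV", "MKTA"]; descs == ["query some description", "hit_1"]
    return seqs, descs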
def parse_stockholm(
stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
stockholm_string: The string contents of a stockholm file. The first
sequence in the file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* The names of the targets matched, including the jackhmmer subsequence
suffix.
"""
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(("#", "//")):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ""
name_to_sequence[name] += sequence
msa = []
deletion_matrix = []
query = ""
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != "-"]
# Remove the columns with gaps in the query from all sequences.
aligned_sequence = "".join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence)
# Count the number of deletions w.r.t. query.
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != "-" or query_res != "-":
if query_res == "-":
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
return msa, deletion_matrix, list(name_to_sequence.keys())
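# Minimal sketch (for exposition only): columns that are gaps in the query are
# dropped from every sequence, and the deletion matrix counts how many residues
# were removed immediately before each kept position.
def _example_parse_stockholm():
    sto = "# STOCKHOLM 1.0\nquery MK-T\nhit_1 MKAT\n//\n"
    msa, deletion_matrix, names = parse_stockholm(sto)
    # msa == ["MKT", "MKT"]; deletion_matrix == [[0, 0, 0], [0, 0, 1]]
    # names == ["query", "hit_1"]
    return msa, deletion_matrix, names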
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
        a3m_string: The string contents of an a3m file. The first sequence in the
file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
"""
sequences, _ = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
deletion_count = 0
for j in msa_sequence:
if j.islower():
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
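# Minimal sketch (for exposition only): lowercase letters in an a3m alignment are
# insertions relative to the query; they are stripped from the returned sequences
# and tallied in the deletion matrix instead.
def _example_parse_a3m():
    a3m = ">query\nMKT\n>hit_1\nMaKT\n"
    seqs, deletion_matrix = parse_a3m(a3m)
    # seqs == ["MKT", "MKT"]; deletion_matrix == [[0, 0, 0], [0, 1, 0]]
    return seqs, deletion_matrix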
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != "-":
yield sequence_res.lower()
def convert_stockholm_to_a3m(
stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
for line in stockholm_format.splitlines():
reached_max_sequences = (
max_sequences and len(sequences) >= max_sequences
)
if line.strip() and not line.startswith(("#", "//")):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ""
sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines():
if line[:4] == "#=GS":
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ""
if feature != "DE":
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break
# Convert sto format to a3m line by line
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != "-" for res in query_sequence]
for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = "".join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
)
fasta_chunks = (
f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences
)
return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f"Could not parse query line {line}")
return match.groups()
def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]
):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == "-":
indices_list.append(-1)
else:
indices_list.append(counter)
counter += 1
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
"""Parses the detailed HMM HMM comparison section for a single Hit.
This works on .hhr files generated from both HHBlits and HHSearch.
Args:
detailed_lines: A list of lines from a single comparison section between 2
            sequences (which each have their own HMMs)
    Returns:
        A TemplateHit with the information from that detailed comparison section
Raises:
RuntimeError: If a certain line cannot be processed
"""
# Parse first 2 lines.
number_of_hit = int(detailed_lines[0].split()[-1])
name_hit = detailed_lines[1][1:]
# Parse the summary line.
pattern = (
"Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
" ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
"]*Template_Neff=(.*)"
)
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
"Could not parse section: %s. Expected this: \n%s to contain summary."
% (detailed_lines, detailed_lines[2])
)
(prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
float(x) for x in match.groups()
]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ""
hit_sequence = ""
indices_query = []
indices_hit = []
length_block = None
for line in detailed_lines[3:]:
# Parse the query sequence line
if (
line.startswith("Q ")
and not line.startswith("Q ss_dssp")
and not line.startswith("Q ss_pred")
and not line.startswith("Q Consensus")
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices,
# and ensure it is the same as the actual block length.
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions
assert length_block == len(delta_query)
# Update the query sequence and indices list.
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith("T "):
# Parse the hit sequence.
if (
not line.startswith("T ss_dssp")
and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
assert length_block == len(delta_hit_sequence)
# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit(
index=number_of_hit,
name=name_hit,
aligned_cols=int(aligned_cols),
sum_probs=sum_probs,
query=query,
hit_sequence=hit_sequence,
indices_query=indices_query,
indices_hit=indices_hit,
)
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
"""Parses the content of an entire HHR file."""
lines = hhr_string.splitlines()
# Each .hhr file starts with a results table, then has a sequence of hit
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits
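# Usage sketch (hypothetical path; for exposition only): parse_hhr() consumes the
# text of an HHSearch/HHBlits .hhr results file and yields one TemplateHit per
# alignment block; indices_query/indices_hit hold zero-based residue indices per
# alignment column, with -1 marking gaps.
def _example_parse_hhr(hhr_path: str) -> Sequence[TemplateHit]:
    with open(hhr_path, "r") as fp:
        return parse_hhr(fp.read())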
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {"query": 0}
lines = [line for line in tblout.splitlines() if line[0] != "#"]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
for line in lines:
fields = line.split()
e_value = fields[4]
target_name = fields[0]
e_values[target_name] = float(e_value)
return e_values
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import dataclasses
import datetime
import glob
import json
import logging
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
import numpy as np
from fastfold.data import parsers, mmcif_parsing
from fastfold.data.errors import Error
from fastfold.data.tools import kalign
from fastfold.data.tools.utils import to_date
from fastfold.common import residue_constants
class NoChainsError(Error):
"""An error indicating that template mmCIF didn't have any chains."""
class SequenceNotInTemplateError(Error):
"""An error indicating that template mmCIF didn't contain the sequence."""
class NoAtomDataInTemplateError(Error):
"""An error indicating that template mmCIF didn't contain atom positions."""
class TemplateAtomMaskAllZerosError(Error):
"""An error indicating that template mmCIF had all atom positions masked."""
class QueryToTemplateAlignError(Error):
"""An error indicating that the query can't be aligned to the template."""
class CaDistanceError(Error):
"""An error indicating that a CA atom distance exceeds a threshold."""
# Prefilter exceptions.
class PrefilterError(Exception):
"""A base class for template prefilter exceptions."""
class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class PdbIdError(PrefilterError):
"""An error indicating that the hit PDB ID was identical to the query."""
class AlignRatioError(PrefilterError):
"""An error indicating that the hit align ratio to the query was too small."""
class DuplicateError(PrefilterError):
"""An error indicating that the hit was an exact subsequence of the query."""
class LengthError(PrefilterError):
"""An error indicating that the hit was too short."""
TEMPLATE_FEATURES = {
"template_aatype": np.int64,
"template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32,
"template_domain_names": np.object,
"template_sequence": np.object,
"template_sum_probs": np.float32,
}
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
if not id_match:
raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
pdb_id, chain_id = id_match.group(0).split("_")
return pdb_id.lower(), chain_id
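# Minimal sketch (hypothetical hit; for exposition only): template hit names are
# expected to begin with "<pdb_id>_<chain_id>".
def _example_get_pdb_id_and_chain() -> Tuple[str, str]:
    hit = parsers.TemplateHit(
        index=1, name="1ABC_A some description", aligned_cols=50, sum_probs=40.0,
        query="", hit_sequence="", indices_query=[], indices_hit=[],
    )
    return _get_pdb_id_and_chain(hit)  # ("1abc", "A")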
def _is_after_cutoff(
pdb_id: str,
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: Optional[datetime.datetime],
) -> bool:
"""Checks if the template date is after the release date cutoff.
Args:
pdb_id: 4 letter pdb code.
release_dates: Dictionary mapping PDB ids to their structure release dates.
release_date_cutoff: Max release date that is valid for this query.
Returns:
True if the template release date is after the cutoff, False otherwise.
"""
pdb_id_upper = pdb_id.upper()
if release_date_cutoff is None:
raise ValueError("The release_date_cutoff must not be None.")
if pdb_id_upper in release_dates:
return release_dates[pdb_id_upper] > release_date_cutoff
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
logging.info(
"Template structure not in release dates dict: %s", pdb_id
)
return False
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
"""Parses the data file from PDB that lists which PDB ids are obsolete."""
with open(obsolete_file_path) as f:
result = {}
for line in f:
line = line.strip()
# We skip obsolete entries that don't contain a mapping to a new entry.
if line.startswith("OBSLTE") and len(line) > 30:
# Format: Date From To
# 'OBSLTE 31-JUL-94 116L 216L'
from_id = line[20:24].lower()
to_id = line[29:33].lower()
result[from_id] = to_id
return result
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
dates = {}
for f in os.listdir(mmcif_dir):
if f.endswith(".cif"):
path = os.path.join(mmcif_dir, f)
with open(path, "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
if mmcif.mmcif_object is None:
logging.info(f"Failed to parse {f}. Skipping...")
continue
mmcif = mmcif.mmcif_object
release_date = mmcif.header["release_date"]
dates[file_id] = release_date
with open(out_path, "r") as fp:
fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
with open(path, "r") as fp:
data = json.load(fp)
return {
pdb.upper(): to_date(v)
for pdb, d in data.items()
for k, v in d.items()
if k == "release_date"
}
def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
min_align_ratio: float = 0.1,
) -> bool:
"""Determines if template is valid (without parsing the template mmcif file).
Args:
hit: HhrHit for the template.
hit_pdb_code: The 4 letter pdb code of the template hit. This might be
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
max_subsequence_ratio: Exclude any exact matches with this much overlap.
min_align_ratio: Minimum overlap between the template and query.
Returns:
True if the hit passed the prefilter. Raises an exception otherwise.
Raises:
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
"""
aligned_cols = hit.aligned_cols
align_ratio = aligned_cols / len(query_sequence)
template_sequence = hit.hit_sequence.replace("-", "")
length_ratio = float(len(template_sequence)) / len(query_sequence)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (
template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
date = release_dates[hit_pdb_code.upper()]
raise DateError(
f"Date ({date}) > max template date "
f"({release_date_cutoff})."
)
if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError("PDB code identical to Query PDB code.")
if align_ratio <= min_align_ratio:
raise AlignRatioError(
"Proportion of residues aligned to query too small. "
f"Align ratio: {align_ratio}."
)
if duplicate:
raise DuplicateError(
"Template is an exact subsequence of query with large "
f"coverage. Length ratio: {length_ratio}."
)
if len(template_sequence) < 10:
raise LengthError(
f"Template too short. Length: {len(template_sequence)}."
)
return True
def _find_template_in_pdb(
template_chain_id: str,
template_sequence: str,
mmcif_object: mmcif_parsing.MmcifObject,
) -> Tuple[str, str, int]:
"""Tries to find the template chain in the given pdb file.
This method tries the three following things in order:
1. Tries if there is an exact match in both the chain ID and the sequence.
If yes, the chain sequence is returned. Otherwise:
2. Tries if there is an exact match only in the sequence.
If yes, the chain sequence is returned. Otherwise:
3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
If yes, the chain sequence is returned.
If none of these succeed, a SequenceNotInTemplateError is thrown.
Args:
template_chain_id: The template chain ID.
template_sequence: The template chain sequence.
mmcif_object: The PDB object to search for the template in.
Returns:
A tuple with:
* The chain sequence that was found to match the template in the PDB object.
* The ID of the chain that is being returned.
* The offset where the template sequence starts in the chain sequence.
Raises:
SequenceNotInTemplateError: If no match is found after the steps described
above.
"""
# Try if there is an exact match in both the chain ID and the (sub)sequence.
pdb_id = mmcif_object.file_id
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
if chain_sequence and (template_sequence in chain_sequence):
logging.info(
"Found an exact template match %s_%s.", pdb_id, template_chain_id
)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, template_chain_id, mapping_offset
# Try if there is an exact match in the (sub)sequence only.
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
if chain_sequence and (template_sequence in chain_sequence):
logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, chain_id, mapping_offset
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
regex = re.compile("".join(regex))
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
match = re.search(regex, chain_sequence)
if match:
logging.info(
"Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
)
mapping_offset = match.start()
return chain_sequence, chain_id, mapping_offset
# No hits, raise an error.
raise SequenceNotInTemplateError(
"Could not find the template sequence in %s_%s. Template sequence: %s, "
"chain_to_seqres: %s"
% (
pdb_id,
template_chain_id,
template_sequence,
mmcif_object.chain_to_seqres,
)
)
def _realign_pdb_template_to_query(
old_template_sequence: str,
template_chain_id: str,
mmcif_object: mmcif_parsing.MmcifObject,
old_mapping: Mapping[int, int],
kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
"""Aligns template from the mmcif_object to the query.
In case PDB70 contains a different version of the template sequence, we need
to perform a realignment to the actual sequence that is in the mmCIF file.
This method performs such realignment, but returns the new sequence and
mapping only if the sequence in the mmCIF file is 90% identical to the old
sequence.
Note that the old_template_sequence comes from the hit, and contains only that
part of the chain that matches with the query while the new_template_sequence
is the full chain.
Args:
old_template_sequence: The template sequence that was returned by the PDB
template search (typically done using HHSearch).
template_chain_id: The template chain id was returned by the PDB template
search (typically done using HHSearch). This is used to find the right
chain in the mmcif_object chain_to_seqres mapping.
mmcif_object: A mmcif_object which holds the actual template data.
old_mapping: A mapping from the query sequence to the template sequence.
This mapping will be used to compute the new mapping from the query
sequence to the actual mmcif_object template sequence by aligning the
old_template_sequence and the actual template sequence.
kalign_binary_path: The path to a kalign executable.
Returns:
A tuple (new_template_sequence, new_query_to_template_mapping) where:
* new_template_sequence is the actual template sequence that was found in
the mmcif_object.
* new_query_to_template_mapping is the new mapping from the query to the
actual template found in the mmcif_object.
Raises:
QueryToTemplateAlignError:
* If there was an error thrown by the alignment tool.
* Or if the actual template sequence differs by more than 10% from the
old_template_sequence.
"""
aligner = kalign.Kalign(binary_path=kalign_binary_path)
new_template_sequence = mmcif_object.chain_to_seqres.get(
template_chain_id, ""
)
# Sometimes the template chain id is unknown. But if there is only a single
# sequence within the mmcif_object, it is safe to assume it is that one.
if not new_template_sequence:
if len(mmcif_object.chain_to_seqres) == 1:
logging.info(
"Could not find %s in %s, but there is only 1 sequence, so "
"using that one.",
template_chain_id,
mmcif_object.file_id,
)
new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
0
]
else:
raise QueryToTemplateAlignError(
f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
"If there are no mmCIF parsing errors, it is possible it was not a "
"protein chain."
)
try:
(old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence])
)
except Exception as e:
raise QueryToTemplateAlignError(
"Could not align old template %s to template %s (%s_%s). Error: %s"
% (
old_template_sequence,
new_template_sequence,
mmcif_object.file_id,
template_chain_id,
str(e),
)
)
logging.info(
"Old aligned template: %s\nNew aligned template: %s",
old_aligned_template,
new_aligned_template,
)
old_to_new_template_mapping = {}
old_template_index = -1
new_template_index = -1
num_same = 0
for old_template_aa, new_template_aa in zip(
old_aligned_template, new_aligned_template
):
if old_template_aa != "-":
old_template_index += 1
if new_template_aa != "-":
new_template_index += 1
if old_template_aa != "-" and new_template_aa != "-":
old_to_new_template_mapping[old_template_index] = new_template_index
if old_template_aa == new_template_aa:
num_same += 1
# Require at least 90% sequence identity wrt the shorter of the two sequences.
if (
float(num_same)
/ min(len(old_template_sequence), len(new_template_sequence))
< 0.9
):
raise QueryToTemplateAlignError(
"Insufficient similarity of the sequence in the database: %s to the "
"actual sequence in the mmCIF file %s_%s: %s. We require at least "
"90 %% similarity wrt to the shorter of the sequences. This is not a "
"problem unless you think this is a template that should be included."
% (
old_template_sequence,
mmcif_object.file_id,
template_chain_id,
new_template_sequence,
)
)
new_query_to_template_mapping = {}
for query_index, old_template_index in old_mapping.items():
new_query_to_template_mapping[
query_index
] = old_to_new_template_mapping.get(old_template_index, -1)
new_template_sequence = new_template_sequence.replace("-", "")
return new_template_sequence, new_query_to_template_mapping
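# Minimal sketch (hypothetical helper, not part of the upstream module) of the
# index bookkeeping performed above: walk a pair of gapped, aligned sequences
# and record, for every column where both have a residue, the ungapped index in
# the old sequence mapped to the ungapped index in the new sequence.
def _demo_old_to_new_mapping(
    old_aligned: str = "AC-DE", new_aligned: str = "ACQD-"
) -> Mapping[int, int]:
    """Returns {0: 0, 1: 1, 2: 3} for the default arguments."""
    mapping, old_idx, new_idx = {}, -1, -1
    for old_aa, new_aa in zip(old_aligned, new_aligned):
        if old_aa != "-":
            old_idx += 1
        if new_aa != "-":
            new_idx += 1
        if old_aa != "-" and new_aa != "-":
            mapping[old_idx] = new_idx
    return mapping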
def _check_residue_distances(
all_positions: np.ndarray,
all_positions_mask: np.ndarray,
max_ca_ca_distance: float,
):
"""Checks if the distance between unmasked neighbor residues is ok."""
ca_position = residue_constants.atom_order["CA"]
prev_is_unmasked = False
prev_calpha = None
for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
this_is_unmasked = bool(mask[ca_position])
if this_is_unmasked:
this_calpha = coords[ca_position]
if prev_is_unmasked:
distance = np.linalg.norm(this_calpha - prev_calpha)
if distance > max_ca_ca_distance:
raise CaDistanceError(
"The distance between residues %d and %d is %f > limit %f."
% (i, i + 1, distance, max_ca_ca_distance)
)
prev_calpha = this_calpha
prev_is_unmasked = this_is_unmasked
def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float,
_zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
)
all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances(
all_atom_positions, all_atom_mask, max_ca_ca_distance
)
return all_atom_positions, all_atom_mask
def _extract_template_features(
mmcif_object: mmcif_parsing.MmcifObject,
pdb_id: str,
mapping: Mapping[int, int],
template_sequence: str,
query_sequence: str,
template_chain_id: str,
kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query.
Atoms for each residue in the template structure are indexed to coincide
with their corresponding residue in the query sequence, according to the
alignment mapping provided.
Args:
mmcif_object: mmcif_parsing.MmcifObject representing the template.
pdb_id: PDB code for the template.
mapping: Dictionary mapping indices in the query sequence to indices in
the template sequence.
template_sequence: String describing the amino acid sequence for the
template protein.
query_sequence: String describing the amino acid sequence for the query
protein.
template_chain_id: String ID describing which chain of the template
structure should be used.
kalign_binary_path: The path to a kalign executable used for template
realignment.
Returns:
A tuple with:
* A dictionary containing the extra features derived from the template
protein structure.
* A warning message if the hit was realigned to the actual mmCIF sequence.
Otherwise None.
Raises:
NoChainsError: If the mmcif object doesn't contain any chains.
SequenceNotInTemplateError: If the given chain id / sequence can't
be found in the mmcif object.
QueryToTemplateAlignError: If the actual template in the mmCIF file
can't be aligned to the query.
NoAtomDataInTemplateError: If the mmcif object doesn't contain
atom positions.
TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
unmasked residues.
"""
if mmcif_object is None or not mmcif_object.chain_to_seqres:
raise NoChainsError(
"No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
)
warning = None
try:
seqres, chain_id, mapping_offset = _find_template_in_pdb(
template_chain_id=template_chain_id,
template_sequence=template_sequence,
mmcif_object=mmcif_object,
)
except SequenceNotInTemplateError:
# If PDB70 contains a different version of the template, we use the sequence
# from the mmcif_object.
chain_id = template_chain_id
warning = (
f"The exact sequence {template_sequence} was not found in "
f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
)
logging.warning(warning)
# This throws an exception if it fails to realign the hit.
seqres, mapping = _realign_pdb_template_to_query(
old_template_sequence=template_sequence,
template_chain_id=template_chain_id,
mmcif_object=mmcif_object,
old_mapping=mapping,
kalign_binary_path=kalign_binary_path,
)
logging.info(
"Sequence in %s_%s: %s successfully realigned to %s",
pdb_id,
chain_id,
template_sequence,
seqres,
)
# The template sequence changed.
template_sequence = seqres
# No mapping offset, the query is aligned to the actual sequence.
mapping_offset = 0
try:
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
)
except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError(
"Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
) from ex
all_atom_positions = np.split(
all_atom_positions, all_atom_positions.shape[0]
)
all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
output_templates_sequence = []
templates_all_atom_positions = []
templates_all_atom_masks = []
for _ in query_sequence:
# Residues in the query_sequence that are not in the template_sequence:
templates_all_atom_positions.append(
np.zeros((residue_constants.atom_type_num, 3))
)
templates_all_atom_masks.append(
np.zeros(residue_constants.atom_type_num)
)
output_templates_sequence.append("-")
for k, v in mapping.items():
template_index = v + mapping_offset
templates_all_atom_positions[k] = all_atom_positions[template_index][0]
templates_all_atom_masks[k] = all_atom_masks[template_index][0]
output_templates_sequence[k] = template_sequence[v]
# Alanine has 5 heavy atoms (C, CA, CB, N, O); require at least that many unmasked atoms.
if np.sum(templates_all_atom_masks) < 5:
raise TemplateAtomMaskAllZerosError(
"Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
% (
pdb_id,
chain_id,
min(mapping.values()) + mapping_offset,
max(mapping.values()) + mapping_offset,
)
)
output_templates_sequence = "".join(output_templates_sequence)
templates_aatype = residue_constants.sequence_to_onehot(
output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
)
return (
{
"template_all_atom_positions": np.array(
templates_all_atom_positions
),
"template_all_atom_mask": np.array(templates_all_atom_masks),
"template_sequence": output_templates_sequence.encode(),
"template_aatype": np.array(templates_aatype),
"template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
},
warning,
)
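# Illustrative sketch (hypothetical names, not upstream code) of the scatter
# pattern used above: the output is first filled with gap placeholders for the
# full query length, then template residues are written only at the query
# positions covered by the alignment mapping.
def _demo_scatter_template_sequence(
    query_sequence: str = "MKTAYIA",
    template_sequence: str = "KTYI",
    mapping: Optional[Mapping[int, int]] = None,
) -> str:
    """Returns '-KT-YI-' for the default arguments."""
    if mapping is None:
        # Assumed query-index -> template-index alignment, for illustration only.
        mapping = {1: 0, 2: 1, 4: 2, 5: 3}
    out = ["-"] * len(query_sequence)
    for query_index, template_index in mapping.items():
        out[query_index] = template_sequence[template_index]
    return "".join(out)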
def _build_query_to_hit_index_mapping(
hit_query_sequence: str,
hit_sequence: str,
indices_hit: Sequence[int],
indices_query: Sequence[int],
original_query_sequence: str,
) -> Mapping[int, int]:
"""Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap
characters. hit_query_sequence contains only the part of the original query
sequence that matched the hit. When interpreting the indices from the .hhr, we
need to correct for this to recover a mapping from original query sequence to
the hit sequence.
Args:
hit_query_sequence: The portion of the query sequence that is in the .hhr
hit.
hit_sequence: The portion of the hit sequence that is in the .hhr hit.
indices_hit: The indices for each amino acid relative to the hit sequence.
indices_query: The indices for each amino acid relative to the original query
sequence.
original_query_sequence: String describing the original query sequence.
Returns:
Dictionary with indices in the original query sequence as keys and indices
in the hit sequence as values.
"""
# If the hit is empty (no aligned residues), return empty mapping
if not hit_query_sequence:
return {}
# Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence = hit_query_sequence.replace("-", "")
hit_sequence = hit_sequence.replace("-", "")
hhsearch_query_offset = original_query_sequence.find(
hhsearch_query_sequence
)
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
min_idx = min(x for x in indices_hit if x > -1)
fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
min_idx = min(x for x in indices_query if x > -1)
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
# Zip the corrected indices, ignoring positions where either sequence has a gap character.
mapping = {}
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
if q_t != -1 and q_i != -1:
if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
original_query_sequence
):
continue
mapping[q_i + hhsearch_query_offset] = q_t
return mapping
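# Tiny worked sketch (hypothetical helper, not upstream code) of the offset
# correction above: the .hhr indices are local to the aligned fragment, so they
# are shifted by where the gap-free hit query starts in the original query.
def _demo_offset_correction(
    original_query_sequence: str = "GGMKTAYIAGG",
    hit_query_sequence: str = "MKT-AYI",
    indices_query: Sequence[int] = (0, 1, 2, -1, 3, 4, 5),
) -> Sequence[int]:
    """Returns the hit-local query indices shifted into the original query."""
    gapless = hit_query_sequence.replace("-", "")
    offset = original_query_sequence.find(gapless)  # 2 for the defaults
    min_idx = min(x for x in indices_query if x > -1)
    return [x - min_idx + offset if x > -1 else -1 for x in indices_query]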
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
valid: bool
error: Optional[str]
warning: Optional[str]
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
features: Optional[Mapping[str, Any]]
error: Optional[str]
warning: Optional[str]
def _prefilter_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
strict_error_check: bool = False,
) -> PrefilterResult:
"""Checks whether a template hit passes the date/identity prefilters."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
if hit_pdb_code not in release_dates:
if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code]
# Pass hit_pdb_code since it might have changed due to the pdb being
# obsolete.
try:
_assess_hhsearch_hit(
hit=hit,
hit_pdb_code=hit_pdb_code,
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
release_dates=release_dates,
release_date_cutoff=max_template_date,
)
except PrefilterError as e:
hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)
):
# In strict mode we treat some prefilter cases as errors.
return PrefilterResult(valid=False, error=msg, warning=None)
return PrefilterResult(valid=False, error=None, warning=None)
return PrefilterResult(valid=True, error=None, warning=None)
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
if hit_pdb_code not in release_dates:
if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code]
mapping = _build_query_to_hit_index_mapping(
hit.query,
hit.hit_sequence,
hit.indices_hit,
hit.indices_query,
query_sequence,
)
# The mapping is from the query to the actual hit sequence, so we need to
# remove gaps (which regardless have a missing confidence score).
template_sequence = hit.hit_sequence.replace("-", "")
cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
logging.info(
"Reading PDB entry from %s. Query: %s, template: %s",
cif_path,
query_sequence,
template_sequence,
)
# Fail if we can't find the mmCIF file.
with open(cif_path, "r") as cif_file:
cif_string = cif_file.read()
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string
)
if parsing_result.mmcif_object is not None:
hit_release_date = datetime.datetime.strptime(
parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
)
if hit_release_date > max_template_date:
error = "Template %s date (%s) > max template date (%s)." % (
hit_pdb_code,
hit_release_date,
max_template_date,
)
if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None)
else:
logging.info(error)
return SingleHitResult(features=None, error=None, warning=None)
try:
features, realign_warning = _extract_template_features(
mmcif_object=parsing_result.mmcif_object,
pdb_id=hit_pdb_code,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
)
features["template_sum_probs"] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
# computed. In such a case the mmCIF parsing errors are not relevant.
return SingleHitResult(
features=features, error=None, warning=realign_warning
)
except (
NoChainsError,
NoAtomDataInTemplateError,
TemplateAtomMaskAllZerosError,
) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None)
else:
return SingleHitResult(features=None, error=None, warning=warning)
except Error as e:
error = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
return SingleHitResult(features=None, error=error, warning=None)
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
features: Mapping[str, Any]
errors: Sequence[str]
warnings: Sequence[str]
class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features."""
def __init__(
self,
mmcif_dir: str,
max_template_date: str,
max_hits: int,
kalign_binary_path: str,
release_dates_path: Optional[str] = None,
obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
):
"""Initializes the Template Search.
Args:
mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
is found by HHSearch, this directory is used to retrieve the template
data.
max_template_date: The maximum date permitted for template structures. No
template with a release date later than this will be returned. In ISO 8601
date format, YYYY-MM-DD.
max_hits: The maximum number of templates that will be returned.
kalign_binary_path: The path to a kalign executable used for template
realignment.
release_dates_path: An optional path to a file with a mapping from PDB IDs
to their release dates. Thanks to this we don't have to redundantly
parse mmCIF files to get that information.
obsolete_pdbs_path: An optional path to a file containing a mapping from
obsolete PDB IDs to the PDB IDs of their replacements.
strict_error_check: If True, then the following will be treated as errors:
* If any template date is after the max_template_date.
* If any template has identical PDB ID to the query.
* If any template is a duplicate of the query.
* Any feature computation errors.
"""
self._mmcif_dir = mmcif_dir
if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
logging.error("Could not find CIFs in %s", self._mmcif_dir)
raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")
try:
self._max_template_date = datetime.datetime.strptime(
max_template_date, "%Y-%m-%d"
)
except ValueError:
raise ValueError(
"max_template_date must be set and have format YYYY-MM-DD."
)
self.max_hits = max_hits
self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check
if release_dates_path:
logging.info(
"Using precomputed release dates %s.", release_dates_path
)
self._release_dates = _parse_release_dates(release_dates_path)
else:
self._release_dates = {}
if obsolete_pdbs_path:
logging.info(
"Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
)
self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
else:
self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates(
self,
query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info("Searching for template for: %s", query_pdb_code)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
# Always use a max_template_date. Set to query_release_date minus 60 days
# if that's earlier.
template_cutoff_date = self._max_template_date
if query_release_date:
delta = datetime.timedelta(days=60)
if query_release_date - delta < template_cutoff_date:
template_cutoff_date = query_release_date - delta
assert template_cutoff_date < query_release_date
assert template_cutoff_date <= self._max_template_date
num_hits = 0
errors = []
warnings = []
filtered = []
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
)
if prefilter_result.error:
errors.append(prefilter_result.error)
if prefilter_result.warning:
warnings.append(prefilter_result.warning)
if prefilter_result.valid:
filtered.append(hit)
filtered = sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
idx = list(range(len(filtered)))
if self._shuffle_top_k_prefiltered:
stk = self._shuffle_top_k_prefiltered
idx[:stk] = np.random.permutation(idx[:stk])
for i in idx:
# We got all the templates we wanted, stop processing hits.
if num_hits >= self.max_hits:
break
hit = filtered[i]
result = _process_single_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
)
if result.error:
errors.append(result.error)
# There could be an error even if there are some results, e.g. thrown by
# other unparsable chains in the same mmCIF file.
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.info(
"Skipped invalid hit %s, error: %s, warning: %s",
hit.name,
result.error,
result.warning,
)
else:
# Increment the hit counter, since we got features out of this hit.
num_hits += 1
for k in template_features:
template_features[k].append(result.features[k])
for name in template_features:
if num_hits > 0:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else:
# Make sure the feature has correct dtype even if empty.
template_features[name] = np.array(
[], dtype=TEMPLATE_FEATURES[name]
)
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
)
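# Example usage sketch (not part of the upstream module): how the featurizer is
# typically driven. The paths and the default max_template_date below are
# hypothetical; `hits` would normally come from parsers.parse_hhr on an .hhr
# file produced by HHSearch.
def _example_get_templates(
    hits: Sequence[parsers.TemplateHit],
    query_sequence: str,
    mmcif_dir: str = "/data/pdb_mmcif/mmcif_files",
    kalign_binary_path: str = "kalign",
) -> TemplateSearchResult:
    """Builds a featurizer and turns HHSearch hits into template features."""
    featurizer = TemplateHitFeaturizer(
        mmcif_dir=mmcif_dir,
        max_template_date="2021-10-10",
        max_hits=20,
        kalign_binary_path=kalign_binary_path,
    )
    return featurizer.get_templates(
        query_sequence=query_sequence,
        query_pdb_code=None,
        query_release_date=None,
        hits=hits,
    )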
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHblits from Python."""
import glob
import logging
import os
import subprocess
from typing import Any, Mapping, Optional, Sequence
from fastfold.data.tools import utils
_HHBLITS_DEFAULT_P = 20
_HHBLITS_DEFAULT_Z = 500
class HHBlits:
"""Python wrapper of the HHblits binary."""
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 4,
n_iter: int = 3,
e_value: float = 0.001,
maxseq: int = 1_000_000,
realign_max: int = 100_000,
maxfilt: int = 100_000,
min_prefilter_hits: int = 1000,
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z,
):
"""Initializes the Python HHblits wrapper.
Args:
binary_path: The path to the HHblits executable.
databases: A sequence of HHblits database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to give HHblits.
n_iter: The number of HHblits iterations.
e_value: The E-value, see HHblits docs for more details.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
HHblits default: 20000.
min_prefilter_hits: Min number of hits to pass prefilter.
HHblits default: 100.
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
HHblits default: False.
alt: Show up to this many alternative alignments.
p: Minimum Prob for a hit to be included in the output hhr file.
HHblits default: 20.
z: Hard cap on number of hits reported in the hhr file.
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
Raises:
RuntimeError: If HHblits binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
for database_path in self.databases:
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHBlits database %s", database_path
)
raise ValueError(
f"Could not find HHBlits database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.maxseq = maxseq
self.realign_max = realign_max
self.maxfilt = maxfilt
self.min_prefilter_hits = min_prefilter_hits
self.all_seqs = all_seqs
self.alt = alt
self.p = p
self.z = z
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
for db_path in self.databases:
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [
self.binary_path,
"-i",
input_fasta_path,
"-cpu",
str(self.n_cpu),
"-oa3m",
a3m_path,
"-o",
"/dev/null",
"-n",
str(self.n_iter),
"-e",
str(self.e_value),
"-maxseq",
str(self.maxseq),
"-realign_max",
str(self.realign_max),
"-maxfilt",
str(self.maxfilt),
"-min_prefilter_hits",
str(self.min_prefilter_hits),
]
if self.all_seqs:
cmd += ["-all"]
if self.alt:
cmd += ["-alt", str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ["-p", str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ["-Z", str(self.z)]
cmd += db_cmd
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("HHblits query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
logging.error("HHblits failed. HHblits stderr begin:")
for error_line in stderr.decode("utf-8").splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error("HHblits stderr end")
raise RuntimeError(
"HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
)
with open(a3m_path) as f:
a3m = f.read()
raw_output = dict(
a3m=a3m,
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value,
)
return raw_output
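# Example usage sketch (not part of the upstream module): running an HHblits
# search against a UniClust30-style database and keeping only the a3m MSA. The
# binary and database paths are hypothetical placeholders.
def _example_hhblits_query(input_fasta_path: str) -> str:
    """Runs HHblits on a single-sequence FASTA and returns the resulting a3m."""
    runner = HHBlits(
        binary_path="/usr/bin/hhblits",
        databases=["/data/uniclust30/uniclust30_2018_08"],
        n_cpu=4,
        n_iter=3,
    )
    return runner.query(input_fasta_path)["a3m"]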
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""
import glob
import logging
import os
import subprocess
from typing import Sequence
from fastfold.data.tools import utils
class HHSearch:
"""Python wrapper of the HHsearch binary."""
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 2,
maxseq: int = 1_000_000,
):
"""Initializes the Python HHsearch wrapper.
Args:
binary_path: The path to the HHsearch executable.
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to use.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
Raises:
RuntimeError: If HHsearch binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
self.n_cpu = n_cpu
self.maxseq = maxseq
for database_path in self.databases:
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHsearch database %s", database_path
)
raise ValueError(
f"Could not find HHsearch database {database_path}"
)
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
with open(input_path, "w") as f:
f.write(a3m)
db_cmd = []
for db_path in self.databases:
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [
self.binary_path,
"-i",
input_path,
"-o",
hhr_path,
"-maxseq",
str(self.maxseq),
"-cpu",
str(self.n_cpu),
] + db_cmd
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("HHsearch query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
"HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
)
with open(hhr_path) as f:
hhr = f.read()
return hhr
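# Example usage sketch (not part of the upstream module; paths are hypothetical
# placeholders). The a3m typically comes from an HHblits query of the target
# sequence, and the returned .hhr text is usually parsed into template hits and
# fed to a template featurizer.
def _example_hhsearch_query(a3m: str) -> str:
    """Searches a PDB70-style database with a query MSA and returns the .hhr text."""
    searcher = HHSearch(
        binary_path="/usr/bin/hhsearch",
        databases=["/data/pdb70/pdb70"],
        n_cpu=2,
    )
    return searcher.query(a3m)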