Commit 1109480e authored by Augustin-Zidek's avatar Augustin-Zidek
Browse files

Initial release of AlphaFold.

PiperOrigin-RevId: 384954738
parents
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import shutil
import tempfile
import time
from typing import Optional
from absl import logging
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
"""Context manager that deletes a temporary directory on exit."""
tmpdir = tempfile.mkdtemp(dir=base_dir)
try:
yield tmpdir
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
logging.info('Started %s', msg)
tic = time.time()
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations.
Generally we employ two different representations for all atom coordinates,
one is atom37 where each heavy atom corresponds to a given position in a 37
dimensional array, This mapping is non amino acid specific, but each slot
corresponds to an atom of a given name, for example slot 12 always corresponds
to 'C delta 1', positions that are not present for a given amino acid are
zeroed out and denoted by a mask.
The other representation we employ is called atom14, this is a more dense way
of representing atoms with 14 slots. Here a given slot will correspond to a
different kind of atom depending on amino acid type, for example slot 5
corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine.
14 is chosen because it is the maximum number of heavy atoms for any standard
amino acid.
The order of slots can be found in 'residue_constants.residue_atoms'.
Internally the model uses the atom14 representation because it is
computationally more efficient.
The internal atom14 representation is turned into the atom37 at the output of
the network to facilitate easier conversion to existing protein datastructures.
"""
from typing import Dict, Optional
import jax
import jax.numpy as jnp
import numpy as np
from alphafold.common import residue_constants
from alphafold.model import r3
from alphafold.model import utils
def squared_difference(x, y):
return jnp.square(x - y)
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return jnp.asarray(chi_atom_indices)
def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...)
batch: Dict[str, jnp.ndarray]
) -> jnp.ndarray: # (N, 37, ...)
"""Convert atom14 to atom37 representation."""
assert len(atom14_data.shape) in [2, 3]
assert 'residx_atom37_to_atom14' in batch
assert 'atom37_atom_exists' in batch
atom37_data = utils.batched_gather(atom14_data,
batch['residx_atom37_to_atom14'],
batch_dims=1)
if len(atom14_data.shape) == 2:
atom37_data *= batch['atom37_atom_exists']
elif len(atom14_data.shape) == 3:
atom37_data *= batch['atom37_atom_exists'][:, :,
None].astype(atom37_data.dtype)
return atom37_data
def atom37_to_atom14(
atom37_data: jnp.ndarray, # (N, 37, ...)
batch: Dict[str, jnp.ndarray]) -> jnp.ndarray: # (N, 14, ...)
"""Convert atom14 to atom37 representation."""
assert len(atom37_data.shape) in [2, 3]
assert 'residx_atom14_to_atom37' in batch
assert 'atom14_atom_exists' in batch
atom14_data = utils.batched_gather(atom37_data,
batch['residx_atom14_to_atom37'],
batch_dims=1)
if len(atom37_data.shape) == 2:
atom14_data *= batch['atom14_atom_exists'].astype(atom14_data.dtype)
elif len(atom37_data.shape) == 3:
atom14_data *= batch['atom14_atom_exists'][:, :,
None].astype(atom14_data.dtype)
return atom14_data
def atom37_to_frames(
aatype: jnp.ndarray, # (...)
all_atom_positions: jnp.ndarray, # (..., 37, 3)
all_atom_mask: jnp.ndarray, # (..., 37)
) -> Dict[str, jnp.ndarray]:
"""Computes the frames for the up to 8 rigid groups for each residue.
The rigid groups are defined by the possible torsions in a given amino acid.
We group the atoms according to their dependence on the torsion angles into
"rigid groups". E.g., the position of atoms in the chi2-group depend on
chi1 and chi2, but do not depend on chi3 or chi4.
Jumper et al. (2021) Suppl. Table 2 and corresponding text.
Args:
aatype: Amino acid type, given as array with integers.
all_atom_positions: atom37 representation of all atom coordinates.
all_atom_mask: atom37 representation of mask on all atom coordinates.
Returns:
Dictionary containing:
* 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions'
represented as flat 12 dimensional array.
* 'rigidgroups_gt_exists': Mask denoting whether the atom positions for
the given frame are available in the ground truth, e.g. if they were
resolved in the experiment.
* 'rigidgroups_group_exists': Mask denoting whether given group is in
principle present for given amino acid type.
* 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
affected by naming ambiguity.
* 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
corresponding to 'all_atom_positions' represented as flat
12 dimensional array.
"""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
aatype_in_shape = aatype.shape
# If there is a batch axis, just flatten it away, and reshape everything
# back at the end of the function.
aatype = jnp.reshape(aatype, [-1])
all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(residue_constants.restypes):
resname = residue_constants.restype_1to3[restype_letter]
for chi_idx in range(4):
if residue_constants.chi_angles_mask[restype][chi_idx]:
atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :] = atom_names[1:]
# Create mask for existing rigid groups.
restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_mask[:, 0] = 1
restype_rigidgroup_mask[:, 3] = 1
restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask
# Translate atom names into atom37 indices.
lookuptable = residue_constants.atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
restype_rigidgroup_base_atom_names)
# Compute the gather indices for all residues in the chain.
# shape (N, 8, 3)
residx_rigidgroup_base_atom37_idx = utils.batched_gather(
restype_rigidgroup_base_atom37_idx, aatype)
# Gather the base atom positions for each rigid group.
base_atom_pos = utils.batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
batch_dims=1)
# Compute the Rigids.
gt_frames = r3.rigids_from_3_points(
point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :])
)
# Compute a mask whether the group exists.
# (N, 8)
group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.astype(jnp.float32),
residx_rigidgroup_base_atom37_idx,
batch_dims=1)
gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))
# The frames for ambiguous rigid groups are just rotated by 180 degree around
# the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])
for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
# Gather the ambiguity information for each residue.
residx_rigidgroup_is_ambiguous = utils.batched_gather(
restype_rigidgroup_is_ambiguous, aatype)
residx_rigidgroup_ambiguity_rot = utils.batched_gather(
restype_rigidgroup_rots, aatype)
# Create the alternative ground truth frames.
alt_gt_frames = r3.rigids_mul_rots(
gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))
gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)
# reshape back to original residue layout
gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,))
group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,))
gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
residx_rigidgroup_is_ambiguous = jnp.reshape(residx_rigidgroup_is_ambiguous,
aatype_in_shape + (8,))
alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12,
aatype_in_shape + (8, 12,))
return {
'rigidgroups_gt_frames': gt_frames_flat12, # (..., 8, 12)
'rigidgroups_gt_exists': gt_exists, # (..., 8)
'rigidgroups_group_exists': group_exists, # (..., 8)
'rigidgroups_group_is_ambiguous':
residx_rigidgroup_is_ambiguous, # (..., 8)
'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # (..., 8, 12)
}
def atom37_to_torsion_angles(
aatype: jnp.ndarray, # (B, N)
all_atom_pos: jnp.ndarray, # (B, N, 37, 3)
all_atom_mask: jnp.ndarray, # (B, N, 37)
placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:
"""Computes the 7 torsion angles (in sin, cos encoding) for each residue.
The 7 torsion angles are in the order
'[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
here pre_omega denotes the omega torsion angle between the given amino acid
and the previous amino acid.
Args:
aatype: Amino acid type, given as array with integers.
all_atom_pos: atom37 representation of all atom coordinates.
all_atom_mask: atom37 representation of mask on all atom coordinates.
placeholder_for_undefined: flag denoting whether to set masked torsion
angles to zero.
Returns:
Dict containing:
* 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
2 dimensions denote sin and cos respectively
* 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
with the angle shifted by pi for all chi angles affected by the naming
ambiguities.
* 'torsion_angles_mask': Mask for which chi angles are present.
"""
# Map aatype > 20 to 'Unknown' (20).
aatype = jnp.minimum(aatype, 20)
# Compute the backbone angles.
num_batch, num_res = aatype.shape
pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)
pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
# For each torsion angle collect the 4 atom positions that define this angle.
# shape (B, N, atoms=4, xyz=3)
pre_omega_atom_pos = jnp.concatenate(
[prev_all_atom_pos[:, :, 1:3, :], # prev CA, C
all_atom_pos[:, :, 0:2, :] # this N, CA
], axis=-2)
phi_atom_pos = jnp.concatenate(
[prev_all_atom_pos[:, :, 2:3, :], # prev C
all_atom_pos[:, :, 0:3, :] # this N, CA, C
], axis=-2)
psi_atom_pos = jnp.concatenate(
[all_atom_pos[:, :, 0:3, :], # this N, CA, C
all_atom_pos[:, :, 4:5, :] # this O
], axis=-2)
# Collect the masks from these atoms.
# Shape [batch, num_res]
pre_omega_mask = (
jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1) # prev CA, C
* jnp.prod(all_atom_mask[:, :, 0:2], axis=-1)) # this N, CA
phi_mask = (
prev_all_atom_mask[:, :, 2] # prev C
* jnp.prod(all_atom_mask[:, :, 0:3], axis=-1)) # this N, CA, C
psi_mask = (
jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) * # this N, CA, C
all_atom_mask[:, :, 4]) # this O
# Collect the atoms for the chi-angles.
# Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
chi_atom_indices = get_chi_atom_indices()
# Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
atom_indices = utils.batched_gather(
params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)
# Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
chis_atom_pos = utils.batched_gather(
params=all_atom_pos, indices=atom_indices, axis=-2,
batch_dims=2)
# Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = jnp.asarray(chi_angles_mask)
# Compute the chi angle mask. I.e. which chis angles exist according to the
# aatype. Shape [batch, num_res, chis=4].
chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
axis=0, batch_dims=0)
# Constrain the chis_mask to those chis, where the ground truth coordinates of
# all defining four atoms are available.
# Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4].
chi_angle_atoms_mask = utils.batched_gather(
params=all_atom_mask, indices=atom_indices, axis=-1,
batch_dims=2)
# Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)
# Stack all torsion angle atom positions.
# Shape (B, N, torsions=7, atoms=4, xyz=3)
torsions_atom_pos = jnp.concatenate(
[pre_omega_atom_pos[:, :, None, :, :],
phi_atom_pos[:, :, None, :, :],
psi_atom_pos[:, :, None, :, :],
chis_atom_pos
], axis=2)
# Stack up masks for all torsion angles.
# shape (B, N, torsions=7)
torsion_angles_mask = jnp.concatenate(
[pre_omega_mask[:, :, None],
phi_mask[:, :, None],
psi_mask[:, :, None],
chis_mask
], axis=2)
# Create a frame from the first three atoms:
# First atom: point on x-y-plane
# Second atom: point on negative x-axis
# Third atom: origin
# r3.Rigids (B, N, torsions=7)
torsion_frames = r3.rigids_from_3_points(
point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
# Compute the position of the forth atom in this frame (y and z coordinate
# define the chi angle)
# r3.Vecs (B, N, torsions=7)
forth_atom_rel_pos = r3.rigids_mul_vecs(
r3.invert_rigids(torsion_frames),
r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
# Normalize to have the sin and cos of the torsion angle.
# jnp.ndarray (B, N, torsions=7, sincos=2)
torsion_angles_sin_cos = jnp.stack(
[forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
torsion_angles_sin_cos /= jnp.sqrt(
jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
+ 1e-8)
# Mirror psi, because we computed it from the Oxygen-atom.
torsion_angles_sin_cos *= jnp.asarray(
[1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]
# Create alternative angles for ambiguous atom names.
chi_is_ambiguous = utils.batched_gather(
jnp.asarray(residue_constants.chi_pi_periodic), aatype)
mirror_torsion_angles = jnp.concatenate(
[jnp.ones([num_batch, num_res, 3]),
1.0 - 2.0 * chi_is_ambiguous], axis=-1)
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
if placeholder_for_undefined:
# Add placeholder torsions in place of undefined torsion angles
# (e.g. N-terminus pre-omega)
placeholder_torsions = jnp.stack([
jnp.ones(torsion_angles_sin_cos.shape[:-1]),
jnp.zeros(torsion_angles_sin_cos.shape[:-1])
], axis=-1)
torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
return {
'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2)
'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2)
'torsion_angles_mask': torsion_angles_mask # (B, N, 7)
}
def torsion_angles_to_frames(
aatype: jnp.ndarray, # (N)
backb_to_global: r3.Rigids, # (N)
torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2)
) -> r3.Rigids: # (N, 8)
"""Compute rigid group frames from torsion angles.
Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10
Jumper et al. (2021) Suppl. Alg. 25 "makeRotX"
Args:
aatype: aatype for each residue
backb_to_global: Rigid transformations describing transformation from
backbone frame to global frame.
torsion_angles_sin_cos: sin and cosine of the 7 torsion angles
Returns:
Frames corresponding to all the Sidechain Rigid Transforms
"""
assert len(aatype.shape) == 1
assert len(backb_to_global.rot.xx.shape) == 1
assert len(torsion_angles_sin_cos.shape) == 3
assert torsion_angles_sin_cos.shape[1] == 7
assert torsion_angles_sin_cos.shape[2] == 2
# Gather the default frames for all rigid groups.
# r3.Rigids with shape (N, 8)
m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame,
aatype)
default_frames = r3.rigids_from_tensor4x4(m)
# Create the rotation matrices according to the given angles (each frame is
# defined such that its rotation is around the x-axis).
sin_angles = torsion_angles_sin_cos[..., 0]
cos_angles = torsion_angles_sin_cos[..., 1]
# insert zero rotation for backbone group.
num_residues, = aatype.shape
sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],
axis=-1)
cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],
axis=-1)
zeros = jnp.zeros_like(sin_angles)
ones = jnp.ones_like(sin_angles)
# all_rots are r3.Rots with shape (N, 8)
all_rots = r3.Rots(ones, zeros, zeros,
zeros, cos_angles, -sin_angles,
zeros, sin_angles, cos_angles)
# Apply rotations to the frames.
all_frames = r3.rigids_mul_rots(default_frames, all_rots)
# chi2, chi3, and chi4 frames do not transform to the backbone frame but to
# the previous frame. So chain them up accordingly.
chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames)
chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames)
chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames)
chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames)
chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb,
chi2_frame_to_frame)
chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb,
chi3_frame_to_frame)
chi4_frame_to_backb = r3.rigids_mul_rigids(chi3_frame_to_backb,
chi4_frame_to_frame)
# Recombine them to a r3.Rigids with shape (N, 8).
def _concat_frames(xall, x5, x6, x7):
return jnp.concatenate(
[xall[:, 0:5], x5[:, None], x6[:, None], x7[:, None]], axis=-1)
all_frames_to_backb = jax.tree_map(
_concat_frames,
all_frames,
chi2_frame_to_backb,
chi3_frame_to_backb,
chi4_frame_to_backb)
# Create the global frames.
# shape (N, 8)
all_frames_to_global = r3.rigids_mul_rigids(
jax.tree_map(lambda x: x[:, None], backb_to_global),
all_frames_to_backb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
aatype: jnp.ndarray, # (N)
all_frames_to_global: r3.Rigids # (N, 8)
) -> r3.Vecs: # (N, 14)
"""Put atom literature positions (atom14 encoding) in each rigid group.
Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11
Args:
aatype: aatype for each residue.
all_frames_to_global: All per residue coordinate frames.
Returns:
Positions of all atom coordinates in global frame.
"""
# Pick the appropriate transform for every atom.
residx_to_group_idx = utils.batched_gather(
residue_constants.restype_atom14_to_rigid_group, aatype)
group_mask = jax.nn.one_hot(
residx_to_group_idx, num_classes=8) # shape (N, 14, 8)
# r3.Rigids with shape (N, 14)
map_atoms_to_global = jax.tree_map(
lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
all_frames_to_global)
# Gather the literature atom positions for each residue.
# r3.Vecs with shape (N, 14)
lit_positions = r3.vecs_from_tensor(
utils.batched_gather(
residue_constants.restype_atom14_rigid_group_positions, aatype))
# Transform each atom from its local frame to the global frame.
# r3.Vecs with shape (N, 14)
pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions)
# Mask out non-existing atoms.
mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
pred_positions = jax.tree_map(lambda x: x * mask, pred_positions)
return pred_positions
def extreme_ca_ca_distance_violations(
pred_atom_positions: jnp.ndarray, # (N, 37(14), 3)
pred_atom_mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
max_angstrom_tolerance=1.5
) -> jnp.ndarray:
"""Counts residues whose Ca is a large distance from its neighbor.
Measures the fraction of CA-CA pairs between consectutive amino acids that
are more than 'max_angstrom_tolerance' apart.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
residue_index: Residue index for given amino acid, this is assumed to be
monotonically increasing.
max_angstrom_tolerance: Maximum distance allowed to not count as violation.
Returns:
Fraction of consecutive CA-CA pairs with violation.
"""
this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3)
this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1)
next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3)
next_ca_mask = pred_atom_mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
ca_ca_distance = jnp.sqrt(
1e-6 + jnp.sum(squared_difference(this_ca_pos, next_ca_pos), axis=-1))
violations = (ca_ca_distance -
residue_constants.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
return utils.mask_mean(mask=mask, value=violations)
def between_residue_bond_loss(
pred_atom_positions: jnp.ndarray, # (N, 37(14), 3)
pred_atom_mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
aatype: jnp.ndarray, # (N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0
) -> Dict[str, jnp.ndarray]:
"""Flat-bottom loss to penalize structural violations between residues.
This is a loss penalizing any violation of the geometry around the peptide
bond between consecutive amino acids. This loss corresponds to
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
residue_index: Residue index for given amino acid, this is assumed to be
monotonically increasing.
aatype: Amino acid type of given residue
tolerance_factor_soft: soft tolerance factor measured in standard deviations
of pdb distributions
tolerance_factor_hard: hard tolerance factor measured in standard deviations
of pdb distributions
Returns:
Dict containing:
* 'c_n_loss_mean': Loss for peptide bond length violations
* 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
by CA, C, N
* 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
by C, N, CA
* 'per_residue_loss_sum': sum of all losses for each residue
* 'per_residue_violation_mask': mask denoting all residues with violation
present.
"""
assert len(pred_atom_positions.shape) == 3
assert len(pred_atom_mask.shape) == 2
assert len(residue_index.shape) == 1
assert len(aatype.shape) == 1
# Get the positions of the relevant backbone atoms.
this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3)
this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1)
this_c_pos = pred_atom_positions[:-1, 2, :] # (N - 1, 3)
this_c_mask = pred_atom_mask[:-1, 2] # (N - 1)
next_n_pos = pred_atom_positions[1:, 0, :] # (N - 1, 3)
next_n_mask = pred_atom_mask[1:, 0] # (N - 1)
next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3)
next_ca_mask = pred_atom_mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
# Compute loss for the C--N bond.
c_n_bond_length = jnp.sqrt(
1e-6 + jnp.sum(squared_difference(this_c_pos, next_n_pos), axis=-1))
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32)
gt_length = (
(1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1])
gt_stddev = (
(1. - next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1])
c_n_bond_length_error = jnp.sqrt(1e-6 +
jnp.square(c_n_bond_length - gt_length))
c_n_loss_per_residue = jax.nn.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))
# Compute loss for the angles.
ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum(
squared_difference(this_ca_pos, this_c_pos), axis=-1))
n_ca_bond_length = jnp.sqrt(1e-6 + jnp.sum(
squared_difference(next_n_pos, next_ca_pos), axis=-1))
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[:, None]
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None]
ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
ca_c_n_loss_per_residue = jax.nn.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
c_n_ca_loss_per_residue = jax.nn.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
jnp.pad(per_residue_loss_sum, [[1, 0]]))
# Compute hard violations.
violation_mask = jnp.max(
jnp.stack([c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask]), axis=0)
violation_mask = jnp.maximum(
jnp.pad(violation_mask, [[0, 1]]),
jnp.pad(violation_mask, [[1, 0]]))
return {'c_n_loss_mean': c_n_loss, # shape ()
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
'per_residue_violation_mask': violation_mask # shape (N)
}
def between_residue_clash_loss(
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_exists: jnp.ndarray, # (N, 14)
atom14_atom_radius: jnp.ndarray, # (N, 14)
residue_index: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5
) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes between residues.
This is a loss penalizing any steric clashes due to non bonded atoms in
different peptides coming too close. This loss corresponds to the part with
different residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
atom14_atom_radius: Van der Waals radius for each atom.
residue_index: Residue index for given amino acid.
overlap_tolerance_soft: Soft tolerance factor.
overlap_tolerance_hard: Hard tolerance factor.
Returns:
Dict containing:
* 'mean_loss': average clash loss
* 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
* 'per_atom_clash_mask': mask whether atom clashes with any other atom
shape (N, 14)
"""
assert len(atom14_pred_positions.shape) == 3
assert len(atom14_atom_exists.shape) == 2
assert len(atom14_atom_radius.shape) == 2
assert len(residue_index.shape) == 1
# Create the distance matrix.
# (N, N, 14, 14)
dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_pred_positions[:, None, :, None, :],
atom14_pred_positions[None, :, None, :, :]),
axis=-1))
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (atom14_atom_exists[:, None, :, None] *
atom14_atom_exists[None, :, None, :])
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask *= (
residue_index[:, None, None, None] < residue_index[None, :, None, None])
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = jax.nn.one_hot(2, num_classes=14)
n_one_hot = jax.nn.one_hot(0, num_classes=14)
neighbour_mask = ((residue_index[:, None, None, None] +
1) == residue_index[None, :, None, None])
c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
None] * n_one_hot[None, None, None, :]
dists_mask *= (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG')
cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (cys_sg_one_hot[None, None, :, None] *
cys_sg_one_hot[None, None, None, :])
dists_mask *= (1. - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (atom14_atom_radius[:, None, :, None] +
atom14_atom_radius[None, :, None, :])
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * jax.nn.relu(
dists_lower_bound - overlap_tolerance_soft - dists)
# Compute the mean loss.
# shape ()
mean_loss = (jnp.sum(dists_to_low_error)
/ (1e-6 + jnp.sum(dists_mask)))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) +
jnp.sum(dists_to_low_error, axis=[1, 3]))
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = jnp.maximum(
jnp.max(clash_mask, axis=[0, 2]),
jnp.max(clash_mask, axis=[1, 3]))
return {'mean_loss': mean_loss, # shape ()
'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14)
}
def within_residue_violations(
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_exists: jnp.ndarray, # (N, 14)
atom14_dists_lower_bound: jnp.ndarray, # (N, 14, 14)
atom14_dists_upper_bound: jnp.ndarray, # (N, 14, 14)
tighten_bounds_for_loss=0.0,
) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes within residues.
This is a loss penalizing any steric violations or clashes of non-bonded atoms
in a given peptide. This loss corresponds to the part with
the same residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
atom14_dists_lower_bound: Lower bound on allowed distances.
atom14_dists_upper_bound: Upper bound on allowed distances
tighten_bounds_for_loss: Extra factor to tighten loss
Returns:
Dict containing:
* 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
* 'per_atom_clash_mask': mask whether atom clashes with any other atom
shape (N, 14)
"""
assert len(atom14_pred_positions.shape) == 3
assert len(atom14_atom_exists.shape) == 2
assert len(atom14_dists_lower_bound.shape) == 3
assert len(atom14_dists_upper_bound.shape) == 3
# Compute the mask for each residue.
# shape (N, 14, 14)
dists_masks = (1. - jnp.eye(14, 14)[None])
dists_masks *= (atom14_atom_exists[:, :, None] *
atom14_atom_exists[:, None, :])
# Distance matrix
# shape (N, 14, 14)
dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_pred_positions[:, :, None, :],
atom14_pred_positions[:, None, :, :]),
axis=-1))
# Compute the loss.
# shape (N, 14, 14)
dists_to_low_error = jax.nn.relu(
atom14_dists_lower_bound + tighten_bounds_for_loss - dists)
dists_to_high_error = jax.nn.relu(
dists - (atom14_dists_upper_bound - tighten_bounds_for_loss))
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(loss, axis=1) +
jnp.sum(loss, axis=2))
# Compute the violations mask.
# shape (N, 14, 14)
violations = dists_masks * ((dists < atom14_dists_lower_bound) |
(dists > atom14_dists_upper_bound))
# Compute the per atom violations.
# shape (N, 14)
per_atom_violations = jnp.maximum(
jnp.max(violations, axis=1), jnp.max(violations, axis=2))
return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_violations': per_atom_violations # shape (N, 14)
}
def find_optimal_renaming(
atom14_gt_positions: jnp.ndarray, # (N, 14, 3)
atom14_alt_gt_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_is_ambiguous: jnp.ndarray, # (N, 14)
atom14_gt_exists: jnp.ndarray, # (N, 14)
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_exists: jnp.ndarray, # (N, 14)
) -> jnp.ndarray: # (N):
"""Find optimal renaming for ground truth that maximizes LDDT.
Jumper et al. (2021) Suppl. Alg. 26
"renameSymmetricGroundTruthAtoms" lines 1-5
Args:
atom14_gt_positions: Ground truth positions in global frame of ground truth.
atom14_alt_gt_positions: Alternate ground truth positions in global frame of
ground truth with coordinates of ambiguous atoms swapped relative to
'atom14_gt_positions'.
atom14_atom_is_ambiguous: Mask denoting whether atom is among ambiguous
atoms, see Jumper et al. (2021) Suppl. Table 3
atom14_gt_exists: Mask denoting whether atom at positions exists in ground
truth.
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
Returns:
Float array of shape [N] with 1. where atom14_alt_gt_positions is closer to
prediction and 0. otherwise
"""
assert len(atom14_gt_positions.shape) == 3
assert len(atom14_alt_gt_positions.shape) == 3
assert len(atom14_atom_is_ambiguous.shape) == 2
assert len(atom14_gt_exists.shape) == 2
assert len(atom14_pred_positions.shape) == 3
assert len(atom14_atom_exists.shape) == 2
# Create the pred distance matrix.
# shape (N, N, 14, 14)
pred_dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_pred_positions[:, None, :, None, :],
atom14_pred_positions[None, :, None, :, :]),
axis=-1))
# Compute distances for ground truth with original and alternative names.
# shape (N, N, 14, 14)
gt_dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_gt_positions[:, None, :, None, :],
atom14_gt_positions[None, :, None, :, :]),
axis=-1))
alt_gt_dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_alt_gt_positions[:, None, :, None, :],
atom14_alt_gt_positions[None, :, None, :, :]),
axis=-1))
# Compute LDDT's.
# shape (N, N, 14, 14)
lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists))
alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists))
# Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
# in cols.
# shape (N ,N, 14, 14)
mask = (atom14_gt_exists[:, None, :, None] * # rows
atom14_atom_is_ambiguous[:, None, :, None] * # rows
atom14_gt_exists[None, :, None, :] * # cols
(1. - atom14_atom_is_ambiguous[None, :, None, :])) # cols
# Aggregate distances for each residue to the non-amibuguous atoms.
# shape (N)
per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3])
alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3])
# Decide for each residue, whether alternative naming is better.
# shape (N)
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32)
return alt_naming_is_better # shape (N)
def frame_aligned_point_error(
pred_frames: r3.Rigids, # shape (num_frames)
target_frames: r3.Rigids, # shape (num_frames)
frames_mask: jnp.ndarray, # shape (num_frames)
pred_positions: r3.Vecs, # shape (num_positions)
target_positions: r3.Vecs, # shape (num_positions)
positions_mask: jnp.ndarray, # shape (num_positions)
length_scale: float,
l1_clamp_distance: Optional[float] = None,
epsilon=1e-4) -> jnp.ndarray: # shape ()
"""Measure point error under different alignments.
Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE"
Computes error between two structures with B points under A alignments derived
from the given pairs of frames.
Args:
pred_frames: num_frames reference frames for 'pred_positions'.
target_frames: num_frames reference frames for 'target_positions'.
frames_mask: Mask for frame pairs to use.
pred_positions: num_positions predicted positions of the structure.
target_positions: num_positions target positions of the structure.
positions_mask: Mask on which positions to score.
length_scale: length scale to divide loss by.
l1_clamp_distance: Distance cutoff on error beyond which gradients will
be zero.
epsilon: small value used to regularize denominator for masked average.
Returns:
Masked Frame Aligned Point Error.
"""
assert pred_frames.rot.xx.ndim == 1
assert target_frames.rot.xx.ndim == 1
assert frames_mask.ndim == 1, frames_mask.ndim
assert pred_positions.x.ndim == 1
assert target_positions.x.ndim == 1
assert positions_mask.ndim == 1
# Compute array of predicted positions in the predicted frames.
# r3.Vecs (num_frames, num_positions)
local_pred_pos = r3.rigids_mul_vecs(
jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)),
jax.tree_map(lambda x: x[None, :], pred_positions))
# Compute array of target positions in the target frames.
# r3.Vecs (num_frames, num_positions)
local_target_pos = r3.rigids_mul_vecs(
jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)),
jax.tree_map(lambda x: x[None, :], target_positions))
# Compute errors between the structures.
# jnp.ndarray (num_frames, num_positions)
error_dist = jnp.sqrt(
r3.vecs_squared_distance(local_pred_pos, local_target_pos)
+ epsilon)
if l1_clamp_distance:
error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error *= jnp.expand_dims(frames_mask, axis=-1)
normed_error *= jnp.expand_dims(positions_mask, axis=-2)
normalization_factor = (
jnp.sum(frames_mask, axis=-1) *
jnp.sum(positions_mask, axis=-1))
return (jnp.sum(normed_error, axis=(-2, -1)) /
(epsilon + normalization_factor))
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
residue_constants.restype_1to3[res] for res in residue_constants.restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
RENAMING_MATRICES = _make_renaming_matrices()
def get_alt_atom14(aatype, positions, mask):
"""Get alternative atom14 positions.
Constructs renamed atom positions for ambiguous residues.
Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree-
rotation-symmetry"
Args:
aatype: Amino acid at given position
positions: Atom positions as r3.Vecs in atom14 representation, (N, 14)
mask: Atom masks in atom14 representation, (N, 14)
Returns:
renamed atom positions, renamed atom mask
"""
# pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14)
renaming_transform = utils.batched_gather(
jnp.asarray(RENAMING_MATRICES), aatype)
positions = jax.tree_map(lambda x: x[:, :, None], positions)
alternative_positions = jax.tree_map(
lambda x: jnp.sum(x, axis=1), positions * renaming_transform)
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position)
alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)
return alternative_positions, alternative_mask
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for all_atom."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from alphafold.model import all_atom
from alphafold.model import r3
L1_CLAMP_DISTANCE = 10
def get_identity_rigid(shape):
"""Returns identity rigid transform."""
ones = np.ones(shape)
zeros = np.zeros(shape)
rot = r3.Rots(ones, zeros, zeros,
zeros, ones, zeros,
zeros, zeros, ones)
trans = r3.Vecs(zeros, zeros, zeros)
return r3.Rigids(rot, trans)
def get_global_rigid_transform(rot_angle, translation, bcast_dims):
"""Returns rigid transform that globally rotates/translates by same amount."""
rot_angle = np.asarray(rot_angle)
translation = np.asarray(translation)
if bcast_dims:
for _ in range(bcast_dims):
rot_angle = np.expand_dims(rot_angle, 0)
translation = np.expand_dims(translation, 0)
sin_angle = np.sin(np.deg2rad(rot_angle))
cos_angle = np.cos(np.deg2rad(rot_angle))
ones = np.ones_like(sin_angle)
zeros = np.zeros_like(sin_angle)
rot = r3.Rots(ones, zeros, zeros,
zeros, cos_angle, -sin_angle,
zeros, sin_angle, cos_angle)
trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2])
return r3.Rigids(rot, trans)
class AllAtomTest(parameterized.TestCase, absltest.TestCase):
@parameterized.named_parameters(
('identity', 0, [0, 0, 0]),
('rot_90', 90, [0, 0, 0]),
('trans_10', 0, [0, 0, 10]),
('rot_174_trans_1', 174, [1, 1, 1]))
def test_frame_aligned_point_error_perfect_on_global_transform(
self, rot_angle, translation):
"""Tests global transform between target and preds gives perfect score."""
# pylint: disable=bad-whitespace
target_positions = np.array(
[[ 21.182, 23.095, 19.731],
[ 22.055, 20.919, 17.294],
[ 24.599, 20.005, 15.041],
[ 25.567, 18.214, 12.166],
[ 28.063, 17.082, 10.043],
[ 28.779, 15.569, 6.985],
[ 30.581, 13.815, 4.612],
[ 29.258, 12.193, 2.296]])
# pylint: enable=bad-whitespace
global_rigid_transform = get_global_rigid_transform(
rot_angle, translation, 1)
target_positions = r3.vecs_from_tensor(target_positions)
pred_positions = r3.rigids_mul_vecs(
global_rigid_transform, target_positions)
positions_mask = np.ones(target_positions.x.shape[0])
target_frames = get_identity_rigid(10)
pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
frames_mask = np.ones(10)
fape = all_atom.frame_aligned_point_error(
pred_frames, target_frames, frames_mask, pred_positions,
target_positions, positions_mask, L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE, epsilon=0)
self.assertAlmostEqual(fape, 0.)
@parameterized.named_parameters(
('identity',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
0.),
('shift_2.5',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
0.25),
('shift_5',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[5, 0, 0], [10, 0, 0], [15, 0, 0]],
0.5),
('shift_10',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[10, 0, 0], [15, 0, 0], [0, 0, 0]],
1.))
def test_frame_aligned_point_error_matches_expected(
self, target_positions, pred_positions, expected_alddt):
"""Tests score matches expected."""
target_frames = get_identity_rigid(2)
pred_frames = target_frames
frames_mask = np.ones(2)
target_positions = r3.vecs_from_tensor(np.array(target_positions))
pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
positions_mask = np.ones(target_positions.x.shape[0])
alddt = all_atom.frame_aligned_point_error(
pred_frames, target_frames, frames_mask, pred_positions,
target_positions, positions_mask, L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE, epsilon=0)
self.assertAlmostEqual(alddt, expected_alddt)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common Haiku modules for use in protein folding."""
import haiku as hk
import jax.numpy as jnp
class Linear(hk.Module):
"""Protein folding specific Linear Module.
This differs from the standard Haiku Linear in a few ways:
* It supports inputs of arbitrary rank
* Initializers are specified by strings
"""
def __init__(self,
num_output: int,
initializer: str = 'linear',
use_bias: bool = True,
bias_init: float = 0.,
name: str = 'linear'):
"""Constructs Linear Module.
Args:
num_output: number of output channels.
initializer: What initializer to use, should be one of {'linear', 'relu',
'zeros'}
use_bias: Whether to include trainable bias
bias_init: Value used to initialize bias.
name: name of module, used for name scopes.
"""
super().__init__(name=name)
self.num_output = num_output
self.initializer = initializer
self.use_bias = use_bias
self.bias_init = bias_init
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Connects Module.
Args:
inputs: Tensor of shape [..., num_channel]
Returns:
output of shape [..., num_output]
"""
n_channels = int(inputs.shape[-1])
weight_shape = [n_channels, self.num_output]
if self.initializer == 'linear':
weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
elif self.initializer == 'relu':
weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
elif self.initializer == 'zeros':
weight_init = hk.initializers.Constant(0.0)
weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
weight_init)
# this is equivalent to einsum('...c,cd->...d', inputs, weights)
# but turns out to be slightly faster
inputs = jnp.swapaxes(inputs, -1, -2)
output = jnp.einsum('...cb,cd->...db', inputs, weights)
output = jnp.swapaxes(output, -1, -2)
if self.use_bias:
bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
hk.initializers.Constant(self.bias_init))
output += bias
return output
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model config."""
import copy
import ml_collections
from alphafold.model.tf import shape_placeholders
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""
if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
return cfg
CONFIG_DIFFS = {
'model_1': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
'data.common.max_extra_msa': 5120,
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True
},
'model_2': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True
},
'model_3': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
'data.common.max_extra_msa': 5120,
},
'model_4': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
'data.common.max_extra_msa': 5120,
},
'model_5': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
},
# The following models are fine-tuned from the corresponding models above
# with an additional predicted_aligned_error head that can produce
# predicted TM-score (pTM) and predicted aligned errors.
'model_1_ptm': {
'data.common.max_extra_msa': 5120,
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_2_ptm': {
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_3_ptm': {
'data.common.max_extra_msa': 5120,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_4_ptm': {
'data.common.max_extra_msa': 5120,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_5_ptm': {
'model.heads.predicted_aligned_error.weight': 0.1
}
}
CONFIG = ml_collections.ConfigDict({
'data': {
'common': {
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'max_extra_msa': 1024,
'msa_cluster_features': True,
'num_recycle': 3,
'reduce_msa_clusters_by_max_templates': False,
'resample_msa_in_recycling': True,
'template_features': [
'template_all_atom_positions', 'template_sum_probs',
'template_aatype', 'template_all_atom_masks',
'template_domain_names'
],
'unsupervised_features': [
'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
'num_alignments', 'seq_length', 'between_segment_residues',
'deletion_matrix'
],
'use_templates': False,
},
'eval': {
'feat': {
'aatype': [NUM_RES],
'all_atom_mask': [NUM_RES, None],
'all_atom_positions': [NUM_RES, None, None],
'alt_chi_angles': [NUM_RES, None],
'atom14_alt_gt_exists': [NUM_RES, None],
'atom14_alt_gt_positions': [NUM_RES, None, None],
'atom14_atom_exists': [NUM_RES, None],
'atom14_atom_is_ambiguous': [NUM_RES, None],
'atom14_gt_exists': [NUM_RES, None],
'atom14_gt_positions': [NUM_RES, None, None],
'atom37_atom_exists': [NUM_RES, None],
'backbone_affine_mask': [NUM_RES],
'backbone_affine_tensor': [NUM_RES, None],
'bert_mask': [NUM_MSA_SEQ, NUM_RES],
'chi_angles': [NUM_RES, None],
'chi_mask': [NUM_RES, None],
'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_row_mask': [NUM_EXTRA_SEQ],
'is_distillation': [],
'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
'msa_mask': [NUM_MSA_SEQ, NUM_RES],
'msa_row_mask': [NUM_MSA_SEQ],
'pseudo_beta': [NUM_RES, None],
'pseudo_beta_mask': [NUM_RES],
'random_crop_to_size_seed': [None],
'residue_index': [NUM_RES],
'residx_atom14_to_atom37': [NUM_RES, None],
'residx_atom37_to_atom14': [NUM_RES, None],
'resolution': [],
'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
'rigidgroups_group_exists': [NUM_RES, None],
'rigidgroups_group_is_ambiguous': [NUM_RES, None],
'rigidgroups_gt_exists': [NUM_RES, None],
'rigidgroups_gt_frames': [NUM_RES, None, None],
'seq_length': [],
'seq_mask': [NUM_RES],
'target_feat': [NUM_RES, None],
'template_aatype': [NUM_TEMPLATES, NUM_RES],
'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
'template_all_atom_positions': [
NUM_TEMPLATES, NUM_RES, None, None],
'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
'template_backbone_affine_tensor': [
NUM_TEMPLATES, NUM_RES, None],
'template_mask': [NUM_TEMPLATES],
'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
'template_sum_probs': [NUM_TEMPLATES, None],
'true_msa': [NUM_MSA_SEQ, NUM_RES]
},
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
},
},
'model': {
'embeddings_and_evoformer': {
'evoformer_num_block': 48,
'evoformer': {
'msa_row_attention_with_pair_bias': {
'dropout_rate': 0.15,
'gating': True,
'num_head': 8,
'orientation': 'per_row',
'shared_dropout': True
},
'msa_column_attention': {
'dropout_rate': 0.0,
'gating': True,
'num_head': 8,
'orientation': 'per_column',
'shared_dropout': True
},
'msa_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'outer_product_mean': {
'chunk_size': 128,
'dropout_rate': 0.0,
'num_outer_channel': 32,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'max_relative_feature': 32,
'msa_channel': 256,
'pair_channel': 128,
'prev_pos': {
'min_bin': 3.25,
'max_bin': 20.75,
'num_bins': 15
},
'recycle_features': True,
'recycle_pos': True,
'seq_channel': 384,
'template': {
'attention': {
'gating': False,
'key_dim': 64,
'num_head': 4,
'value_dim': 64
},
'dgram_features': {
'min_bin': 3.25,
'max_bin': 50.75,
'num_bins': 39
},
'embed_torsion_angles': False,
'enabled': False,
'template_pair_stack': {
'num_block': 2,
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'key_dim': 64,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True,
'value_dim': 64
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'key_dim': 64,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True,
'value_dim': 64
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 2,
'orientation': 'per_row',
'shared_dropout': True
}
},
'max_templates': 4,
'subbatch_size': 128,
'use_template_unit_vector': False,
}
},
'global_config': {
'deterministic': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
},
'heads': {
'distogram': {
'first_break': 2.3125,
'last_break': 21.6875,
'num_bins': 64,
'weight': 0.3
},
'predicted_aligned_error': {
# `num_bins - 1` bins uniformly space the
# [0, max_error_bin A] range.
# The final bin covers [max_error_bin A, +infty]
# 31A gives bins with 0.5A width.
'max_error_bin': 31.,
'num_bins': 64,
'num_channels': 128,
'filter_by_resolution': True,
'min_resolution': 0.1,
'max_resolution': 3.0,
'weight': 0.0,
},
'experimentally_resolved': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'weight': 0.01
},
'structure_module': {
'num_layer': 8,
'fape': {
'clamp_distance': 10.0,
'clamp_type': 'relu',
'loss_unit_distance': 10.0
},
'angle_norm_weight': 0.01,
'chi_weight': 0.5,
'clash_overlap_tolerance': 1.5,
'compute_in_graph_metrics': True,
'dropout': 0.1,
'num_channel': 384,
'num_head': 12,
'num_layer_in_transition': 3,
'num_point_qk': 4,
'num_point_v': 8,
'num_scalar_qk': 16,
'num_scalar_v': 16,
'position_scale': 10.0,
'sidechain': {
'atom_clamp_distance': 10.0,
'num_channel': 128,
'num_residual_block': 2,
'weight_frac': 0.5,
'length_scale': 10.,
},
'structural_violation_loss_weight': 1.0,
'violation_tolerance_factor': 12.0,
'weight': 1.0
},
'predicted_lddt': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 50,
'num_channels': 128,
'weight': 0.01
},
'masked_msa': {
'num_output': 23,
'weight': 2.0
},
},
'num_recycle': 3,
'resample_msa_in_recycling': True
},
})
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convenience functions for reading data."""
import io
import os
from typing import List
import haiku as hk
import numpy as np
from alphafold.model import utils
# Internal import (7716).
def casp_model_names(data_dir: str) -> List[str]:
params = os.listdir(os.path.join(data_dir, 'params'))
return [os.path.splitext(filename)[0] for filename in params]
def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
"""Get the Haiku parameters from a model name."""
path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
with open(path, 'rb') as f:
params = np.load(io.BytesIO(f.read()), allow_pickle=False)
return utils.flat_params_to_haiku(params)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to generate processed features."""
import copy
from typing import List, Mapping, Tuple
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
from alphafold.model.tf import input_pipeline
from alphafold.model.tf import proteins_dataset
FeatureDict = Mapping[str, np.ndarray]
def make_data_config(
config: ml_collections.ConfigDict,
num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
"""Makes a data config for the input pipeline."""
cfg = copy.deepcopy(config.data)
feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
with cfg.unlocked():
cfg.eval.crop_size = num_res
return cfg, feature_names
def tf_example_to_features(tf_example: tf.train.Example,
config: ml_collections.ConfigDict,
random_seed: int = 0) -> FeatureDict:
"""Converts tf_example to numpy feature dictionary."""
num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
cfg, feature_names = make_data_config(config, num_res=num_res)
if 'deletion_matrix_int' in set(tf_example.features.feature):
deletion_matrix_int = (
tf_example.features.feature['deletion_matrix_int'].int64_list.value)
feat = tf.train.Feature(float_list=tf.train.FloatList(
value=map(float, deletion_matrix_int)))
tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
del tf_example.features.feature['deletion_matrix_int']
tf_graph = tf.Graph()
with tf_graph.as_default(), tf.device('/device:CPU:0'):
tf.compat.v1.set_random_seed(random_seed)
tensor_dict = proteins_dataset.create_tensor_dict(
raw_data=tf_example.SerializeToString(),
features=feature_names)
processed_batch = input_pipeline.process_tensors_from_config(
tensor_dict, cfg)
tf_graph.finalize()
with tf.Session(graph=tf_graph) as sess:
features = sess.run(processed_batch)
return {k: v for k, v in features.items() if v.dtype != 'O'}
def np_example_to_features(np_example: FeatureDict,
config: ml_collections.ConfigDict,
random_seed: int = 0) -> FeatureDict:
"""Preprocesses NumPy feature dict using TF pipeline."""
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(config, num_res=num_res)
if 'deletion_matrix_int' in np_example:
np_example['deletion_matrix'] = (
np_example.pop('deletion_matrix_int').astype(np.float32))
tf_graph = tf.Graph()
with tf_graph.as_default(), tf.device('/device:CPU:0'):
tf.compat.v1.set_random_seed(random_seed)
tensor_dict = proteins_dataset.np_to_tensor_dict(
np_example=np_example, features=feature_names)
processed_batch = input_pipeline.process_tensors_from_config(
tensor_dict, cfg)
tf_graph.finalize()
with tf.Session(graph=tf_graph) as sess:
features = sess.run(processed_batch)
return {k: v for k, v in features.items() if v.dtype != 'O'}
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and utilities for the structure module."""
import functools
from typing import Dict
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import r3
from alphafold.model import utils
def squared_difference(x, y):
return jnp.square(x - y)
class InvariantPointAttention(hk.Module):
"""Invariant Point attention module.
The high-level idea is that this attention module works over a set of points
and associated orientations in 3D space (e.g. protein residues).
Each residue outputs a set of queries and keys as points in their local
reference frame. The attention is then defined as the euclidean distance
between the queries and keys in the global frame.
Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
"""
def __init__(self,
config,
global_config,
dist_epsilon=1e-8,
name='invariant_point_attention'):
"""Initialize.
Args:
config: Structure Module Config
global_config: Global Config of Model.
dist_epsilon: Small value to avoid NaN in distance calculation.
name: Haiku Module name.
"""
super().__init__(name=name)
self._dist_epsilon = dist_epsilon
self._zero_initialize_last = global_config.zero_init
self.config = config
self.global_config = global_config
def __call__(self, inputs_1d, inputs_2d, mask, affine):
"""Compute geometry-aware attention.
Given a set of query residues (defined by affines and associated scalar
features), this function computes geometry-aware attention between the
query residues and target residues.
The residues produce points in their local reference frame, which
are converted into the global frame in order to compute attention via
euclidean distance.
Equivalently, the target residues produce points in their local frame to be
used as attention values, which are converted into the query residues'
local frames.
Args:
inputs_1d: (N, C) 1D input embedding that is the basis for the
scalar queries.
inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
mask: (N, 1) mask to indicate which elements of inputs_1d participate
in the attention.
affine: QuatAffine object describing the position and orientation of
every element in inputs_1d.
Returns:
Transformation of the input embedding.
"""
num_residues, _ = inputs_1d.shape
# Improve readability by removing a large number of 'self's.
num_head = self.config.num_head
num_scalar_qk = self.config.num_scalar_qk
num_point_qk = self.config.num_point_qk
num_scalar_v = self.config.num_scalar_v
num_point_v = self.config.num_point_v
num_output = self.config.num_channel
assert num_scalar_qk > 0
assert num_point_qk > 0
assert num_point_v > 0
# Construct scalar queries of shape:
# [num_query_residues, num_head, num_points]
q_scalar = common_modules.Linear(
num_head * num_scalar_qk, name='q_scalar')(
inputs_1d)
q_scalar = jnp.reshape(
q_scalar, [num_residues, num_head, num_scalar_qk])
# Construct scalar keys/values of shape:
# [num_target_residues, num_head, num_points]
kv_scalar = common_modules.Linear(
num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')(
inputs_1d)
kv_scalar = jnp.reshape(kv_scalar,
[num_residues, num_head,
num_scalar_v + num_scalar_qk])
k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)
# Construct query points of shape:
# [num_residues, num_head, num_point_qk]
# First construct query points in local frame.
q_point_local = common_modules.Linear(
num_head * 3 * num_point_qk, name='q_point_local')(
inputs_1d)
q_point_local = jnp.split(q_point_local, 3, axis=-1)
# Project query points into global frame.
q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
# Reshape query point for later use.
q_point = [
jnp.reshape(x, [num_residues, num_head, num_point_qk])
for x in q_point_global]
# Construct key and value points.
# Key points have shape [num_residues, num_head, num_point_qk]
# Value points have shape [num_residues, num_head, num_point_v]
# Construct key and value points in local frame.
kv_point_local = common_modules.Linear(
num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')(
inputs_1d)
kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
# Project key and value points into global frame.
kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
kv_point_global = [
jnp.reshape(x, [num_residues,
num_head, (num_point_qk + num_point_v)])
for x in kv_point_global]
# Split key and value points.
k_point, v_point = list(
zip(*[
jnp.split(x, [num_point_qk,], axis=-1)
for x in kv_point_global
]))
# We assume that all queries and keys come iid from N(0, 1) distribution
# and compute the variances of the attention logits.
# Each scalar pair (q, k) contributes Var q*k = 1
scalar_variance = max(num_scalar_qk, 1) * 1.
# Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
point_variance = max(num_point_qk, 1) * 9. / 2
# Allocate equal variance to scalar, point and attention 2d parts so that
# the sum is 1.
num_logit_terms = 3
scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))
# Trainable per-head weights for points.
trainable_point_weights = jax.nn.softplus(hk.get_parameter(
'trainable_point_weights', shape=[num_head],
# softplus^{-1} (1)
init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)
v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]
q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
dist2 = [
squared_difference(qx[:, :, None, :], kx[:, None, :, :])
for qx, kx in zip(q_point, k_point)
]
dist2 = sum(dist2)
attn_qk_point = -0.5 * jnp.sum(
point_weights[:, None, None, :] * dist2, axis=-1)
v = jnp.swapaxes(v_scalar, -2, -3)
q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
k = jnp.swapaxes(k_scalar, -2, -3)
attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
attn_logits = attn_qk_scalar + attn_qk_point
attention_2d = common_modules.Linear(
num_head, name='attention_2d')(
inputs_2d)
attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
attention_2d = attention_2d_weights * attention_2d
attn_logits += attention_2d
mask_2d = mask * jnp.swapaxes(mask, -1, -2)
attn_logits -= 1e5 * (1. - mask_2d)
# [num_head, num_query_residues, num_target_residues]
attn = jax.nn.softmax(attn_logits)
# [num_head, num_query_residues, num_head * num_scalar_v]
result_scalar = jnp.matmul(attn, v)
# For point result, implement matmul manually so that it will be a float32
# on TPU. This is equivalent to
# result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx)
# for vx in v_point]
# but on the TPU, doing the multiply and reduce_sum ensures the
# computation happens in float32 instead of bfloat16.
result_point_global = [jnp.sum(
attn[:, :, :, None] * vx[:, None, :, :],
axis=-2) for vx in v_point]
# [num_query_residues, num_head, num_head * num_(scalar|point)_v]
result_scalar = jnp.swapaxes(result_scalar, -2, -3)
result_point_global = [
jnp.swapaxes(x, -2, -3)
for x in result_point_global]
# Features used in the linear output projection. Should have the size
# [num_query_residues, ?]
output_features = []
result_scalar = jnp.reshape(
result_scalar, [num_residues, num_head * num_scalar_v])
output_features.append(result_scalar)
result_point_global = [
jnp.reshape(r, [num_residues, num_head * num_point_v])
for r in result_point_global]
result_point_local = affine.invert_point(result_point_global, extra_dims=1)
output_features.extend(result_point_local)
output_features.append(jnp.sqrt(self._dist_epsilon +
jnp.square(result_point_local[0]) +
jnp.square(result_point_local[1]) +
jnp.square(result_point_local[2])))
# Dimensions: h = heads, i and j = residues,
# c = inputs_2d channels
# Contraction happens over the second residue dimension, similarly to how
# the usual attention is performed.
result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
num_out = num_head * result_attention_over_2d.shape[-1]
output_features.append(
jnp.reshape(result_attention_over_2d,
[num_residues, num_out]))
final_init = 'zeros' if self._zero_initialize_last else 'linear'
final_act = jnp.concatenate(output_features, axis=-1)
return common_modules.Linear(
num_output,
initializer=final_init,
name='output_projection')(final_act)
class FoldIteration(hk.Module):
"""A single iteration of the main structure module loop.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21
First, each residue attends to all residues using InvariantPointAttention.
Then, we apply transition layers to update the hidden representations.
Finally, we use the hidden representations to produce an update to the
affine of each residue.
"""
def __init__(self, config, global_config,
name='fold_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
activations,
sequence_mask,
update_affine,
is_training,
initial_act,
safe_key=None,
static_feat_2d=None,
aatype=None):
c = self.config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
def safe_dropout_fn(tensor, safe_key):
return prng.safe_dropout(
tensor=tensor,
safe_key=safe_key,
rate=c.dropout,
is_deterministic=self.global_config.deterministic,
is_training=is_training)
affine = quat_affine.QuatAffine.from_tensor(activations['affine'])
act = activations['act']
attention_module = InvariantPointAttention(self.config, self.global_config)
# Attention
attn = attention_module(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=sequence_mask,
affine=affine)
act += attn
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='attention_layer_norm')(
act)
final_init = 'zeros' if self.global_config.zero_init else 'linear'
# Transition
input_act = act
for i in range(c.num_layer_in_transition):
init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
act = common_modules.Linear(
c.num_channel,
initializer=init,
name='transition')(
act)
if i < c.num_layer_in_transition - 1:
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='transition_layer_norm')(act)
if update_affine:
# This block corresponds to
# Jumper et al. (2021) Alg. 23 "Backbone update"
affine_update_size = 6
# Affine update
affine_update = common_modules.Linear(
affine_update_size,
initializer=final_init,
name='affine_update')(
act)
affine = affine.pre_compose(affine_update)
sc = MultiRigidSidechain(c.sidechain, self.global_config)(
affine.scale_translation(c.position_scale), [act, initial_act], aatype)
outputs = {'affine': affine.to_tensor(), 'sc': sc}
affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)
new_activations = {
'act': act,
'affine': affine.to_tensor()
}
return new_activations, outputs
def generate_affines(representations, batch, config, global_config,
is_training, safe_key):
"""Generate predicted affines for a single chain.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
This is the main part of the structure module - it iteratively applies
folding to produce a set of predicted residue positions.
Args:
representations: Representations dictionary.
batch: Batch dictionary.
config: Config for the structure module.
global_config: Global config.
is_training: Whether the model is being trained.
safe_key: A prng.SafeKey object that wraps a PRNG key.
Returns:
A dictionary containing residue affines and sidechain positions.
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='single_layer_norm')(
representations['single'])
initial_act = act
act = common_modules.Linear(
c.num_channel, name='initial_projection')(
act)
affine = generate_new_affine(sequence_mask)
fold_iteration = FoldIteration(
c, global_config, name='fold_iteration')
assert len(batch['seq_mask'].shape) == 1
activations = {'act': act,
'affine': affine.to_tensor(),
}
act_2d = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='pair_layer_norm')(
representations['pair'])
outputs = []
safe_keys = safe_key.split(c.num_layer)
for sub_key in safe_keys:
activations, output = fold_iteration(
activations,
initial_act=initial_act,
static_feat_2d=act_2d,
safe_key=sub_key,
sequence_mask=sequence_mask,
update_affine=True,
is_training=is_training,
aatype=batch['aatype'])
outputs.append(output)
output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
# Include the activations in the output dict for use by the LDDT-Head.
output['act'] = activations['act']
return output
class StructureModule(hk.Module):
"""StructureModule as a network head.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
"""
def __init__(self, config, global_config, compute_loss=True,
name='structure_module'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.compute_loss = compute_loss
def __call__(self, representations, batch, is_training,
safe_key=None):
c = self.config
ret = {}
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
output = generate_affines(
representations=representations,
batch=batch,
config=self.config,
global_config=self.global_config,
is_training=is_training,
safe_key=safe_key)
representations['structure_module'] = output['act']
ret['traj'] = output['affine'] * jnp.array([1.] * 4 +
[c.position_scale] * 3)
ret['sidechains'] = output['sc']
atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1]
ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3)
ret['final_atom14_mask'] = batch['atom14_atom_exists'] # (N, 14)
atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions,
batch)
atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None]
ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3)
ret['final_atom_mask'] = batch['atom37_atom_exists'] # (N, 37)
ret['final_affines'] = ret['traj'][-1]
if self.compute_loss:
return ret
else:
no_loss_features = ['final_atom_positions', 'final_atom_mask']
no_loss_ret = {k: ret[k] for k in no_loss_features}
return no_loss_ret
def loss(self, value, batch):
ret = {'loss': 0.}
ret['metrics'] = {}
# If requested, compute in-graph metrics.
if self.config.compute_in_graph_metrics:
atom14_pred_positions = value['final_atom14_positions']
# Compute renaming and violations.
value.update(compute_renamed_ground_truth(batch, atom14_pred_positions))
value['violations'] = find_structural_violations(
batch, atom14_pred_positions, self.config)
# Several violation metrics:
violation_metrics = compute_violation_metrics(
batch=batch,
atom14_pred_positions=atom14_pred_positions,
violations=value['violations'])
ret['metrics'].update(violation_metrics)
backbone_loss(ret, batch, value, self.config)
if 'renamed_atom14_gt_positions' not in value:
value.update(compute_renamed_ground_truth(
batch, value['final_atom14_positions']))
sc_loss = sidechain_loss(batch, value, self.config)
ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
self.config.sidechain.weight_frac * sc_loss['loss'])
ret['sidechain_fape'] = sc_loss['fape']
supervised_chi_loss(ret, batch, value, self.config)
if self.config.structural_violation_loss_weight:
if 'violations' not in value:
value['violations'] = find_structural_violations(
batch, value['final_atom14_positions'], self.config)
structural_violation_loss(ret, batch, value, self.config)
return ret
def compute_renamed_ground_truth(
batch: Dict[str, jnp.ndarray],
atom14_pred_positions: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
"""Find optimal renaming of ground truth based on the predicted positions.
Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"
This renamed ground truth is then used for all losses,
such that each loss moves the atoms in the same direction.
Shape (N).
Args:
batch: Dictionary containing:
* atom14_gt_positions: Ground truth positions.
* atom14_alt_gt_positions: Ground truth positions with renaming swaps.
* atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
renaming swaps.
* atom14_gt_exists: Mask for which atoms exist in ground truth.
* atom14_alt_gt_exists: Mask for which atoms exist in ground truth
after renaming.
* atom14_atom_exists: Mask for whether each atom is part of the given
amino acid type.
atom14_pred_positions: Array of atom positions in global frame with shape
(N, 14, 3).
Returns:
Dictionary containing:
alt_naming_is_better: Array with 1.0 where alternative swap is better.
renamed_atom14_gt_positions: Array of optimal ground truth positions
after renaming swaps are performed.
renamed_atom14_gt_exists: Mask after renaming swap is performed.
"""
alt_naming_is_better = all_atom.find_optimal_renaming(
atom14_gt_positions=batch['atom14_gt_positions'],
atom14_alt_gt_positions=batch['atom14_alt_gt_positions'],
atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'],
atom14_gt_exists=batch['atom14_gt_exists'],
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'])
renamed_atom14_gt_positions = (
(1. - alt_naming_is_better[:, None, None])
* batch['atom14_gt_positions']
+ alt_naming_is_better[:, None, None]
* batch['atom14_alt_gt_positions'])
renamed_atom14_gt_mask = (
(1. - alt_naming_is_better[:, None]) * batch['atom14_gt_exists']
+ alt_naming_is_better[:, None] * batch['atom14_alt_gt_exists'])
return {
'alt_naming_is_better': alt_naming_is_better, # (N)
'renamed_atom14_gt_positions': renamed_atom14_gt_positions, # (N, 14, 3)
'renamed_atom14_gt_exists': renamed_atom14_gt_mask, # (N, 14)
}
def backbone_loss(ret, batch, value, config):
"""Backbone FAPE Loss.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17
Args:
ret: Dictionary to write outputs into, needs to contain 'loss'.
batch: Batch, needs to contain 'backbone_affine_tensor',
'backbone_affine_mask'.
value: Dictionary containing structure module output, needs to contain
'traj', a trajectory of rigids.
config: Configuration of loss, should contain 'fape.clamp_distance' and
'fape.loss_unit_distance'.
"""
affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)
gt_affine = quat_affine.QuatAffine.from_tensor(
batch['backbone_affine_tensor'])
gt_rigid = r3.rigids_from_quataffine(gt_affine)
backbone_mask = batch['backbone_affine_mask']
fape_loss_fn = functools.partial(
all_atom.frame_aligned_point_error,
l1_clamp_distance=config.fape.clamp_distance,
length_scale=config.fape.loss_unit_distance)
fape_loss_fn = jax.vmap(fape_loss_fn, (0, None, None, 0, None, None))
fape_loss = fape_loss_fn(rigid_trajectory, gt_rigid, backbone_mask,
rigid_trajectory.trans, gt_rigid.trans,
backbone_mask)
if 'use_clamped_fape' in batch:
# Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details"
use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32)
unclamped_fape_loss_fn = functools.partial(
all_atom.frame_aligned_point_error,
l1_clamp_distance=None,
length_scale=config.fape.loss_unit_distance)
unclamped_fape_loss_fn = jax.vmap(unclamped_fape_loss_fn,
(0, None, None, 0, None, None))
fape_loss_unclamped = unclamped_fape_loss_fn(rigid_trajectory, gt_rigid,
backbone_mask,
rigid_trajectory.trans,
gt_rigid.trans,
backbone_mask)
fape_loss = (fape_loss * use_clamped_fape +
fape_loss_unclamped * (1 - use_clamped_fape))
ret['fape'] = fape_loss[-1]
ret['loss'] += jnp.mean(fape_loss)
def sidechain_loss(batch, value, config):
"""All Atom FAPE Loss using renamed rigids."""
# Rename Frames
# Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
alt_naming_is_better = value['alt_naming_is_better']
renamed_gt_frames = (
(1. - alt_naming_is_better[:, None, None])
* batch['rigidgroups_gt_frames']
+ alt_naming_is_better[:, None, None]
* batch['rigidgroups_alt_gt_frames'])
flat_gt_frames = r3.rigids_from_tensor_flat12(
jnp.reshape(renamed_gt_frames, [-1, 12]))
flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])
flat_gt_positions = r3.vecs_from_tensor(
jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])
# Compute frame_aligned_point_error score for the final layer.
pred_frames = value['sidechains']['frames']
pred_positions = value['sidechains']['atom_pos']
def _slice_last_layer_and_flatten(x):
return jnp.reshape(x[-1], [-1])
flat_pred_frames = jax.tree_map(
_slice_last_layer_and_flatten, pred_frames)
flat_pred_positions = jax.tree_map(
_slice_last_layer_and_flatten, pred_positions)
# FAPE Loss on sidechains
fape = all_atom.frame_aligned_point_error(
pred_frames=flat_pred_frames,
target_frames=flat_gt_frames,
frames_mask=flat_frames_mask,
pred_positions=flat_pred_positions,
target_positions=flat_gt_positions,
positions_mask=flat_positions_mask,
l1_clamp_distance=config.sidechain.atom_clamp_distance,
length_scale=config.sidechain.length_scale)
return {
'fape': fape,
'loss': fape}
def structural_violation_loss(ret, batch, value, config):
"""Computes loss for structural violations."""
assert config.sidechain.weight_frac
# Put all violation losses together to one large loss.
violations = value['violations']
num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32)
ret['loss'] += (config.structural_violation_loss_weight * (
violations['between_residues']['bonds_c_n_loss_mean'] +
violations['between_residues']['angles_ca_c_n_loss_mean'] +
violations['between_residues']['angles_c_n_ca_loss_mean'] +
jnp.sum(
violations['between_residues']['clashes_per_atom_loss_sum'] +
violations['within_residues']['per_atom_loss_sum']) /
(1e-6 + num_atoms)))
def find_structural_violations(
batch: Dict[str, jnp.ndarray],
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
config: ml_collections.ConfigDict
):
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = all_atom.between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
residue_index=batch['residue_index'].astype(jnp.float32),
aatype=batch['aatype'],
tolerance_factor_soft=config.violation_tolerance_factor,
tolerance_factor_hard=config.violation_tolerance_factor)
# Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# Shape: (N, 14).
atomtype_radius = [
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
atomtype_radius, batch['residx_atom14_to_atom37'])
# Compute the between residue clash loss.
between_residue_clashes = all_atom.between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch['residue_index'],
overlap_tolerance_soft=config.clash_overlap_tolerance,
overlap_tolerance_hard=config.clash_overlap_tolerance)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=config.clash_overlap_tolerance,
bond_length_tolerance_factor=config.violation_tolerance_factor)
atom14_dists_lower_bound = utils.batched_gather(
restype_atom14_bounds['lower_bound'], batch['aatype'])
atom14_dists_upper_bound = utils.batched_gather(
restype_atom14_bounds['upper_bound'], batch['aatype'])
within_residue_violations = all_atom.within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0)
# Combine them to a single per-residue violation mask (used later for LDDT).
per_residue_violations_mask = jnp.max(jnp.stack([
connection_violations['per_residue_violation_mask'],
jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
jnp.max(within_residue_violations['per_atom_violations'],
axis=-1)]), axis=0)
return {
'between_residues': {
'bonds_c_n_loss_mean':
connection_violations['c_n_loss_mean'], # ()
'angles_ca_c_n_loss_mean':
connection_violations['ca_c_n_loss_mean'], # ()
'angles_c_n_ca_loss_mean':
connection_violations['c_n_ca_loss_mean'], # ()
'connections_per_residue_loss_sum':
connection_violations['per_residue_loss_sum'], # (N)
'connections_per_residue_violation_mask':
connection_violations['per_residue_violation_mask'], # (N)
'clashes_mean_loss':
between_residue_clashes['mean_loss'], # ()
'clashes_per_atom_loss_sum':
between_residue_clashes['per_atom_loss_sum'], # (N, 14)
'clashes_per_atom_clash_mask':
between_residue_clashes['per_atom_clash_mask'], # (N, 14)
},
'within_residues': {
'per_atom_loss_sum':
within_residue_violations['per_atom_loss_sum'], # (N, 14)
'per_atom_violations':
within_residue_violations['per_atom_violations'], # (N, 14),
},
'total_per_residue_violations_mask':
per_residue_violations_mask, # (N)
}
def compute_violation_metrics(
batch: Dict[str, jnp.ndarray],
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
violations: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
"""Compute several metrics to assess the structural violations."""
ret = {}
extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
residue_index=batch['residue_index'].astype(jnp.float32))
ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations
ret['violations_between_residue_bond'] = utils.mask_mean(
mask=batch['seq_mask'],
value=violations['between_residues'][
'connections_per_residue_violation_mask'])
ret['violations_between_residue_clash'] = utils.mask_mean(
mask=batch['seq_mask'],
value=jnp.max(
violations['between_residues']['clashes_per_atom_clash_mask'],
axis=-1))
ret['violations_within_residue'] = utils.mask_mean(
mask=batch['seq_mask'],
value=jnp.max(
violations['within_residues']['per_atom_violations'], axis=-1))
ret['violations_per_residue'] = utils.mask_mean(
mask=batch['seq_mask'],
value=violations['total_per_residue_violations_mask'])
return ret
def supervised_chi_loss(ret, batch, value, config):
"""Computes loss for direct chi angle supervision.
Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"
Args:
ret: Dictionary to write outputs into, needs to contain 'loss'.
batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'.
value: Dictionary containing structure module output, needs to contain
value['sidechains']['angles_sin_cos'] for angles and
value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized
angles.
config: Configuration of loss, should contain 'chi_weight' and
'angle_norm_weight', 'angle_norm_weight' scales angle norm term,
'chi_weight' scales torsion term.
"""
eps = 1e-6
sequence_mask = batch['seq_mask']
num_res = sequence_mask.shape[0]
chi_mask = batch['chi_mask'].astype(jnp.float32)
pred_angles = jnp.reshape(
value['sidechains']['angles_sin_cos'], [-1, num_res, 7, 2])
pred_angles = pred_angles[:, :, 3:]
residue_type_one_hot = jax.nn.one_hot(
batch['aatype'], residue_constants.restype_num + 1,
dtype=jnp.float32)[None]
chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot,
jnp.asarray(residue_constants.chi_pi_periodic))
true_chi = batch['chi_angles'][None]
sin_true_chi = jnp.sin(true_chi)
cos_true_chi = jnp.cos(true_chi)
sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1)
# This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
sq_chi_error = jnp.sum(
squared_difference(sin_cos_true_chi, pred_angles), -1)
sq_chi_error_shifted = jnp.sum(
squared_difference(sin_cos_true_chi_shifted, pred_angles), -1)
sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)
sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
ret['chi_loss'] = sq_chi_loss
ret['loss'] += config.chi_weight * sq_chi_loss
unnormed_angles = jnp.reshape(
value['sidechains']['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2])
angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
norm_error = jnp.abs(angle_norm - 1.)
angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None],
value=norm_error)
ret['angle_norm_loss'] = angle_norm_loss
ret['loss'] += config.angle_norm_weight * angle_norm_loss
def generate_new_affine(sequence_mask):
num_residues, _ = sequence_mask.shape
quaternion = jnp.tile(
jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4]),
[num_residues, 1])
translation = jnp.zeros([num_residues, 3])
return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True)
def l2_normalize(x, axis=-1, epsilon=1e-12):
return x / jnp.sqrt(
jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))
class MultiRigidSidechain(hk.Module):
"""Class to make side chain atoms."""
def __init__(self, config, global_config, name='rigid_sidechain'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, affine, representations_list, aatype):
"""Predict side chains using multi-rigid representations.
Args:
affine: The affines for each residue (translations in angstroms).
representations_list: A list of activations to predict side chains from.
aatype: Amino acid types.
Returns:
Dict containing atom positions and frames (in angstroms).
"""
act = [
common_modules.Linear( # pylint: disable=g-complex-comprehension
self.config.num_channel,
name='input_projection')(jax.nn.relu(x))
for x in representations_list
]
# Sum the activation list (equivalent to concat then Linear).
act = sum(act)
final_init = 'zeros' if self.global_config.zero_init else 'linear'
# Mapping with some residual blocks.
for _ in range(self.config.num_residual_block):
old_act = act
act = common_modules.Linear(
self.config.num_channel,
initializer='relu',
name='resblock1')(
jax.nn.relu(act))
act = common_modules.Linear(
self.config.num_channel,
initializer=final_init,
name='resblock2')(
jax.nn.relu(act))
act += old_act
# Map activations to torsion angles. Shape: (num_res, 14).
num_res = act.shape[0]
unnormalized_angles = common_modules.Linear(
14, name='unnormalized_angles')(
jax.nn.relu(act))
unnormalized_angles = jnp.reshape(
unnormalized_angles, [num_res, 7, 2])
angles = l2_normalize(unnormalized_angles, axis=-1)
outputs = {
'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2)
'unnormalized_angles_sin_cos':
unnormalized_angles, # jnp.ndarray (N, 7, 2)
}
# Map torsion angles to frames.
backb_to_global = r3.rigids_from_quataffine(affine)
# Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"
# r3.Rigids with shape (N, 8).
all_frames_to_global = all_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
angles)
# Use frames and literature positions to create the final atom coordinates.
# r3.Vecs with shape (N, 14).
pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos(
aatype, all_frames_to_global)
outputs.update({
'atom_pos': pred_positions, # r3.Vecs (N, 14)
'frames': all_frames_to_global, # r3.Rigids (N, 8)
})
return outputs
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Function to stack repeats of a layer function without shared parameters."""
import collections
import contextlib
import functools
import inspect
from typing import Any, Callable, Optional, Tuple, Union
import haiku as hk
import jax
import jax.numpy as jnp
LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
LayerStackScanned = collections.namedtuple('LayerStackScanned',
['i', 'args_ys'])
# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
# to inform the user. In reality, the typing below will accept anything.
NestedArray = Any
WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray]]]
def _check_no_varargs(f):
if list(inspect.signature(
f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL:
raise ValueError(
'The function `f` should not have any `varargs` (that is *args) '
'argument. Instead, it should only use explicit positional'
'arguments.')
@contextlib.contextmanager
def nullcontext():
yield
def maybe_with_rng(key):
if key is not None:
return hk.with_rng(key)
else:
return nullcontext()
def maybe_fold_in(key, data):
if key is not None:
return jax.random.fold_in(key, data)
else:
return None
class _LayerStack(hk.Module):
"""Module to compose parameterized functions, implemented as a scan."""
def __init__(self,
count: int,
unroll: int,
name: Optional[str] = None):
"""Iterate a function `f` `count` times, with non-shared parameters."""
super().__init__(name=name)
self._count = count
self._unroll = unroll
def __call__(self, x, *args_ys):
count = self._count
if hk.running_init():
# At initialization time, we run just one layer but add an extra first
# dimension to every initialized tensor, making sure to use different
# random keys for different slices.
def creator(next_creator, shape, dtype, init, context):
del context
def multi_init(shape, dtype):
assert shape[0] == count
key = hk.maybe_next_rng_key()
def rng_context_init(slice_idx):
slice_key = maybe_fold_in(key, slice_idx)
with maybe_with_rng(slice_key):
return init(shape[1:], dtype)
return jax.vmap(rng_context_init)(jnp.arange(count))
return next_creator((count,) + tuple(shape), dtype, multi_init)
def getter(next_getter, value, context):
trailing_dims = len(context.original_shape) + 1
sliced_value = jax.lax.index_in_dim(
value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
return next_getter(sliced_value)
with hk.experimental.custom_creator(
creator), hk.experimental.custom_getter(getter):
if len(args_ys) == 1 and args_ys[0] is None:
args0 = (None,)
else:
args0 = [
jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
for ys in args_ys
]
x, z = self._call_wrapped(x, *args0)
if z is None:
return x, z
# Broadcast state to hold each layer state.
def broadcast_state(layer_state):
return jnp.broadcast_to(
layer_state, [count,] + list(layer_state.shape))
zs = jax.tree_util.tree_map(broadcast_state, z)
return x, zs
else:
# Use scan during apply, threading through random seed so that it's
# unique for each layer.
def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
rng = carry.rng
def getter(next_getter, value, context):
# Getter slices the full param at the current loop index.
trailing_dims = len(context.original_shape) + 1
assert value.shape[value.ndim - trailing_dims] == count, (
f'Attempting to use a parameter stack of size '
f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
f'size {count}.')
sliced_value = jax.lax.dynamic_index_in_dim(
value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
return next_getter(sliced_value)
with hk.experimental.custom_getter(getter):
if rng is None:
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
else:
rng, rng_ = jax.random.split(rng)
with hk.with_rng(rng_):
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
return LayerStackCarry(x=out_x, rng=rng), z
carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
args_ys=args_ys)
carry, zs = hk.scan(
layer, carry, scanned, length=count, unroll=self._unroll)
return carry.x, zs
def _call_wrapped(self,
x: jnp.ndarray,
*args,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
raise NotImplementedError()
class _LayerStackNoState(_LayerStack):
"""_LayerStack impl with no per-layer state provided to the function."""
def __init__(self,
f: WrappedFn,
count: int,
unroll: int,
name: Optional[str] = None):
super().__init__(count=count, unroll=unroll, name=name)
_check_no_varargs(f)
self._f = f
@hk.transparent
def _call_wrapped(self, args, y):
del y
ret = self._f(*args)
if len(args) == 1:
# If the function takes a single argument, the wrapped function receives
# a tuple of length 1, and therefore it must return a tuple of length 1.
ret = (ret,)
return ret, None
class _LayerStackWithState(_LayerStack):
"""_LayerStack impl with per-layer state provided to the function."""
def __init__(self,
f: WrappedFn,
count: int,
unroll: int,
name: Optional[str] = None):
super().__init__(count=count, unroll=unroll, name=name)
self._f = f
@hk.transparent
def _call_wrapped(self, x, *args):
return self._f(x, *args)
def layer_stack(num_layers: int,
with_state=False,
unroll: int = 1,
name: Optional[str] = None):
"""Utility to wrap a Haiku function and recursively apply it to an input.
A function is valid if it uses only explicit position parameters, and
its return type matches its input type. The position parameters can be
arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
that kwargs are not supported, neither are functions with variable number
of parameters (specified by `*args`).
If `with_state=False` then the new, wrapped function can be understood as
performing the following:
```
for i in range(num_layers):
x = f(x)
return x
```
And if `with_state=True`, assuming `f` takes two arguments on top of `x`:
```
for i in range(num_layers):
x, zs[i] = f(x, ys_0[i], ys_1[i])
return x, zs
```
The code using `layer_stack` for the above function would be:
```
def f(x, y_0, y_1):
...
return new_x, z
x, zs = layer_stack.layer_stack(num_layers,
with_state=True)(f)(x, ys_0, ys_1)
```
Crucially, any parameters created inside `f` will not be shared across
iterations.
Args:
num_layers: The number of times to iterate the wrapped function.
with_state: Whether or not to pass per-layer state to the wrapped function.
unroll: the unroll used by `scan`.
name: Name of the Haiku context.
Returns:
Callable that will produce a layer stack when called with a valid function.
"""
def iterate(f):
if with_state:
@functools.wraps(f)
def wrapped(x, *args):
for ys in args:
assert ys.shape[0] == num_layers
return _LayerStackWithState(
f, num_layers, unroll=unroll, name=name)(x, *args)
else:
_check_no_varargs(f)
@functools.wraps(f)
def wrapped(*args):
ret = _LayerStackNoState(
f, num_layers, unroll=unroll, name=name)(args, None)[0]
if len(args) == 1:
# If the function takes a single argument, we must also return a
# single value, and not a tuple of length 1.
ret = ret[0]
return ret
return wrapped
return iterate
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for layer_stack."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import scipy
from alphafold.model import layer_stack
# Suffixes applied by Haiku for repeated module names.
suffixes = [''] + [f'_{i}' for i in range(1, 100)]
def _slice_layers_params(layers_params):
sliced_layers_params = {}
for k, v in layers_params.items():
for inner_k in v:
for var_slice, suffix in zip(v[inner_k], suffixes):
k_new = k.split('/')[-1] + suffix
if k_new not in sliced_layers_params:
sliced_layers_params[k_new] = {}
sliced_layers_params[k_new][inner_k] = var_slice
return sliced_layers_params
class LayerStackTest(parameterized.TestCase):
@parameterized.parameters([1, 2, 4])
def test_layer_stack(self, unroll):
"""Compare layer_stack to the equivalent unrolled stack.
Tests that the layer_stack application of a Haiku layer function is
equivalent to repeatedly applying the layer function in an unrolled loop.
Args:
unroll: Number of unrolled layers.
"""
num_layers = 20
def inner_fn(x):
x += hk.Linear(100, name='linear1')(x)
x += hk.Linear(100, name='linear2')(x)
return x
def outer_fn_unrolled(x):
for _ in range(num_layers):
x = inner_fn(x)
return x
def outer_fn_layer_stack(x):
stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
return stack(x)
unrolled_fn = hk.transform(outer_fn_unrolled)
layer_stack_fn = hk.transform(outer_fn_layer_stack)
x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
rng_init = jax.random.PRNGKey(42)
params = layer_stack_fn.init(rng_init, x)
sliced_params = _slice_layers_params(params)
unrolled_pred = unrolled_fn.apply(sliced_params, None, x)
layer_stack_pred = layer_stack_fn.apply(params, None, x)
np.testing.assert_allclose(unrolled_pred, layer_stack_pred)
def test_layer_stack_multi_args(self):
"""Compare layer_stack to the equivalent unrolled stack.
Similar to `test_layer_stack`, but use a function that takes more than one
argument.
"""
num_layers = 20
def inner_fn(x, y):
x_out = x + hk.Linear(100, name='linear1')(y)
y_out = y + hk.Linear(100, name='linear2')(x)
return x_out, y_out
def outer_fn_unrolled(x, y):
for _ in range(num_layers):
x, y = inner_fn(x, y)
return x, y
def outer_fn_layer_stack(x, y):
stack = layer_stack.layer_stack(num_layers)(inner_fn)
return stack(x, y)
unrolled_fn = hk.transform(outer_fn_unrolled)
layer_stack_fn = hk.transform(outer_fn_layer_stack)
x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
y = jax.random.uniform(jax.random.PRNGKey(1), [10, 256, 100])
rng_init = jax.random.PRNGKey(42)
params = layer_stack_fn.init(rng_init, x, y)
sliced_params = _slice_layers_params(params)
unrolled_x, unrolled_y = unrolled_fn.apply(sliced_params, None, x, y)
layer_stack_x, layer_stack_y = layer_stack_fn.apply(params, None, x, y)
np.testing.assert_allclose(unrolled_x, layer_stack_x)
np.testing.assert_allclose(unrolled_y, layer_stack_y)
def test_layer_stack_no_varargs(self):
"""Test an error is raised when using a function with varargs."""
class VarArgsModule(hk.Module):
"""When used, this module should cause layer_stack to raise an Error."""
def __call__(self, *args):
return args
class NoVarArgsModule(hk.Module):
"""This module should be fine to use with layer_stack."""
def __call__(self, x):
return x
def build_and_init_stack(module_class):
def stack_fn(x):
module = module_class()
return layer_stack.layer_stack(1)(module)(x)
stack = hk.without_apply_rng(hk.transform(stack_fn))
stack.init(jax.random.PRNGKey(1729), jnp.ones([5]))
build_and_init_stack(NoVarArgsModule)
with self.assertRaisesRegex(
ValueError, 'The function `f` should not have any `varargs`'):
build_and_init_stack(VarArgsModule)
@parameterized.parameters([1, 2, 4])
def test_layer_stack_grads(self, unroll):
"""Compare layer_stack gradients to the equivalent unrolled stack.
Tests that the layer_stack application of a Haiku layer function is
equivalent to repeatedly applying the layer function in an unrolled loop.
Args:
unroll: Number of unrolled layers.
"""
num_layers = 20
def inner_fn(x):
x += hk.Linear(100, name='linear1')(x)
x += hk.Linear(100, name='linear2')(x)
return x
def outer_fn_unrolled(x):
for _ in range(num_layers):
x = inner_fn(x)
return x
def outer_fn_layer_stack(x):
stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
return stack(x)
unrolled_fn = hk.transform(outer_fn_unrolled)
layer_stack_fn = hk.transform(outer_fn_layer_stack)
x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
rng_init = jax.random.PRNGKey(42)
params = layer_stack_fn.init(rng_init, x)
sliced_params = _slice_layers_params(params)
unrolled_grad = jax.grad(
lambda p, x: jnp.mean(unrolled_fn.apply(p, None, x)))(sliced_params, x)
layer_stack_grad = jax.grad(
lambda p, x: jnp.mean(layer_stack_fn.apply(p, None, x)))(params, x)
assert_fn = functools.partial(
np.testing.assert_allclose, atol=1e-4, rtol=1e-4)
jax.tree_multimap(assert_fn, unrolled_grad,
_slice_layers_params(layer_stack_grad))
def test_random(self):
"""Random numbers should be handled correctly."""
n = 100
@hk.transform
@layer_stack.layer_stack(n)
def add_random(x):
x = x + jax.random.normal(hk.next_rng_key())
return x
# Evaluate a bunch of times
key, *keys = jax.random.split(jax.random.PRNGKey(7), 1024 + 1)
params = add_random.init(key, 0.)
apply_fn = jax.jit(add_random.apply)
values = [apply_fn(params, key, 0.) for key in keys]
# Should be roughly N(0, sqrt(n))
cdf = scipy.stats.norm(scale=np.sqrt(n)).cdf
_, p = scipy.stats.kstest(values, cdf)
self.assertLess(0.3, p)
def test_threading(self):
"""Test @layer_stack when the function gets per-layer state."""
n = 5
@layer_stack.layer_stack(n, with_state=True)
def f(x, y):
x = x + y * jax.nn.one_hot(y, len(x)) / 10
return x, 2 * y
@hk.without_apply_rng
@hk.transform
def g(x, ys):
x, zs = f(x, ys)
# Check here to catch issues at init time
self.assertEqual(zs.shape, (n,))
return x, zs
rng = jax.random.PRNGKey(7)
x = np.zeros(n)
ys = np.arange(n).astype(np.float32)
params = g.init(rng, x, ys)
x, zs = g.apply(params, x, ys)
self.assertTrue(np.allclose(x, [0, .1, .2, .3, .4]))
self.assertTrue(np.all(zs == 2 * ys))
def test_nested_stacks(self):
def stack_fn(x):
def layer_fn(x):
return hk.Linear(100)(x)
outer_fn = layer_stack.layer_stack(10)(layer_fn)
layer_outer = layer_stack.layer_stack(20)(outer_fn)
return layer_outer(x)
hk_mod = hk.transform(stack_fn)
apply_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
params = hk_mod.init(init_rng, jnp.zeros([10, 100]))
hk_mod.apply(params, apply_rng, jnp.zeros([10, 100]))
p, = params.values()
assert p['w'].shape == (10, 20, 100, 100)
assert p['b'].shape == (10, 20, 100)
def test_with_state_multi_args(self):
"""Test layer_stack with state with multiple arguments."""
width = 4
batch_size = 5
stack_height = 3
def f_with_multi_args(x, a, b):
return hk.Linear(
width, w_init=hk.initializers.Constant(
jnp.eye(width)))(x) * a + b, None
@hk.without_apply_rng
@hk.transform
def hk_fn(x):
return layer_stack.layer_stack(
stack_height,
with_state=True)(f_with_multi_args)(x, jnp.full([stack_height], 2.),
jnp.ones([stack_height]))
x = jnp.zeros([batch_size, width])
key_seq = hk.PRNGSequence(19)
params = hk_fn.init(next(key_seq), x)
output, z = hk_fn.apply(params, x)
self.assertIsNone(z)
self.assertEqual(output.shape, (batch_size, width))
np.testing.assert_equal(output, np.full([batch_size, width], 7.))
def test_with_container_state(self):
width = 2
batch_size = 2
stack_height = 3
def f_with_container_state(x):
hk_layer = hk.Linear(
width, w_init=hk.initializers.Constant(jnp.eye(width)))
layer_output = hk_layer(x)
layer_state = {
'raw_output': layer_output,
'output_projection': jnp.sum(layer_output)
}
return layer_output + jnp.ones_like(layer_output), layer_state
@hk.without_apply_rng
@hk.transform
def hk_fn(x):
return layer_stack.layer_stack(
stack_height,
with_state=True)(f_with_container_state)(x)
x = jnp.zeros([batch_size, width])
key_seq = hk.PRNGSequence(19)
params = hk_fn.init(next(key_seq), x)
output, z = hk_fn.apply(params, x)
self.assertEqual(z['raw_output'].shape, (stack_height, batch_size, width))
self.assertEqual(output.shape, (batch_size, width))
self.assertEqual(z['output_projection'].shape, (stack_height,))
np.testing.assert_equal(np.sum(z['output_projection']), np.array(12.))
np.testing.assert_equal(
np.all(z['raw_output'] == np.array([0., 1., 2.])[..., None, None]),
np.array(True))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""lDDT protein distance score."""
import jax.numpy as jnp
def lddt(predicted_points,
true_points,
true_points_mask,
cutoff=15.,
per_residue=False):
"""Measure (approximate) lDDT for a batch of coordinates.
lDDT reference:
Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
superposition-free score for comparing protein structures and models using
distance difference tests. Bioinformatics 29, 2722–2728 (2013).
lDDT is a measure of the difference between the true distance matrix and the
distance matrix of the predicted points. The difference is computed only on
points closer than cutoff *in the true structure*.
This function does not compute the exact lDDT value that the original paper
describes because it does not include terms for physical feasibility
(e.g. bond length violations). Therefore this is only an approximate
lDDT score.
Args:
predicted_points: (batch, length, 3) array of predicted 3D points
true_points: (batch, length, 3) array of true 3D points
true_points_mask: (batch, length, 1) binary-valued float array. This mask
should be 1 for points that exist in the true points.
cutoff: Maximum distance for a pair of points to be included
per_residue: If true, return score for each residue. Note that the overall
lDDT is not exactly the mean of the per_residue lDDT's because some
residues have more contacts than others.
Returns:
An (approximate, see above) lDDT score in the range 0-1.
"""
assert len(predicted_points.shape) == 3
assert predicted_points.shape[-1] == 3
assert true_points_mask.shape[-1] == 1
assert len(true_points_mask.shape) == 3
# Compute true and predicted distance matrices.
dmat_true = jnp.sqrt(1e-10 + jnp.sum(
(true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))
dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
(predicted_points[:, :, None] -
predicted_points[:, None, :])**2, axis=-1))
dists_to_score = (
(dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
jnp.transpose(true_points_mask, [0, 2, 1]) *
(1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction.
)
# Shift unscored distances to be far away.
dist_l1 = jnp.abs(dmat_true - dmat_predicted)
# True lDDT uses a number of fixed bins.
# We ignore the physical plausibility correction to lDDT, though.
score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
(dist_l1 < 1.0).astype(jnp.float32) +
(dist_l1 < 2.0).astype(jnp.float32) +
(dist_l1 < 4.0).astype(jnp.float32))
# Normalize over the appropriate axes.
reduce_axes = (-1,) if per_residue else (-2, -1)
norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))
return score
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lddt."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from alphafold.model import lddt
class LddtTest(parameterized.TestCase, absltest.TestCase):
@parameterized.named_parameters(
('same',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[1, 1, 1]),
('all_shifted',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[-1, 0, 0], [4, 0, 0], [9, 0, 0]],
[1, 1, 1]),
('all_rotated',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[0, 0, 0], [0, 5, 0], [0, 10, 0]],
[1, 1, 1]),
('half_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [5.5-1e-5, 0, 0]],
[1, 1]),
('one_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [6-1e-5, 0, 0]],
[0.75, 0.75]),
('two_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [7-1e-5, 0, 0]],
[0.5, 0.5]),
('four_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [9-1e-5, 0, 0]],
[0.25, 0.25],),
('five_a_dist',
[[0, 0, 0], [16-1e-5, 0, 0]],
[[0, 0, 0], [11, 0, 0]],
[0, 0]),
('no_pairs',
[[0, 0, 0], [20, 0, 0]],
[[0, 0, 0], [25-1e-5, 0, 0]],
[1, 1]),
)
def test_lddt(
self, predicted_pos, true_pos, exp_lddt):
predicted_pos = np.array([predicted_pos], dtype=np.float32)
true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32)
true_pos = np.array([true_pos], dtype=np.float32)
cutoff = 15.0
per_residue = True
result = lddt.lddt(
predicted_pos, true_pos, true_points_mask, cutoff,
per_residue)
np.testing.assert_almost_equal(result, [exp_lddt], decimal=4)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specialized mapping functions."""
import functools
from typing import Any, Callable, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
PYTREE = Any
PYTREE_JAX_ARRAY = Any
partial = functools.partial
PROXY = object()
def _maybe_slice(array, i, slice_size, axis):
if axis is PROXY:
return array
else:
return jax.lax.dynamic_slice_in_dim(
array, i, slice_size=slice_size, axis=axis)
def _maybe_get_size(array, axis):
if axis == PROXY:
return -1
else:
return array.shape[axis]
def _expand_axes(axes, values, name='sharded_apply'):
values_tree_def = jax.tree_flatten(values)[1]
flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
# Replace None's with PROXY
flat_axes = [PROXY if x is None else x for x in flat_axes]
return jax.tree_unflatten(values_tree_def, flat_axes)
def sharded_map(
fun: Callable[..., PYTREE_JAX_ARRAY],
shard_size: Union[int, None] = 1,
in_axes: Union[int, PYTREE] = 0,
out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
"""Sharded vmap.
Maps `fun` over axes, in a way similar to vmap, but does so in shards of
`shard_size`. This allows a smooth trade-off between memory usage
(as in a plain map) vs higher throughput (as in a vmap).
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
in_axes: Either integer or pytree describing which axis to map over for each
input to `fun`, None denotes broadcasting.
out_axes: integer or pytree denoting to what axis in the output the mapped
over axis maps.
Returns:
function with smap applied.
"""
vmapped_fun = hk.vmap(fun, in_axes, out_axes)
return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
def sharded_apply(
fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic
shard_size: Union[int, None] = 1,
in_axes: Union[int, PYTREE] = 0,
out_axes: Union[int, PYTREE] = 0,
new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
"""Sharded apply.
Applies `fun` over shards to axes, in a way similar to vmap,
but does so in shards of `shard_size`. Shards are stacked after.
This allows a smooth trade-off between
memory usage (as in a plain map) vs higher throughput (as in a vmap).
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
in_axes: Either integer or pytree describing which axis to map over for each
input to `fun`, None denotes broadcasting.
out_axes: integer or pytree denoting to what axis in the output the mapped
over axis maps.
new_out_axes: whether to stack outputs on new axes. This assumes that the
output sizes for each shard (including the possible remainder shard) are
the same.
Returns:
function with smap applied.
"""
docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
'but with additional array axes over which {fun} is mapped.')
if new_out_axes:
raise NotImplementedError('New output axes not yet implemented.')
# shard size None denotes no sharding
if shard_size is None:
return fun
@jax.util.wraps(fun, docstr=docstr)
def mapped_fn(*args):
# Expand in axes and Determine Loop range
in_axes_ = _expand_axes(in_axes, args)
in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
flat_sizes = jax.tree_flatten(in_sizes)[0]
in_size = max(flat_sizes)
assert all(i in {in_size, -1} for i in flat_sizes)
num_extra_shards = (in_size - 1) // shard_size
# Fix Up if necessary
last_shard_size = in_size % shard_size
last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
def apply_fun_to_slice(slice_start, slice_size):
input_slice = jax.tree_multimap(
lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
), args, in_axes_)
return fun(*input_slice)
remainder_shape_dtype = hk.eval_shape(
partial(apply_fun_to_slice, 0, last_shard_size))
out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)
if num_extra_shards > 0:
regular_shard_shape_dtype = hk.eval_shape(
partial(apply_fun_to_slice, 0, shard_size))
shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype)
def make_output_shape(axis, shard_shape, remainder_shape):
return shard_shape[:axis] + (
shard_shape[axis] * num_extra_shards +
remainder_shape[axis],) + shard_shape[axis + 1:]
out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
out_shapes)
# Calls dynamic Update slice with different argument order
# This is here since tree_multimap only works with positional arguments
def dynamic_update_slice_in_dim(full_array, update, axis, i):
return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
def compute_shard(outputs, slice_start, slice_size):
slice_out = apply_fun_to_slice(slice_start, slice_size)
update_slice = partial(
dynamic_update_slice_in_dim, i=slice_start)
return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)
def scan_iteration(outputs, i):
new_outputs = compute_shard(outputs, i, shard_size)
return new_outputs, ()
slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)
def allocate_buffer(dtype, shape):
return jnp.zeros(shape, dtype=dtype)
outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)
if slice_starts.shape[0] > 0:
outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
if last_shard_size != shard_size:
remainder_start = in_size - last_shard_size
outputs = compute_shard(outputs, remainder_start, last_shard_size)
return outputs
return mapped_fn
def inference_subbatch(
module: Callable[..., PYTREE_JAX_ARRAY],
subbatch_size: int,
batched_args: Sequence[PYTREE_JAX_ARRAY],
nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
low_memory: bool = True,
input_subbatch_dim: int = 0,
output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
"""Run through subbatches (like batch apply but with split and concat)."""
assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test
if not low_memory:
args = list(batched_args) + list(nonbatched_args)
return module(*args)
if output_subbatch_dim is None:
output_subbatch_dim = input_subbatch_dim
def run_module(*batched_args):
args = list(batched_args) + list(nonbatched_args)
return module(*args)
sharded_module = sharded_apply(run_module,
shard_size=subbatch_size,
in_axes=input_subbatch_dim,
out_axes=output_subbatch_dim)
return sharded_module(*batched_args)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for constructing the model."""
from typing import Any, Mapping, Optional, Union
from absl import logging
import haiku as hk
import jax
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
import tree
from alphafold.common import confidence
from alphafold.model import features
from alphafold.model import modules
def get_confidence_metrics(
prediction_result: Mapping[str, Any]) -> Mapping[str, Any]:
"""Post processes prediction_result to get confidence metrics."""
confidence_metrics = {}
confidence_metrics['plddt'] = confidence.compute_plddt(
prediction_result['predicted_lddt']['logits'])
if 'predicted_aligned_error' in prediction_result:
confidence_metrics.update(confidence.compute_predicted_aligned_error(
prediction_result['predicted_aligned_error']['logits'],
prediction_result['predicted_aligned_error']['breaks']))
confidence_metrics['ptm'] = confidence.predicted_tm_score(
prediction_result['predicted_aligned_error']['logits'],
prediction_result['predicted_aligned_error']['breaks'])
return confidence_metrics
class RunModel:
"""Container for JAX model."""
def __init__(self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
self.config = config
self.params = params
def _forward_fn(batch):
model = modules.AlphaFold(self.config.model)
return model(
batch,
is_training=False,
compute_loss=False,
ensemble_representations=True)
self.apply = jax.jit(hk.transform(_forward_fn).apply)
self.init = jax.jit(hk.transform(_forward_fn).init)
def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
"""Initializes the model parameters.
If none were provided when this class was instantiated then the parameters
are randomly initialized.
Args:
feat: A dictionary of NumPy feature arrays as output by
RunModel.process_features.
random_seed: A random seed to use to initialize the parameters if none
were set when this class was initialized.
"""
if not self.params:
# Init params randomly.
rng = jax.random.PRNGKey(random_seed)
self.params = hk.data_structures.to_mutable_dict(
self.init(rng, feat))
logging.warning('Initialized parameters randomly')
def process_features(
self,
raw_features: Union[tf.train.Example, features.FeatureDict],
random_seed: int) -> features.FeatureDict:
"""Processes features to prepare for feeding them into the model.
Args:
raw_features: The output of the data pipeline either as a dict of NumPy
arrays or as a tf.train.Example.
random_seed: The random seed to use when processing the features.
Returns:
A dict of NumPy feature arrays suitable for feeding into the model.
"""
if isinstance(raw_features, dict):
return features.np_example_to_features(
np_example=raw_features,
config=self.config,
random_seed=random_seed)
else:
return features.tf_example_to_features(
tf_example=raw_features,
config=self.config,
random_seed=random_seed)
def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
self.init_params(feat)
logging.info('Running eval_shape with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
logging.info('Output shape was %s', shape)
return shape
def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]:
"""Makes a prediction by inferencing the model on the provided features.
Args:
feat: A dictionary of NumPy feature arrays as output by
RunModel.process_features.
Returns:
A dictionary of model outputs.
"""
self.init_params(feat)
logging.info('Running predict with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
result = self.apply(self.params, jax.random.PRNGKey(0), feat)
# This block is to ensure benchmark timings are accurate. Some blocking is
# already happening when computing get_confidence_metrics, and this ensures
# all outputs are blocked on.
jax.tree_map(lambda x: x.block_until_ready(), result)
result.update(get_confidence_metrics(result))
logging.info('Output shape was %s',
tree.map_structure(lambda x: x.shape, result))
return result
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and code used in the core part of AlphaFold.
The structure generation code is in 'folding.py'.
"""
import functools
import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import folding
from alphafold.model import layer_stack
from alphafold.model import lddt
from alphafold.model import mapping
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import utils
def softmax_cross_entropy(logits, labels):
"""Computes softmax cross entropy given logits and one-hot class labels."""
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
return jnp.asarray(loss)
def sigmoid_cross_entropy(logits, labels):
"""Computes sigmoid cross entropy given logits and multiple class labels."""
log_p = jax.nn.log_sigmoid(logits)
# log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable
log_not_p = jax.nn.log_sigmoid(-logits)
loss = -labels * log_p - (1. - labels) * log_not_p
return jnp.asarray(loss)
def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None):
"""Applies dropout to a tensor."""
if is_training and rate != 0.0:
shape = list(tensor.shape)
if broadcast_dim is not None:
shape[broadcast_dim] = 1
keep_rate = 1.0 - rate
keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape)
return keep * tensor / keep_rate
else:
return tensor
def dropout_wrapper(module,
input_act,
mask,
safe_key,
global_config,
output_act=None,
is_training=True,
**kwargs):
"""Applies module + dropout + residual update."""
if output_act is None:
output_act = input_act
gc = global_config
residual = module(input_act, mask, is_training=is_training, **kwargs)
dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate
if module.config.shared_dropout:
if module.config.orientation == 'per_row':
broadcast_dim = 0
else:
broadcast_dim = 1
else:
broadcast_dim = None
residual = apply_dropout(tensor=residual,
safe_key=safe_key,
rate=dropout_rate,
is_training=is_training,
broadcast_dim=broadcast_dim)
new_act = output_act + residual
return new_act
def create_extra_msa_feature(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Arguments:
batch: a dictionary with the following keys:
* 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster
centre. Note, that this is not one-hot encoded.
* 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to
the left of each position in the extra MSA.
* 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to
the left of each position in the extra MSA.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23)
msa_feat = [msa_1hot,
jnp.expand_dims(batch['extra_has_deletion'], axis=-1),
jnp.expand_dims(batch['extra_deletion_value'], axis=-1)]
return jnp.concatenate(msa_feat, axis=-1)
class AlphaFoldIteration(hk.Module):
"""A single recycling iteration of AlphaFold architecture.
Computes ensembled (averaged) representations from the provided features.
These representations are then passed to the various heads
that have been requested by the configuration file. Each head also returns a
loss which is combined as a weighted sum to produce the total loss.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22
"""
def __init__(self, config, global_config, name='alphafold_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
ensembled_batch,
non_ensembled_batch,
is_training,
compute_loss=False,
ensemble_representations=False,
return_representations=False):
num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0])
if not ensemble_representations:
assert ensembled_batch['seq_length'].shape[0] == 1
def slice_batch(i):
b = {k: v[i] for k, v in ensembled_batch.items()}
b.update(non_ensembled_batch)
return b
# Compute representations for each batch element and average.
evoformer_module = EmbeddingsAndEvoformer(
self.config.embeddings_and_evoformer, self.global_config)
batch0 = slice_batch(0)
representations = evoformer_module(batch0, is_training)
# MSA representations are not ensembled so
# we don't pass tensor into the loop.
msa_representation = representations['msa']
del representations['msa']
# Average the representations (except MSA) over the batch dimension.
if ensemble_representations:
def body(x):
"""Add one element to the representations ensemble."""
i, current_representations = x
feats = slice_batch(i)
representations_update = evoformer_module(
feats, is_training)
new_representations = {}
for k in current_representations:
new_representations[k] = (
current_representations[k] + representations_update[k])
return i+1, new_representations
if hk.running_init():
# When initializing the Haiku module, run one iteration of the
# while_loop to initialize the Haiku modules used in `body`.
_, representations = body((1, representations))
else:
_, representations = hk.while_loop(
lambda x: x[0] < num_ensemble,
body,
(1, representations))
for k in representations:
if k != 'msa':
representations[k] /= num_ensemble.astype(representations[k].dtype)
representations['msa'] = msa_representation
batch = batch0 # We are not ensembled from here on.
heads = {}
for head_name, head_config in sorted(self.config.heads.items()):
if not head_config.weight:
continue # Do not instantiate zero-weight heads.
head_factory = {
'masked_msa': MaskedMsaHead,
'distogram': DistogramHead,
'structure_module': functools.partial(
folding.StructureModule, compute_loss=compute_loss),
'predicted_lddt': PredictedLDDTHead,
'predicted_aligned_error': PredictedAlignedErrorHead,
'experimentally_resolved': ExperimentallyResolvedHead,
}[head_name]
heads[head_name] = (head_config,
head_factory(head_config, self.global_config))
total_loss = 0.
ret = {}
ret['representations'] = representations
def loss(module, head_config, ret, name, filter_ret=True):
if filter_ret:
value = ret[name]
else:
value = ret
loss_output = module.loss(value, batch)
ret[name].update(loss_output)
loss = head_config.weight * ret[name]['loss']
return loss
for name, (head_config, module) in heads.items():
# Skip PredictedLDDTHead and PredictedAlignedErrorHead until
# StructureModule is executed.
if name in ('predicted_lddt', 'predicted_aligned_error'):
continue
else:
ret[name] = module(representations, batch, is_training)
if compute_loss:
total_loss += loss(module, head_config, ret, name)
if self.config.heads.get('predicted_lddt.weight', 0.0):
# Add PredictedLDDTHead after StructureModule executes.
name = 'predicted_lddt'
# Feed all previous results to give access to structure_module result.
head_config, module = heads[name]
ret[name] = module(representations, batch, is_training)
if compute_loss:
total_loss += loss(module, head_config, ret, name, filter_ret=False)
if ('predicted_aligned_error' in self.config.heads
and self.config.heads.get('predicted_aligned_error.weight', 0.0)):
# Add PredictedAlignedErrorHead after StructureModule executes.
name = 'predicted_aligned_error'
# Feed all previous results to give access to structure_module result.
head_config, module = heads[name]
ret[name] = module(representations, batch, is_training)
if compute_loss:
total_loss += loss(module, head_config, ret, name, filter_ret=False)
if compute_loss:
return ret, total_loss
else:
return ret
class AlphaFold(hk.Module):
"""AlphaFold model with recycling.
Jumper et al. (2021) Suppl. Alg. 2 "Inference"
"""
def __init__(self, config, name='alphafold'):
super().__init__(name=name)
self.config = config
self.global_config = config.global_config
def __call__(
self,
batch,
is_training,
compute_loss=False,
ensemble_representations=False,
return_representations=False):
"""Run the AlphaFold model.
Arguments:
batch: Dictionary with inputs to the AlphaFold model.
is_training: Whether the system is in training or inference mode.
compute_loss: Whether to compute losses (requires extra features
to be present in the batch and knowing the true structure).
ensemble_representations: Whether to use ensembling of representations.
return_representations: Whether to also return the intermediate
representations.
Returns:
When compute_loss is True:
a tuple of loss and output of AlphaFoldIteration.
When compute_loss is False:
just output of AlphaFoldIteration.
The output of AlphaFoldIteration is a nested dictionary containing
predictions from the various heads.
"""
impl = AlphaFoldIteration(self.config, self.global_config)
batch_size, num_residues = batch['aatype'].shape
def get_prev(ret):
new_prev = {
'prev_pos':
ret['structure_module']['final_atom_positions'],
'prev_msa_first_row': ret['representations']['msa_first_row'],
'prev_pair': ret['representations']['pair'],
}
return jax.tree_map(jax.lax.stop_gradient, new_prev)
def do_call(prev,
recycle_idx,
compute_loss=compute_loss):
if self.config.resample_msa_in_recycling:
num_ensemble = batch_size // (self.config.num_recycle + 1)
def slice_recycle_idx(x):
start = recycle_idx * num_ensemble
size = num_ensemble
return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
else:
num_ensemble = batch_size
ensembled_batch = batch
non_ensembled_batch = jax.tree_map(lambda x: x, prev)
return impl(
ensembled_batch=ensembled_batch,
non_ensembled_batch=non_ensembled_batch,
is_training=is_training,
compute_loss=compute_loss,
ensemble_representations=ensemble_representations)
if self.config.num_recycle:
emb_config = self.config.embeddings_and_evoformer
prev = {
'prev_pos': jnp.zeros(
[num_residues, residue_constants.atom_type_num, 3]),
'prev_msa_first_row': jnp.zeros(
[num_residues, emb_config.msa_channel]),
'prev_pair': jnp.zeros(
[num_residues, num_residues, emb_config.pair_channel]),
}
if 'num_iter_recycling' in batch:
# Training time: num_iter_recycling is in batch.
# The value for each ensemble batch is the same, so arbitrarily taking
# 0-th.
num_iter = batch['num_iter_recycling'][0]
# Add insurance that we will not run more
# recyclings than the model is configured to run.
num_iter = jnp.minimum(num_iter, self.config.num_recycle)
else:
# Eval mode or tests: use the maximum number of iterations.
num_iter = self.config.num_recycle
body = lambda x: (x[0] + 1, # pylint: disable=g-long-lambda
get_prev(do_call(x[1], recycle_idx=x[0],
compute_loss=False)))
if hk.running_init():
# When initializing the Haiku module, run one iteration of the
# while_loop to initialize the Haiku modules used in `body`.
_, prev = body((0, prev))
else:
_, prev = hk.while_loop(
lambda x: x[0] < num_iter,
body,
(0, prev))
else:
prev = {}
num_iter = 0
ret = do_call(prev=prev, recycle_idx=num_iter)
if compute_loss:
ret = ret[0], [ret[1]]
if not return_representations:
del (ret[0] if compute_loss else ret)['representations'] # pytype: disable=unsupported-operands
return ret
class TemplatePairStack(hk.Module):
"""Pair stack for the templates.
Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
"""
def __init__(self, config, global_config, name='template_pair_stack'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
"""Builds TemplatePairStack module.
Arguments:
pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
pair_mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
safe_key: Safe key object encapsulating the random number generation key.
Returns:
Updated pair_act, shape [N_res, N_res, c_t].
"""
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
gc = self.global_config
c = self.config
if not c.num_block:
return pair_act
def block(x):
"""One block of the template pair stack."""
pair_act, safe_key = x
dropout_wrapper_fn = functools.partial(
dropout_wrapper, is_training=is_training, global_config=gc)
safe_key, *sub_keys = safe_key.split(6)
sub_keys = iter(sub_keys)
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
name='triangle_multiplication_outgoing'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_incoming, gc,
name='triangle_multiplication_incoming'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
Transition(c.pair_transition, gc, name='pair_transition'),
pair_act,
pair_mask,
next(sub_keys))
return pair_act, safe_key
if gc.use_remat:
block = hk.remat(block)
res_stack = layer_stack.layer_stack(c.num_block)(block)
pair_act, safe_key = res_stack((pair_act, safe_key))
return pair_act
class Transition(hk.Module):
"""Transition layer.
Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
"""
def __init__(self, config, global_config, name='transition_block'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, act, mask, is_training=True):
"""Builds Transition module.
Arguments:
act: A tensor of queries of size [batch_size, N_res, N_channel].
mask: A tensor denoting the mask of size [batch_size, N_res].
is_training: Whether the module is in training mode.
Returns:
A float32 tensor of size [batch_size, N_res, N_channel].
"""
_, _, nc = act.shape
num_intermediate = int(nc * self.config.num_intermediate_factor)
mask = jnp.expand_dims(mask, axis=-1)
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='input_layer_norm')(
act)
transition_module = hk.Sequential([
common_modules.Linear(
num_intermediate,
initializer='relu',
name='transition1'), jax.nn.relu,
common_modules.Linear(
nc,
initializer=utils.final_init(self.global_config),
name='transition2')
])
act = mapping.inference_subbatch(
transition_module,
self.global_config.subbatch_size,
batched_args=[act],
nonbatched_args=[],
low_memory=not is_training)
return act
def glorot_uniform():
return hk.initializers.VarianceScaling(scale=1.0,
mode='fan_avg',
distribution='uniform')
class Attention(hk.Module):
"""Multihead attention."""
def __init__(self, config, global_config, output_dim, name='attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, bias, nonbatched_bias=None):
"""Builds Attention module.
Arguments:
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
m_data: A tensor of memories from which the keys and values are
projected, shape [batch_size, N_keys, m_channels].
bias: A bias for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
Returns:
A float32 tensor of shape [batch_size, N_queries, output_dim].
"""
# Sensible default for when the config keys are missing
key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
num_head = self.config.num_head
assert key_dim % num_head == 0
assert value_dim % num_head == 0
key_dim = key_dim // num_head
value_dim = value_dim // num_head
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], num_head, key_dim),
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], num_head, value_dim),
init=glorot_uniform())
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
weights = jax.nn.softmax(logits)
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
init = glorot_uniform()
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
gating_weights) + gating_bias
gate_values = jax.nn.sigmoid(gate_values)
weighted_avg *= gate_values
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
init=init)
o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
init=hk.initializers.Constant(0.0))
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
return output
class GlobalAttention(hk.Module):
"""Global attention.
Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7
"""
def __init__(self, config, global_config, output_dim, name='attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, q_mask, bias):
"""Builds GlobalAttention module.
Arguments:
q_data: A tensor of queries with size [batch_size, N_queries,
q_channels]
m_data: A tensor of memories from which the keys and values
projected. Size [batch_size, N_keys, m_channels]
q_mask: A binary mask for q_data with zeros in the padded sequence
elements and ones otherwise. Size [batch_size, N_queries, q_channels]
(or broadcastable to this shape).
bias: A bias for the attention.
Returns:
A float32 tensor of size [batch_size, N_queries, output_dim].
"""
# Sensible default for when the config keys are missing
key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
num_head = self.config.num_head
assert key_dim % num_head == 0
assert value_dim % num_head == 0
key_dim = key_dim // num_head
value_dim = value_dim // num_head
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], key_dim),
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], value_dim),
init=glorot_uniform())
v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
q_avg = utils.mask_mean(q_mask, q_data, axis=1)
q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias
weights = jax.nn.softmax(logits)
weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)
if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
init = glorot_uniform()
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
init=init)
o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
init=hk.initializers.Constant(0.0))
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
gate_values = jax.nn.sigmoid(gate_values + gating_bias)
weighted_avg = weighted_avg[:, None] * gate_values
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
else:
output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias
output = output[:, None]
return output
class MSARowAttentionWithPairBias(hk.Module):
"""MSA per-row attention biased by the pair representation.
Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias"
"""
def __init__(self, config, global_config,
name='msa_row_attention_with_pair_bias'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
pair_act,
is_training=False):
"""Builds MSARowAttentionWithPairBias module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
pair_act: [N_res, N_res, c_z] pair representation.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m].
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_row'
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
pair_act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='feat_2d_norm')(
pair_act)
init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
attn_mod = Attention(
c, self.global_config, msa_act.shape[-1])
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, bias],
nonbatched_args=[nonbatched_bias],
low_memory=not is_training)
return msa_act
class MSAColumnAttention(hk.Module):
"""MSA per-column attention.
Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
"""
def __init__(self, config, global_config, name='msa_column_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
is_training=False):
"""Builds MSAColumnAttention module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m]
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_column'
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
attn_mod = Attention(
c, self.global_config, msa_act.shape[-1])
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, bias],
nonbatched_args=[],
low_memory=not is_training)
msa_act = jnp.swapaxes(msa_act, -2, -3)
return msa_act
class MSAColumnGlobalAttention(hk.Module):
"""MSA per-column global attention.
Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
"""
def __init__(self, config, global_config, name='msa_column_global_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
is_training=False):
"""Builds MSAColumnGlobalAttention module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m].
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_column'
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
attn_mod = GlobalAttention(
c, self.global_config, msa_act.shape[-1],
name='attention')
# [N_seq, N_res, 1]
msa_mask = jnp.expand_dims(msa_mask, axis=-1)
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, msa_mask, bias],
nonbatched_args=[],
low_memory=not is_training)
msa_act = jnp.swapaxes(msa_act, -2, -3)
return msa_act
class TriangleAttention(hk.Module):
"""Triangle Attention.
Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
"""
def __init__(self, config, global_config, name='triangle_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, pair_act, pair_mask, is_training=False):
"""Builds TriangleAttention module.
Arguments:
pair_act: [N_res, N_res, c_z] pair activations tensor
pair_mask: [N_res, N_res] mask of non-padded regions in the tensor.
is_training: Whether the module is in training mode.
Returns:
Update to pair_act, shape [N_res, N_res, c_z].
"""
c = self.config
assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']
if c.orientation == 'per_column':
pair_act = jnp.swapaxes(pair_act, -2, -3)
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
pair_act = hk.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
pair_act)
init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
attn_mod = Attention(
c, self.global_config, pair_act.shape[-1])
pair_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[pair_act, pair_act, bias],
nonbatched_args=[nonbatched_bias],
low_memory=not is_training)
if c.orientation == 'per_column':
pair_act = jnp.swapaxes(pair_act, -2, -3)
return pair_act
class MaskedMsaHead(hk.Module):
"""Head to predict MSA at the masked locations.
The MaskedMsaHead employs a BERT-style objective to reconstruct a masked
version of the full MSA, based on a linear projection of
the MSA representation.
Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction"
"""
def __init__(self, config, global_config, name='masked_msa_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds MaskedMsaHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'msa': MSA representation, shape [N_seq, N_res, c_m].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* 'logits': logits of shape [N_seq, N_res, N_aatype] with
(unnormalized) log probabilies of predicted aatype at position.
"""
del batch
logits = common_modules.Linear(
self.config.num_output,
initializer=utils.final_init(self.global_config),
name='logits')(
representations['msa'])
return dict(logits=logits)
def loss(self, value, batch):
errors = softmax_cross_entropy(
labels=jax.nn.one_hot(batch['true_msa'], num_classes=23),
logits=value['logits'])
loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) /
(1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1))))
return {'loss': loss}
class PredictedLDDTHead(hk.Module):
"""Head to predict the per-residue LDDT to be used as a confidence measure.
Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)"
Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca"
"""
def __init__(self, config, global_config, name='predicted_lddt_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds ExperimentallyResolvedHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'structure_module': Single representation from the structure module,
shape [N_res, c_s].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing :
* 'logits': logits of shape [N_res, N_bins] with
(unnormalized) log probabilies of binned predicted lDDT.
"""
act = representations['structure_module']
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='input_layer_norm')(
act)
act = common_modules.Linear(
self.config.num_channels,
initializer='relu',
name='act_0')(
act)
act = jax.nn.relu(act)
act = common_modules.Linear(
self.config.num_channels,
initializer='relu',
name='act_1')(
act)
act = jax.nn.relu(act)
logits = common_modules.Linear(
self.config.num_bins,
initializer=utils.final_init(self.global_config),
name='logits')(
act)
# Shape (batch_size, num_res, num_bins)
return dict(logits=logits)
def loss(self, value, batch):
# Shape (num_res, 37, 3)
pred_all_atom_pos = value['structure_module']['final_atom_positions']
# Shape (num_res, 37, 3)
true_all_atom_pos = batch['all_atom_positions']
# Shape (num_res, 37)
all_atom_mask = batch['all_atom_mask']
# Shape (num_res,)
lddt_ca = lddt.lddt(
# Shape (batch_size, num_res, 3)
predicted_points=pred_all_atom_pos[None, :, 1, :],
# Shape (batch_size, num_res, 3)
true_points=true_all_atom_pos[None, :, 1, :],
# Shape (batch_size, num_res, 1)
true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32),
cutoff=15.,
per_residue=True)[0]
lddt_ca = jax.lax.stop_gradient(lddt_ca)
num_bins = self.config.num_bins
bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32)
# protect against out of range for lddt_ca == 1
bin_index = jnp.minimum(bin_index, num_bins - 1)
lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins)
# Shape (num_res, num_channel)
logits = value['predicted_lddt']['logits']
errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits)
# Shape (num_res,)
mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']]
mask_ca = mask_ca.astype(jnp.float32)
loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8)
if self.config.filter_by_resolution:
# NMR & distillation have resolution = 0
loss *= ((batch['resolution'] >= self.config.min_resolution)
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
output = {'loss': loss}
return output
class PredictedAlignedErrorHead(hk.Module):
"""Head to predict the distance errors in the backbone alignment frames.
Can be used to compute predicted TM-Score.
Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction"
"""
def __init__(self, config, global_config,
name='predicted_aligned_error_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds PredictedAlignedErrorHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'pair': pair representation, shape [N_res, N_res, c_z].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* logits: logits for aligned error, shape [N_res, N_res, N_bins].
* bin_breaks: array containing bin breaks, shape [N_bins - 1].
"""
act = representations['pair']
# Shape (num_res, num_res, num_bins)
logits = common_modules.Linear(
self.config.num_bins,
initializer=utils.final_init(self.global_config),
name='logits')(act)
# Shape (num_bins,)
breaks = jnp.linspace(
0., self.config.max_error_bin, self.config.num_bins - 1)
return dict(logits=logits, breaks=breaks)
def loss(self, value, batch):
# Shape (num_res, 7)
predicted_affine = quat_affine.QuatAffine.from_tensor(
value['structure_module']['final_affines'])
# Shape (num_res, 7)
true_affine = quat_affine.QuatAffine.from_tensor(
batch['backbone_affine_tensor'])
# Shape (num_res)
mask = batch['backbone_affine_mask']
# Shape (num_res, num_res)
square_mask = mask[:, None] * mask[None, :]
num_bins = self.config.num_bins
# (1, num_bins - 1)
breaks = value['predicted_aligned_error']['breaks']
# (1, num_bins)
logits = value['predicted_aligned_error']['logits']
# Compute the squared error for each alignment.
def _local_frame_points(affine):
points = [jnp.expand_dims(x, axis=-2) for x in affine.translation]
return affine.invert_point(points, extra_dims=1)
error_dist2_xyz = [
jnp.square(a - b)
for a, b in zip(_local_frame_points(predicted_affine),
_local_frame_points(true_affine))]
error_dist2 = sum(error_dist2_xyz)
# Shape (num_res, num_res)
# First num_res are alignment frames, second num_res are the residues.
error_dist2 = jax.lax.stop_gradient(error_dist2)
sq_breaks = jnp.square(breaks)
true_bins = jnp.sum((
error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1)
errors = softmax_cross_entropy(
labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits)
loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) /
(1e-8 + jnp.sum(square_mask, axis=(-2, -1))))
if self.config.filter_by_resolution:
# NMR & distillation have resolution = 0
loss *= ((batch['resolution'] >= self.config.min_resolution)
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
output = {'loss': loss}
return output
class ExperimentallyResolvedHead(hk.Module):
"""Predicts if an atom is experimentally resolved in a high-res structure.
Only trained on high-resolution X-ray crystals & cryo-EM.
Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction'
"""
def __init__(self, config, global_config,
name='experimentally_resolved_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds ExperimentallyResolvedHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'single': Single representation, shape [N_res, c_s].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* 'logits': logits of shape [N_res, 37],
log probability that an atom is resolved in atom37 representation,
can be converted to probability by applying sigmoid.
"""
logits = common_modules.Linear(
37, # atom_exists.shape[-1]
initializer=utils.final_init(self.global_config),
name='logits')(representations['single'])
return dict(logits=logits)
def loss(self, value, batch):
logits = value['logits']
assert len(logits.shape) == 2
# Does the atom appear in the amino acid?
atom_exists = batch['atom37_atom_exists']
# Is the atom resolved in the experiment? Subset of atom_exists,
# *except for OXT*
all_atom_mask = batch['all_atom_mask'].astype(jnp.float32)
xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits)
loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists))
if self.config.filter_by_resolution:
# NMR & distillation examples have resolution = 0.
loss *= ((batch['resolution'] >= self.config.min_resolution)
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
output = {'loss': loss}
return output
class TriangleMultiplication(hk.Module):
"""Triangle multiplication layer ("outgoing" or "incoming").
Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
"""
def __init__(self, config, global_config, name='triangle_multiplication'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, act, mask, is_training=True):
"""Builds TriangleMultiplication module.
Arguments:
act: Pair activations, shape [N_res, N_res, c_z]
mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
Returns:
Outputs, same shape/type as act.
"""
del is_training
c = self.config
gc = self.global_config
mask = mask[..., None]
act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(act)
input_act = act
left_projection = common_modules.Linear(
c.num_intermediate_channel,
name='left_projection')
left_proj_act = mask * left_projection(act)
right_projection = common_modules.Linear(
c.num_intermediate_channel,
name='right_projection')
right_proj_act = mask * right_projection(act)
left_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='left_gate')(act))
right_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='right_gate')(act))
left_proj_act *= left_gate_values
right_proj_act *= right_gate_values
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='center_layer_norm')(
act)
output_channel = int(input_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = jax.nn.sigmoid(common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(input_act))
act *= gate_values
return act
class DistogramHead(hk.Module):
"""Head to predict a distogram.
Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction"
"""
def __init__(self, config, global_config, name='distogram_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds DistogramHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'pair': pair representation, shape [N_res, N_res, c_z].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* logits: logits for distogram, shape [N_res, N_res, N_bins].
* bin_breaks: array containing bin breaks, shape [N_bins - 1,].
"""
half_logits = common_modules.Linear(
self.config.num_bins,
initializer=utils.final_init(self.global_config),
name='half_logits')(
representations['pair'])
logits = half_logits + jnp.swapaxes(half_logits, -2, -3)
breaks = jnp.linspace(self.config.first_break, self.config.last_break,
self.config.num_bins - 1)
return dict(logits=logits, bin_edges=breaks)
def loss(self, value, batch):
return _distogram_log_loss(value['logits'], value['bin_edges'],
batch, self.config.num_bins)
def _distogram_log_loss(logits, bin_edges, batch, num_bins):
"""Log loss of a distogram."""
assert len(logits.shape) == 3
positions = batch['pseudo_beta']
mask = batch['pseudo_beta_mask']
assert positions.shape[-1] == 3
sq_breaks = jnp.square(bin_edges)
dist2 = jnp.sum(
jnp.square(
jnp.expand_dims(positions, axis=-2) -
jnp.expand_dims(positions, axis=-3)),
axis=-1,
keepdims=True)
true_bins = jnp.sum(dist2 > sq_breaks, axis=-1)
errors = softmax_cross_entropy(
labels=jax.nn.one_hot(true_bins, num_bins), logits=logits)
square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1)
avg_error = (
jnp.sum(errors * square_mask, axis=(-2, -1)) /
(1e-6 + jnp.sum(square_mask, axis=(-2, -1))))
dist2 = dist2[..., 0]
return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2))
class OuterProductMean(hk.Module):
"""Computes mean outer product.
Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean"
"""
def __init__(self,
config,
global_config,
num_output_channel,
name='outer_product_mean'):
super().__init__(name=name)
self.global_config = global_config
self.config = config
self.num_output_channel = num_output_channel
def __call__(self, act, mask, is_training=True):
"""Builds OuterProductMean module.
Arguments:
act: MSA representation, shape [N_seq, N_res, c_m].
mask: MSA mask, shape [N_seq, N_res].
is_training: Whether the module is in training mode.
Returns:
Update to pair representation, shape [N_res, N_res, c_z].
"""
gc = self.global_config
c = self.config
mask = mask[..., None]
act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
left_act = mask * common_modules.Linear(
c.num_outer_channel,
initializer='linear',
name='left_projection')(
act)
right_act = mask * common_modules.Linear(
c.num_outer_channel,
initializer='linear',
name='right_projection')(
act)
if gc.zero_init:
init_w = hk.initializers.Constant(0.0)
else:
init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in')
output_w = hk.get_parameter(
'output_w',
shape=(c.num_outer_channel, c.num_outer_channel,
self.num_output_channel),
init=init_w)
output_b = hk.get_parameter(
'output_b', shape=(self.num_output_channel,),
init=hk.initializers.Constant(0.0))
def compute_chunk(left_act):
# This is equivalent to
#
# act = jnp.einsum('abc,ade->dceb', left_act, right_act)
# act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b
#
# but faster.
left_act = jnp.transpose(left_act, [0, 2, 1])
act = jnp.einsum('acb,ade->dceb', left_act, right_act)
act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b
return jnp.transpose(act, [1, 0, 2])
act = mapping.inference_subbatch(
compute_chunk,
c.chunk_size,
batched_args=[left_act],
nonbatched_args=[],
low_memory=True,
input_subbatch_dim=1,
output_subbatch_dim=0)
epsilon = 1e-3
norm = jnp.einsum('abc,adc->bdc', mask, mask)
act /= epsilon + norm
return act
def dgram_from_positions(positions, num_bins, min_bin, max_bin):
"""Compute distogram from amino acid positions.
Arguments:
positions: [N_res, 3] Position coordinates.
num_bins: The number of bins in the distogram.
min_bin: The left edge of the first bin.
max_bin: The left edge of the final bin. The final bin catches
everything larger than `max_bin`.
Returns:
Distogram with the specified number of bins.
"""
def squared_difference(x, y):
return jnp.square(x - y)
lower_breaks = jnp.linspace(min_bin, max_bin, num_bins)
lower_breaks = jnp.square(lower_breaks)
upper_breaks = jnp.concatenate([lower_breaks[1:],
jnp.array([1e8], dtype=jnp.float32)], axis=-1)
dist2 = jnp.sum(
squared_difference(
jnp.expand_dims(positions, axis=-2),
jnp.expand_dims(positions, axis=-3)),
axis=-1, keepdims=True)
dgram = ((dist2 > lower_breaks).astype(jnp.float32) *
(dist2 < upper_breaks).astype(jnp.float32))
return dgram
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
"""Create pseudo beta features."""
is_gly = jnp.equal(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
pseudo_beta = jnp.where(
jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
pseudo_beta_mask = jnp.where(
is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
class EvoformerIteration(hk.Module):
"""Single iteration (block) of Evoformer stack.
Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10
"""
def __init__(self, config, global_config, is_extra_msa,
name='evoformer_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.is_extra_msa = is_extra_msa
def __call__(self, activations, masks, is_training=True, safe_key=None):
"""Builds EvoformerIteration module.
Arguments:
activations: Dictionary containing activations:
* 'msa': MSA activations, shape [N_seq, N_res, c_m].
* 'pair': pair activations, shape [N_res, N_res, c_z].
masks: Dictionary of masks:
* 'msa': MSA mask, shape [N_seq, N_res].
* 'pair': pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
safe_key: prng.SafeKey encapsulating rng key.
Returns:
Outputs, same shape/type as act.
"""
c = self.config
gc = self.global_config
msa_act, pair_act = activations['msa'], activations['pair']
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
msa_mask, pair_mask = masks['msa'], masks['pair']
dropout_wrapper_fn = functools.partial(
dropout_wrapper,
is_training=is_training,
global_config=gc)
safe_key, *sub_keys = safe_key.split(10)
sub_keys = iter(sub_keys)
msa_act = dropout_wrapper_fn(
MSARowAttentionWithPairBias(
c.msa_row_attention_with_pair_bias, gc,
name='msa_row_attention_with_pair_bias'),
msa_act,
msa_mask,
safe_key=next(sub_keys),
pair_act=pair_act)
if not self.is_extra_msa:
attn_mod = MSAColumnAttention(
c.msa_column_attention, gc, name='msa_column_attention')
else:
attn_mod = MSAColumnGlobalAttention(
c.msa_column_attention, gc, name='msa_column_global_attention')
msa_act = dropout_wrapper_fn(
attn_mod,
msa_act,
msa_mask,
safe_key=next(sub_keys))
msa_act = dropout_wrapper_fn(
Transition(c.msa_transition, gc, name='msa_transition'),
msa_act,
msa_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
OuterProductMean(
config=c.outer_product_mean,
global_config=self.global_config,
num_output_channel=int(pair_act.shape[-1]),
name='outer_product_mean'),
msa_act,
msa_mask,
safe_key=next(sub_keys),
output_act=pair_act)
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
name='triangle_multiplication_outgoing'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_incoming, gc,
name='triangle_multiplication_incoming'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
Transition(c.pair_transition, gc, name='pair_transition'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
return {'msa': msa_act, 'pair': pair_act}
class EmbeddingsAndEvoformer(hk.Module):
"""Embeds the input data and runs Evoformer.
Produces the MSA, single and pair representations.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18
"""
def __init__(self, config, global_config, name='evoformer'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, batch, is_training, safe_key=None):
c = self.config
gc = self.global_config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
# Embed clustered MSA.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
# Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')(
batch['target_feat'])
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')(
batch['msa_feat'])
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear(
c.pair_channel, name='left_single')(
batch['target_feat'])
right_single = common_modules.Linear(
c.pair_channel, name='right_single')(
batch['target_feat'])
pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
# Inject previous outputs for recycling.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
# Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
if c.recycle_pos and 'prev_pos' in batch:
prev_pseudo_beta = pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
pair_activations += common_modules.Linear(
c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features:
if 'prev_msa_first_row' in batch:
prev_msa_first_row = hk.LayerNorm([-1],
True,
True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = jax.ops.index_add(msa_activations, 0,
prev_msa_first_row)
if 'prev_pair' in batch:
pair_activations += hk.LayerNorm([-1],
True,
True,
name='prev_pair_norm')(
batch['prev_pair'])
# Relative position encoding.
# Jumper et al. (2021) Suppl. Alg. 4 "relpos"
# Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
if c.max_relative_feature:
# Add one-hot-encoded clipped residue distances to the pair activations.
pos = batch['residue_index']
offset = pos[:, None] - pos[None, :]
rel_pos = jax.nn.one_hot(
jnp.clip(
offset + c.max_relative_feature,
a_min=0,
a_max=2 * c.max_relative_feature),
2 * c.max_relative_feature + 1)
pair_activations += common_modules.Linear(
c.pair_channel, name='pair_activiations')(
rel_pos)
# Embed templates into the pair activations.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13
if c.template.enabled:
template_batch = {k: batch[k] for k in batch if k.startswith('template_')}
template_pair_representation = TemplateEmbedding(c.template, gc)(
pair_activations,
template_batch,
mask_2d,
is_training=is_training)
pair_activations += template_pair_representation
# Embed extra MSA features.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16
extra_msa_feat = create_extra_msa_feature(batch)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat)
# Extra MSA Stack.
# Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
extra_msa_stack_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
extra_msa_stack_iteration = EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
def extra_msa_stack_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_msa_stack_iteration(
activations=act,
masks={
'msa': batch['extra_msa_mask'],
'pair': mask_2d
},
is_training=is_training,
safe_key=safe_subkey)
return (extra_evoformer_output, safe_key)
if gc.use_remat:
extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)
extra_msa_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_msa_stack_fn)
extra_msa_output, safe_key = extra_msa_stack(
(extra_msa_stack_input, safe_key))
pair_activations = extra_msa_output['pair']
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d}
# Append num_templ rows to msa_activations with template embeddings.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8
if c.template.enabled and c.template.embed_torsion_angles:
num_templ, num_res = batch['template_aatype'].shape
# Embed the templates aatypes.
aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
# Embed the templates aatype, torsion angles and masks.
# Shape (templates, residues, msa_channels)
ret = all_atom.atom37_to_torsion_angles(
aatype=batch['template_aatype'],
all_atom_pos=batch['template_all_atom_positions'],
all_atom_mask=batch['template_all_atom_masks'],
# Ensure consistent behaviour during testing:
placeholder_for_undefined=not gc.zero_init)
template_features = jnp.concatenate([
aatype_one_hot,
jnp.reshape(
ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]),
jnp.reshape(
ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]),
ret['torsion_angles_mask']], axis=-1)
template_activations = common_modules.Linear(
c.msa_channel,
initializer='relu',
name='template_single_embedding')(
template_features)
template_activations = jax.nn.relu(template_activations)
template_activations = common_modules.Linear(
c.msa_channel,
initializer='relu',
name='template_projection')(
template_activations)
# Concatenate the templates to the msa.
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_activations], axis=0)
# Concatenate templates masks to the msa masks.
# Use mask from the psi angle, as it only depends on the backbone atoms
# from a single residue.
torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
torsion_angle_mask = torsion_angle_mask.astype(
evoformer_masks['msa'].dtype)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], torsion_angle_mask], axis=0)
# Main trunk of the network
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
evoformer_iteration = EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
evoformer_output, safe_key = evoformer_stack(
(evoformer_input, safe_key))
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
num_sequences = batch['msa_feat'].shape[0]
output = {
'single': single_activations,
'pair': pair_activations,
# Crop away template rows such that they are not used in MaskedMsaHead.
'msa': msa_activations[:num_sequences, :, :],
'msa_first_row': msa_activations[0],
}
return output
class SingleTemplateEmbedding(hk.Module):
"""Embeds a single template.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11
"""
def __init__(self, config, global_config, name='single_template_embedding'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, query_embedding, batch, mask_2d, is_training):
"""Build the single template embedding.
Arguments:
query_embedding: Query pair representation, shape [N_res, N_res, c_z].
batch: A batch of template features (note the template dimension has been
stripped out as this module only runs over a single template).
mask_2d: Padding mask (Note: this doesn't care if a template exists,
unlike the template_pseudo_beta_mask).
is_training: Whether the module is in training mode.
Returns:
A template embedding [N_res, N_res, c_z].
"""
assert mask_2d.dtype == query_embedding.dtype
dtype = query_embedding.dtype
num_res = batch['template_aatype'].shape[0]
num_channels = (self.config.template_pair_stack
.triangle_attention_ending_node.value_dim)
template_mask = batch['template_pseudo_beta_mask']
template_mask_2d = template_mask[:, None] * template_mask[None, :]
template_mask_2d = template_mask_2d.astype(dtype)
template_dgram = dgram_from_positions(batch['template_pseudo_beta'],
**self.config.dgram_features)
template_dgram = template_dgram.astype(dtype)
to_concat = [template_dgram, template_mask_2d[:, :, None]]
aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype)
to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1]))
to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1]))
n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')]
rot, trans = quat_affine.make_transform_from_reference(
n_xyz=batch['template_all_atom_positions'][:, n],
ca_xyz=batch['template_all_atom_positions'][:, ca],
c_xyz=batch['template_all_atom_positions'][:, c])
affines = quat_affine.QuatAffine(
quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True),
translation=trans,
rotation=rot,
unstack_inputs=True)
points = [jnp.expand_dims(x, axis=-2) for x in affines.translation]
affine_vec = affines.invert_point(points, extra_dims=1)
inv_distance_scalar = jax.lax.rsqrt(
1e-6 + sum([jnp.square(x) for x in affine_vec]))
# Backbone affine mask: whether the residue has C, CA, N
# (the template mask defined above only considers pseudo CB).
template_mask = (
batch['template_all_atom_masks'][..., n] *
batch['template_all_atom_masks'][..., ca] *
batch['template_all_atom_masks'][..., c])
template_mask_2d = template_mask[:, None] * template_mask[None, :]
inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype)
unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec]
unit_vector = [x.astype(dtype) for x in unit_vector]
template_mask_2d = template_mask_2d.astype(dtype)
if not self.config.use_template_unit_vector:
unit_vector = [jnp.zeros_like(x) for x in unit_vector]
to_concat.extend(unit_vector)
to_concat.append(template_mask_2d[..., None])
act = jnp.concatenate(to_concat, axis=-1)
# Mask out non-template regions so we don't get arbitrary values in the
# distogram for these regions.
act *= template_mask_2d[..., None]
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9
act = common_modules.Linear(
num_channels,
initializer='relu',
name='embedding2d')(
act)
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11
act = TemplatePairStack(
self.config.template_pair_stack, self.global_config)(
act, mask_2d, is_training)
act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
return act
class TemplateEmbedding(hk.Module):
"""Embeds a set of templates.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
"""
def __init__(self, config, global_config, name='template_embedding'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, query_embedding, template_batch, mask_2d, is_training):
"""Build TemplateEmbedding module.
Arguments:
query_embedding: Query pair representation, shape [N_res, N_res, c_z].
template_batch: A batch of template features.
mask_2d: Padding mask (Note: this doesn't care if a template exists,
unlike the template_pseudo_beta_mask).
is_training: Whether the module is in training mode.
Returns:
A template embedding [N_res, N_res, c_z].
"""
num_templates = template_batch['template_mask'].shape[0]
num_channels = (self.config.template_pair_stack
.triangle_attention_ending_node.value_dim)
num_res = query_embedding.shape[0]
dtype = query_embedding.dtype
template_mask = template_batch['template_mask']
template_mask = template_mask.astype(dtype)
query_num_channels = query_embedding.shape[-1]
# Make sure the weights are shared across templates by constructing the
# embedder here.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
def map_fn(batch):
return template_embedder(query_embedding, batch, mask_2d, is_training)
template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(
template_batch)
# Cross attend from the query to the templates along the residue
# dimension by flattening everything else into the batch dimension.
# Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
flat_query = jnp.reshape(query_embedding,
[num_res * num_res, 1, query_num_channels])
flat_templates = jnp.reshape(
jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
[num_res * num_res, num_templates, num_channels])
bias = (1e9 * (template_mask[None, None, None, :] - 1.))
template_pointwise_attention_module = Attention(
self.config.attention, self.global_config, query_num_channels)
nonbatched_args = [bias]
batched_args = [flat_query, flat_templates]
embedding = mapping.inference_subbatch(
template_pointwise_attention_module,
self.config.subbatch_size,
batched_args=batched_args,
nonbatched_args=nonbatched_args,
low_memory=not is_training)
embedding = jnp.reshape(embedding,
[num_res, num_res, query_num_channels])
# No gradients if no templates.
embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype)
return embedding
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of utilities surrounding PRNG usage in protein folding."""
import haiku as hk
import jax
def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training):
if is_training and rate != 0.0 and not is_deterministic:
return hk.dropout(safe_key.get(), rate, tensor)
else:
return tensor
class SafeKey:
"""Safety wrapper for PRNG keys."""
def __init__(self, key):
self._key = key
self._used = False
def _assert_not_used(self):
if self._used:
raise RuntimeError('Random key has been used previously.')
def get(self):
self._assert_not_used()
self._used = True
return self._key
def split(self, num_keys=2):
self._assert_not_used()
self._used = True
new_keys = jax.random.split(self._key, num_keys)
return jax.tree_map(SafeKey, tuple(new_keys))
def duplicate(self, num_keys=2):
self._assert_not_used()
self._used = True
return tuple(SafeKey(self._key) for _ in range(num_keys))
def _safe_key_flatten(safe_key):
# Flatten transfers "ownership" to the tree
return (safe_key._key,), safe_key._used # pylint: disable=protected-access
def _safe_key_unflatten(aux_data, children):
ret = SafeKey(children[0])
ret._used = aux_data # pylint: disable=protected-access
return ret
jax.tree_util.register_pytree_node(
SafeKey, _safe_key_flatten, _safe_key_unflatten)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for prng."""
from absl.testing import absltest
import jax
from alphafold.model import prng
class PrngTest(absltest.TestCase):
def test_key_reuse(self):
init_key = jax.random.PRNGKey(42)
safe_key = prng.SafeKey(init_key)
_, safe_key = safe_key.split()
raw_key = safe_key.get()
self.assertNotEqual(raw_key[0], init_key[0])
self.assertNotEqual(raw_key[1], init_key[1])
with self.assertRaises(RuntimeError):
safe_key.get()
with self.assertRaises(RuntimeError):
safe_key.split()
with self.assertRaises(RuntimeError):
safe_key.duplicate()
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quaternion geometry modules.
This introduces a representation of coordinate frames that is based around a
‘QuatAffine’ object. This object describes an array of coordinate frames.
It consists of vectors corresponding to the
origin of the frames as well as orientations which are stored in two
ways, as unit quaternions as well as a rotation matrices.
The rotation matrices are derived from the unit quaternions and the two are kept
in sync.
For an explanation of the relation between unit quaternions and rotations see
https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
This representation is used in the model for the backbone frames.
One important thing to note here, is that while we update both representations
the jit compiler is going to ensure that only the parts that are
actually used are executed.
"""
import functools
from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np
# pylint: disable=bad-whitespace
QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr
QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii
QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj
QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk
QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij
QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik
QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk
QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir
QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr
QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr
QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
[ 0,-1, 0, 0],
[ 0, 0,-1, 0],
[ 0, 0, 0,-1]]
QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
[ 1, 0, 0, 0],
[ 0, 0, 0, 1],
[ 0, 0,-1, 0]]
QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
[ 0, 0, 0,-1],
[ 1, 0, 0, 0],
[ 0, 1, 0, 0]]
QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
[ 0, 0, 1, 0],
[ 0,-1, 0, 0],
[ 1, 0, 0, 0]]
QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :]
# pylint: enable=bad-whitespace
def rot_to_quat(rot, unstack_inputs=False):
"""Convert rotation matrix to quaternion.
Note that this function calls self_adjoint_eig which is extremely expensive on
the GPU. If at all possible, this function should run on the CPU.
Args:
rot: rotation matrix (see below for format).
unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
otherwise the rotation matrix should be a list of lists of tensors.
Returns:
Quaternion as (..., 4) tensor.
"""
if unstack_inputs:
rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
# pylint: disable=bad-whitespace
k = [[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]]
# pylint: enable=bad-whitespace
k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k],
axis=-2)
# Get eigenvalues in non-decreasing order and associated.
_, qs = jnp.linalg.eigh(k)
return qs[..., -1]
def rot_list_to_tensor(rot_list):
"""Convert list of lists to rotation tensor."""
return jnp.stack(
[jnp.stack(rot_list[0], axis=-1),
jnp.stack(rot_list[1], axis=-1),
jnp.stack(rot_list[2], axis=-1)],
axis=-2)
def vec_list_to_tensor(vec_list):
"""Convert list to vector tensor."""
return jnp.stack(vec_list, axis=-1)
def quat_to_rot(normalized_quat):
"""Convert a normalized quaternion to a rotation matrix."""
rot_tensor = jnp.sum(
np.reshape(QUAT_TO_ROT, (4, 4, 9)) *
normalized_quat[..., :, None, None] *
normalized_quat[..., None, :, None],
axis=(-3, -2))
rot = jnp.moveaxis(rot_tensor, -1, 0) # Unstack.
return [[rot[0], rot[1], rot[2]],
[rot[3], rot[4], rot[5]],
[rot[6], rot[7], rot[8]]]
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
return jnp.sum(
QUAT_MULTIPLY_BY_VEC *
quat[..., :, None, None] *
vec[..., None, :, None],
axis=(-3, -2))
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
return jnp.sum(
QUAT_MULTIPLY *
quat1[..., :, None, None] *
quat2[..., None, :, None],
axis=(-3, -2))
def apply_rot_to_vec(rot, vec, unstack=False):
"""Multiply rotation matrix by a vector."""
if unstack:
x, y, z = [vec[:, i] for i in range(3)]
else:
x, y, z = vec
return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]
def apply_inverse_rot_to_vec(rot, vec):
"""Multiply the inverse of a rotation matrix by a vector."""
# Inverse rotation is just transpose
return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2],
rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2],
rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]]
class QuatAffine(object):
"""Affine transformation represented by quaternion and vector."""
def __init__(self, quaternion, translation, rotation=None, normalize=True,
unstack_inputs=False):
"""Initialize from quaternion and translation.
Args:
quaternion: Rotation represented by a quaternion, to be applied
before translation. Must be a unit quaternion unless normalize==True.
translation: Translation represented as a vector.
rotation: Same rotation as the quaternion, represented as a (..., 3, 3)
tensor. If None, rotation will be calculated from the quaternion.
normalize: If True, l2 normalize the quaternion on input.
unstack_inputs: If True, translation is a vector with last component 3
"""
if quaternion is not None:
assert quaternion.shape[-1] == 4
if unstack_inputs:
if rotation is not None:
rotation = [jnp.moveaxis(x, -1, 0) # Unstack.
for x in jnp.moveaxis(rotation, -2, 0)] # Unstack.
translation = jnp.moveaxis(translation, -1, 0) # Unstack.
if normalize and quaternion is not None:
quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1,
keepdims=True)
if rotation is None:
rotation = quat_to_rot(quaternion)
self.quaternion = quaternion
self.rotation = [list(row) for row in rotation]
self.translation = list(translation)
assert all(len(row) == 3 for row in self.rotation)
assert len(self.translation) == 3
def to_tensor(self):
return jnp.concatenate(
[self.quaternion] +
[jnp.expand_dims(x, axis=-1) for x in self.translation],
axis=-1)
def apply_tensor_fn(self, tensor_fn):
"""Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient)."""
return QuatAffine(
tensor_fn(self.quaternion),
[tensor_fn(x) for x in self.translation],
rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
normalize=False)
def apply_rotation_tensor_fn(self, tensor_fn):
"""Return a new QuatAffine with tensor_fn applied to the rotation part."""
return QuatAffine(
tensor_fn(self.quaternion),
[x for x in self.translation],
rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
normalize=False)
def scale_translation(self, position_scale):
"""Return a new quat affine with a different scale for translation."""
return QuatAffine(
self.quaternion,
[x * position_scale for x in self.translation],
rotation=[[x for x in row] for row in self.rotation],
normalize=False)
@classmethod
def from_tensor(cls, tensor, normalize=False):
quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
return cls(quaternion,
[tx[..., 0], ty[..., 0], tz[..., 0]],
normalize=normalize)
def pre_compose(self, update):
"""Return a new QuatAffine which applies the transformation update first.
Args:
update: Length-6 vector. 3-vector of x, y, and z such that the quaternion
update is (1, x, y, z) and zero for the 3-vector is the identity
quaternion. 3-vector for translation concatenated.
Returns:
New QuatAffine object.
"""
vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1)
trans_update = [jnp.squeeze(x, axis=-1),
jnp.squeeze(y, axis=-1),
jnp.squeeze(z, axis=-1)]
new_quaternion = (self.quaternion +
quat_multiply_by_vec(self.quaternion,
vector_quaternion_update))
trans_update = apply_rot_to_vec(self.rotation, trans_update)
new_translation = [
self.translation[0] + trans_update[0],
self.translation[1] + trans_update[1],
self.translation[2] + trans_update[2]]
return QuatAffine(new_quaternion, new_translation)
def apply_to_point(self, point, extra_dims=0):
"""Apply affine to a point.
Args:
point: List of 3 tensors to apply affine.
extra_dims: Number of dimensions at the end of the transformed_point
shape that are not present in the rotation and translation. The most
common use is rotation N points at once with extra_dims=1 for use in a
network.
Returns:
Transformed point after applying affine.
"""
rotation = self.rotation
translation = self.translation
for _ in range(extra_dims):
expand_fn = functools.partial(jnp.expand_dims, axis=-1)
rotation = jax.tree_map(expand_fn, rotation)
translation = jax.tree_map(expand_fn, translation)
rot_point = apply_rot_to_vec(rotation, point)
return [
rot_point[0] + translation[0],
rot_point[1] + translation[1],
rot_point[2] + translation[2]]
def invert_point(self, transformed_point, extra_dims=0):
"""Apply inverse of transformation to a point.
Args:
transformed_point: List of 3 tensors to apply affine
extra_dims: Number of dimensions at the end of the transformed_point
shape that are not present in the rotation and translation. The most
common use is rotation N points at once with extra_dims=1 for use in a
network.
Returns:
Transformed point after applying affine.
"""
rotation = self.rotation
translation = self.translation
for _ in range(extra_dims):
expand_fn = functools.partial(jnp.expand_dims, axis=-1)
rotation = jax.tree_map(expand_fn, rotation)
translation = jax.tree_map(expand_fn, translation)
rot_point = [
transformed_point[0] - translation[0],
transformed_point[1] - translation[1],
transformed_point[2] - translation[2]]
return apply_inverse_rot_to_vec(rotation, rot_point)
def __repr__(self):
return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation)
def _multiply(a, b):
return jnp.stack([
jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0],
a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1],
a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]),
jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0],
a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1],
a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]),
jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0],
a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1],
a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])])
def make_canonical_transform(
n_xyz: jnp.ndarray,
ca_xyz: jnp.ndarray,
c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Returns translation and rotation matrices to canonicalize residue atoms.
Note that this method does not take care of symmetries. If you provide the
atom positions in the non-standard way, the N atom will end up not at
[-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
need to take care of such cases in your code.
Args:
n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
Returns:
A tuple (translation, rotation) where:
translation is an array of shape [batch, 3] defining the translation.
rotation is an array of shape [batch, 3, 3] defining the rotation.
After applying the translation and rotation to all atoms in a residue:
* All atoms will be shifted so that CA is at the origin,
* All atoms will be rotated so that C is at the x-axis,
* All atoms will be shifted so that N is in the xy plane.
"""
assert len(n_xyz.shape) == 2, n_xyz.shape
assert n_xyz.shape[-1] == 3, n_xyz.shape
assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
n_xyz.shape, ca_xyz.shape, c_xyz.shape)
# Place CA at the origin.
translation = -ca_xyz
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
# Place C on the x-axis.
c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
# Rotate by angle c1 in the x-y plane (around the z-axis).
sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
zeros = jnp.zeros_like(sin_c1)
ones = jnp.ones_like(sin_c1)
# pylint: disable=bad-whitespace
c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
jnp.array([sin_c1, cos_c1, zeros]),
jnp.array([zeros, zeros, ones])])
# Rotate by angle c2 in the x-z plane (around the y-axis).
sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt(
1e-20 + c_x**2 + c_y**2 + c_z**2)
c2_rot_matrix = jnp.stack([jnp.array([cos_c2, zeros, sin_c2]),
jnp.array([zeros, ones, zeros]),
jnp.array([-sin_c2, zeros, cos_c2])])
c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T
# Place N in the x-y plane.
_, n_y, n_z = [n_xyz[:, i] for i in range(3)]
# Rotate by angle alpha in the y-z plane (around the x-axis).
sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
n_rot_matrix = jnp.stack([jnp.array([ones, zeros, zeros]),
jnp.array([zeros, cos_n, -sin_n]),
jnp.array([zeros, sin_n, cos_n])])
# pylint: enable=bad-whitespace
return (translation,
jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
def make_transform_from_reference(
n_xyz: jnp.ndarray,
ca_xyz: jnp.ndarray,
c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Returns rotation and translation matrices to convert from reference.
Note that this method does not take care of symmetries. If you provide the
atom positions in the non-standard way, the N atom will end up not at
[-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
need to take care of such cases in your code.
Args:
n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
Returns:
A tuple (rotation, translation) where:
rotation is an array of shape [batch, 3, 3] defining the rotation.
translation is an array of shape [batch, 3] defining the translation.
After applying the translation and rotation to the reference backbone,
the coordinates will approximately equal to the input coordinates.
The order of translation and rotation differs from make_canonical_transform
because the rotation from this function should be applied before the
translation, unlike make_canonical_transform.
"""
translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
return np.transpose(rotation, (0, 2, 1)), -translation
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for quat_affine."""
from absl import logging
from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np
from alphafold.model import quat_affine
VERBOSE = False
np.set_printoptions(precision=3, suppress=True)
r2t = quat_affine.rot_list_to_tensor
v2t = quat_affine.vec_list_to_tensor
q2r = lambda q: r2t(quat_affine.quat_to_rot(q))
class QuatAffineTest(absltest.TestCase):
def _assert_check(self, to_check, tol=1e-5):
for k, (correct, generated) in to_check.items():
if VERBOSE:
logging.info(k)
logging.info('Correct %s', correct)
logging.info('Predicted %s', generated)
self.assertLess(np.max(np.abs(correct - generated)), tol)
def test_conversion(self):
quat = jnp.array([-2., 5., -1., 4.])
rotation = jnp.array([
[0.26087, 0.130435, 0.956522],
[-0.565217, -0.782609, 0.26087],
[0.782609, -0.608696, -0.130435]])
translation = jnp.array([1., -3., 4.])
point = jnp.array([0.7, 3.2, -2.9])
a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)
true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation
self._assert_check({
'rot': (rotation, r2t(a.rotation)),
'trans': (translation, v2t(a.translation)),
'point': (true_new_point,
v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))),
# Because of the double cover, we must be careful and compare rotations
'quat': (q2r(a.quaternion),
q2r(quat_affine.rot_to_quat(a.rotation))),
})
def test_double_cover(self):
"""Test that -q is the same rotation as q."""
rng = jax.random.PRNGKey(42)
keys = jax.random.split(rng)
q = jax.random.normal(keys[0], (2, 4))
trans = jax.random.normal(keys[1], (2, 3))
a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)
self._assert_check({
'rot': (r2t(a1.rotation),
r2t(a2.rotation)),
'trans': (v2t(a1.translation),
v2t(a2.translation)),
})
def test_homomorphism(self):
rng = jax.random.PRNGKey(42)
keys = jax.random.split(rng, 4)
vec_q1 = jax.random.normal(keys[0], (2, 3))
q1 = jnp.concatenate([
jnp.ones_like(vec_q1)[:, :1],
vec_q1], axis=-1)
q2 = jax.random.normal(keys[1], (2, 4))
t1 = jax.random.normal(keys[2], (2, 3))
t2 = jax.random.normal(keys[3], (2, 3))
a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)
a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1))
rng, key = jax.random.split(rng)
x = jax.random.normal(key, (2, 3))
new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0))
new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0)))
self._assert_check({
'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)),
q2r(a21.quaternion)),
'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)),
r2t(a21.rotation)),
'point': (v2t(new_x_apply2),
v2t(new_x)),
'inverse': (x, v2t(a21.invert_point(new_x))),
})
def test_batching(self):
"""Test that affine applies batchwise."""
rng = jax.random.PRNGKey(42)
keys = jax.random.split(rng, 3)
q = jax.random.uniform(keys[0], (5, 2, 4))
t = jax.random.uniform(keys[1], (2, 3))
x = jax.random.uniform(keys[2], (5, 1, 3))
a = quat_affine.QuatAffine(q, t, unstack_inputs=True)
y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0)))
y_list = []
for i in range(5):
for j in range(2):
a_local = quat_affine.QuatAffine(q[i, j], t[j],
unstack_inputs=True)
y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))
y_list.append(y_local)
y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3))
self._assert_check({
'batch': (y_combine, y),
'quat': (q2r(a.quaternion),
q2r(quat_affine.rot_to_quat(a.rotation))),
})
def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06):
self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol))
def assertAllEqual(self, a, b):
self.assertTrue(np.all(np.array(a) == np.array(b)))
if __name__ == '__main__':
absltest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment