Commit eb93322b authored by mashun1

dtk24.04.1

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from typing import Dict, Optional
from alphafold.common import residue_constants
from alphafold.model import geometry
from alphafold.model import utils
import jax
import jax.numpy as jnp
import numpy as np
def squared_difference(x, y):
return jnp.square(x - y)
def _make_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes, with the unknown
residue type appended at the end. For chi angles that are not defined for a
residue, the atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return np.array(chi_atom_indices)
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
residue_constants.restype_1to3[res] for res in residue_constants.restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
def _make_restype_atom37_mask():
"""Mask of which atoms are present for which residue type in atom37."""
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
return restype_atom37_mask
def _make_restype_atom14_mask():
"""Mask of which atoms are present for which residue type in atom14."""
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
restype_atom14_mask.append([0.] * 14)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
return restype_atom14_mask
def _make_restype_atom37_to_atom14():
"""Map from atom37 to atom14 per residue type."""
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom37_to_atom14.append([0] * 37)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
return restype_atom37_to_atom14
def _make_restype_atom14_to_atom37():
"""Map from atom14 to atom37 per residue type."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
return restype_atom14_to_atom37
def _make_restype_atom14_is_ambiguous():
"""Mask which atoms are ambiguous in atom14."""
# create an ambiguous atoms mask. shape: (21, 14)
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
return restype_atom14_is_ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
"""Create Map from rigidgroups to atom37 indices."""
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(residue_constants.restypes):
resname = residue_constants.restype_1to3[restype_letter]
for chi_idx in range(4):
if residue_constants.chi_angles_mask[restype][chi_idx]:
atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]
# Translate atom names into atom37 indices.
lookuptable = residue_constants.atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
base_atom_names)
return restype_rigidgroup_base_atom37_idx
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()
# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = residue_constants.chi_angles_mask
def get_atom37_mask(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_MASK), aatype)
def get_atom14_mask(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype)
def get_atom14_is_ambiguous(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_IS_AMBIGUOUS), aatype)
def get_atom14_to_atom37_map(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype)
def get_atom37_to_atom14_map(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_TO_ATOM14), aatype)
def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...)
aatype: jnp.ndarray
) -> jnp.ndarray: # (N, 37, ...)
"""Convert atom14 to atom37 representation."""
assert len(atom14_data.shape) in [2, 3]
idx_atom37_to_atom14 = get_atom37_to_atom14_map(aatype)
atom37_data = utils.batched_gather(
atom14_data, idx_atom37_to_atom14, batch_dims=1)
atom37_mask = get_atom37_mask(aatype)
if len(atom14_data.shape) == 2:
atom37_data *= atom37_mask
elif len(atom14_data.shape) == 3:
atom37_data *= atom37_mask[:, :, None].astype(atom37_data.dtype)
return atom37_data
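# Illustrative usage (not part of the library): a shape-level sketch of the
# atom14 -> atom37 conversion for a single alanine residue; the zero
# coordinates are placeholders, not real atom positions.
def _example_atom14_to_atom37():
  aatype = jnp.zeros([1], dtype=jnp.int32)  # restype 0 is 'A' (alanine)
  atom14_xyz = jnp.zeros([1, 14, 3])  # (num_res, 14, xyz)
  # Atoms that do not exist for alanine are masked to zero in atom37.
  return atom14_to_atom37(atom14_xyz, aatype)  # shape (1, 37, 3)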
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
"""Convert Atom37 positions to Atom14 positions."""
residx_atom14_to_atom37 = utils.batched_gather(
jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype)
atom14_mask = utils.batched_gather(
all_atom_mask, residx_atom14_to_atom37, batch_dims=1).astype(jnp.float32)
# create a mask for known groundtruth positions
atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype)
# gather the groundtruth positions
atom14_positions = jax.tree_map(
lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1),
all_atom_pos)
atom14_positions = atom14_mask * atom14_positions
return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask):
"""Get alternative atom14 positions."""
# pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14)
renaming_transform = utils.batched_gather(
jnp.asarray(RENAMING_MATRICES), aatype)
alternative_positions = jax.tree_map(
lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform)
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position)
alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)
return alternative_positions, alternative_mask
def atom37_to_frames(
aatype: jnp.ndarray, # (...)
all_atom_positions: geometry.Vec3Array, # (..., 37)
all_atom_mask: jnp.ndarray, # (..., 37)
) -> Dict[str, jnp.ndarray]:
"""Computes the frames for the up to 8 rigid groups for each residue."""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
aatype_in_shape = aatype.shape
# If there is a batch axis, just flatten it away, and reshape everything
# back at the end of the function.
aatype = jnp.reshape(aatype, [-1])
all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]),
all_atom_positions)
all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])
# Compute the gather indices for all residues in the chain.
# shape (N, 8, 3)
residx_rigidgroup_base_atom37_idx = utils.batched_gather(
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype)
# Gather the base atom positions for each rigid group.
base_atom_pos = jax.tree_map(
lambda x: utils.batched_gather( # pylint: disable=g-long-lambda
x, residx_rigidgroup_base_atom37_idx, batch_dims=1),
all_atom_positions)
# Compute the Rigids.
point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin = base_atom_pos[:, :, 1]
point_on_xy_plane = base_atom_pos[:, :, 2]
gt_rotation = geometry.Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = geometry.Rigid3Array(gt_rotation, origin)
# Compute a mask whether the group exists.
# (N, 8)
group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype)
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.astype(jnp.float32),
residx_rigidgroup_base_atom37_idx,
batch_dims=1)
gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = gt_frames.compose_rotation(
geometry.Rot3Array.from_array(rots))
# The frames for ambiguous rigid groups are just rotated by 180 degrees
# around the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])
for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
# Gather the ambiguity information for each residue.
residx_rigidgroup_is_ambiguous = utils.batched_gather(
restype_rigidgroup_is_ambiguous, aatype)
ambiguity_rot = utils.batched_gather(restype_rigidgroup_rots, aatype)
ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,))
# reshape back to original residue layout
gt_frames = jax.tree_map(fix_shape, gt_frames)
gt_exists = fix_shape(gt_exists)
group_exists = fix_shape(group_exists)
residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames)
return {
'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8)
'rigidgroups_gt_exists': gt_exists, # (..., 8)
'rigidgroups_group_exists': group_exists, # (..., 8)
'rigidgroups_group_is_ambiguous':
residx_rigidgroup_is_ambiguous, # (..., 8)
'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8)
}
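# Illustrative usage (not part of the library): a shape-level sketch of
# `atom37_to_frames`. Zero coordinates are placeholders, so the resulting
# frames are not geometrically meaningful; real inputs come from features.
def _example_atom37_to_frames():
  aatype = jnp.zeros([1], dtype=jnp.int32)  # a single alanine
  positions = geometry.Vec3Array.from_array(jnp.zeros([1, 37, 3]))
  mask = jnp.ones([1, 37])
  out = atom37_to_frames(aatype, positions, mask)
  return out['rigidgroups_gt_frames']  # geometry.Rigid3Array, shape (1, 8)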
def torsion_angles_to_frames(
aatype: jnp.ndarray, # (N)
backb_to_global: geometry.Rigid3Array, # (N)
torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2)
) -> geometry.Rigid3Array: # (N, 8)
"""Compute rigid group frames from torsion angles."""
assert len(aatype.shape) == 1, (
f'Expected array of rank 1, got array with shape: {aatype.shape}.')
assert len(backb_to_global.rotation.shape) == 1, (
f'Expected array of rank 1, got array with shape: '
f'{backb_to_global.rotation.shape}')
assert len(torsion_angles_sin_cos.shape) == 3, (
f'Expected array of rank 3, got array with shape: '
f'{torsion_angles_sin_cos.shape}')
assert torsion_angles_sin_cos.shape[1] == 7, (
f'wrong shape {torsion_angles_sin_cos.shape}')
assert torsion_angles_sin_cos.shape[2] == 2, (
f'wrong shape {torsion_angles_sin_cos.shape}')
# Gather the default frames for all rigid groups.
# geometry.Rigid3Array with shape (N, 8)
m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame,
aatype)
default_frames = geometry.Rigid3Array.from_array4x4(m)
# Create the rotation matrices according to the given angles (each frame is
# defined such that its rotation is around the x-axis).
sin_angles = torsion_angles_sin_cos[..., 0]
cos_angles = torsion_angles_sin_cos[..., 1]
# insert zero rotation for backbone group.
num_residues, = aatype.shape
sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],
axis=-1)
cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],
axis=-1)
zeros = jnp.zeros_like(sin_angles)
ones = jnp.ones_like(sin_angles)
# all_rots are geometry.Rot3Array with shape (N, 8)
all_rots = geometry.Rot3Array(ones, zeros, zeros,
zeros, cos_angles, -sin_angles,
zeros, sin_angles, cos_angles)
# Apply rotations to the frames.
all_frames = default_frames.compose_rotation(all_rots)
# chi2, chi3, and chi4 frames do not transform to the backbone frame but to
# the previous frame. So chain them up accordingly.
chi1_frame_to_backb = all_frames[:, 4]
chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[:, 5]
chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]
all_frames_to_backb = jax.tree_map(
lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
chi4_frame_to_backb[:, None])
# Create the global frames.
# shape (N, 8)
all_frames_to_global = backb_to_global[:, None] @ all_frames_to_backb
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
aatype: jnp.ndarray, # (N)
all_frames_to_global: geometry.Rigid3Array # (N, 8)
) -> geometry.Vec3Array: # (N, 14)
"""Put atom literature positions (atom14 encoding) in each rigid group."""
# Pick the appropriate transform for every atom.
residx_to_group_idx = utils.batched_gather(
residue_constants.restype_atom14_to_rigid_group, aatype)
group_mask = jax.nn.one_hot(
residx_to_group_idx, num_classes=8) # shape (N, 14, 8)
# geometry.Rigid3Array with shape (N, 14)
map_atoms_to_global = jax.tree_map(
lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
all_frames_to_global)
# Gather the literature atom positions for each residue.
# geometry.Vec3Array with shape (N, 14)
lit_positions = geometry.Vec3Array.from_array(
utils.batched_gather(
residue_constants.restype_atom14_rigid_group_positions, aatype))
# Transform each atom from its local frame to the global frame.
# geometry.Vec3Array with shape (N, 14)
pred_positions = map_atoms_to_global.apply_to_point(lit_positions)
# Mask out non-existing atoms.
mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
pred_positions = pred_positions * mask
return pred_positions
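# Illustrative usage (not part of the library): chaining the two functions
# above maps backbone frames plus torsion angles to idealized atom14
# coordinates. Identity backbone frames and zero torsions (sin=0, cos=1)
# reproduce the default rigid-group geometry. `Rigid3Array.identity` is
# assumed to be available in `geometry`, as in the multimer geometry module.
def _example_idealized_atom14_positions():
  num_res = 2
  aatype = jnp.zeros([num_res], dtype=jnp.int32)  # two alanines
  backb_to_global = geometry.Rigid3Array.identity((num_res,))
  torsions_sin_cos = jnp.tile(jnp.array([0., 1.]), [num_res, 7, 1])  # (N, 7, 2)
  frames = torsion_angles_to_frames(aatype, backb_to_global, torsions_sin_cos)
  return frames_and_literature_positions_to_atom14_pos(aatype, frames)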
def extreme_ca_ca_distance_violations(
positions: geometry.Vec3Array, # (N, 37(14))
mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
max_angstrom_tolerance=1.5
) -> jnp.ndarray:
"""Counts residues whose Ca is a large distance from its neighbor."""
this_ca_pos = positions[:-1, 1] # (N - 1,)
this_ca_mask = mask[:-1, 1] # (N - 1)
next_ca_pos = positions[1:, 1] # (N - 1,)
next_ca_mask = mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, 1e-6)
violations = (ca_ca_distance -
residue_constants.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
return utils.mask_mean(mask=mask, value=violations)
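# Illustrative only: two consecutive residues whose Ca-Ca distance (10A) far
# exceeds the ideal ~3.8A plus the tolerance, so the violation fraction is 1.
def _example_ca_ca_violations():
  xyz = jnp.zeros([2, 37, 3]).at[1, 1, 0].set(10.0)  # Ca is atom37 index 1
  positions = geometry.Vec3Array.from_array(xyz)
  mask = jnp.ones([2, 37])
  residue_index = jnp.arange(2)
  return extreme_ca_ca_distance_violations(positions, mask, residue_index)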
def between_residue_bond_loss(
pred_atom_positions: geometry.Vec3Array, # (N, 37(14))
pred_atom_mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
aatype: jnp.ndarray, # (N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0) -> Dict[str, jnp.ndarray]:
"""Flat-bottom loss to penalize structural violations between residues."""
assert len(pred_atom_positions.shape) == 2
assert len(pred_atom_mask.shape) == 2
assert len(residue_index.shape) == 1
assert len(aatype.shape) == 1
# Get the positions of the relevant backbone atoms.
this_ca_pos = pred_atom_positions[:-1, 1] # (N - 1)
this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1)
this_c_pos = pred_atom_positions[:-1, 2] # (N - 1)
this_c_mask = pred_atom_mask[:-1, 2] # (N - 1)
next_n_pos = pred_atom_positions[1:, 0] # (N - 1)
next_n_mask = pred_atom_mask[1:, 0] # (N - 1)
next_ca_pos = pred_atom_positions[1:, 1] # (N - 1)
next_ca_mask = pred_atom_mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
# Compute loss for the C--N bond.
c_n_bond_length = geometry.euclidean_distance(this_c_pos, next_n_pos, 1e-6)
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[1:] == residue_constants.restype_order['P']).astype(jnp.float32)
gt_length = (
(1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1])
gt_stddev = (
(1. - next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1])
c_n_bond_length_error = jnp.sqrt(1e-6 +
jnp.square(c_n_bond_length - gt_length))
c_n_loss_per_residue = jax.nn.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))
# Compute loss for the angles.
c_ca_unit_vec = (this_ca_pos - this_c_pos).normalized(1e-6)
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length
n_ca_unit_vec = (next_ca_pos - next_n_pos).normalized(1e-6)
ca_c_n_cos_angle = c_ca_unit_vec.dot(c_n_unit_vec)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
ca_c_n_loss_per_residue = jax.nn.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
c_n_ca_cos_angle = (-c_n_unit_vec).dot(n_ca_unit_vec)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
c_n_ca_loss_per_residue = jax.nn.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
jnp.pad(per_residue_loss_sum, [[1, 0]]))
# Compute hard violations.
violation_mask = jnp.max(
jnp.stack([c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask]), axis=0)
violation_mask = jnp.maximum(
jnp.pad(violation_mask, [[0, 1]]),
jnp.pad(violation_mask, [[1, 0]]))
return {'c_n_loss_mean': c_n_loss, # shape ()
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
'per_residue_violation_mask': violation_mask # shape (N)
}
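# A numeric sketch of the flat-bottom penalty above (the C--N length ~1.329A
# and stddev ~0.014A are the literature values used by residue_constants;
# treat the numbers as approximate): deviations within
# `tolerance_factor_soft * stddev` of the target cost nothing, larger ones
# are penalized linearly.
def _example_flat_bottom_loss():
  pred_length, gt_length, gt_stddev = 1.55, 1.329, 0.014
  error = jnp.sqrt(1e-6 + jnp.square(pred_length - gt_length))
  return jax.nn.relu(error - 12.0 * gt_stddev)  # ~= 0.053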
def between_residue_clash_loss(
pred_positions: geometry.Vec3Array, # (N, 14)
atom_exists: jnp.ndarray, # (N, 14)
atom_radius: jnp.ndarray, # (N, 14)
residue_index: jnp.ndarray, # (N)
asym_id: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes between residues."""
assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2
assert len(atom_radius.shape) == 2
assert len(residue_index.shape) == 1
# Create the distance matrix.
# (N, N, 14, 14)
dists = geometry.euclidean_distance(pred_positions[:, None, :, None],
pred_positions[None, :, None, :], 1e-10)
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (atom_exists[:, None, :, None] * atom_exists[None, :, None, :])
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask *= (
residue_index[:, None, None, None] < residue_index[None, :, None, None])
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = jax.nn.one_hot(2, num_classes=14)
n_one_hot = jax.nn.one_hot(0, num_classes=14)
neighbour_mask = ((residue_index[:, None] + 1) == residue_index[None, :])
neighbour_mask &= (asym_id[:, None] == asym_id[None, :])
neighbour_mask = neighbour_mask[..., None, None]
c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
None] * n_one_hot[None, None, None, :]
dists_mask *= (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG')
cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (cys_sg_one_hot[None, None, :, None] *
cys_sg_one_hot[None, None, None, :])
dists_mask *= (1. - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (
atom_radius[:, None, :, None] + atom_radius[None, :, None, :])
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * jax.nn.relu(
dists_lower_bound - overlap_tolerance_soft - dists)
# Compute the mean loss.
# shape ()
mean_loss = (jnp.sum(dists_to_low_error)
/ (1e-6 + jnp.sum(dists_mask)))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) +
jnp.sum(dists_to_low_error, axis=[1, 3]))
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = jnp.maximum(
jnp.max(clash_mask, axis=[0, 2]),
jnp.max(clash_mask, axis=[1, 3]))
return {'mean_loss': mean_loss, # shape ()
'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14)
}
def within_residue_violations(
pred_positions: geometry.Vec3Array, # (N, 14)
atom_exists: jnp.ndarray, # (N, 14)
dists_lower_bound: jnp.ndarray, # (N, 14, 14)
dists_upper_bound: jnp.ndarray, # (N, 14, 14)
tighten_bounds_for_loss=0.0,
) -> Dict[str, jnp.ndarray]:
"""Find within-residue violations."""
assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2
assert len(dists_lower_bound.shape) == 3
assert len(dists_upper_bound.shape) == 3
# Compute the mask for each residue.
# shape (N, 14, 14)
dists_masks = (1. - jnp.eye(14, 14)[None])
dists_masks *= (atom_exists[:, :, None] * atom_exists[:, None, :])
# Distance matrix
# shape (N, 14, 14)
dists = geometry.euclidean_distance(pred_positions[:, :, None],
pred_positions[:, None, :], 1e-10)
# Compute the loss.
# shape (N, 14, 14)
dists_to_low_error = jax.nn.relu(
dists_lower_bound + tighten_bounds_for_loss - dists)
dists_to_high_error = jax.nn.relu(
dists + tighten_bounds_for_loss - dists_upper_bound)
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(loss, axis=1) +
jnp.sum(loss, axis=2))
# Compute the violations mask.
# shape (N, 14, 14)
violations = dists_masks * ((dists < dists_lower_bound) |
(dists > dists_upper_bound))
# Compute the per atom violations.
# shape (N, 14)
per_atom_violations = jnp.maximum(
jnp.max(violations, axis=1), jnp.max(violations, axis=2))
return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_violations': per_atom_violations # shape (N, 14)
}
def find_optimal_renaming(
gt_positions: geometry.Vec3Array, # (N, 14)
alt_gt_positions: geometry.Vec3Array, # (N, 14)
atom_is_ambiguous: jnp.ndarray, # (N, 14)
gt_exists: jnp.ndarray, # (N, 14)
pred_positions: geometry.Vec3Array, # (N, 14)
) -> jnp.ndarray: # (N):
"""Find optimal renaming for ground truth that maximizes LDDT."""
assert len(gt_positions.shape) == 2
assert len(alt_gt_positions.shape) == 2
assert len(atom_is_ambiguous.shape) == 2
assert len(gt_exists.shape) == 2
assert len(pred_positions.shape) == 2
# Create the pred distance matrix.
# shape (N, N, 14, 14)
pred_dists = geometry.euclidean_distance(pred_positions[:, None, :, None],
pred_positions[None, :, None, :],
1e-10)
# Compute distances for ground truth with original and alternative names.
# shape (N, N, 14, 14)
gt_dists = geometry.euclidean_distance(gt_positions[:, None, :, None],
gt_positions[None, :, None, :], 1e-10)
alt_gt_dists = geometry.euclidean_distance(alt_gt_positions[:, None, :, None],
alt_gt_positions[None, :, None, :],
1e-10)
# Compute LDDT's.
# shape (N, N, 14, 14)
lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists))
alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists))
# Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
# in cols.
# shape (N ,N, 14, 14)
mask = (
gt_exists[:, None, :, None] * # rows
atom_is_ambiguous[:, None, :, None] * # rows
gt_exists[None, :, None, :] * # cols
(1. - atom_is_ambiguous[None, :, None, :])) # cols
# Aggregate distances for each residue to the non-ambiguous atoms.
# shape (N)
per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3])
alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3])
# Decide for each residue whether the alternative naming is better.
# shape (N)
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32)
return alt_naming_is_better # shape (N)
def frame_aligned_point_error(
pred_frames: geometry.Rigid3Array, # shape (num_frames)
target_frames: geometry.Rigid3Array, # shape (num_frames)
frames_mask: jnp.ndarray, # shape (num_frames)
pred_positions: geometry.Vec3Array, # shape (num_positions)
target_positions: geometry.Vec3Array, # shape (num_positions)
positions_mask: jnp.ndarray, # shape (num_positions)
pair_mask: Optional[jnp.ndarray], # shape (num_frames, num_positions)
l1_clamp_distance: float,
length_scale=20.,
epsilon=1e-4) -> jnp.ndarray: # shape ()
"""Measure point error under different alignements.
Computes error between two structures with B points
under A alignments derived form the given pairs of frames.
Args:
pred_frames: num_frames reference frames for 'pred_positions'.
target_frames: num_frames reference frames for 'target_positions'.
frames_mask: Mask for frame pairs to use.
pred_positions: num_positions predicted positions of the structure.
target_positions: num_positions target positions of the structure.
positions_mask: Mask on which positions to score.
pair_mask: A (num_frames, num_positions) mask to use in the loss, useful
for separating intra from inter chain losses.
l1_clamp_distance: Distance cutoff on error beyond which gradients will
be zero.
length_scale: length scale to divide loss by.
epsilon: small value used to regularize denominator for masked average.
Returns:
Masked Frame aligned point error.
"""
# For now we do not allow any batch dimensions.
assert len(pred_frames.rotation.shape) == 1
assert len(target_frames.rotation.shape) == 1
assert frames_mask.ndim == 1
assert pred_positions.x.ndim == 1
assert target_positions.x.ndim == 1
assert positions_mask.ndim == 1
# Compute array of predicted positions in the predicted frames.
# geometry.Vec3Array (num_frames, num_positions)
local_pred_pos = pred_frames[:, None].inverse().apply_to_point(
pred_positions[None, :])
# Compute array of target positions in the target frames.
# geometry.Vec3Array (num_frames, num_positions)
local_target_pos = target_frames[:, None].inverse().apply_to_point(
target_positions[None, :])
# Compute errors between the structures.
# jnp.ndarray (num_frames, num_positions)
error_dist = geometry.euclidean_distance(local_pred_pos, local_target_pos,
epsilon)
clipped_error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)
normed_error = clipped_error_dist / length_scale
normed_error *= jnp.expand_dims(frames_mask, axis=-1)
normed_error *= jnp.expand_dims(positions_mask, axis=-2)
if pair_mask is not None:
normed_error *= pair_mask
mask = (jnp.expand_dims(frames_mask, axis=-1) *
jnp.expand_dims(positions_mask, axis=-2))
if pair_mask is not None:
mask *= pair_mask
normalization_factor = jnp.sum(mask, axis=(-1, -2))
return (jnp.sum(normed_error, axis=(-2, -1)) /
(epsilon + normalization_factor))
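# Illustrative only: identical frames and points give (near-)zero FAPE; the
# small residual comes from `epsilon` inside the distance computation.
# `Rigid3Array.identity` is assumed to be available in `geometry`.
def _example_fape_identical_structures():
  frames = geometry.Rigid3Array.identity((3,))
  points = geometry.Vec3Array.from_array(jnp.arange(9.).reshape(3, 3))
  return frame_aligned_point_error(
      frames, frames, jnp.ones(3), points, points, jnp.ones(3),
      pair_mask=None, l1_clamp_distance=10.0)  # ~= sqrt(epsilon) / length_scale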
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes, with the unknown
residue type appended at the end. For chi angles that are not defined for a
residue, the atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return jnp.asarray(chi_atom_indices)
def compute_chi_angles(positions: geometry.Vec3Array,
mask: jnp.ndarray,
aatype: jnp.ndarray):
"""Computes the chi angles given all atom positions and the amino acid type.
Args:
positions: A Vec3Array of shape
[num_res, residue_constants.atom_type_num], with positions of
atoms needed to calculate chi angles. Supports up to 1 batch dimension.
mask: A tensor of shape
[num_res, residue_constants.atom_type_num] that masks which atom
positions are set for each residue. The chi mask will be set to 1 for a
chi angle only if the amino acid has that chi angle and all of the chi
atoms needed to calculate that chi angle are set.
aatype: A tensor of shape [num_res] with amino acid type integer
codes (0 to 20). Supports up to 1 batch dimension.
Returns:
A tuple of tensors (chi_angles, chi_mask), where both have shape
[num_res, 4]. The mask masks out unused chi angles for amino acid
types that have fewer than 4 chi angles, as well as chi angles whose
atoms were not all set in `mask`.
"""
# Don't assert on the num_res and batch dimensions as they might be unknown.
assert positions.shape[-1] == residue_constants.atom_type_num
assert mask.shape[-1] == residue_constants.atom_type_num
# Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
chi_atom_indices = get_chi_atom_indices()
# Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4].
atom_indices = utils.batched_gather(
params=chi_atom_indices, indices=aatype, axis=0)
# Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
chi_angle_atoms = jax.tree_map(
lambda x: utils.batched_gather( # pylint: disable=g-long-lambda
params=x, indices=atom_indices, axis=-1, batch_dims=1), positions)
a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
chi_angles = geometry.dihedral_angle(a, b, c, d)
# Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = jnp.asarray(chi_angles_mask)
# Compute the chi angle mask. Shape [num_res, chis=4].
chi_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
axis=0)
# The chi_mask is set to 1 only when all necessary chi angle atoms were set.
# Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
chi_angle_atoms_mask = utils.batched_gather(
params=mask, indices=atom_indices, axis=-1, batch_dims=1)
# Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32)
return chi_angles, chi_mask
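# Illustrative usage (not part of the library): chi angles for one arginine
# with placeholder coordinates. With real positions, `chi_angles` holds
# dihedrals in radians and `chi_mask` marks which of the 4 chis exist and
# had all four atoms present.
def _example_compute_chi_angles():
  positions = geometry.Vec3Array.from_array(jnp.zeros([1, 37, 3]))
  mask = jnp.ones([1, 37])
  aatype = jnp.ones([1], dtype=jnp.int32)  # index 1 is 'R' in restypes
  return compute_chi_angles(positions, mask, aatype)  # shapes (1, 4), (1, 4)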
def make_transform_from_reference(
a_xyz: geometry.Vec3Array,
b_xyz: geometry.Vec3Array,
c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array:
"""Returns rotation and translation matrices to convert from reference.
Note that this method does not take care of symmetries. If you provide the
coordinates in the non-standard way, the A atom will end up on the negative
y-axis rather than on the positive y-axis. You need to handle such cases in
your own code.
Args:
a_xyz: A Vec3Array.
b_xyz: A Vec3Array.
c_xyz: A Vec3Array.
Returns:
A Rigid3Array which, when applied to coordinates in a canonicalized
reference frame, will give coordinates approximately equal to the original
coordinates (in the global frame).
"""
rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz,
a_xyz - b_xyz)
return geometry.Rigid3Array(rotation, b_xyz)
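# Illustrative only (hypothetical coordinates): with b at the origin, the
# returned transform has its translation at b and its rotation built from
# c - b and a - b, e.g. for backbone atoms a=N, b=CA, c=C.
def _example_make_transform_from_reference():
  a = geometry.Vec3Array.from_array(jnp.array([-0.525, 1.363, 0.]))  # e.g. N
  b = geometry.Vec3Array.from_array(jnp.array([0., 0., 0.]))  # e.g. CA
  c = geometry.Vec3Array.from_array(jnp.array([1.526, 0., 0.]))  # e.g. C
  return make_transform_from_reference(a, b, c)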
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for all_atom."""
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model import all_atom
from alphafold.model import r3
import numpy as np
L1_CLAMP_DISTANCE = 10
def get_identity_rigid(shape):
"""Returns identity rigid transform."""
ones = np.ones(shape)
zeros = np.zeros(shape)
rot = r3.Rots(ones, zeros, zeros,
zeros, ones, zeros,
zeros, zeros, ones)
trans = r3.Vecs(zeros, zeros, zeros)
return r3.Rigids(rot, trans)
def get_global_rigid_transform(rot_angle, translation, bcast_dims):
"""Returns rigid transform that globally rotates/translates by same amount."""
rot_angle = np.asarray(rot_angle)
translation = np.asarray(translation)
if bcast_dims:
for _ in range(bcast_dims):
rot_angle = np.expand_dims(rot_angle, 0)
translation = np.expand_dims(translation, 0)
sin_angle = np.sin(np.deg2rad(rot_angle))
cos_angle = np.cos(np.deg2rad(rot_angle))
ones = np.ones_like(sin_angle)
zeros = np.zeros_like(sin_angle)
rot = r3.Rots(ones, zeros, zeros,
zeros, cos_angle, -sin_angle,
zeros, sin_angle, cos_angle)
trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2])
return r3.Rigids(rot, trans)
class AllAtomTest(parameterized.TestCase, absltest.TestCase):
@parameterized.named_parameters(
('identity', 0, [0, 0, 0]),
('rot_90', 90, [0, 0, 0]),
('trans_10', 0, [0, 0, 10]),
('rot_174_trans_1', 174, [1, 1, 1]))
def test_frame_aligned_point_error_perfect_on_global_transform(
self, rot_angle, translation):
"""Tests global transform between target and preds gives perfect score."""
# pylint: disable=bad-whitespace
target_positions = np.array(
[[ 21.182, 23.095, 19.731],
[ 22.055, 20.919, 17.294],
[ 24.599, 20.005, 15.041],
[ 25.567, 18.214, 12.166],
[ 28.063, 17.082, 10.043],
[ 28.779, 15.569, 6.985],
[ 30.581, 13.815, 4.612],
[ 29.258, 12.193, 2.296]])
# pylint: enable=bad-whitespace
global_rigid_transform = get_global_rigid_transform(
rot_angle, translation, 1)
target_positions = r3.vecs_from_tensor(target_positions)
pred_positions = r3.rigids_mul_vecs(
global_rigid_transform, target_positions)
positions_mask = np.ones(target_positions.x.shape[0])
target_frames = get_identity_rigid(10)
pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
frames_mask = np.ones(10)
fape = all_atom.frame_aligned_point_error(
pred_frames, target_frames, frames_mask, pred_positions,
target_positions, positions_mask, L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE, epsilon=0)
self.assertAlmostEqual(fape, 0.)
@parameterized.named_parameters(
('identity',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
0.),
('shift_2.5',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
0.25),
('shift_5',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[5, 0, 0], [10, 0, 0], [15, 0, 0]],
0.5),
('shift_10',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[10, 0, 0], [15, 0, 0], [0, 0, 0]],
1.))
def test_frame_aligned_point_error_matches_expected(
self, target_positions, pred_positions, expected_alddt):
"""Tests score matches expected."""
target_frames = get_identity_rigid(2)
pred_frames = target_frames
frames_mask = np.ones(2)
target_positions = r3.vecs_from_tensor(np.array(target_positions))
pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
positions_mask = np.ones(target_positions.x.shape[0])
alddt = all_atom.frame_aligned_point_error(
pred_frames, target_frames, frames_mask, pred_positions,
target_positions, positions_mask, L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE, epsilon=0)
self.assertAlmostEqual(alddt, expected_alddt)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common Haiku modules for use in protein folding."""
import numbers
from typing import Union, Sequence
import haiku as hk
import jax.numpy as jnp
import numpy as np
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
dtype=np.float32)
def get_initializer_scale(initializer_name, input_shape):
"""Get Initializer for weights and scale to multiply activations by."""
if initializer_name == 'zeros':
w_init = hk.initializers.Constant(0.0)
else:
# fan-in scaling
scale = 1.
for channel_dim in input_shape:
scale /= channel_dim
if initializer_name == 'relu':
scale *= 2
noise_scale = scale
stddev = np.sqrt(noise_scale)
# Adjust stddev for truncation.
stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
return w_init
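# For instance (illustrative numbers): with initializer_name='relu' and a
# 256-channel input, the variance scale is 2/256 and the truncated-normal
# stddev is sqrt(2/256) / TRUNCATED_NORMAL_STDDEV_FACTOR ~= 0.100.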
class Linear(hk.Module):
"""Protein folding specific Linear module.
This differs from the standard Haiku Linear in a few ways:
* It supports inputs and outputs of arbitrary rank
* Initializers are specified by strings
"""
def __init__(self,
num_output: Union[int, Sequence[int]],
initializer: str = 'linear',
num_input_dims: int = 1,
use_bias: bool = True,
bias_init: float = 0.,
precision = None,
name: str = 'linear'):
"""Constructs Linear Module.
Args:
num_output: Number of output channels. Can be tuple when outputting
multiple dimensions.
initializer: What initializer to use, should be one of {'linear', 'relu',
'zeros'}
num_input_dims: Number of dimensions from the end to project.
use_bias: Whether to include trainable bias
bias_init: Value used to initialize bias.
precision: What precision to use for matrix multiplication, defaults
to None.
name: Name of module, used for name scopes.
"""
super().__init__(name=name)
if isinstance(num_output, numbers.Integral):
self.output_shape = (num_output,)
else:
self.output_shape = tuple(num_output)
self.initializer = initializer
self.use_bias = use_bias
self.bias_init = bias_init
self.num_input_dims = num_input_dims
self.num_output_dims = len(self.output_shape)
self.precision = precision
def __call__(self, inputs):
"""Connects Module.
Args:
inputs: Tensor with at least num_input_dims dimensions.
Returns:
Output of shape [...] + num_output.
"""
num_input_dims = self.num_input_dims
if self.num_input_dims > 0:
in_shape = inputs.shape[-self.num_input_dims:]
else:
in_shape = ()
weight_init = get_initializer_scale(self.initializer, in_shape)
in_letters = 'abcde'[:self.num_input_dims]
out_letters = 'hijkl'[:self.num_output_dims]
weight_shape = in_shape + self.output_shape
weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
weight_init)
equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
output = jnp.einsum(equation, inputs, weights, precision=self.precision)
if self.use_bias:
bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
hk.initializers.Constant(self.bias_init))
output += bias
return output
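# Illustrative usage (hypothetical shapes; Haiku modules must be built inside
# an hk.transform):
#
#   def forward(x):
#     # Projects the last two input dims (5, 7) onto output shape (4, 8).
#     return Linear(num_output=(4, 8), num_input_dims=2)(x)
#
#   forward_t = hk.without_apply_rng(hk.transform(forward))
#   params = forward_t.init(rng, jnp.zeros([3, 5, 7]))  # rng: a jax PRNGKey
#   out = forward_t.apply(params, jnp.zeros([3, 5, 7]))  # shape (3, 4, 8)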
class LayerNorm(hk.LayerNorm):
"""LayerNorm module.
Equivalent to hk.LayerNorm but with different parameter shapes: they are
always vectors rather than possibly higher-rank tensors. This makes it easier
to change the layout whilst keeping the model weight-compatible.
"""
def __init__(self,
axis,
create_scale: bool,
create_offset: bool,
eps: float = 1e-5,
scale_init=None,
offset_init=None,
use_fast_variance: bool = False,
name=None,
param_axis=None):
super().__init__(
axis=axis,
create_scale=False,
create_offset=False,
eps=eps,
scale_init=None,
offset_init=None,
use_fast_variance=use_fast_variance,
name=name,
param_axis=param_axis)
self._temp_create_scale = create_scale
self._temp_create_offset = create_offset
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
is_bf16 = (x.dtype == jnp.bfloat16)
if is_bf16:
x = x.astype(jnp.float32)
param_axis = self.param_axis[0] if self.param_axis else -1
param_shape = (x.shape[param_axis],)
param_broadcast_shape = [1] * x.ndim
param_broadcast_shape[param_axis] = x.shape[param_axis]
scale = None
offset = None
if self._temp_create_scale:
scale = hk.get_parameter(
'scale', param_shape, x.dtype, init=self.scale_init)
scale = scale.reshape(param_broadcast_shape)
if self._temp_create_offset:
offset = hk.get_parameter(
'offset', param_shape, x.dtype, init=self.offset_init)
offset = offset.reshape(param_broadcast_shape)
out = super().__call__(x, scale=scale, offset=offset)
if is_bf16:
out = out.astype(jnp.bfloat16)
return out
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model config."""
import copy
from alphafold.model.tf import shape_placeholders
import ml_collections
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""
if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
if 'multimer' in name:
cfg = copy.deepcopy(CONFIG_MULTIMER)
else:
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
return cfg
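# Illustrative usage (field names taken from the CONFIG tree below):
#
#   cfg = model_config('model_1_ptm')
#   cfg.data.common.max_extra_msa  # -> 5120 for the model_1-style presets
#   cfg.model.num_recycle = 6      # e.g. override the number of recycles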
MODEL_PRESETS = {
'monomer': (
'model_1',
'model_2',
'model_3',
'model_4',
'model_5',
),
'monomer_ptm': (
'model_1_ptm',
'model_2_ptm',
'model_3_ptm',
'model_4_ptm',
'model_5_ptm',
),
'multimer': (
'model_1_multimer_v3',
'model_2_multimer_v3',
'model_3_multimer_v3',
'model_4_multimer_v3',
'model_5_multimer_v3',
),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
CONFIG_DIFFS = {
'model_1': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
'data.common.max_extra_msa': 5120,
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True
},
'model_2': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True
},
'model_3': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
'data.common.max_extra_msa': 5120,
},
'model_4': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
'data.common.max_extra_msa': 5120,
},
'model_5': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
},
# The following models are fine-tuned from the corresponding models above
# with an additional predicted_aligned_error head that can produce
# predicted TM-score (pTM) and predicted aligned errors.
'model_1_ptm': {
'data.common.max_extra_msa': 5120,
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_2_ptm': {
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_3_ptm': {
'data.common.max_extra_msa': 5120,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_4_ptm': {
'data.common.max_extra_msa': 5120,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_5_ptm': {
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_1_multimer_v3': {},
'model_2_multimer_v3': {},
'model_3_multimer_v3': {},
'model_4_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
'model_5_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
}
# Key differences between multimer v1/v2 and v3, mostly due to numerical
# optimisations in the TriangleMultiplication module.
common_updates = {
'model.embeddings_and_evoformer.num_msa': 252,
'model.embeddings_and_evoformer.num_extra_msa': 1152,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
}
CONFIG_DIFFS.update(
{f'model_{i}_multimer': common_updates for i in range(1, 6)})
CONFIG_DIFFS.update(
{f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})
CONFIG = ml_collections.ConfigDict({
'data': {
'common': {
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'max_extra_msa': 1024,
'msa_cluster_features': True,
'num_recycle': 3,
'reduce_msa_clusters_by_max_templates': False,
'resample_msa_in_recycling': True,
'template_features': [
'template_all_atom_positions', 'template_sum_probs',
'template_aatype', 'template_all_atom_masks',
'template_domain_names'
],
'unsupervised_features': [
'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
'num_alignments', 'seq_length', 'between_segment_residues',
'deletion_matrix'
],
'use_templates': False,
},
'eval': {
'feat': {
'aatype': [NUM_RES],
'all_atom_mask': [NUM_RES, None],
'all_atom_positions': [NUM_RES, None, None],
'alt_chi_angles': [NUM_RES, None],
'atom14_alt_gt_exists': [NUM_RES, None],
'atom14_alt_gt_positions': [NUM_RES, None, None],
'atom14_atom_exists': [NUM_RES, None],
'atom14_atom_is_ambiguous': [NUM_RES, None],
'atom14_gt_exists': [NUM_RES, None],
'atom14_gt_positions': [NUM_RES, None, None],
'atom37_atom_exists': [NUM_RES, None],
'backbone_affine_mask': [NUM_RES],
'backbone_affine_tensor': [NUM_RES, None],
'bert_mask': [NUM_MSA_SEQ, NUM_RES],
'chi_angles': [NUM_RES, None],
'chi_mask': [NUM_RES, None],
'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_row_mask': [NUM_EXTRA_SEQ],
'is_distillation': [],
'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
'msa_mask': [NUM_MSA_SEQ, NUM_RES],
'msa_row_mask': [NUM_MSA_SEQ],
'pseudo_beta': [NUM_RES, None],
'pseudo_beta_mask': [NUM_RES],
'random_crop_to_size_seed': [None],
'residue_index': [NUM_RES],
'residx_atom14_to_atom37': [NUM_RES, None],
'residx_atom37_to_atom14': [NUM_RES, None],
'resolution': [],
'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
'rigidgroups_group_exists': [NUM_RES, None],
'rigidgroups_group_is_ambiguous': [NUM_RES, None],
'rigidgroups_gt_exists': [NUM_RES, None],
'rigidgroups_gt_frames': [NUM_RES, None, None],
'seq_length': [],
'seq_mask': [NUM_RES],
'target_feat': [NUM_RES, None],
'template_aatype': [NUM_TEMPLATES, NUM_RES],
'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
'template_all_atom_positions': [
NUM_TEMPLATES, NUM_RES, None, None],
'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
'template_backbone_affine_tensor': [
NUM_TEMPLATES, NUM_RES, None],
'template_mask': [NUM_TEMPLATES],
'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
'template_sum_probs': [NUM_TEMPLATES, None],
'true_msa': [NUM_MSA_SEQ, NUM_RES]
},
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
},
},
'model': {
'embeddings_and_evoformer': {
'evoformer_num_block': 48,
'evoformer': {
'msa_row_attention_with_pair_bias': {
'dropout_rate': 0.15,
'gating': True,
'num_head': 8,
'orientation': 'per_row',
'shared_dropout': True
},
'msa_column_attention': {
'dropout_rate': 0.0,
'gating': True,
'num_head': 8,
'orientation': 'per_column',
'shared_dropout': True
},
'msa_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'outer_product_mean': {
'first': False,
'chunk_size': 128,
'dropout_rate': 0.0,
'num_outer_channel': 32,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'max_relative_feature': 32,
'msa_channel': 256,
'pair_channel': 128,
'prev_pos': {
'min_bin': 3.25,
'max_bin': 20.75,
'num_bins': 15
},
'recycle_features': True,
'recycle_pos': True,
'seq_channel': 384,
'template': {
'attention': {
'gating': False,
'key_dim': 64,
'num_head': 4,
'value_dim': 64
},
'dgram_features': {
'min_bin': 3.25,
'max_bin': 50.75,
'num_bins': 39
},
'embed_torsion_angles': False,
'enabled': False,
'template_pair_stack': {
'num_block': 2,
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'key_dim': 64,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True,
'value_dim': 64
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'key_dim': 64,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True,
'value_dim': 64
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 2,
'orientation': 'per_row',
'shared_dropout': True
}
},
'max_templates': 4,
'subbatch_size': 128,
'use_template_unit_vector': False,
}
},
'global_config': {
'deterministic': False,
'multimer_mode': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True,
'eval_dropout': False,
},
'heads': {
'distogram': {
'first_break': 2.3125,
'last_break': 21.6875,
'num_bins': 64,
'weight': 0.3
},
'predicted_aligned_error': {
# `num_bins - 1` bins uniformly space the
# [0, max_error_bin A] range.
# The final bin covers [max_error_bin A, +infty).
# 31A gives bins with 0.5A width.
'max_error_bin': 31.,
'num_bins': 64,
'num_channels': 128,
'filter_by_resolution': True,
'min_resolution': 0.1,
'max_resolution': 3.0,
'weight': 0.0,
},
'experimentally_resolved': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'weight': 0.01
},
'structure_module': {
'num_layer': 8,
'fape': {
'clamp_distance': 10.0,
'clamp_type': 'relu',
'loss_unit_distance': 10.0
},
'angle_norm_weight': 0.01,
'chi_weight': 0.5,
'clash_overlap_tolerance': 1.5,
'compute_in_graph_metrics': True,
'dropout': 0.1,
'num_channel': 384,
'num_head': 12,
'num_layer_in_transition': 3,
'num_point_qk': 4,
'num_point_v': 8,
'num_scalar_qk': 16,
'num_scalar_v': 16,
'position_scale': 10.0,
'sidechain': {
'atom_clamp_distance': 10.0,
'num_channel': 128,
'num_residual_block': 2,
'weight_frac': 0.5,
'length_scale': 10.,
},
'structural_violation_loss_weight': 1.0,
'violation_tolerance_factor': 12.0,
'weight': 1.0
},
'predicted_lddt': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 50,
'num_channels': 128,
'weight': 0.01
},
'masked_msa': {
'num_output': 23,
'weight': 2.0
},
},
'num_recycle': 3,
'resample_msa_in_recycling': True
},
})
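# A minimal usage sketch (not part of the original file): ml_collections
# ConfigDicts support attribute-style access, so the settings above are
# typically deep-copied and overridden before model construction, e.g.:
#   import copy
#   cfg = copy.deepcopy(CONFIG)
#   cfg.model.num_recycle = 1  # fewer recycling iterations
#   cfg.model.heads.predicted_aligned_error.weight = 0.1  # enable PAE loss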
CONFIG_MULTIMER = ml_collections.ConfigDict({
'model': {
'embeddings_and_evoformer': {
'evoformer_num_block': 48,
'evoformer': {
'msa_column_attention': {
'dropout_rate': 0.0,
'gating': True,
'num_head': 8,
'orientation': 'per_column',
'shared_dropout': True
},
'msa_row_attention_with_pair_bias': {
'dropout_rate': 0.15,
'gating': True,
'num_head': 8,
'orientation': 'per_row',
'shared_dropout': True
},
'msa_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'outer_product_mean': {
'chunk_size': 128,
'dropout_rate': 0.0,
'first': True,
'num_outer_channel': 32,
'orientation': 'per_row',
'shared_dropout': True
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'num_msa': 508,
'num_extra_msa': 2048,
'masked_msa': {
'profile_prob': 0.1,
'replace_fraction': 0.15,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'use_chain_relative': True,
'max_relative_chain': 2,
'max_relative_idx': 32,
'seq_channel': 384,
'msa_channel': 256,
'pair_channel': 128,
'prev_pos': {
'max_bin': 20.75,
'min_bin': 3.25,
'num_bins': 15
},
'recycle_features': True,
'recycle_pos': True,
'template': {
'attention': {
'gating': False,
'num_head': 4
},
'dgram_features': {
'max_bin': 50.75,
'min_bin': 3.25,
'num_bins': 39
},
'enabled': True,
'max_templates': 4,
'num_channels': 64,
'subbatch_size': 128,
'template_pair_stack': {
'num_block': 2,
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 2,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
}
}
},
},
'global_config': {
'bfloat16': True,
'bfloat16_output': False,
'deterministic': False,
'multimer_mode': True,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True,
'eval_dropout': False,
},
'heads': {
'distogram': {
'first_break': 2.3125,
'last_break': 21.6875,
'num_bins': 64,
'weight': 0.3
},
'experimentally_resolved': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'weight': 0.01
},
'masked_msa': {
'weight': 2.0
},
'predicted_aligned_error': {
'filter_by_resolution': True,
'max_error_bin': 31.0,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 64,
'num_channels': 128,
'weight': 0.1
},
'predicted_lddt': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 50,
'num_channels': 128,
'weight': 0.01
},
'structure_module': {
'angle_norm_weight': 0.01,
'chi_weight': 0.5,
'clash_overlap_tolerance': 1.5,
'dropout': 0.1,
'interface_fape': {
'atom_clamp_distance': 1000.0,
'loss_unit_distance': 20.0
},
'intra_chain_fape': {
'atom_clamp_distance': 10.0,
'loss_unit_distance': 10.0
},
'num_channel': 384,
'num_head': 12,
'num_layer': 8,
'num_layer_in_transition': 3,
'num_point_qk': 4,
'num_point_v': 8,
'num_scalar_qk': 16,
'num_scalar_v': 16,
'position_scale': 20.0,
'sidechain': {
'atom_clamp_distance': 10.0,
'loss_unit_distance': 10.0,
'num_channel': 128,
'num_residual_block': 2,
'weight_frac': 0.5
},
'structural_violation_loss_weight': 1.0,
'violation_tolerance_factor': 12.0,
'weight': 1.0
}
},
'num_ensemble_eval': 1,
'num_recycle': 20,
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `num_recycle` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
'recycle_early_stop_tolerance': 0.5,
'resample_msa_in_recycling': True
}
})
"""Convenience functions for reading data."""
import io
import os
from alphafold.model import utils
import haiku as hk
import numpy as np
# Internal import (7716).
def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
"""Get the Haiku parameters from a model name."""
path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
with open(path, 'rb') as f:
params = np.load(io.BytesIO(f.read()), allow_pickle=False)
return utils.flat_params_to_haiku(params)
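# Hedged usage sketch (the model name and directory layout below are
# assumptions; `data_dir` must contain a `params/` folder holding the
# released `params_<model_name>.npz` checkpoints):
#   params = get_model_haiku_params('model_1', '/path/to/alphafold_data')
#   # `params` is a nested haiku mapping: module name -> parameter arrays.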
"""Code to generate processed features."""
import copy
from typing import List, Mapping, Tuple
from alphafold.model.tf import input_pipeline
from alphafold.model.tf import proteins_dataset
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
FeatureDict = Mapping[str, np.ndarray]
def make_data_config(
config: ml_collections.ConfigDict,
num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
"""Makes a data config for the input pipeline."""
cfg = copy.deepcopy(config.data)
feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
with cfg.unlocked():
cfg.eval.crop_size = num_res
return cfg, feature_names
def tf_example_to_features(tf_example: tf.train.Example,
config: ml_collections.ConfigDict,
random_seed: int = 0) -> FeatureDict:
"""Converts tf_example to numpy feature dictionary."""
num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
cfg, feature_names = make_data_config(config, num_res=num_res)
if 'deletion_matrix_int' in set(tf_example.features.feature):
deletion_matrix_int = (
tf_example.features.feature['deletion_matrix_int'].int64_list.value)
feat = tf.train.Feature(float_list=tf.train.FloatList(
value=map(float, deletion_matrix_int)))
tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
del tf_example.features.feature['deletion_matrix_int']
tf_graph = tf.Graph()
with tf_graph.as_default(), tf.device('/device:CPU:0'):
tf.compat.v1.set_random_seed(random_seed)
tensor_dict = proteins_dataset.create_tensor_dict(
raw_data=tf_example.SerializeToString(),
features=feature_names)
processed_batch = input_pipeline.process_tensors_from_config(
tensor_dict, cfg)
tf_graph.finalize()
with tf.Session(graph=tf_graph) as sess:
features = sess.run(processed_batch)
return {k: v for k, v in features.items() if v.dtype != 'O'}
def np_example_to_features(np_example: FeatureDict,
config: ml_collections.ConfigDict,
random_seed: int = 0) -> FeatureDict:
"""Preprocesses NumPy feature dict using TF pipeline."""
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(config, num_res=num_res)
if 'deletion_matrix_int' in np_example:
np_example['deletion_matrix'] = (
np_example.pop('deletion_matrix_int').astype(np.float32))
tf_graph = tf.Graph()
with tf_graph.as_default(), tf.device('/device:CPU:0'):
tf.compat.v1.set_random_seed(random_seed)
tensor_dict = proteins_dataset.np_to_tensor_dict(
np_example=np_example, features=feature_names)
processed_batch = input_pipeline.process_tensors_from_config(
tensor_dict, cfg)
tf_graph.finalize()
with tf.Session(graph=tf_graph) as sess:
features = sess.run(processed_batch)
return {k: v for k, v in features.items() if v.dtype != 'O'}
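# Hedged usage sketch: `raw_features` is assumed to be a FeatureDict from the
# upstream data pipeline (containing at least 'seq_length' plus the
# unsupervised features named in the config), and `CONFIG` the monomer model
# config defined earlier:
#   processed = np_example_to_features(raw_features, CONFIG, random_seed=0)
#   # `processed` maps feature names to fixed-size numpy arrays; object-dtype
#   # entries are dropped.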
"""Modules and utilities for the structure module."""
import functools
from typing import Dict
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import r3
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
def squared_difference(x, y):
return jnp.square(x - y)
class InvariantPointAttention(hk.Module):
"""Invariant Point attention module.
The high-level idea is that this attention module works over a set of points
and associated orientations in 3D space (e.g. protein residues).
Each residue outputs a set of queries and keys as points in their local
  reference frame. The attention is then computed from the Euclidean
  distance between the queries and keys in the global frame.
Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
"""
def __init__(self,
config,
global_config,
dist_epsilon=1e-8,
name='invariant_point_attention'):
"""Initialize.
Args:
config: Structure Module Config
global_config: Global Config of Model.
dist_epsilon: Small value to avoid NaN in distance calculation.
name: Haiku Module name.
"""
super().__init__(name=name)
self._dist_epsilon = dist_epsilon
self._zero_initialize_last = global_config.zero_init
self.config = config
self.global_config = global_config
def __call__(self, inputs_1d, inputs_2d, mask, affine):
"""Compute geometry-aware attention.
Given a set of query residues (defined by affines and associated scalar
features), this function computes geometry-aware attention between the
query residues and target residues.
The residues produce points in their local reference frame, which
are converted into the global frame in order to compute attention via
euclidean distance.
Equivalently, the target residues produce points in their local frame to be
used as attention values, which are converted into the query residues'
local frames.
Args:
inputs_1d: (N, C) 1D input embedding that is the basis for the
scalar queries.
inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
mask: (N, 1) mask to indicate which elements of inputs_1d participate
in the attention.
affine: QuatAffine object describing the position and orientation of
every element in inputs_1d.
Returns:
Transformation of the input embedding.
"""
num_residues, _ = inputs_1d.shape
# Improve readability by removing a large number of 'self's.
num_head = self.config.num_head
num_scalar_qk = self.config.num_scalar_qk
num_point_qk = self.config.num_point_qk
num_scalar_v = self.config.num_scalar_v
num_point_v = self.config.num_point_v
num_output = self.config.num_channel
assert num_scalar_qk > 0
assert num_point_qk > 0
assert num_point_v > 0
# Construct scalar queries of shape:
# [num_query_residues, num_head, num_points]
q_scalar = common_modules.Linear(
num_head * num_scalar_qk, name='q_scalar')(
inputs_1d)
q_scalar = jnp.reshape(
q_scalar, [num_residues, num_head, num_scalar_qk])
# Construct scalar keys/values of shape:
# [num_target_residues, num_head, num_points]
kv_scalar = common_modules.Linear(
num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')(
inputs_1d)
kv_scalar = jnp.reshape(kv_scalar,
[num_residues, num_head,
num_scalar_v + num_scalar_qk])
k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)
# Construct query points of shape:
# [num_residues, num_head, num_point_qk]
# First construct query points in local frame.
q_point_local = common_modules.Linear(
num_head * 3 * num_point_qk, name='q_point_local')(
inputs_1d)
q_point_local = jnp.split(q_point_local, 3, axis=-1)
# Project query points into global frame.
q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
# Reshape query point for later use.
q_point = [
jnp.reshape(x, [num_residues, num_head, num_point_qk])
for x in q_point_global]
# Construct key and value points.
# Key points have shape [num_residues, num_head, num_point_qk]
# Value points have shape [num_residues, num_head, num_point_v]
# Construct key and value points in local frame.
kv_point_local = common_modules.Linear(
num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')(
inputs_1d)
kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
# Project key and value points into global frame.
kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
kv_point_global = [
jnp.reshape(x, [num_residues,
num_head, (num_point_qk + num_point_v)])
for x in kv_point_global]
# Split key and value points.
k_point, v_point = list(
zip(*[
jnp.split(x, [num_point_qk,], axis=-1)
for x in kv_point_global
]))
    # We assume that all queries and keys are drawn iid from an N(0, 1)
    # distribution and compute the variances of the attention logits.
# Each scalar pair (q, k) contributes Var q*k = 1
scalar_variance = max(num_scalar_qk, 1) * 1.
# Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
point_variance = max(num_point_qk, 1) * 9. / 2
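    # Derivation of the 9/2: per coordinate, Var[0.5 q^2] = 0.5 and
    # Var[q k] = 1 for q, k ~ N(0, 1), with zero covariance between the two
    # terms, so Var[0.5 q^2 - q k] = 1.5; summing over the 3 spatial
    # coordinates gives 9/2.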
# Allocate equal variance to scalar, point and attention 2d parts so that
# the sum is 1.
num_logit_terms = 3
scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))
# Trainable per-head weights for points.
trainable_point_weights = jax.nn.softplus(hk.get_parameter(
'trainable_point_weights', shape=[num_head],
# softplus^{-1} (1)
init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
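    # Initializing at softplus^{-1}(1) = log(e - 1) makes each head's
    # trainable point weight start at exactly 1.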
point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)
v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]
q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
dist2 = [
squared_difference(qx[:, :, None, :], kx[:, None, :, :])
for qx, kx in zip(q_point, k_point)
]
dist2 = sum(dist2)
attn_qk_point = -0.5 * jnp.sum(
point_weights[:, None, None, :] * dist2, axis=-1)
v = jnp.swapaxes(v_scalar, -2, -3)
q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
k = jnp.swapaxes(k_scalar, -2, -3)
attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
attn_logits = attn_qk_scalar + attn_qk_point
attention_2d = common_modules.Linear(
num_head, name='attention_2d')(
inputs_2d)
attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
attention_2d = attention_2d_weights * attention_2d
attn_logits += attention_2d
mask_2d = mask * jnp.swapaxes(mask, -1, -2)
attn_logits -= 1e5 * (1. - mask_2d)
# [num_head, num_query_residues, num_target_residues]
attn = jax.nn.softmax(attn_logits)
# [num_head, num_query_residues, num_head * num_scalar_v]
result_scalar = jnp.matmul(attn, v)
# For point result, implement matmul manually so that it will be a float32
# on TPU. This is equivalent to
# result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx)
# for vx in v_point]
# but on the TPU, doing the multiply and reduce_sum ensures the
# computation happens in float32 instead of bfloat16.
result_point_global = [jnp.sum(
attn[:, :, :, None] * vx[:, None, :, :],
axis=-2) for vx in v_point]
# [num_query_residues, num_head, num_head * num_(scalar|point)_v]
result_scalar = jnp.swapaxes(result_scalar, -2, -3)
result_point_global = [
jnp.swapaxes(x, -2, -3)
for x in result_point_global]
# Features used in the linear output projection. Should have the size
# [num_query_residues, ?]
output_features = []
result_scalar = jnp.reshape(
result_scalar, [num_residues, num_head * num_scalar_v])
output_features.append(result_scalar)
result_point_global = [
jnp.reshape(r, [num_residues, num_head * num_point_v])
for r in result_point_global]
result_point_local = affine.invert_point(result_point_global, extra_dims=1)
output_features.extend(result_point_local)
output_features.append(jnp.sqrt(self._dist_epsilon +
jnp.square(result_point_local[0]) +
jnp.square(result_point_local[1]) +
jnp.square(result_point_local[2])))
# Dimensions: h = heads, i and j = residues,
# c = inputs_2d channels
# Contraction happens over the second residue dimension, similarly to how
# the usual attention is performed.
result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
num_out = num_head * result_attention_over_2d.shape[-1]
output_features.append(
jnp.reshape(result_attention_over_2d,
[num_residues, num_out]))
final_init = 'zeros' if self._zero_initialize_last else 'linear'
final_act = jnp.concatenate(output_features, axis=-1)
return common_modules.Linear(
num_output,
initializer=final_init,
name='output_projection')(final_act)
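# A minimal call sketch (assumptions: this runs inside an hk.transform-ed
# function, and `cfg`/`gc` are the structure-module and global configs):
#   ipa = InvariantPointAttention(cfg, gc)
#   update = ipa(inputs_1d=act,           # (N, C) single representation
#                inputs_2d=pair_act,      # (N, N, C') pair representation
#                mask=seq_mask[:, None],  # (N, 1)
#                affine=affine)           # QuatAffine over the N residues
#   # `update` has shape (N, cfg.num_channel); FoldIteration below adds it
#   # to `act`.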
class FoldIteration(hk.Module):
"""A single iteration of the main structure module loop.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21
First, each residue attends to all residues using InvariantPointAttention.
Then, we apply transition layers to update the hidden representations.
Finally, we use the hidden representations to produce an update to the
affine of each residue.
"""
def __init__(self, config, global_config,
name='fold_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
activations,
sequence_mask,
update_affine,
is_training,
initial_act,
safe_key=None,
static_feat_2d=None,
aatype=None):
c = self.config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
def safe_dropout_fn(tensor, safe_key):
return prng.safe_dropout(
tensor=tensor,
safe_key=safe_key,
rate=c.dropout,
is_deterministic=self.global_config.deterministic,
is_training=is_training)
affine = quat_affine.QuatAffine.from_tensor(activations['affine'])
act = activations['act']
attention_module = InvariantPointAttention(self.config, self.global_config)
# Attention
attn = attention_module(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=sequence_mask,
affine=affine)
act += attn
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='attention_layer_norm')(
act)
final_init = 'zeros' if self.global_config.zero_init else 'linear'
# Transition
input_act = act
for i in range(c.num_layer_in_transition):
init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
act = common_modules.Linear(
c.num_channel,
initializer=init,
name='transition')(
act)
if i < c.num_layer_in_transition - 1:
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='transition_layer_norm')(act)
if update_affine:
# This block corresponds to
# Jumper et al. (2021) Alg. 23 "Backbone update"
affine_update_size = 6
# Affine update
affine_update = common_modules.Linear(
affine_update_size,
initializer=final_init,
name='affine_update')(
act)
affine = affine.pre_compose(affine_update)
sc = MultiRigidSidechain(c.sidechain, self.global_config)(
affine.scale_translation(c.position_scale), [act, initial_act], aatype)
outputs = {'affine': affine.to_tensor(), 'sc': sc}
affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)
new_activations = {
'act': act,
'affine': affine.to_tensor()
}
return new_activations, outputs
def generate_affines(representations, batch, config, global_config,
is_training, safe_key):
"""Generate predicted affines for a single chain.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
This is the main part of the structure module - it iteratively applies
folding to produce a set of predicted residue positions.
Args:
representations: Representations dictionary.
batch: Batch dictionary.
config: Config for the structure module.
global_config: Global config.
is_training: Whether the model is being trained.
safe_key: A prng.SafeKey object that wraps a PRNG key.
Returns:
A dictionary containing residue affines and sidechain positions.
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='single_layer_norm')(
representations['single'])
initial_act = act
act = common_modules.Linear(
c.num_channel, name='initial_projection')(
act)
affine = generate_new_affine(sequence_mask)
fold_iteration = FoldIteration(
c, global_config, name='fold_iteration')
assert len(batch['seq_mask'].shape) == 1
activations = {'act': act,
'affine': affine.to_tensor(),
}
act_2d = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='pair_layer_norm')(
representations['pair'])
outputs = []
safe_keys = safe_key.split(c.num_layer)
for sub_key in safe_keys:
activations, output = fold_iteration(
activations,
initial_act=initial_act,
static_feat_2d=act_2d,
safe_key=sub_key,
sequence_mask=sequence_mask,
update_affine=True,
is_training=is_training,
aatype=batch['aatype'])
outputs.append(output)
output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
# Include the activations in the output dict for use by the LDDT-Head.
output['act'] = activations['act']
return output
class StructureModule(hk.Module):
"""StructureModule as a network head.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
"""
def __init__(self, config, global_config, compute_loss=True,
name='structure_module'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.compute_loss = compute_loss
def __call__(self, representations, batch, is_training,
safe_key=None):
c = self.config
ret = {}
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
output = generate_affines(
representations=representations,
batch=batch,
config=self.config,
global_config=self.global_config,
is_training=is_training,
safe_key=safe_key)
ret['representations'] = {'structure_module': output['act']}
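    # The affine tensor layout is [quaternion (4), translation (3)]; the
    # multiplication below rescales only the translation back to angstroms.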
ret['traj'] = output['affine'] * jnp.array([1.] * 4 +
[c.position_scale] * 3)
ret['sidechains'] = output['sc']
atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1]
ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3)
ret['final_atom14_mask'] = batch['atom14_atom_exists'] # (N, 14)
atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions,
batch)
atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None]
ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3)
ret['final_atom_mask'] = batch['atom37_atom_exists'] # (N, 37)
ret['final_affines'] = ret['traj'][-1]
if self.compute_loss:
return ret
else:
no_loss_features = ['final_atom_positions', 'final_atom_mask',
'representations']
no_loss_ret = {k: ret[k] for k in no_loss_features}
return no_loss_ret
def loss(self, value, batch):
ret = {'loss': 0.}
ret['metrics'] = {}
# If requested, compute in-graph metrics.
if self.config.compute_in_graph_metrics:
atom14_pred_positions = value['final_atom14_positions']
# Compute renaming and violations.
value.update(compute_renamed_ground_truth(batch, atom14_pred_positions))
value['violations'] = find_structural_violations(
batch, atom14_pred_positions, self.config)
# Several violation metrics:
violation_metrics = compute_violation_metrics(
batch=batch,
atom14_pred_positions=atom14_pred_positions,
violations=value['violations'])
ret['metrics'].update(violation_metrics)
backbone_loss(ret, batch, value, self.config)
if 'renamed_atom14_gt_positions' not in value:
value.update(compute_renamed_ground_truth(
batch, value['final_atom14_positions']))
sc_loss = sidechain_loss(batch, value, self.config)
ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
self.config.sidechain.weight_frac * sc_loss['loss'])
ret['sidechain_fape'] = sc_loss['fape']
supervised_chi_loss(ret, batch, value, self.config)
if self.config.structural_violation_loss_weight:
if 'violations' not in value:
value['violations'] = find_structural_violations(
batch, value['final_atom14_positions'], self.config)
structural_violation_loss(ret, batch, value, self.config)
return ret
def compute_renamed_ground_truth(
batch: Dict[str, jnp.ndarray],
atom14_pred_positions: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
"""Find optimal renaming of ground truth based on the predicted positions.
Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"
This renamed ground truth is then used for all losses,
such that each loss moves the atoms in the same direction.
Shape (N).
Args:
batch: Dictionary containing:
* atom14_gt_positions: Ground truth positions.
* atom14_alt_gt_positions: Ground truth positions with renaming swaps.
* atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
renaming swaps.
* atom14_gt_exists: Mask for which atoms exist in ground truth.
* atom14_alt_gt_exists: Mask for which atoms exist in ground truth
after renaming.
* atom14_atom_exists: Mask for whether each atom is part of the given
amino acid type.
atom14_pred_positions: Array of atom positions in global frame with shape
(N, 14, 3).
Returns:
Dictionary containing:
alt_naming_is_better: Array with 1.0 where alternative swap is better.
renamed_atom14_gt_positions: Array of optimal ground truth positions
after renaming swaps are performed.
renamed_atom14_gt_exists: Mask after renaming swap is performed.
"""
alt_naming_is_better = all_atom.find_optimal_renaming(
atom14_gt_positions=batch['atom14_gt_positions'],
atom14_alt_gt_positions=batch['atom14_alt_gt_positions'],
atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'],
atom14_gt_exists=batch['atom14_gt_exists'],
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'])
renamed_atom14_gt_positions = (
(1. - alt_naming_is_better[:, None, None])
* batch['atom14_gt_positions']
+ alt_naming_is_better[:, None, None]
* batch['atom14_alt_gt_positions'])
renamed_atom14_gt_mask = (
(1. - alt_naming_is_better[:, None]) * batch['atom14_gt_exists']
+ alt_naming_is_better[:, None] * batch['atom14_alt_gt_exists'])
return {
'alt_naming_is_better': alt_naming_is_better, # (N)
'renamed_atom14_gt_positions': renamed_atom14_gt_positions, # (N, 14, 3)
'renamed_atom14_gt_exists': renamed_atom14_gt_mask, # (N, 14)
}
def backbone_loss(ret, batch, value, config):
"""Backbone FAPE Loss.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17
Args:
ret: Dictionary to write outputs into, needs to contain 'loss'.
batch: Batch, needs to contain 'backbone_affine_tensor',
'backbone_affine_mask'.
value: Dictionary containing structure module output, needs to contain
'traj', a trajectory of rigids.
config: Configuration of loss, should contain 'fape.clamp_distance' and
'fape.loss_unit_distance'.
"""
affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)
gt_affine = quat_affine.QuatAffine.from_tensor(
batch['backbone_affine_tensor'])
gt_rigid = r3.rigids_from_quataffine(gt_affine)
backbone_mask = batch['backbone_affine_mask']
fape_loss_fn = functools.partial(
all_atom.frame_aligned_point_error,
l1_clamp_distance=config.fape.clamp_distance,
length_scale=config.fape.loss_unit_distance)
fape_loss_fn = jax.vmap(fape_loss_fn, (0, None, None, 0, None, None))
fape_loss = fape_loss_fn(rigid_trajectory, gt_rigid, backbone_mask,
rigid_trajectory.trans, gt_rigid.trans,
backbone_mask)
if 'use_clamped_fape' in batch:
# Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details"
use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32)
unclamped_fape_loss_fn = functools.partial(
all_atom.frame_aligned_point_error,
l1_clamp_distance=None,
length_scale=config.fape.loss_unit_distance)
unclamped_fape_loss_fn = jax.vmap(unclamped_fape_loss_fn,
(0, None, None, 0, None, None))
fape_loss_unclamped = unclamped_fape_loss_fn(rigid_trajectory, gt_rigid,
backbone_mask,
rigid_trajectory.trans,
gt_rigid.trans,
backbone_mask)
fape_loss = (fape_loss * use_clamped_fape +
fape_loss_unclamped * (1 - use_clamped_fape))
ret['fape'] = fape_loss[-1]
ret['loss'] += jnp.mean(fape_loss)
def sidechain_loss(batch, value, config):
"""All Atom FAPE Loss using renamed rigids."""
# Rename Frames
# Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
alt_naming_is_better = value['alt_naming_is_better']
renamed_gt_frames = (
(1. - alt_naming_is_better[:, None, None])
* batch['rigidgroups_gt_frames']
+ alt_naming_is_better[:, None, None]
* batch['rigidgroups_alt_gt_frames'])
flat_gt_frames = r3.rigids_from_tensor_flat12(
jnp.reshape(renamed_gt_frames, [-1, 12]))
flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])
flat_gt_positions = r3.vecs_from_tensor(
jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])
# Compute frame_aligned_point_error score for the final layer.
pred_frames = value['sidechains']['frames']
pred_positions = value['sidechains']['atom_pos']
def _slice_last_layer_and_flatten(x):
return jnp.reshape(x[-1], [-1])
flat_pred_frames = jax.tree_map(
_slice_last_layer_and_flatten, pred_frames)
flat_pred_positions = jax.tree_map(
_slice_last_layer_and_flatten, pred_positions)
# FAPE Loss on sidechains
fape = all_atom.frame_aligned_point_error(
pred_frames=flat_pred_frames,
target_frames=flat_gt_frames,
frames_mask=flat_frames_mask,
pred_positions=flat_pred_positions,
target_positions=flat_gt_positions,
positions_mask=flat_positions_mask,
l1_clamp_distance=config.sidechain.atom_clamp_distance,
length_scale=config.sidechain.length_scale)
return {
'fape': fape,
'loss': fape}
def structural_violation_loss(ret, batch, value, config):
"""Computes loss for structural violations."""
assert config.sidechain.weight_frac
# Put all violation losses together to one large loss.
violations = value['violations']
num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32)
ret['loss'] += (config.structural_violation_loss_weight * (
violations['between_residues']['bonds_c_n_loss_mean'] +
violations['between_residues']['angles_ca_c_n_loss_mean'] +
violations['between_residues']['angles_c_n_ca_loss_mean'] +
jnp.sum(
violations['between_residues']['clashes_per_atom_loss_sum'] +
violations['within_residues']['per_atom_loss_sum']) /
(1e-6 + num_atoms)))
def find_structural_violations(
batch: Dict[str, jnp.ndarray],
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
config: ml_collections.ConfigDict
):
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = all_atom.between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
residue_index=batch['residue_index'].astype(jnp.float32),
aatype=batch['aatype'],
tolerance_factor_soft=config.violation_tolerance_factor,
tolerance_factor_hard=config.violation_tolerance_factor)
# Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# Shape: (N, 14).
atomtype_radius = jnp.array([
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
])
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
atomtype_radius, batch['residx_atom14_to_atom37'])
# Compute the between residue clash loss.
between_residue_clashes = all_atom.between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch['residue_index'],
overlap_tolerance_soft=config.clash_overlap_tolerance,
overlap_tolerance_hard=config.clash_overlap_tolerance)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=config.clash_overlap_tolerance,
bond_length_tolerance_factor=config.violation_tolerance_factor)
atom14_dists_lower_bound = utils.batched_gather(
restype_atom14_bounds['lower_bound'], batch['aatype'])
atom14_dists_upper_bound = utils.batched_gather(
restype_atom14_bounds['upper_bound'], batch['aatype'])
within_residue_violations = all_atom.within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0)
# Combine them to a single per-residue violation mask (used later for LDDT).
per_residue_violations_mask = jnp.max(jnp.stack([
connection_violations['per_residue_violation_mask'],
jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
jnp.max(within_residue_violations['per_atom_violations'],
axis=-1)]), axis=0)
return {
'between_residues': {
'bonds_c_n_loss_mean':
connection_violations['c_n_loss_mean'], # ()
'angles_ca_c_n_loss_mean':
connection_violations['ca_c_n_loss_mean'], # ()
'angles_c_n_ca_loss_mean':
connection_violations['c_n_ca_loss_mean'], # ()
'connections_per_residue_loss_sum':
connection_violations['per_residue_loss_sum'], # (N)
'connections_per_residue_violation_mask':
connection_violations['per_residue_violation_mask'], # (N)
'clashes_mean_loss':
between_residue_clashes['mean_loss'], # ()
'clashes_per_atom_loss_sum':
between_residue_clashes['per_atom_loss_sum'], # (N, 14)
'clashes_per_atom_clash_mask':
between_residue_clashes['per_atom_clash_mask'], # (N, 14)
},
'within_residues': {
'per_atom_loss_sum':
within_residue_violations['per_atom_loss_sum'], # (N, 14)
'per_atom_violations':
within_residue_violations['per_atom_violations'], # (N, 14),
},
'total_per_residue_violations_mask':
per_residue_violations_mask, # (N)
}
def compute_violation_metrics(
batch: Dict[str, jnp.ndarray],
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
violations: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
"""Compute several metrics to assess the structural violations."""
ret = {}
extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
residue_index=batch['residue_index'].astype(jnp.float32))
ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations
ret['violations_between_residue_bond'] = utils.mask_mean(
mask=batch['seq_mask'],
value=violations['between_residues'][
'connections_per_residue_violation_mask'])
ret['violations_between_residue_clash'] = utils.mask_mean(
mask=batch['seq_mask'],
value=jnp.max(
violations['between_residues']['clashes_per_atom_clash_mask'],
axis=-1))
ret['violations_within_residue'] = utils.mask_mean(
mask=batch['seq_mask'],
value=jnp.max(
violations['within_residues']['per_atom_violations'], axis=-1))
ret['violations_per_residue'] = utils.mask_mean(
mask=batch['seq_mask'],
value=violations['total_per_residue_violations_mask'])
return ret
def supervised_chi_loss(ret, batch, value, config):
"""Computes loss for direct chi angle supervision.
Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"
Args:
ret: Dictionary to write outputs into, needs to contain 'loss'.
batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'.
value: Dictionary containing structure module output, needs to contain
value['sidechains']['angles_sin_cos'] for angles and
value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized
angles.
config: Configuration of loss, should contain 'chi_weight' and
'angle_norm_weight', 'angle_norm_weight' scales angle norm term,
'chi_weight' scales torsion term.
"""
eps = 1e-6
sequence_mask = batch['seq_mask']
num_res = sequence_mask.shape[0]
chi_mask = batch['chi_mask'].astype(jnp.float32)
pred_angles = jnp.reshape(
value['sidechains']['angles_sin_cos'], [-1, num_res, 7, 2])
pred_angles = pred_angles[:, :, 3:]
residue_type_one_hot = jax.nn.one_hot(
batch['aatype'], residue_constants.restype_num + 1,
dtype=jnp.float32)[None]
chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot,
jnp.asarray(residue_constants.chi_pi_periodic))
true_chi = batch['chi_angles'][None]
sin_true_chi = jnp.sin(true_chi)
cos_true_chi = jnp.cos(true_chi)
sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1)
# This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
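  # Worked example: ASP chi2 is pi-periodic, so true angles chi and chi + pi
  # describe the same physical state. Since sin(chi + pi) = -sin(chi) and
  # cos(chi + pi) = -cos(chi), the shifted target above is exactly the
  # sin/cos of chi + pi, and taking the minimum of the two errors below makes
  # the loss invariant to this ambiguity.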
sq_chi_error = jnp.sum(
squared_difference(sin_cos_true_chi, pred_angles), -1)
sq_chi_error_shifted = jnp.sum(
squared_difference(sin_cos_true_chi_shifted, pred_angles), -1)
sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)
sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
ret['chi_loss'] = sq_chi_loss
ret['loss'] += config.chi_weight * sq_chi_loss
unnormed_angles = jnp.reshape(
value['sidechains']['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2])
angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
norm_error = jnp.abs(angle_norm - 1.)
angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None],
value=norm_error)
ret['angle_norm_loss'] = angle_norm_loss
ret['loss'] += config.angle_norm_weight * angle_norm_loss
def generate_new_affine(sequence_mask):
num_residues, _ = sequence_mask.shape
quaternion = jnp.tile(
jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4]),
[num_residues, 1])
translation = jnp.zeros([num_residues, 3])
return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True)
def l2_normalize(x, axis=-1, epsilon=1e-12):
return x / jnp.sqrt(
jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))
class MultiRigidSidechain(hk.Module):
"""Class to make side chain atoms."""
def __init__(self, config, global_config, name='rigid_sidechain'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, affine, representations_list, aatype):
"""Predict side chains using multi-rigid representations.
Args:
affine: The affines for each residue (translations in angstroms).
representations_list: A list of activations to predict side chains from.
aatype: Amino acid types.
Returns:
Dict containing atom positions and frames (in angstroms).
"""
act = [
common_modules.Linear( # pylint: disable=g-complex-comprehension
self.config.num_channel,
name='input_projection')(jax.nn.relu(x))
for x in representations_list
]
# Sum the activation list (equivalent to concat then Linear).
act = sum(act)
final_init = 'zeros' if self.global_config.zero_init else 'linear'
# Mapping with some residual blocks.
for _ in range(self.config.num_residual_block):
old_act = act
act = common_modules.Linear(
self.config.num_channel,
initializer='relu',
name='resblock1')(
jax.nn.relu(act))
act = common_modules.Linear(
self.config.num_channel,
initializer=final_init,
name='resblock2')(
jax.nn.relu(act))
act += old_act
# Map activations to torsion angles. Shape: (num_res, 14).
num_res = act.shape[0]
unnormalized_angles = common_modules.Linear(
14, name='unnormalized_angles')(
jax.nn.relu(act))
unnormalized_angles = jnp.reshape(
unnormalized_angles, [num_res, 7, 2])
angles = l2_normalize(unnormalized_angles, axis=-1)
outputs = {
'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2)
'unnormalized_angles_sin_cos':
unnormalized_angles, # jnp.ndarray (N, 7, 2)
}
# Map torsion angles to frames.
backb_to_global = r3.rigids_from_quataffine(affine)
# Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"
# r3.Rigids with shape (N, 8).
all_frames_to_global = all_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
angles)
# Use frames and literature positions to create the final atom coordinates.
# r3.Vecs with shape (N, 14).
pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos(
aatype, all_frames_to_global)
outputs.update({
'atom_pos': pred_positions, # r3.Vecs (N, 14)
'frames': all_frames_to_global, # r3.Rigids (N, 8)
})
return outputs
"""Modules and utilities for the structure module in the multimer system."""
import functools
import numbers
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
from alphafold.common import residue_constants
from alphafold.model import all_atom_multimer
from alphafold.model import common_modules
from alphafold.model import geometry
from alphafold.model import modules
from alphafold.model import prng
from alphafold.model import utils
from alphafold.model.geometry import utils as geometry_utils
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
EPSILON = 1e-8
Float = Union[float, jnp.ndarray]
def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes Squared difference between two arrays."""
return jnp.square(x - y)
def make_backbone_affine(
positions: geometry.Vec3Array,
mask: jnp.ndarray,
aatype: jnp.ndarray,
) -> Tuple[geometry.Rigid3Array, jnp.ndarray]:
"""Make backbone Rigid3Array and mask."""
del aatype
a = residue_constants.atom_order['N']
b = residue_constants.atom_order['CA']
c = residue_constants.atom_order['C']
rigid_mask = (mask[:, a] * mask[:, b] * mask[:, c]).astype(
jnp.float32)
rigid = all_atom_multimer.make_transform_from_reference(
a_xyz=positions[:, a], b_xyz=positions[:, b], c_xyz=positions[:, c])
return rigid, rigid_mask
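# Hedged usage sketch (assumes `positions` is a geometry.Vec3Array of shape
# (num_res, 37) in atom37 order and `mask` is the matching (num_res, 37)
# atom existence mask):
#   backbone_rigid, backbone_rigid_mask = make_backbone_affine(
#       positions, mask, aatype)
#   # backbone_rigid: per-residue N-CA-C frames as a Rigid3Array of shape
#   # (num_res,); the mask is 1.0 only where N, CA and C all exist.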
class QuatRigid(hk.Module):
"""Module for projecting Rigids via a quaternion."""
def __init__(self,
global_config: ml_collections.ConfigDict,
rigid_shape: Union[int, Iterable[int]] = tuple(),
full_quat: bool = False,
init: str = 'zeros',
name: str = 'quat_rigid'):
"""Module projecting a Rigid Object.
    For this module the rotation is parametrized as a quaternion. If
    'full_quat' is True, a 4-vector is produced for the rotation, which is
    normalized and treated as a quaternion. When 'full_quat' is False, a
    3-vector is produced and the first component of the quaternion is set
    to 1.
Args:
global_config: Global Config, used to set certain properties of underlying
Linear module, see common_modules.Linear for details.
rigid_shape: Shape of Rigids relative to shape of activations, e.g. when
activations have shape (n,) and this is (m,) output will be (n, m)
full_quat: Whether to parametrize rotation using full quaternion.
init: initializer to use, see common_modules.Linear for details
name: Name to use for module.
"""
self.init = init
self.global_config = global_config
if isinstance(rigid_shape, int):
self.rigid_shape = (rigid_shape,)
else:
self.rigid_shape = tuple(rigid_shape)
self.full_quat = full_quat
super(QuatRigid, self).__init__(name=name)
def __call__(self, activations: jnp.ndarray) -> geometry.Rigid3Array:
"""Executes Module.
    This returns a set of rigids with the same shape as the activations,
    with the channel dimension projected away; rigid_shape controls the
    trailing dimensions.
For example when activations is shape (12, 5) and rigid_shape is (3, 2)
then the shape of the output rigids will be (12, 3, 2).
    This also supports passing in an empty tuple for rigid_shape; in that
    case the example would produce a rigid of shape (12,).
Args:
activations: Activations to use for projection, shape [..., num_channel]
Returns:
Rigid transformations with shape [...] + rigid_shape
"""
if self.full_quat:
rigid_dim = 7
else:
rigid_dim = 6
linear_dims = self.rigid_shape + (rigid_dim,)
rigid_flat = common_modules.Linear(
linear_dims,
initializer=self.init,
precision=jax.lax.Precision.HIGHEST,
name='rigid')(
activations)
rigid_flat = geometry_utils.unstack(rigid_flat)
if self.full_quat:
qw, qx, qy, qz = rigid_flat[:4]
translation = rigid_flat[4:]
else:
qx, qy, qz = rigid_flat[:3]
qw = jnp.ones_like(qx)
translation = rigid_flat[3:]
rotation = geometry.Rot3Array.from_quaternion(
qw, qx, qy, qz, normalize=True)
translation = geometry.Vec3Array(*translation)
return geometry.Rigid3Array(rotation, translation)
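# A minimal sketch of the rigid-update projection (assuming `gc` is the
# global config and `act` has shape (num_res, num_channel)):
#   rigid_update = QuatRigid(gc, init='zeros')(act)  # Rigid3Array, (num_res,)
#   rigid = rigid @ rigid_update  # composed exactly as in FoldIteration below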
class PointProjection(hk.Module):
"""Given input reprensentation and frame produces points in global frame."""
def __init__(self,
num_points: Union[Iterable[int], int],
global_config: ml_collections.ConfigDict,
return_local_points: bool = False,
name: str = 'point_projection'):
"""Constructs Linear Module.
Args:
num_points: number of points to project. Can be tuple when outputting
multiple dimensions
global_config: Global Config, passed through to underlying Linear
return_local_points: Whether to return points in local frame as well.
name: name of module, used for name scopes.
"""
if isinstance(num_points, numbers.Integral):
self.num_points = (num_points,)
else:
self.num_points = tuple(num_points)
self.return_local_points = return_local_points
self.global_config = global_config
super().__init__(name=name)
def __call__(
self, activations: jnp.ndarray, rigids: geometry.Rigid3Array
) -> Union[geometry.Vec3Array, Tuple[geometry.Vec3Array, geometry.Vec3Array]]:
output_shape = self.num_points
output_shape = output_shape[:-1] + (3 * output_shape[-1],)
points_local = common_modules.Linear(
output_shape,
precision=jax.lax.Precision.HIGHEST,
name='point_projection')(
activations)
points_local = jnp.split(points_local, 3, axis=-1)
points_local = geometry.Vec3Array(*points_local)
rigids = rigids[(...,) + (None,) * len(output_shape)]
points_global = rigids.apply_to_point(points_local)
if self.return_local_points:
return points_global, points_local
else:
return points_global
class InvariantPointAttention(hk.Module):
"""Invariant point attention module.
The high-level idea is that this attention module works over a set of points
and associated orientations in 3D space (e.g. protein residues).
Each residue outputs a set of queries and keys as points in their local
  reference frame. The attention is then computed from the Euclidean
  distance between the queries and keys in the global frame.
"""
def __init__(self,
config: ml_collections.ConfigDict,
global_config: ml_collections.ConfigDict,
dist_epsilon: float = 1e-8,
name: str = 'invariant_point_attention'):
"""Initialize.
Args:
config: iterative Fold Head Config
global_config: Global Config of Model.
dist_epsilon: Small value to avoid NaN in distance calculation.
      name: Haiku module name.
"""
super().__init__(name=name)
self._dist_epsilon = dist_epsilon
self._zero_initialize_last = global_config.zero_init
self.config = config
self.global_config = global_config
def __call__(
self,
inputs_1d: jnp.ndarray,
inputs_2d: jnp.ndarray,
mask: jnp.ndarray,
rigid: geometry.Rigid3Array,
) -> jnp.ndarray:
"""Compute geometric aware attention.
Given a set of query residues (defined by affines and associated scalar
features), this function computes geometric aware attention between the
query residues and target residues.
The residues produce points in their local reference frame, which
are converted into the global frame to get attention via euclidean distance.
Equivalently the target residues produce points in their local frame to be
used as attention values, which are converted into the query residues local
frames.
Args:
inputs_1d: (N, C) 1D input embedding that is the basis for the
scalar queries.
      inputs_2d: (N, M, C') 2D input embedding, used for bias values in the
        attention between the query and target residues.
      mask: (N, 1) mask indicating which elements of inputs_1d participate
        in the attention.
      rigid: Rigid3Array describing the position and orientation of every
        element in inputs_1d.
Returns:
Transformation of the input embedding.
"""
num_head = self.config.num_head
attn_logits = 0.
num_point_qk = self.config.num_point_qk
# Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
point_variance = max(num_point_qk, 1) * 9. / 2
point_weights = np.sqrt(1.0 / point_variance)
# This is equivalent to jax.nn.softplus, but avoids a bug in the test...
softplus = lambda x: jnp.logaddexp(x, jnp.zeros_like(x))
raw_point_weights = hk.get_parameter(
'trainable_point_weights',
shape=[num_head],
# softplus^{-1} (1)
init=hk.initializers.Constant(np.log(np.exp(1.) - 1.)))
# Trainable per-head weights for points.
trainable_point_weights = softplus(raw_point_weights)
point_weights *= trainable_point_weights
q_point = PointProjection([num_head, num_point_qk],
self.global_config,
name='q_point_projection')(inputs_1d,
rigid)
k_point = PointProjection([num_head, num_point_qk],
self.global_config,
name='k_point_projection')(inputs_1d,
rigid)
dist2 = geometry.square_euclidean_distance(
q_point[:, None, :, :], k_point[None, :, :, :], epsilon=0.)
attn_qk_point = -0.5 * jnp.sum(point_weights[:, None] * dist2, axis=-1)
attn_logits += attn_qk_point
num_scalar_qk = self.config.num_scalar_qk
    # We assume that all queries and keys are drawn iid from an N(0, 1)
    # distribution and compute the variances of the attention logits.
# Each scalar pair (q, k) contributes Var q*k = 1
scalar_variance = max(num_scalar_qk, 1) * 1.
scalar_weights = np.sqrt(1.0 / scalar_variance)
q_scalar = common_modules.Linear([num_head, num_scalar_qk],
use_bias=False,
name='q_scalar_projection')(
inputs_1d)
k_scalar = common_modules.Linear([num_head, num_scalar_qk],
use_bias=False,
name='k_scalar_projection')(
inputs_1d)
q_scalar *= scalar_weights
attn_logits += jnp.einsum('qhc,khc->qkh', q_scalar, k_scalar)
attention_2d = common_modules.Linear(
num_head, name='attention_2d')(inputs_2d)
attn_logits += attention_2d
mask_2d = mask * jnp.swapaxes(mask, -1, -2)
attn_logits -= 1e5 * (1. - mask_2d[..., None])
attn_logits *= np.sqrt(1. / 3) # Normalize by number of logit terms (3)
attn = jax.nn.softmax(attn_logits, axis=-2)
num_scalar_v = self.config.num_scalar_v
v_scalar = common_modules.Linear([num_head, num_scalar_v],
use_bias=False,
name='v_scalar_projection')(
inputs_1d)
# [num_query_residues, num_head, num_scalar_v]
result_scalar = jnp.einsum('qkh, khc->qhc', attn, v_scalar)
num_point_v = self.config.num_point_v
v_point = PointProjection([num_head, num_point_v],
self.global_config,
name='v_point_projection')(inputs_1d,
rigid)
result_point_global = jax.tree_map(
lambda x: jnp.sum(attn[..., None] * x, axis=-3), v_point[None])
# Features used in the linear output projection. Should have the size
# [num_query_residues, ?]
output_features = []
num_query_residues, _ = inputs_1d.shape
flat_shape = [num_query_residues, -1]
result_scalar = jnp.reshape(result_scalar, flat_shape)
output_features.append(result_scalar)
result_point_global = jax.tree_map(lambda r: jnp.reshape(r, flat_shape),
result_point_global)
result_point_local = rigid[..., None].apply_inverse_to_point(
result_point_global)
output_features.extend(
[result_point_local.x, result_point_local.y, result_point_local.z])
point_norms = result_point_local.norm(self._dist_epsilon)
output_features.append(point_norms)
# Dimensions: h = heads, i and j = residues,
# c = inputs_2d channels
# Contraction happens over the second residue dimension, similarly to how
# the usual attention is performed.
result_attention_over_2d = jnp.einsum('ijh, ijc->ihc', attn, inputs_2d)
output_features.append(jnp.reshape(result_attention_over_2d, flat_shape))
final_init = 'zeros' if self._zero_initialize_last else 'linear'
final_act = jnp.concatenate(output_features, axis=-1)
return common_modules.Linear(
self.config.num_channel,
initializer=final_init,
name='output_projection')(final_act)
class FoldIteration(hk.Module):
"""A single iteration of iterative folding.
First, each residue attends to all residues using InvariantPointAttention.
Then, we apply transition layers to update the hidden representations.
Finally, we use the hidden representations to produce an update to the
affine of each residue.
"""
def __init__(self,
config: ml_collections.ConfigDict,
global_config: ml_collections.ConfigDict,
name: str = 'fold_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(
self,
activations: Mapping[str, Any],
aatype: jnp.ndarray,
sequence_mask: jnp.ndarray,
update_rigid: bool,
is_training: bool,
initial_act: jnp.ndarray,
safe_key: Optional[prng.SafeKey] = None,
static_feat_2d: Optional[jnp.ndarray] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
c = self.config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
def safe_dropout_fn(tensor, safe_key):
return modules.apply_dropout(
tensor=tensor,
safe_key=safe_key,
rate=0.0 if self.global_config.deterministic else c.dropout,
is_training=is_training)
rigid = activations['rigid']
act = activations['act']
attention_module = InvariantPointAttention(
self.config, self.global_config)
# Attention
act += attention_module(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=sequence_mask,
rigid=rigid)
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
name='attention_layer_norm')(
act)
final_init = 'zeros' if self.global_config.zero_init else 'linear'
# Transition
input_act = act
for i in range(c.num_layer_in_transition):
init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
act = common_modules.Linear(
c.num_channel,
initializer=init,
name='transition')(
act)
if i < c.num_layer_in_transition - 1:
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
name='transition_layer_norm')(act)
if update_rigid:
# Rigid update
rigid_update = QuatRigid(
self.global_config, init=final_init)(
act)
rigid = rigid @ rigid_update
sc = MultiRigidSidechain(c.sidechain, self.global_config)(
rigid.scale_translation(c.position_scale), [act, initial_act], aatype)
outputs = {'rigid': rigid, 'sc': sc}
rotation = jax.tree_map(jax.lax.stop_gradient, rigid.rotation)
rigid = geometry.Rigid3Array(rotation, rigid.translation)
new_activations = {
'act': act,
'rigid': rigid
}
return new_activations, outputs
def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
batch: Mapping[str, jnp.ndarray],
config: ml_collections.ConfigDict,
global_config: ml_collections.ConfigDict,
is_training: bool,
safe_key: prng.SafeKey
) -> Dict[str, Any]:
"""Generate predicted Rigid's for a single chain.
This is the main part of the iterative fold head - it iteratively applies
folding to produce a set of predicted residue positions.
Args:
representations: Embeddings dictionary.
batch: Batch dictionary.
config: config for the iterative fold head.
global_config: global config.
is_training: is training.
safe_key: A prng.SafeKey object that wraps a PRNG key.
Returns:
A dictionary containing residue Rigid's and sidechain positions.
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = common_modules.LayerNorm(
axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
representations['single'])
initial_act = act
act = common_modules.Linear(
c.num_channel, name='initial_projection')(act)
# Sequence mask has an extra trailing dimension of size 1, which is stripped here.
rigid = geometry.Rigid3Array.identity(sequence_mask.shape[:-1])
fold_iteration = FoldIteration(
c, global_config, name='fold_iteration')
assert len(batch['seq_mask'].shape) == 1
activations = {
'act':
act,
'rigid':
rigid
}
act_2d = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
name='pair_layer_norm')(
representations['pair'])
safe_keys = safe_key.split(c.num_layer)
outputs = []
for key in safe_keys:
activations, output = fold_iteration(
activations,
initial_act=initial_act,
static_feat_2d=act_2d,
aatype=batch['aatype'],
safe_key=key,
sequence_mask=sequence_mask,
update_rigid=True,
is_training=is_training,
)
outputs.append(output)
output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
# Pass along for LDDT-Head.
output['act'] = activations['act']
return output
class StructureModule(hk.Module):
"""StructureModule as a network head.
Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
"""
def __init__(self,
config: ml_collections.ConfigDict,
global_config: ml_collections.ConfigDict,
name: str = 'structure_module'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
representations: Mapping[str, jnp.ndarray],
batch: Mapping[str, Any],
is_training: bool,
safe_key: Optional[prng.SafeKey] = None,
compute_loss: bool = False
) -> Dict[str, Any]:
c = self.config
ret = {}
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
output = generate_monomer_rigids(
representations=representations,
batch=batch,
config=self.config,
global_config=self.global_config,
is_training=is_training,
safe_key=safe_key)
ret['traj'] = output['rigid'].scale_translation(c.position_scale).to_array()
ret['sidechains'] = output['sc']
ret['sidechains']['atom_pos'] = ret['sidechains']['atom_pos'].to_array()
ret['sidechains']['frames'] = ret['sidechains']['frames'].to_array()
if 'local_atom_pos' in ret['sidechains']:
ret['sidechains']['local_atom_pos'] = ret['sidechains'][
'local_atom_pos'].to_array()
ret['sidechains']['local_frames'] = ret['sidechains'][
'local_frames'].to_array()
aatype = batch['aatype']
seq_mask = batch['seq_mask']
atom14_pred_mask = all_atom_multimer.get_atom14_mask(
aatype) * seq_mask[:, None]
atom14_pred_positions = output['sc']['atom_pos'][-1]
ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3)
ret['final_atom14_mask'] = atom14_pred_mask # (N, 14)
atom37_mask = all_atom_multimer.get_atom37_mask(aatype) * seq_mask[:, None]
atom37_pred_positions = all_atom_multimer.atom14_to_atom37(
atom14_pred_positions, aatype)
atom37_pred_positions *= atom37_mask[:, :, None]
ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3)
ret['final_atom_mask'] = atom37_mask # (N, 37)
ret['final_rigids'] = ret['traj'][-1]
ret['act'] = output['act']
if compute_loss:
return ret
else:
no_loss_features = ['final_atom_positions', 'final_atom_mask', 'act']
no_loss_ret = {k: ret[k] for k in no_loss_features}
return no_loss_ret
def loss(self,
value: Mapping[str, Any],
batch: Mapping[str, Any]
) -> Dict[str, Any]:
raise NotImplementedError(
'This function should be called on a batch with reordered chains (see '
'Evans et al. (2021), Section 7.3, Multi-Chain Permutation Alignment).')
ret = {'loss': 0.}
ret['metrics'] = {}
aatype = batch['aatype']
all_atom_positions = batch['all_atom_positions']
all_atom_positions = geometry.Vec3Array.from_array(all_atom_positions)
all_atom_mask = batch['all_atom_mask']
seq_mask = batch['seq_mask']
residue_index = batch['residue_index']
gt_rigid, gt_affine_mask = make_backbone_affine(all_atom_positions,
all_atom_mask,
aatype)
chi_angles, chi_mask = all_atom_multimer.compute_chi_angles(
all_atom_positions, all_atom_mask, aatype)
pred_mask = all_atom_multimer.get_atom14_mask(aatype)
pred_mask *= seq_mask[:, None]
pred_positions = value['final_atom14_positions']
pred_positions = geometry.Vec3Array.from_array(pred_positions)
gt_positions, gt_mask, alt_naming_is_better = compute_atom14_gt(
aatype, all_atom_positions, all_atom_mask, pred_positions)
violations = find_structural_violations(
aatype=aatype,
residue_index=residue_index,
mask=pred_mask,
pred_positions=pred_positions,
config=self.config,
asym_id=batch['asym_id'])
sidechains = value['sidechains']
gt_chi_angles = get_renamed_chi_angles(aatype, chi_angles,
alt_naming_is_better)
# Several violation metrics:
violation_metrics = compute_violation_metrics(
residue_index=residue_index,
mask=pred_mask,
seq_mask=seq_mask,
pred_positions=pred_positions,
violations=violations)
ret['metrics'].update(violation_metrics)
target_rigid = geometry.Rigid3Array.from_array(value['traj'])
gt_frames_mask = gt_affine_mask
# Split the loss into within-chain and between-chain components.
intra_chain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
intra_chain_bb_loss, intra_chain_fape = backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=gt_frames_mask,
gt_positions_mask=gt_affine_mask,
target_rigid=target_rigid,
config=self.config.intra_chain_fape,
pair_mask=intra_chain_mask)
interface_bb_loss, interface_fape = backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=gt_frames_mask,
gt_positions_mask=gt_affine_mask,
target_rigid=target_rigid,
config=self.config.interface_fape,
pair_mask=1. - intra_chain_mask)
bb_loss = intra_chain_bb_loss + interface_bb_loss
ret['fape'] = intra_chain_fape + interface_fape
ret['bb_loss'] = bb_loss
ret['loss'] += bb_loss
pred_frames = geometry.Rigid3Array.from_array(sidechains['frames'])
pred_positions = geometry.Vec3Array.from_array(sidechains['atom_pos'])
gt_sc_frames, gt_sc_frames_mask = compute_frames(
aatype=aatype,
all_atom_positions=all_atom_positions,
all_atom_mask=all_atom_mask,
use_alt=alt_naming_is_better)
sc_loss = sidechain_loss(
gt_frames=gt_sc_frames,
gt_frames_mask=gt_sc_frames_mask,
gt_positions=gt_positions,
gt_mask=gt_mask,
pred_frames=pred_frames,
pred_positions=pred_positions,
config=self.config)
ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
self.config.sidechain.weight_frac * sc_loss['loss'])
ret['sidechain_fape'] = sc_loss['fape']
unnormed_angles = sidechains['unnormalized_angles_sin_cos']
pred_angles = sidechains['angles_sin_cos']
sup_chi_loss, ret['chi_loss'], ret[
'angle_norm_loss'] = supervised_chi_loss(
sequence_mask=seq_mask,
target_chi_mask=chi_mask,
target_chi_angles=gt_chi_angles,
aatype=aatype,
pred_angles=pred_angles,
unnormed_angles=unnormed_angles,
config=self.config)
ret['loss'] += sup_chi_loss
if self.config.structural_violation_loss_weight:
ret['loss'] += structural_violation_loss(
mask=pred_mask, violations=violations, config=self.config)
return ret
def compute_atom14_gt(
aatype: jnp.ndarray,
all_atom_positions: geometry.Vec3Array,
all_atom_mask: jnp.ndarray,
pred_pos: geometry.Vec3Array
) -> Tuple[geometry.Vec3Array, jnp.ndarray, jnp.ndarray]:
"""Find atom14 positions, this includes finding the correct renaming."""
gt_positions, gt_mask = all_atom_multimer.atom37_to_atom14(
aatype, all_atom_positions,
all_atom_mask)
alt_gt_positions, alt_gt_mask = all_atom_multimer.get_alt_atom14(
aatype, gt_positions, gt_mask)
atom_is_ambiguous = all_atom_multimer.get_atom14_is_ambiguous(aatype)
alt_naming_is_better = all_atom_multimer.find_optimal_renaming(
gt_positions=gt_positions,
alt_gt_positions=alt_gt_positions,
atom_is_ambiguous=atom_is_ambiguous,
gt_exists=gt_mask,
pred_positions=pred_pos)
use_alt = alt_naming_is_better[:, None]
gt_mask = (1. - use_alt) * gt_mask + use_alt * alt_gt_mask
gt_positions = (1. - use_alt) * gt_positions + use_alt * alt_gt_positions
return gt_positions, gt_mask, alt_naming_is_better
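# Illustrative note (not part of the original source): alt_naming_is_better is
# a per-residue 0/1 indicator, so the arithmetic blend above simply selects
# alt_gt_positions where the swapped naming fits the prediction better and
# gt_positions otherwise.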
def backbone_loss(gt_rigid: geometry.Rigid3Array,
gt_frames_mask: jnp.ndarray,
gt_positions_mask: jnp.ndarray,
target_rigid: geometry.Rigid3Array,
config: ml_collections.ConfigDict,
pair_mask: jnp.ndarray
) -> Tuple[Float, jnp.ndarray]:
"""Backbone FAPE Loss."""
loss_fn = functools.partial(
all_atom_multimer.frame_aligned_point_error,
l1_clamp_distance=config.atom_clamp_distance,
length_scale=config.loss_unit_distance)
loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None))
fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask,
target_rigid.translation, gt_rigid.translation,
gt_positions_mask, pair_mask)
return jnp.mean(fape), fape[-1]
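# Illustrative note (not part of the original source): the vmap above maps the
# FAPE computation over the leading (fold iteration) axis of target_rigid, so
# fape has one entry per layer; jnp.mean(fape) averages the loss across layers
# while fape[-1] reports the FAPE of the final predicted structure.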
def compute_frames(
aatype: jnp.ndarray,
all_atom_positions: geometry.Vec3Array,
all_atom_mask: jnp.ndarray,
use_alt: jnp.ndarray
) -> Tuple[geometry.Rigid3Array, jnp.ndarray]:
"""Compute Frames from all atom positions.
Args:
aatype: array of aatypes, int of [N]
all_atom_positions: Vector of all atom positions, shape [N, 37]
all_atom_mask: mask, shape [N, 37]
use_alt: whether to use alternative orientation for ambiguous aatypes,
shape [N]
Returns:
Rigid corresponding to frames, with shape [N, 8],
mask of which Rigids are present, with shape [N, 8]
"""
frames_batch = all_atom_multimer.atom37_to_frames(aatype, all_atom_positions,
all_atom_mask)
gt_frames = frames_batch['rigidgroups_gt_frames']
alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames']
use_alt = use_alt[:, None]
renamed_gt_frames = jax.tree_map(
lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames)
return renamed_gt_frames, frames_batch['rigidgroups_gt_exists']
def sidechain_loss(gt_frames: geometry.Rigid3Array,
gt_frames_mask: jnp.ndarray,
gt_positions: geometry.Vec3Array,
gt_mask: jnp.ndarray,
pred_frames: geometry.Rigid3Array,
pred_positions: geometry.Vec3Array,
config: ml_collections.ConfigDict
) -> Dict[str, jnp.ndarray]:
"""Sidechain Loss using cleaned up rigids."""
flat_gt_frames = jax.tree_map(jnp.ravel, gt_frames)
flat_frames_mask = jnp.ravel(gt_frames_mask)
flat_gt_positions = jax.tree_map(jnp.ravel, gt_positions)
flat_positions_mask = jnp.ravel(gt_mask)
# Compute frame_aligned_point_error score for the final layer.
def _slice_last_layer_and_flatten(x):
return jnp.ravel(x[-1])
flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames)
flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten,
pred_positions)
fape = all_atom_multimer.frame_aligned_point_error(
pred_frames=flat_pred_frames,
target_frames=flat_gt_frames,
frames_mask=flat_frames_mask,
pred_positions=flat_pred_positions,
target_positions=flat_gt_positions,
positions_mask=flat_positions_mask,
pair_mask=None,
length_scale=config.sidechain.loss_unit_distance,
l1_clamp_distance=config.sidechain.atom_clamp_distance)
return {
'fape': fape,
'loss': fape}
def structural_violation_loss(mask: jnp.ndarray,
violations: Mapping[str, Float],
config: ml_collections.ConfigDict
) -> Float:
"""Computes Loss for structural Violations."""
# Put all violation losses together to one large loss.
num_atoms = jnp.sum(mask).astype(jnp.float32) + 1e-6
between_residues = violations['between_residues']
within_residues = violations['within_residues']
return (config.structural_violation_loss_weight *
(between_residues['bonds_c_n_loss_mean'] +
between_residues['angles_ca_c_n_loss_mean'] +
between_residues['angles_c_n_ca_loss_mean'] +
jnp.sum(between_residues['clashes_per_atom_loss_sum'] +
within_residues['per_atom_loss_sum']) / num_atoms
))
def find_structural_violations(
aatype: jnp.ndarray,
residue_index: jnp.ndarray,
mask: jnp.ndarray,
pred_positions: geometry.Vec3Array, # (N, 14)
config: ml_collections.ConfigDict,
asym_id: jnp.ndarray,
) -> Dict[str, Any]:
"""Computes several checks for structural Violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = all_atom_multimer.between_residue_bond_loss(
pred_atom_positions=pred_positions,
pred_atom_mask=mask.astype(jnp.float32),
residue_index=residue_index.astype(jnp.float32),
aatype=aatype,
tolerance_factor_soft=config.violation_tolerance_factor,
tolerance_factor_hard=config.violation_tolerance_factor)
# Compute the van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# shape (N, 14)
atomtype_radius = jnp.array([
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
])
residx_atom14_to_atom37 = all_atom_multimer.get_atom14_to_atom37_map(aatype)
atom_radius = mask * utils.batched_gather(atomtype_radius,
residx_atom14_to_atom37)
# Compute the between residue clash loss.
between_residue_clashes = all_atom_multimer.between_residue_clash_loss(
pred_positions=pred_positions,
atom_exists=mask,
atom_radius=atom_radius,
residue_index=residue_index,
overlap_tolerance_soft=config.clash_overlap_tolerance,
overlap_tolerance_hard=config.clash_overlap_tolerance,
asym_id=asym_id)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=config.clash_overlap_tolerance,
bond_length_tolerance_factor=config.violation_tolerance_factor)
dists_lower_bound = utils.batched_gather(restype_atom14_bounds['lower_bound'],
aatype)
dists_upper_bound = utils.batched_gather(restype_atom14_bounds['upper_bound'],
aatype)
within_residue_violations = all_atom_multimer.within_residue_violations(
pred_positions=pred_positions,
atom_exists=mask,
dists_lower_bound=dists_lower_bound,
dists_upper_bound=dists_upper_bound,
tighten_bounds_for_loss=0.0)
# Combine them to a single per-residue violation mask (used later for LDDT).
per_residue_violations_mask = jnp.max(jnp.stack([
connection_violations['per_residue_violation_mask'],
jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
jnp.max(within_residue_violations['per_atom_violations'],
axis=-1)]), axis=0)
return {
'between_residues': {
'bonds_c_n_loss_mean':
connection_violations['c_n_loss_mean'], # ()
'angles_ca_c_n_loss_mean':
connection_violations['ca_c_n_loss_mean'], # ()
'angles_c_n_ca_loss_mean':
connection_violations['c_n_ca_loss_mean'], # ()
'connections_per_residue_loss_sum':
connection_violations['per_residue_loss_sum'], # (N)
'connections_per_residue_violation_mask':
connection_violations['per_residue_violation_mask'], # (N)
'clashes_mean_loss':
between_residue_clashes['mean_loss'], # ()
'clashes_per_atom_loss_sum':
between_residue_clashes['per_atom_loss_sum'], # (N, 14)
'clashes_per_atom_clash_mask':
between_residue_clashes['per_atom_clash_mask'], # (N, 14)
},
'within_residues': {
'per_atom_loss_sum':
within_residue_violations['per_atom_loss_sum'], # (N, 14)
'per_atom_violations':
within_residue_violations['per_atom_violations'], # (N, 14),
},
'total_per_residue_violations_mask':
per_residue_violations_mask, # (N)
}
def compute_violation_metrics(
residue_index: jnp.ndarray,
mask: jnp.ndarray,
seq_mask: jnp.ndarray,
pred_positions: geometry.Vec3Array, # (N, 14)
violations: Mapping[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
"""Compute several metrics to assess the structural violations."""
ret = {}
between_residues = violations['between_residues']
within_residues = violations['within_residues']
extreme_ca_ca_violations = all_atom_multimer.extreme_ca_ca_distance_violations(
positions=pred_positions,
mask=mask.astype(jnp.float32),
residue_index=residue_index.astype(jnp.float32))
ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations
ret['violations_between_residue_bond'] = utils.mask_mean(
mask=seq_mask,
value=between_residues['connections_per_residue_violation_mask'])
ret['violations_between_residue_clash'] = utils.mask_mean(
mask=seq_mask,
value=jnp.max(between_residues['clashes_per_atom_clash_mask'], axis=-1))
ret['violations_within_residue'] = utils.mask_mean(
mask=seq_mask,
value=jnp.max(within_residues['per_atom_violations'], axis=-1))
ret['violations_per_residue'] = utils.mask_mean(
mask=seq_mask, value=violations['total_per_residue_violations_mask'])
return ret
def supervised_chi_loss(
sequence_mask: jnp.ndarray,
target_chi_mask: jnp.ndarray,
aatype: jnp.ndarray,
target_chi_angles: jnp.ndarray,
pred_angles: jnp.ndarray,
unnormed_angles: jnp.ndarray,
config: ml_collections.ConfigDict) -> Tuple[Float, Float, Float]:
"""Computes loss for direct chi angle supervision."""
eps = 1e-6
chi_mask = target_chi_mask.astype(jnp.float32)
pred_angles = pred_angles[:, :, 3:]
residue_type_one_hot = jax.nn.one_hot(
aatype, residue_constants.restype_num + 1, dtype=jnp.float32)[None]
chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot,
jnp.asarray(residue_constants.chi_pi_periodic))
true_chi = target_chi_angles[None]
sin_true_chi = jnp.sin(true_chi)
cos_true_chi = jnp.cos(true_chi)
sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1)
# This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic.
shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
sq_chi_error = jnp.sum(
squared_difference(sin_cos_true_chi, pred_angles), -1)
sq_chi_error_shifted = jnp.sum(
squared_difference(sin_cos_true_chi_shifted, pred_angles), -1)
sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)
sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
norm_error = jnp.abs(angle_norm - 1.)
angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None],
value=norm_error)
loss = (config.chi_weight * sq_chi_loss
+ config.angle_norm_weight * angle_norm_loss)
return loss, sq_chi_loss, angle_norm_loss
def l2_normalize(x: jnp.ndarray,
axis: int = -1,
epsilon: float = 1e-12
) -> jnp.ndarray:
return x / jnp.sqrt(
jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))
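# Minimal usage sketch (illustrative only, not part of the original source):
#   l2_normalize(jnp.array([3.0, 4.0]))  # -> [0.6, 0.8]
# i.e. predicted (sin, cos) pairs are projected onto the unit circle before
# being interpreted as torsion angles.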
def get_renamed_chi_angles(aatype: jnp.ndarray,
chi_angles: jnp.ndarray,
alt_is_better: jnp.ndarray
) -> jnp.ndarray:
"""Return renamed chi angles."""
chi_angle_is_ambiguous = utils.batched_gather(
jnp.array(residue_constants.chi_pi_periodic, dtype=jnp.float32), aatype)
alt_chi_angles = chi_angles + np.pi * chi_angle_is_ambiguous
# Map back to [-pi, pi].
alt_chi_angles = alt_chi_angles - 2 * np.pi * (alt_chi_angles > np.pi).astype(
jnp.float32)
alt_is_better = alt_is_better[:, None]
return (1. - alt_is_better) * chi_angles + alt_is_better * alt_chi_angles
class MultiRigidSidechain(hk.Module):
"""Class to make side chain atoms."""
def __init__(self,
config: ml_collections.ConfigDict,
global_config: ml_collections.ConfigDict,
name: str = 'rigid_sidechain'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
rigid: geometry.Rigid3Array,
representations_list: Iterable[jnp.ndarray],
aatype: jnp.ndarray
) -> Dict[str, Any]:
"""Predict sidechains using multi-rigid representations.
Args:
rigid: The Rigid's for each residue (translations in angstroms)
representations_list: A list of activations to predict sidechains from.
aatype: amino acid types.
Returns:
dict containing atom positions and frames (in angstrom)
"""
act = [
common_modules.Linear( # pylint: disable=g-complex-comprehension
self.config.num_channel,
name='input_projection')(jax.nn.relu(x))
for x in representations_list]
# Sum the activation list (equivalent to concat then Conv1D)
act = sum(act)
final_init = 'zeros' if self.global_config.zero_init else 'linear'
# Mapping with some residual blocks.
for _ in range(self.config.num_residual_block):
old_act = act
act = common_modules.Linear(
self.config.num_channel,
initializer='relu',
name='resblock1')(
jax.nn.relu(act))
act = common_modules.Linear(
self.config.num_channel,
initializer=final_init,
name='resblock2')(
jax.nn.relu(act))
act += old_act
# Map activations to torsion angles.
# [batch_size, num_res, 14]
num_res = act.shape[0]
unnormalized_angles = common_modules.Linear(
14, name='unnormalized_angles')(
jax.nn.relu(act))
unnormalized_angles = jnp.reshape(
unnormalized_angles, [num_res, 7, 2])
angles = l2_normalize(unnormalized_angles, axis=-1)
outputs = {
'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2)
'unnormalized_angles_sin_cos':
unnormalized_angles, # jnp.ndarray (N, 7, 2)
}
# Map torsion angles to frames.
# geometry.Rigid3Array with shape (N, 8)
all_frames_to_global = all_atom_multimer.torsion_angles_to_frames(
aatype,
rigid,
angles)
# Use frames and literature positions to create the final atom coordinates.
# geometry.Vec3Array with shape (N, 14)
pred_positions = all_atom_multimer.frames_and_literature_positions_to_atom14_pos(
aatype, all_frames_to_global)
outputs.update({
'atom_pos': pred_positions, # geometry.Vec3Array (N, 14)
'frames': all_frames_to_global, # geometry.Rigid3Array (N, 8)
})
return outputs
"""Geometry Module."""
from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
from typing import Union
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
Float = Union[float, jnp.ndarray]
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Rigid3Array:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation: rotation_matrix.Rot3Array
translation: vector.Vec3Array
def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
new_rotation = self.rotation @ other.rotation
new_translation = self.apply_to_point(other.translation)
return Rigid3Array(new_rotation, new_translation)
def inverse(self) -> Rigid3Array:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation = self.rotation.inverse()
inv_translation = inv_rotation.apply_to_point(-self.translation)
return Rigid3Array(inv_rotation, inv_translation)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
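# Illustrative property (not part of the original source): for any rigid r and
# point p, apply_inverse_to_point undoes apply_to_point up to float error:
#   r.apply_inverse_to_point(r.apply_to_point(p)) ~= p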
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape),
self.translation)
return Rigid3Array(rot, trans)
@classmethod
def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array:
"""Return identity Rigid3Array of given shape."""
return cls(
rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'."""
return Rigid3Array(self.rotation, self.translation * factor)
def to_array(self):
rot_array = self.rotation.to_array()
vec_array = self.translation.to_array()
return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)
@classmethod
def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
vec = vector.Vec3Array.from_array(array[..., -1])
return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod
def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
"""Construct Rigid3Array from homogeneous 4x4 array."""
assert array.shape[-1] == 4
assert array.shape[-2] == 4
rotation = rotation_matrix.Rot3Array(
array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes
def __getstate__(self):
return (VERSION, (self.rotation, self.translation))
def __setstate__(self, state):
version, (rot, trans) = state
del version
object.__setattr__(self, 'rotation', rot)
object.__setattr__(self, 'translation', trans)
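# Minimal usage sketch (illustrative only, not part of the original source):
#   rigid = Rigid3Array.identity(())
#   point = vector.Vec3Array(jnp.array(1.), jnp.array(0.), jnp.array(0.))
#   moved = rigid.apply_to_point(point)   # identity transform: moved == point
#   composed = rigid @ rigid              # composition is still the identity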
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import utils
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
import numpy as np
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Rot3Array:
"""Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
xy: jnp.ndarray
xz: jnp.ndarray
yx: jnp.ndarray
yy: jnp.ndarray
yz: jnp.ndarray
zx: jnp.ndarray
zy: jnp.ndarray
zz: jnp.ndarray
__array_ufunc__ = None
def inverse(self) -> Rot3Array:
"""Returns inverse of Rot3Array."""
return Rot3Array(self.xx, self.yx, self.zx,
self.xy, self.yy, self.zy,
self.xz, self.yz, self.zz)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies Rot3Array to point."""
return vector.Vec3Array(
self.xx * point.x + self.xy * point.y + self.xz * point.z,
self.yx * point.x + self.yy * point.y + self.yz * point.z,
self.zx * point.x + self.zy * point.y + self.zz * point.z)
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
@classmethod
def identity(cls, shape, dtype=jnp.float32) -> Rot3Array:
"""Returns identity of given shape."""
ones = jnp.ones(shape, dtype=dtype)
zeros = jnp.zeros(shape, dtype=dtype)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod
def from_two_vectors(cls, e0: vector.Vec3Array,
e1: vector.Vec3Array) -> Rot3Array:
"""Construct Rot3Array from two Vectors.
Rot3Array is constructed such that in the corresponding frame 'e0' lies on
the positive x-axis and 'e1' lies in the xy plane with positive y component.
Args:
e0: Vector
e1: Vector
Returns:
Rot3Array
"""
# Normalize the unit vector for the x-axis, e0.
e0 = e0.normalized()
# make e1 perpendicular to e0.
c = e1.dot(e0)
e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes
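# Illustrative example (not part of the original source): for e0 along +x and
# e1 along +y, the Gram-Schmidt construction above yields the identity:
#   e0 = vector.Vec3Array(jnp.array(1.), jnp.array(0.), jnp.array(0.))
#   e1 = vector.Vec3Array(jnp.array(0.), jnp.array(1.), jnp.array(0.))
#   Rot3Array.from_two_vectors(e0, e1)  # == Rot3Array.identity(())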
@classmethod
def from_array(cls, array: jnp.ndarray) -> Rot3Array:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
unstacked = utils.unstack(array, axis=-2)
unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], [])
return cls(*unstacked)
def to_array(self) -> jnp.ndarray:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return jnp.stack(
[jnp.stack([self.xx, self.xy, self.xz], axis=-1),
jnp.stack([self.yx, self.yy, self.yz], axis=-1),
jnp.stack([self.zx, self.zy, self.zz], axis=-1)],
axis=-2)
@classmethod
def from_quaternion(cls,
w: jnp.ndarray,
x: jnp.ndarray,
y: jnp.ndarray,
z: jnp.ndarray,
normalize: bool = True,
epsilon: float = 1e-6) -> Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
w *= inv_norm
x *= inv_norm
y *= inv_norm
z *= inv_norm
xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
yx = 2 * (x * y + w * z)
yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
yz = 2 * (y * z - w * x)
zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x)
zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes
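# Illustrative examples (not part of the original source):
#   Rot3Array.from_quaternion(1., 0., 0., 0.)  # identity rotation
#   Rot3Array.from_quaternion(0., 1., 0., 0.)  # 180 degree rotation about x,
#                                              # i.e. diag(1, -1, -1)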
@classmethod
def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
"""Samples uniform random Rot3Array according to Haar Measure."""
quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
quats = utils.unstack(quat_array)
return cls.from_quaternion(*quats)
def __getstate__(self):
return (VERSION,
[np.asarray(getattr(self, field)) for field in COMPONENTS])
def __setstate__(self, state):
version, state = state
del version
for i, field in enumerate(COMPONENTS):
object.__setattr__(self, field, state[i])
"""Class decorator to represent (nested) struct of arrays."""
import dataclasses
import jax
def get_item(instance, key):
sliced = {}
for field in get_array_fields(instance):
num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
this_key = key
if isinstance(key, tuple) and Ellipsis in this_key:
this_key += (slice(None),) * num_trailing_dims
sliced[field.name] = getattr(instance, field.name)[this_key]
return dataclasses.replace(instance, **sliced)
@property
def get_shape(instance):
"""Returns Shape for given instance of dataclass."""
first_field = dataclasses.fields(instance)[0]
num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
value = getattr(instance, first_field.name)
if num_trailing_dims:
return value.shape[:-num_trailing_dims]
else:
return value.shape
def get_len(instance):
"""Returns length for given instance of dataclass."""
shape = instance.shape
if shape:
return shape[0]
else:
raise TypeError('len() of unsized object') # Match jax.numpy behavior.
@property
def get_dtype(instance):
"""Returns Dtype for given instance of dataclass."""
fields = dataclasses.fields(instance)
sets_dtype = [
field.name for field in fields if field.metadata.get('sets_dtype', False)
]
if sets_dtype:
assert len(sets_dtype) == 1, 'at most one field can set dtype'
field_value = getattr(instance, sets_dtype[0])
elif instance.same_dtype:
field_value = getattr(instance, fields[0].name)
else:
# Should this be a ValueError?
raise AttributeError('Trying to access dtype on StructOfArray without '
'either "same_dtype" or a field setting the dtype')
if hasattr(field_value, 'dtype'):
return field_value.dtype
else:
# Should this be a ValueError?
raise AttributeError(f'field_value {field_value} does not have a dtype')
def replace(instance, **kwargs):
return dataclasses.replace(instance, **kwargs)
def post_init(instance):
"""Validate instance has same shapes & dtypes."""
array_fields = get_array_fields(instance)
arrays = list(get_array_fields(instance, return_values=True).values())
first_field = array_fields[0]
# These slightly weird constructions, checking whether the leaves are
# actual arrays, are needed because e.g. vmap internally relies on being
# able to construct pytrees with object() instances as leaves; that would
# break the checks, so we only validate the object when the entries in the
# dataclass are arrays or other dataclasses of arrays.
try:
dtype = instance.dtype
except AttributeError:
dtype = None
if dtype is not None:
first_shape = instance.shape
for array, field in zip(arrays, array_fields):
field_shape = array.shape
num_trailing_dims = field.metadata.get('num_trailing_dims', None)
if num_trailing_dims:
array_shape = array.shape
field_shape = array_shape[:-num_trailing_dims]
msg = (f'field {field} should have number of trailing dims '
f'{num_trailing_dims}')
assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
else:
field_shape = array.shape
shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
f"match shape {first_shape} of field {first_field}")
assert field_shape == first_shape, shape_msg
field_dtype = array.dtype
allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
if allowed_metadata_dtypes:
msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
assert field_dtype in allowed_metadata_dtypes, msg
if 'dtype' in field.metadata:
target_dtype = field.metadata['dtype']
else:
target_dtype = dtype
msg = f'Dtype is {field_dtype} but must be {target_dtype}'
assert field_dtype == target_dtype, msg
def flatten(instance):
"""Flatten Struct of Array instance."""
array_likes = list(get_array_fields(instance, return_values=True).values())
flat_array_likes = []
inner_treedefs = []
num_arrays = []
for array_like in array_likes:
flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
inner_treedefs.append(inner_treedef)
flat_array_likes += flat_array_like
num_arrays.append(len(flat_array_like))
metadata = get_metadata_fields(instance, return_values=True)
metadata = type(instance).metadata_cls(**metadata)
return flat_array_likes, (inner_treedefs, metadata, num_arrays)
def make_metadata_class(cls):
metadata_fields = get_fields(cls,
lambda x: x.metadata.get('is_metadata', False))
metadata_cls = dataclasses.make_dataclass(
cls_name='Meta' + cls.__name__,
fields=[(field.name, field.type, field) for field in metadata_fields],
frozen=True,
eq=True)
return metadata_cls
def get_fields(cls_or_instance, filterfn, return_values=False):
fields = dataclasses.fields(cls_or_instance)
fields = [field for field in fields if filterfn(field)]
if return_values:
return {
field.name: getattr(cls_or_instance, field.name) for field in fields
}
else:
return fields
def get_array_fields(cls, return_values=False):
return get_fields(
cls,
lambda x: not x.metadata.get('is_metadata', False),
return_values=return_values)
def get_metadata_fields(cls, return_values=False):
return get_fields(
cls,
lambda x: x.metadata.get('is_metadata', False),
return_values=return_values)
class StructOfArray:
"""Class Decorator for Struct Of Arrays."""
def __init__(self, same_dtype=True):
self.same_dtype = same_dtype
def __call__(self, cls):
cls.__array_ufunc__ = None
cls.replace = replace
cls.same_dtype = self.same_dtype
cls.dtype = get_dtype
cls.shape = get_shape
cls.__len__ = get_len
cls.__getitem__ = get_item
cls.__post_init__ = post_init
new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args
# pytree claims to require metadata to be hashable (not sure why), so we
# make a derived dataclass that can just hold the metadata.
new_cls.metadata_cls = make_metadata_class(new_cls)
def unflatten(aux, data):
inner_treedefs, metadata, num_arrays = aux
array_fields = [field.name for field in get_array_fields(new_cls)]
value_dict = {}
array_start = 0
for num_array, inner_treedef, array_field in zip(num_arrays,
inner_treedefs,
array_fields):
value_dict[array_field] = jax.tree_util.tree_unflatten(
inner_treedef, data[array_start:array_start + num_array])
array_start += num_array
metadata_fields = get_metadata_fields(new_cls)
for field in metadata_fields:
value_dict[field.name] = getattr(metadata, field.name)
return new_cls(**value_dict)
jax.tree_util.register_pytree_node(
nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
return new_cls
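# Minimal usage sketch (illustrative only; Point2Array is a hypothetical
# example, not part of this library):
#   @StructOfArray(same_dtype=True)
#   class Point2Array:
#     x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
#     y: jnp.ndarray
#   points = Point2Array(jnp.zeros([5]), jnp.ones([5]))
#   points.shape                               # (5,), from the first field
#   points[1:3]                                # slicing applies to all fields
#   jax.tree_util.tree_map(lambda a: 2. * a, points)  # registered as a pytree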
"""Shared utils for tests."""
import dataclasses
from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import vector
import jax.numpy as jnp
import numpy as np
def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
matrix2: rotation_matrix.Rot3Array):
for field in dataclasses.fields(rotation_matrix.Rot3Array):
field = field.name
np.testing.assert_array_equal(
getattr(matrix1, field), getattr(matrix2, field))
def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
mat2: rotation_matrix.Rot3Array):
np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6)
def assert_array_equal_to_rotation_matrix(array: jnp.ndarray,
matrix: rotation_matrix.Rot3Array):
"""Check that array and Matrix match."""
np.testing.assert_array_equal(matrix.xx, array[..., 0, 0])
np.testing.assert_array_equal(matrix.xy, array[..., 0, 1])
np.testing.assert_array_equal(matrix.xz, array[..., 0, 2])
np.testing.assert_array_equal(matrix.yx, array[..., 1, 0])
np.testing.assert_array_equal(matrix.yy, array[..., 1, 1])
np.testing.assert_array_equal(matrix.yz, array[..., 1, 2])
np.testing.assert_array_equal(matrix.zx, array[..., 2, 0])
np.testing.assert_array_equal(matrix.zy, array[..., 2, 1])
np.testing.assert_array_equal(matrix.zz, array[..., 2, 2])
def assert_array_close_to_rotation_matrix(array: jnp.ndarray,
matrix: rotation_matrix.Rot3Array):
np.testing.assert_array_almost_equal(matrix.to_array(), array, 6)
def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np.testing.assert_array_equal(vec1.x, vec2.x)
np.testing.assert_array_equal(vec1.y, vec2.y)
np.testing.assert_array_equal(vec1.z, vec2.z)
def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-5, rtol=0.)
np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-5, rtol=0.)
np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-5, rtol=0.)
def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.)
def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
np.testing.assert_array_equal(vec.to_array(), array)
def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_equal(rot, rigid.rotation)
assert_vectors_equal(trans, rigid.translation)
def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_close(rot, rigid.rotation)
assert_vectors_close(trans, rigid.translation)
"""Utils for geometry library."""
from typing import List
import jax.numpy as jnp
def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]:
return [jnp.squeeze(v, axis=axis)
for v in jnp.split(value, value.shape[axis], axis=axis)]
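# Illustrative example (not part of the original source):
#   unstack(jnp.zeros([4, 3]))  # -> list of 3 arrays, each of shape [4]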
"""Vec3Array Class."""
from __future__ import annotations
import dataclasses
from typing import Union
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import utils
import jax
import jax.numpy as jnp
import numpy as np
Float = Union[float, jnp.ndarray]
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Vec3Array:
"""Vec3Array in 3 dimensional Space implemented as struct of arrays.
This is done in order to improve performance and precision.
On TPU small matrix multiplications are very suboptimal and will waste large
compute ressources, furthermore any matrix multiplication on tpu happen in
mixed bfloat16/float32 precision, which is often undesirable when handling
physical coordinates.
In most cases this will also be faster on cpu's/gpu's since it allows for
easier use of vector instructions.
"""
x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
y: jnp.ndarray
z: jnp.ndarray
def __post_init__(self):
if hasattr(self.x, 'dtype'):
assert self.x.dtype == self.y.dtype
assert self.x.dtype == self.z.dtype
assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
def __add__(self, other: Vec3Array) -> Vec3Array:
return jax.tree_map(lambda x, y: x + y, self, other)
def __sub__(self, other: Vec3Array) -> Vec3Array:
return jax.tree_map(lambda x, y: x - y, self, other)
def __mul__(self, other: Float) -> Vec3Array:
return jax.tree_map(lambda x: x * other, self)
def __rmul__(self, other: Float) -> Vec3Array:
return self * other
def __truediv__(self, other: Float) -> Vec3Array:
return jax.tree_map(lambda x: x / other, self)
def __neg__(self) -> Vec3Array:
return jax.tree_map(lambda x: -x, self)
def __pos__(self) -> Vec3Array:
return jax.tree_map(lambda x: x, self)
def cross(self, other: Vec3Array) -> Vec3Array:
"""Compute cross product between 'self' and 'other'."""
new_x = self.y * other.z - self.z * other.y
new_y = self.z * other.x - self.x * other.z
new_z = self.x * other.y - self.y * other.x
return Vec3Array(new_x, new_y, new_z)
def dot(self, other: Vec3Array) -> Float:
"""Compute dot product between 'self' and 'other'."""
return self.x * other.x + self.y * other.y + self.z * other.z
def norm(self, epsilon: float = 1e-6) -> Float:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
norm2 = jnp.maximum(norm2, epsilon**2)
return jnp.sqrt(norm2)
def norm2(self):
return self.dot(self)
def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
"""Return unit vector with optional clipping."""
return self / self.norm(epsilon)
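# Illustrative note (not part of the original source): clipping the squared
# norm at epsilon**2 keeps the sqrt differentiable at the origin, e.g.
#   Vec3Array.zeros(()).norm()  # == 1e-6 rather than 0, with finite gradient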
@classmethod
def zeros(cls, shape, dtype=jnp.float32):
"""Return Vec3Array corresponding to zeros of given shape."""
return cls(
jnp.zeros(shape, dtype), jnp.zeros(shape, dtype),
jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
def to_array(self) -> jnp.ndarray:
return jnp.stack([self.x, self.y, self.z], axis=-1)
@classmethod
def from_array(cls, array):
return cls(*utils.unstack(array))
def __getstate__(self):
return (VERSION,
[np.asarray(self.x),
np.asarray(self.y),
np.asarray(self.z)])
def __setstate__(self, state):
version, state = state
del version
for i, letter in enumerate('xyz'):
object.__setattr__(self, letter, state[i])
def square_euclidean_distance(vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6) -> Float:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference = vec1 - vec2
distance = difference.dot(difference)
if epsilon:
distance = jnp.maximum(distance, epsilon)
return distance
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
return vector1.dot(vector2)
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
return vector.norm(epsilon)
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
return vector.normalized(epsilon)
def euclidean_distance(vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6) -> Float:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
distance = jnp.sqrt(distance_sq)
return distance
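# Illustrative example (not part of the original source):
#   v1 = Vec3Array.from_array(jnp.array([0., 0., 0.]))
#   v2 = Vec3Array.from_array(jnp.array([3., 4., 0.]))
#   euclidean_distance(v1, v2)         # -> 5.0
#   square_euclidean_distance(v1, v2)  # -> 25.0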
def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
d: Vec3Array) -> Float:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1 = a - b
v2 = b - c
v3 = d - c
c1 = v1.cross(v2)
c2 = v3.cross(v2)
c3 = c2.cross(c1)
v2_mag = v2.norm()
return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2))
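# Illustrative example (not part of the original source): four points in a
# common plane give a dihedral of 0 (cis/eclipsed) or +/-pi (trans), e.g.
#   a = (1, 0, 1), b = (0, 0, 0), c = (0, 1, 0), d = (1, 1, 1)
# yields 0.0 because a and d lie on the same side of the b-c axis.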
def random_gaussian_vector(shape, key, dtype=jnp.float32):
vec_array = jax.random.normal(key, shape + (3,), dtype)
return Vec3Array.from_array(vec_array)
"""Function to stack repeats of a layer function without shared parameters."""
import collections
import contextlib
import functools
import inspect
from typing import Any, Callable, Optional, Tuple, Union
import haiku as hk
import jax
import jax.numpy as jnp
LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
LayerStackScanned = collections.namedtuple('LayerStackScanned',
['i', 'args_ys'])
# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
# to inform the user. In reality, the typing below will accept anything.
NestedArray = Any
WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray]]]
def _check_no_varargs(f):
if list(inspect.signature(
f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL:
raise ValueError(
'The function `f` should not have any `varargs` (that is *args) '
'argument. Instead, it should only use explicit positional '
'arguments.')
@contextlib.contextmanager
def nullcontext():
yield
def maybe_with_rng(key):
if key is not None:
return hk.with_rng(key)
else:
return nullcontext()
def maybe_fold_in(key, data):
if key is not None:
return jax.random.fold_in(key, data)
else:
return None
class _LayerStack(hk.Module):
"""Module to compose parameterized functions, implemented as a scan."""
def __init__(self,
count: int,
unroll: int,
name: Optional[str] = None):
"""Iterate a function `f` `count` times, with non-shared parameters."""
super().__init__(name=name)
self._count = count
self._unroll = unroll
def __call__(self, x, *args_ys):
count = self._count
if hk.running_init():
# At initialization time, we run just one layer but add an extra first
# dimension to every initialized tensor, making sure to use different
# random keys for different slices.
def creator(next_creator, shape, dtype, init, context):
del context
def multi_init(shape, dtype):
assert shape[0] == count
key = hk.maybe_next_rng_key()
def rng_context_init(slice_idx):
slice_key = maybe_fold_in(key, slice_idx)
with maybe_with_rng(slice_key):
return init(shape[1:], dtype)
return jax.vmap(rng_context_init)(jnp.arange(count))
return next_creator((count,) + tuple(shape), dtype, multi_init)
def getter(next_getter, value, context):
trailing_dims = len(context.original_shape) + 1
sliced_value = jax.lax.index_in_dim(
value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
return next_getter(sliced_value)
with hk.experimental.custom_creator(
creator), hk.experimental.custom_getter(getter):
if len(args_ys) == 1 and args_ys[0] is None:
args0 = (None,)
else:
args0 = [
jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
for ys in args_ys
]
x, z = self._call_wrapped(x, *args0)
if z is None:
return x, z
# Broadcast state to hold each layer state.
def broadcast_state(layer_state):
return jnp.broadcast_to(
layer_state, [count,] + list(layer_state.shape))
zs = jax.tree_util.tree_map(broadcast_state, z)
return x, zs
else:
# Use scan during apply, threading through random seed so that it's
# unique for each layer.
def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
rng = carry.rng
def getter(next_getter, value, context):
# Getter slices the full param at the current loop index.
trailing_dims = len(context.original_shape) + 1
assert value.shape[value.ndim - trailing_dims] == count, (
f'Attempting to use a parameter stack of size '
f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
f'size {count}.')
sliced_value = jax.lax.dynamic_index_in_dim(
value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
return next_getter(sliced_value)
with hk.experimental.custom_getter(getter):
if rng is None:
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
else:
rng, rng_ = jax.random.split(rng)
with hk.with_rng(rng_):
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
return LayerStackCarry(x=out_x, rng=rng), z
carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
args_ys=args_ys)
carry, zs = hk.scan(
layer, carry, scanned, length=count, unroll=self._unroll)
return carry.x, zs
def _call_wrapped(self,
x: jnp.ndarray,
*args,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
raise NotImplementedError()
class _LayerStackNoState(_LayerStack):
"""_LayerStack impl with no per-layer state provided to the function."""
def __init__(self,
f: WrappedFn,
count: int,
unroll: int,
name: Optional[str] = None):
super().__init__(count=count, unroll=unroll, name=name)
_check_no_varargs(f)
self._f = f
@hk.transparent
def _call_wrapped(self, args, y):
del y
ret = self._f(*args)
if len(args) == 1:
# If the function takes a single argument, the wrapped function receives
# a tuple of length 1, and therefore it must return a tuple of length 1.
ret = (ret,)
return ret, None
class _LayerStackWithState(_LayerStack):
"""_LayerStack impl with per-layer state provided to the function."""
def __init__(self,
f: WrappedFn,
count: int,
unroll: int,
name: Optional[str] = None):
super().__init__(count=count, unroll=unroll, name=name)
self._f = f
@hk.transparent
def _call_wrapped(self, x, *args):
return self._f(x, *args)
def layer_stack(num_layers: int,
with_state=False,
unroll: int = 1,
name: Optional[str] = None):
"""Utility to wrap a Haiku function and recursively apply it to an input.
A function is valid if it uses only explicit position parameters, and
its return type matches its input type. The position parameters can be
arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
that kwargs are not supported, neither are functions with variable number
of parameters (specified by `*args`).
If `with_state=False` then the new, wrapped function can be understood as
performing the following:
```
for i in range(num_layers):
x = f(x)
return x
```
And if `with_state=True`, assuming `f` takes two arguments on top of `x`:
```
for i in range(num_layers):
x, zs[i] = f(x, ys_0[i], ys_1[i])
return x, zs
```
The code using `layer_stack` for the above function would be:
```
def f(x, y_0, y_1):
...
return new_x, z
x, zs = layer_stack.layer_stack(num_layers,
with_state=True)(f)(x, ys_0, ys_1)
```
Crucially, any parameters created inside `f` will not be shared across
iterations.
Args:
num_layers: The number of times to iterate the wrapped function.
with_state: Whether or not to pass per-layer state to the wrapped function.
unroll: the unroll used by `scan`.
name: Name of the Haiku context.
Returns:
Callable that will produce a layer stack when called with a valid function.
"""
def iterate(f):
if with_state:
@functools.wraps(f)
def wrapped(x, *args):
for ys in args:
assert ys.shape[0] == num_layers
return _LayerStackWithState(
f, num_layers, unroll=unroll, name=name)(x, *args)
else:
_check_no_varargs(f)
@functools.wraps(f)
def wrapped(*args):
ret = _LayerStackNoState(
f, num_layers, unroll=unroll, name=name)(args, None)[0]
if len(args) == 1:
# If the function takes a single argument, we must also return a
# single value, and not a tuple of length 1.
ret = ret[0]
return ret
return wrapped
return iterate
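# Usage sketch (editor's addition, hedged): a minimal example of wrapping a
# Haiku block with `layer_stack`. The names `_example_block` and `_example_fn`
# are illustrative assumptions, not part of the AlphaFold API.
if __name__ == '__main__':
  def _example_block(x):
    return jax.nn.relu(hk.Linear(8)(x))

  def _example_fn(x):
    # Applies `_example_block` four times; its parameters are created once,
    # stacked along a leading axis of size 4.
    return layer_stack(4)(_example_block)(x)

  _f = hk.transform(_example_fn)
  _params = _f.init(jax.random.PRNGKey(0), jnp.ones([2, 8]))
  _y = _f.apply(_params, None, jnp.ones([2, 8]))
  # Every parameter has a leading dimension of 4, one slice per layer.
  print(jax.tree_map(lambda p: p.shape, _params))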
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for layer_stack."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model import layer_stack
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import scipy.stats
# Suffixes applied by Haiku for repeated module names.
suffixes = [''] + [f'_{i}' for i in range(1, 100)]
def _slice_layers_params(layers_params):
sliced_layers_params = {}
for k, v in layers_params.items():
for inner_k in v:
for var_slice, suffix in zip(v[inner_k], suffixes):
k_new = k.split('/')[-1] + suffix
if k_new not in sliced_layers_params:
sliced_layers_params[k_new] = {}
sliced_layers_params[k_new][inner_k] = var_slice
return sliced_layers_params
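# For example (editor's note): given stacked params from a layer_stack of
# size 2, such as {'stack/~/linear': {'w': w}} with w.shape == (2, 3, 3),
# this returns {'linear': {'w': w[0]}, 'linear_1': {'w': w[1]}}, matching
# the names Haiku gives to repeated modules in an unrolled loop.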
class LayerStackTest(parameterized.TestCase):
@parameterized.parameters([1, 2, 4])
def test_layer_stack(self, unroll):
"""Compare layer_stack to the equivalent unrolled stack.
Tests that the layer_stack application of a Haiku layer function is
equivalent to repeatedly applying the layer function in an unrolled loop.
Args:
unroll: Unroll factor passed through to `layer_stack`'s underlying scan.
"""
num_layers = 20
def inner_fn(x):
x += hk.Linear(100, name='linear1')(x)
x += hk.Linear(100, name='linear2')(x)
return x
def outer_fn_unrolled(x):
for _ in range(num_layers):
x = inner_fn(x)
return x
def outer_fn_layer_stack(x):
stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
return stack(x)
unrolled_fn = hk.transform(outer_fn_unrolled)
layer_stack_fn = hk.transform(outer_fn_layer_stack)
x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
rng_init = jax.random.PRNGKey(42)
params = layer_stack_fn.init(rng_init, x)
sliced_params = _slice_layers_params(params)
unrolled_pred = unrolled_fn.apply(sliced_params, None, x)
layer_stack_pred = layer_stack_fn.apply(params, None, x)
np.testing.assert_allclose(unrolled_pred, layer_stack_pred)
def test_layer_stack_multi_args(self):
"""Compare layer_stack to the equivalent unrolled stack.
Similar to `test_layer_stack`, but uses a function that takes more than one
argument.
"""
num_layers = 20
def inner_fn(x, y):
x_out = x + hk.Linear(100, name='linear1')(y)
y_out = y + hk.Linear(100, name='linear2')(x)
return x_out, y_out
def outer_fn_unrolled(x, y):
for _ in range(num_layers):
x, y = inner_fn(x, y)
return x, y
def outer_fn_layer_stack(x, y):
stack = layer_stack.layer_stack(num_layers)(inner_fn)
return stack(x, y)
unrolled_fn = hk.transform(outer_fn_unrolled)
layer_stack_fn = hk.transform(outer_fn_layer_stack)
x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
y = jax.random.uniform(jax.random.PRNGKey(1), [10, 256, 100])
rng_init = jax.random.PRNGKey(42)
params = layer_stack_fn.init(rng_init, x, y)
sliced_params = _slice_layers_params(params)
unrolled_x, unrolled_y = unrolled_fn.apply(sliced_params, None, x, y)
layer_stack_x, layer_stack_y = layer_stack_fn.apply(params, None, x, y)
np.testing.assert_allclose(unrolled_x, layer_stack_x)
np.testing.assert_allclose(unrolled_y, layer_stack_y)
def test_layer_stack_no_varargs(self):
"""Test an error is raised when using a function with varargs."""
class VarArgsModule(hk.Module):
"""When used, this module should cause layer_stack to raise an Error."""
def __call__(self, *args):
return args
class NoVarArgsModule(hk.Module):
"""This module should be fine to use with layer_stack."""
def __call__(self, x):
return x
def build_and_init_stack(module_class):
def stack_fn(x):
module = module_class()
return layer_stack.layer_stack(1)(module)(x)
stack = hk.without_apply_rng(hk.transform(stack_fn))
stack.init(jax.random.PRNGKey(1729), jnp.ones([5]))
build_and_init_stack(NoVarArgsModule)
with self.assertRaisesRegex(
ValueError, 'The function `f` should not have any `varargs`'):
build_and_init_stack(VarArgsModule)
@parameterized.parameters([1, 2, 4])
def test_layer_stack_grads(self, unroll):
"""Compare layer_stack gradients to the equivalent unrolled stack.
Tests that the layer_stack application of a Haiku layer function is
equivalent to repeatedly applying the layer function in an unrolled loop.
Args:
unroll: Unroll factor passed through to `layer_stack`'s underlying scan.
"""
num_layers = 20
def inner_fn(x):
x += hk.Linear(100, name='linear1')(x)
x += hk.Linear(100, name='linear2')(x)
return x
def outer_fn_unrolled(x):
for _ in range(num_layers):
x = inner_fn(x)
return x
def outer_fn_layer_stack(x):
stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
return stack(x)
unrolled_fn = hk.transform(outer_fn_unrolled)
layer_stack_fn = hk.transform(outer_fn_layer_stack)
x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
rng_init = jax.random.PRNGKey(42)
params = layer_stack_fn.init(rng_init, x)
sliced_params = _slice_layers_params(params)
unrolled_grad = jax.grad(
lambda p, x: jnp.mean(unrolled_fn.apply(p, None, x)))(sliced_params, x)
layer_stack_grad = jax.grad(
lambda p, x: jnp.mean(layer_stack_fn.apply(p, None, x)))(params, x)
assert_fn = functools.partial(
np.testing.assert_allclose, atol=1e-4, rtol=1e-4)
jax.tree_map(assert_fn, unrolled_grad,
_slice_layers_params(layer_stack_grad))
def test_random(self):
"""Random numbers should be handled correctly."""
n = 100
@hk.transform
@layer_stack.layer_stack(n)
def add_random(x):
x = x + jax.random.normal(hk.next_rng_key())
return x
# Evaluate a bunch of times
key, *keys = jax.random.split(jax.random.PRNGKey(7), 1024 + 1)
params = add_random.init(key, 0.)
apply_fn = jax.jit(add_random.apply)
values = [apply_fn(params, key, 0.) for key in keys]
# The sum of n standard normals is roughly N(0, n), i.e. std sqrt(n).
cdf = scipy.stats.norm(scale=np.sqrt(n)).cdf
_, p = scipy.stats.kstest(values, cdf)
self.assertLess(0.3, p)
def test_threading(self):
"""Test @layer_stack when the function gets per-layer state."""
n = 5
@layer_stack.layer_stack(n, with_state=True)
def f(x, y):
x = x + y * jax.nn.one_hot(y, len(x)) / 10
return x, 2 * y
@hk.without_apply_rng
@hk.transform
def g(x, ys):
x, zs = f(x, ys)
# Check here to catch issues at init time
self.assertEqual(zs.shape, (n,))
return x, zs
rng = jax.random.PRNGKey(7)
x = np.zeros(n)
ys = np.arange(n).astype(np.float32)
params = g.init(rng, x, ys)
x, zs = g.apply(params, x, ys)
self.assertTrue(np.allclose(x, [0, .1, .2, .3, .4]))
self.assertTrue(np.all(zs == 2 * ys))
def test_nested_stacks(self):
def stack_fn(x):
def layer_fn(x):
return hk.Linear(100)(x)
outer_fn = layer_stack.layer_stack(10)(layer_fn)
layer_outer = layer_stack.layer_stack(20)(outer_fn)
return layer_outer(x)
hk_mod = hk.transform(stack_fn)
apply_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
params = hk_mod.init(init_rng, jnp.zeros([10, 100]))
hk_mod.apply(params, apply_rng, jnp.zeros([10, 100]))
p, = params.values()
assert p['w'].shape == (10, 20, 100, 100)
assert p['b'].shape == (10, 20, 100)
def test_with_state_multi_args(self):
"""Test layer_stack with state with multiple arguments."""
width = 4
batch_size = 5
stack_height = 3
def f_with_multi_args(x, a, b):
return hk.Linear(
width, w_init=hk.initializers.Constant(
jnp.eye(width)))(x) * a + b, None
@hk.without_apply_rng
@hk.transform
def hk_fn(x):
return layer_stack.layer_stack(
stack_height,
with_state=True)(f_with_multi_args)(x, jnp.full([stack_height], 2.),
jnp.ones([stack_height]))
x = jnp.zeros([batch_size, width])
key_seq = hk.PRNGSequence(19)
params = hk_fn.init(next(key_seq), x)
output, z = hk_fn.apply(params, x)
self.assertIsNone(z)
self.assertEqual(output.shape, (batch_size, width))
np.testing.assert_equal(output, np.full([batch_size, width], 7.))
def test_with_container_state(self):
width = 2
batch_size = 2
stack_height = 3
def f_with_container_state(x):
hk_layer = hk.Linear(
width, w_init=hk.initializers.Constant(jnp.eye(width)))
layer_output = hk_layer(x)
layer_state = {
'raw_output': layer_output,
'output_projection': jnp.sum(layer_output)
}
return layer_output + jnp.ones_like(layer_output), layer_state
@hk.without_apply_rng
@hk.transform
def hk_fn(x):
return layer_stack.layer_stack(
stack_height,
with_state=True)(f_with_container_state)(x)
x = jnp.zeros([batch_size, width])
key_seq = hk.PRNGSequence(19)
params = hk_fn.init(next(key_seq), x)
output, z = hk_fn.apply(params, x)
self.assertEqual(z['raw_output'].shape, (stack_height, batch_size, width))
self.assertEqual(output.shape, (batch_size, width))
self.assertEqual(z['output_projection'].shape, (stack_height,))
np.testing.assert_equal(np.sum(z['output_projection']), np.array(12.))
np.testing.assert_equal(
np.all(z['raw_output'] == np.array([0., 1., 2.])[..., None, None]),
np.array(True))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""lDDT protein distance score."""
import jax.numpy as jnp
def lddt(predicted_points,
true_points,
true_points_mask,
cutoff=15.,
per_residue=False):
"""Measure (approximate) lDDT for a batch of coordinates.
lDDT reference:
Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
superposition-free score for comparing protein structures and models using
distance difference tests. Bioinformatics 29, 2722–2728 (2013).
lDDT is a measure of the difference between the true distance matrix and the
distance matrix of the predicted points. The difference is computed only on
points closer than cutoff *in the true structure*.
This function does not compute the exact lDDT value that the original paper
describes because it does not include terms for physical feasibility
(e.g. bond length violations). Therefore this is only an approximate
lDDT score.
Args:
predicted_points: (batch, length, 3) array of predicted 3D points
true_points: (batch, length, 3) array of true 3D points
true_points_mask: (batch, length, 1) binary-valued float array. This mask
should be 1 for points that exist in the true points.
cutoff: Maximum distance for a pair of points to be included
per_residue: If True, return a score for each residue. Note that the overall
lDDT is not exactly the mean of the per-residue lDDT scores because some
residues have more contacts than others.
Returns:
An (approximate, see above) lDDT score in the range 0-1.
"""
assert len(predicted_points.shape) == 3
assert predicted_points.shape[-1] == 3
assert true_points_mask.shape[-1] == 1
assert len(true_points_mask.shape) == 3
# Compute true and predicted distance matrices.
dmat_true = jnp.sqrt(1e-10 + jnp.sum(
(true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))
dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
(predicted_points[:, :, None] -
predicted_points[:, None, :])**2, axis=-1))
dists_to_score = (
(dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
jnp.transpose(true_points_mask, [0, 2, 1]) *
(1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction.
)
# Per-pair absolute error in distances; unscored pairs are masked out via
# dists_to_score during normalization.
dist_l1 = jnp.abs(dmat_true - dmat_predicted)
# True lDDT uses a number of fixed bins.
# We ignore the physical plausibility correction to lDDT, though.
score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
(dist_l1 < 1.0).astype(jnp.float32) +
(dist_l1 < 2.0).astype(jnp.float32) +
(dist_l1 < 4.0).astype(jnp.float32))
# Normalize over the appropriate axes.
reduce_axes = (-1,) if per_residue else (-2, -1)
norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))
return score
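# Worked example (editor's addition): two points 5A apart in the true
# structure and exactly 6A apart in the prediction give a distance error of
# 1.0A. That satisfies the < 2.0 and < 4.0 thresholds but not < 0.5 or < 1.0,
# so the score is 0.25 * (0 + 0 + 1 + 1) = 0.5.
if __name__ == '__main__':
  true_pts = jnp.array([[[0., 0., 0.], [5., 0., 0.]]])
  pred_pts = jnp.array([[[0., 0., 0.], [6., 0., 0.]]])
  mask = jnp.ones([1, 2, 1])
  print(lddt(pred_pts, true_pts, mask))  # ~0.5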
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lddt."""
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model import lddt
import numpy as np
class LddtTest(parameterized.TestCase, absltest.TestCase):
@parameterized.named_parameters(
('same',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[1, 1, 1]),
('all_shifted',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[-1, 0, 0], [4, 0, 0], [9, 0, 0]],
[1, 1, 1]),
('all_rotated',
[[0, 0, 0], [5, 0, 0], [10, 0, 0]],
[[0, 0, 0], [0, 5, 0], [0, 10, 0]],
[1, 1, 1]),
('half_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [5.5-1e-5, 0, 0]],
[1, 1]),
('one_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [6-1e-5, 0, 0]],
[0.75, 0.75]),
('two_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [7-1e-5, 0, 0]],
[0.5, 0.5]),
('four_a_dist',
[[0, 0, 0], [5, 0, 0]],
[[0, 0, 0], [9-1e-5, 0, 0]],
[0.25, 0.25],),
('five_a_dist',
[[0, 0, 0], [16-1e-5, 0, 0]],
[[0, 0, 0], [11, 0, 0]],
[0, 0]),
('no_pairs',
[[0, 0, 0], [20, 0, 0]],
[[0, 0, 0], [25-1e-5, 0, 0]],
[1, 1]),
)
def test_lddt(
self, predicted_pos, true_pos, exp_lddt):
predicted_pos = np.array([predicted_pos], dtype=np.float32)
true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32)
true_pos = np.array([true_pos], dtype=np.float32)
cutoff = 15.0
per_residue = True
result = lddt.lddt(
predicted_pos, true_pos, true_points_mask, cutoff,
per_residue)
np.testing.assert_almost_equal(result, [exp_lddt], decimal=4)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specialized mapping functions."""
import functools
import inspect
from typing import Any, Callable, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
PYTREE = Any
PYTREE_JAX_ARRAY = Any
partial = functools.partial
PROXY = object()
def _maybe_slice(array, i, slice_size, axis):
if axis is PROXY:
return array
else:
return jax.lax.dynamic_slice_in_dim(
array, i, slice_size=slice_size, axis=axis)
def _maybe_get_size(array, axis):
if axis is PROXY:
return -1
else:
return array.shape[axis]
def _expand_axes(axes, values, name='sharded_apply'):
values_tree_def = jax.tree_util.tree_flatten(values)[1]
flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
# Replace Nones with the PROXY sentinel (None is an empty pytree node in
# JAX, so it cannot be stored as a leaf value).
flat_axes = [PROXY if x is None else x for x in flat_axes]
return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
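# For example (editor's note): _expand_axes(0, ((x, y), z)) broadcasts the
# integer to every leaf, returning ((0, 0), 0), while
# _expand_axes((0, None), (x, y)) returns (0, PROXY).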
def sharded_map(
fun: Callable[..., PYTREE_JAX_ARRAY],
shard_size: Union[int, None] = 1,
in_axes: Union[int, PYTREE] = 0,
out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
"""Sharded vmap.
Maps `fun` over axes, in a way similar to vmap, but does so in shards of
`shard_size`. This allows a smooth trade-off between memory usage
(as in a plain map) and higher throughput (as in a vmap).
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
in_axes: Either an integer or a pytree describing which axis to map over
for each input to `fun`; None denotes broadcasting.
out_axes: An integer or pytree denoting which output axis the mapped-over
axis is placed on.
Returns:
function with smap applied.
"""
if 'split_rng' in inspect.signature(hk.vmap).parameters:
vmapped_fun = hk.vmap(fun, in_axes, out_axes, split_rng=False)
else:
# TODO(tomhennigan): Remove this when older versions of Haiku aren't used.
vmapped_fun = hk.vmap(fun, in_axes, out_axes)
return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
def sharded_apply(
fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic
shard_size: Union[int, None] = 1,
in_axes: Union[int, PYTREE] = 0,
out_axes: Union[int, PYTREE] = 0,
new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
"""Sharded apply.
Applies `fun` to shards along the mapped axes, in a way similar to vmap,
but does so in shards of `shard_size`; the shard outputs are then written
back into one output. This allows a smooth trade-off between
memory usage (as in a plain map) and higher throughput (as in a vmap).
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
in_axes: Either an integer or a pytree describing which axis to map over
for each input to `fun`; None denotes broadcasting.
out_axes: An integer or pytree denoting which output axis the mapped-over
axis is placed on.
new_out_axes: whether to stack outputs on new axes. This assumes that the
output sizes for each shard (including the possible remainder shard) are
the same.
Returns:
function with smap applied.
"""
docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
'but with additional array axes over which {fun} is mapped.')
if new_out_axes:
raise NotImplementedError('New output axes not yet implemented.')
# A shard_size of None denotes no sharding.
if shard_size is None:
return fun
@jax.util.wraps(fun, docstr=docstr)
def mapped_fn(*args):
# Expand in_axes and determine the loop range.
in_axes_ = _expand_axes(in_axes, args)
in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
in_size = max(flat_sizes)
assert all(i in {in_size, -1} for i in flat_sizes)
num_extra_shards = (in_size - 1) // shard_size
# Fix up the last shard size when in_size is not a multiple of shard_size.
last_shard_size = in_size % shard_size
last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
def apply_fun_to_slice(slice_start, slice_size):
input_slice = jax.tree_map(
lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
), args, in_axes_)
return fun(*input_slice)
remainder_shape_dtype = hk.eval_shape(
partial(apply_fun_to_slice, 0, last_shard_size))
out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)
if num_extra_shards > 0:
regular_shard_shape_dtype = hk.eval_shape(
partial(apply_fun_to_slice, 0, shard_size))
shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype)
def make_output_shape(axis, shard_shape, remainder_shape):
return shard_shape[:axis] + (
shard_shape[axis] * num_extra_shards +
remainder_shape[axis],) + shard_shape[axis + 1:]
out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
out_shapes)
# Wraps jax.lax.dynamic_update_slice_in_dim with a different argument order,
# since tree_map only maps over positional arguments.
def dynamic_update_slice_in_dim(full_array, update, axis, i):
return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
def compute_shard(outputs, slice_start, slice_size):
slice_out = apply_fun_to_slice(slice_start, slice_size)
update_slice = partial(
dynamic_update_slice_in_dim, i=slice_start)
return jax.tree_map(update_slice, outputs, slice_out, out_axes_)
def scan_iteration(outputs, i):
new_outputs = compute_shard(outputs, i, shard_size)
return new_outputs, ()
slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)
def allocate_buffer(dtype, shape):
return jnp.zeros(shape, dtype=dtype)
outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)
if slice_starts.shape[0] > 0:
outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
if last_shard_size != shard_size:
remainder_start = in_size - last_shard_size
outputs = compute_shard(outputs, remainder_start, last_shard_size)
return outputs
return mapped_fn
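# Usage sketch (editor's addition, hedged): sharding a per-row computation.
# With 10 rows and shard_size=4, two full shards are processed by `hk.scan`
# and a remainder shard of size 2 is handled separately; outputs are written
# into a preallocated buffer of the full output shape. Wrapping in
# hk.transform below is an assumption made because sharded_apply relies on
# hk.scan and hk.eval_shape internally.
if __name__ == '__main__':
  def _row_norms(x):
    return jnp.sqrt(jnp.sum(x * x, axis=-1))

  _x = jnp.arange(30.).reshape(10, 3)
  _f = hk.transform(lambda x: sharded_apply(_row_norms, shard_size=4)(x))
  _out = _f.apply(_f.init(None, _x), None, _x)
  print(_out.shape)  # (10,)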
def inference_subbatch(
module: Callable[..., PYTREE_JAX_ARRAY],
subbatch_size: int,
batched_args: Sequence[PYTREE_JAX_ARRAY],
nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
low_memory: bool = True,
input_subbatch_dim: int = 0,
output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
"""Run through subbatches (like batch apply but with split and concat)."""
assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test
if not low_memory:
args = list(batched_args) + list(nonbatched_args)
return module(*args)
if output_subbatch_dim is None:
output_subbatch_dim = input_subbatch_dim
def run_module(*batched_args):
args = list(batched_args) + list(nonbatched_args)
return module(*args)
sharded_module = sharded_apply(run_module,
shard_size=subbatch_size,
in_axes=input_subbatch_dim,
out_axes=output_subbatch_dim)
return sharded_module(*batched_args)
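# Usage sketch (editor's addition, hedged): `inference_subbatch` splits the
# batched arguments into subbatches along `input_subbatch_dim`, passes the
# non-batched arguments whole to every subbatch, and stitches the results
# back together. `_scale_rows` is an illustrative stand-in for the
# attention-style modules this is used with.
if __name__ == '__main__':
  def _scale_rows(x, scale):
    return x * scale

  _x = jnp.ones([8, 5])
  _g = hk.transform(lambda x: inference_subbatch(
      _scale_rows,
      subbatch_size=3,
      batched_args=[x],                   # sliced into subbatches of 3, 3, 2
      nonbatched_args=[jnp.float32(2.)],  # broadcast to every subbatch
      low_memory=True))
  _out = _g.apply(_g.init(None, _x), None, _x)
  print(_out.shape)  # (8, 5)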