Commit 2f0d89e7 authored by zhuwenwen

remove duplicated code

parent a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformations for 3D coordinates.
This Module contains objects for representing Vectors (Vecs), Rotation Matrices
(Rots) and proper Rigid transformation (Rigids). These are represented as
named tuples with arrays for each entry, for example a set of
[N, M] points would be represented as a Vecs object with arrays of shape [N, M]
for x, y and z.
This is being done to improve readability by making it very clear what objects
are geometric objects rather than relying on comments and array shapes.
Another reason for this is to avoid using matrix
multiplication primitives like matmul or einsum, on modern accelerator hardware
these can end up on specialized cores such as tensor cores on GPU or the MXU on
cloud TPUs, this often involves lower computational precision which can be
problematic for coordinate geometry. Also these cores are typically optimized
for larger matrices than 3 dimensional, this code is written to avoid any
unintended use of these cores on both GPUs and TPUs.
"""
import collections
from typing import List
from alphafold.model import quat_affine
import jax.numpy as jnp
import tree
# Array of 3-component vectors, stored as individual array for
# each component.
Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z'])
# Array of 3x3 rotation matrices, stored as individual array for
# each component.
Rots = collections.namedtuple('Rots', ['xx', 'xy', 'xz',
'yx', 'yy', 'yz',
'zx', 'zy', 'zz'])
# Array of rigid 3D transformations, stored as array of rotations and
# array of translations.
Rigids = collections.namedtuple('Rigids', ['rot', 'trans'])
def squared_difference(x, y):
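  """Returns the elementwise squared difference (x - y)**2."""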
return jnp.square(x - y)
def invert_rigids(r: Rigids) -> Rigids:
"""Computes group inverse of rigid transformations 'r'."""
inv_rots = invert_rots(r.rot)
t = rots_mul_vecs(inv_rots, r.trans)
inv_trans = Vecs(-t.x, -t.y, -t.z)
return Rigids(inv_rots, inv_trans)
def invert_rots(m: Rots) -> Rots:
"""Computes inverse of rotations 'm'."""
return Rots(m.xx, m.yx, m.zx,
m.xy, m.yy, m.zy,
m.xz, m.yz, m.zz)
def rigids_from_3_points(
point_on_neg_x_axis: Vecs, # shape (...)
origin: Vecs, # shape (...)
point_on_xy_plane: Vecs, # shape (...)
) -> Rigids: # shape (...)
"""Create Rigids from 3 points.
Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points"
  This creates a set of rigid transformations from 3 points by Gram-Schmidt
  orthogonalization.
Args:
point_on_neg_x_axis: Vecs corresponding to points on the negative x axis
origin: Origin of resulting rigid transformations
point_on_xy_plane: Vecs corresponding to points in the xy plane
Returns:
Rigid transformations from global frame to local frames derived from
the input points.
"""
m = rots_from_two_vecs(
e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis),
e1_unnormalized=vecs_sub(point_on_xy_plane, origin))
return Rigids(rot=m, trans=origin)
def rigids_from_list(l: List[jnp.ndarray]) -> Rigids:
"""Converts flat list of arrays to rigid transformations."""
assert len(l) == 12
return Rigids(Rots(*(l[:9])), Vecs(*(l[9:])))
def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids:
"""Converts QuatAffine object to the corresponding Rigids object."""
return Rigids(Rots(*tree.flatten(a.rotation)),
Vecs(*a.translation))
def rigids_from_tensor4x4(
m: jnp.ndarray # shape (..., 4, 4)
) -> Rigids: # shape (...)
"""Construct Rigids object from an 4x4 array.
Here the 4x4 is representing the transformation in homogeneous coordinates.
Args:
m: Array representing transformations in homogeneous coordinates.
Returns:
Rigids object corresponding to transformations m
"""
assert m.shape[-1] == 4
assert m.shape[-2] == 4
return Rigids(
Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]),
Vecs(m[..., 0, 3], m[..., 1, 3], m[..., 2, 3]))
def rigids_from_tensor_flat9(
m: jnp.ndarray # shape (..., 9)
) -> Rigids: # shape (...)
"""Flat9 encoding: first two columns of rotation matrix + translation."""
assert m.shape[-1] == 9
e0 = Vecs(m[..., 0], m[..., 1], m[..., 2])
e1 = Vecs(m[..., 3], m[..., 4], m[..., 5])
trans = Vecs(m[..., 6], m[..., 7], m[..., 8])
return Rigids(rot=rots_from_two_vecs(e0, e1),
trans=trans)
def rigids_from_tensor_flat12(
m: jnp.ndarray # shape (..., 12)
) -> Rigids: # shape (...)
"""Flat12 encoding: rotation matrix (9 floats) + translation (3 floats)."""
assert m.shape[-1] == 12
x = jnp.moveaxis(m, -1, 0) # Unstack
return Rigids(Rots(*x[:9]), Vecs(*x[9:]))
def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids:
"""Group composition of Rigids 'a' and 'b'."""
return Rigids(
rots_mul_rots(a.rot, b.rot),
vecs_add(a.trans, rots_mul_vecs(a.rot, b.trans)))
def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids:
"""Compose rigid transformations 'r' with rotations 'm'."""
return Rigids(rots_mul_rots(r.rot, m), r.trans)
def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs:
"""Apply rigid transforms 'r' to points 'v'."""
return vecs_add(rots_mul_vecs(r.rot, v), r.trans)
def rigids_to_list(r: Rigids) -> List[jnp.ndarray]:
"""Turn Rigids into flat list, inverse of 'rigids_from_list'."""
return list(r.rot) + list(r.trans)
def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine:
"""Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'."""
return quat_affine.QuatAffine(
quaternion=None,
rotation=[[r.rot.xx, r.rot.xy, r.rot.xz],
[r.rot.yx, r.rot.yy, r.rot.yz],
[r.rot.zx, r.rot.zy, r.rot.zz]],
translation=[r.trans.x, r.trans.y, r.trans.z])
def rigids_to_tensor_flat9(
r: Rigids # shape (...)
) -> jnp.ndarray: # shape (..., 9)
"""Flat9 encoding: first two columns of rotation matrix + translation."""
return jnp.stack(
[r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy]
+ list(r.trans), axis=-1)
def rigids_to_tensor_flat12(
r: Rigids # shape (...)
) -> jnp.ndarray: # shape (..., 12)
"""Flat12 encoding: rotation matrix (9 floats) + translation (3 floats)."""
return jnp.stack(list(r.rot) + list(r.trans), axis=-1)
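# A minimal round-trip sketch: flat12 packs the nine rotation entries followed
# by the translation, so rigids_from_tensor_flat12 inverts
# rigids_to_tensor_flat12:
#
#   r = rigids_from_tensor_flat12(m)             # m: shape (..., 12)
#   jnp.allclose(rigids_to_tensor_flat12(r), m)  # True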
def rots_from_tensor3x3(
m: jnp.ndarray, # shape (..., 3, 3)
) -> Rots: # shape (...)
"""Convert rotations represented as (3, 3) array to Rots."""
assert m.shape[-1] == 3
assert m.shape[-2] == 3
return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
m[..., 2, 0], m[..., 2, 1], m[..., 2, 2])
def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
"""Create rotation matrices from unnormalized vectors for the x and y-axes.
This creates a rotation matrix from two vectors using Gram-Schmidt
orthogonalization.
Args:
e0_unnormalized: vectors lying along x-axis of resulting rotation
e1_unnormalized: vectors lying in xy-plane of resulting rotation
Returns:
Rotations resulting from Gram-Schmidt procedure.
"""
  # Normalize the vector for the x-axis to obtain the unit vector e0.
e0 = vecs_robust_normalize(e0_unnormalized)
# make e1 perpendicular to e0.
c = vecs_dot_vecs(e1_unnormalized, e0)
e1 = Vecs(e1_unnormalized.x - c * e0.x,
e1_unnormalized.y - c * e0.y,
e1_unnormalized.z - c * e0.z)
e1 = vecs_robust_normalize(e1)
# Compute e2 as cross product of e0 and e1.
e2 = vecs_cross_vecs(e0, e1)
return Rots(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
def rots_mul_rots(a: Rots, b: Rots) -> Rots:
"""Composition of rotations 'a' and 'b'."""
c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx))
c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy))
c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz))
return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs:
"""Apply rotations 'm' to vectors 'v'."""
return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z,
m.yx * v.x + m.yy * v.y + m.yz * v.z,
m.zx * v.x + m.zy * v.y + m.zz * v.z)
def vecs_add(v1: Vecs, v2: Vecs) -> Vecs:
"""Add two vectors 'v1' and 'v2'."""
return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z)
def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray:
"""Dot product of vectors 'v1' and 'v2'."""
return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z
def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs:
"""Cross product of vectors 'v1' and 'v2'."""
return Vecs(v1.y * v2.z - v1.z * v2.y,
v1.z * v2.x - v1.x * v2.z,
v1.x * v2.y - v1.y * v2.x)
def vecs_from_tensor(x: jnp.ndarray # shape (..., 3)
) -> Vecs: # shape (...)
"""Converts from tensor of shape (3,) to Vecs."""
num_components = x.shape[-1]
assert num_components == 3
return Vecs(x[..., 0], x[..., 1], x[..., 2])
def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs:
"""Normalizes vectors 'v'.
Args:
v: vectors to be normalized.
epsilon: small regularizer added to squared norm before taking square root.
Returns:
normalized vectors
"""
norms = vecs_robust_norm(v, epsilon)
return Vecs(v.x / norms, v.y / norms, v.z / norms)
def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
"""Computes norm of vectors 'v'.
Args:
    v: vectors whose norm is computed.
epsilon: small regularizer added to squared norm before taking square root.
Returns:
norm of 'v'
"""
return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z) + epsilon)
def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:
"""Computes v1 - v2."""
return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z)
def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray:
"""Computes squared euclidean difference between 'v1' and 'v2'."""
return (squared_difference(v1.x, v2.x) +
squared_difference(v1.y, v2.y) +
squared_difference(v1.z, v2.z))
def vecs_to_tensor(v: Vecs # shape (...)
) -> jnp.ndarray: # shape(..., 3)
"""Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'."""
return jnp.stack([v.x, v.y, v.z], axis=-1)
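# A minimal usage sketch (hypothetical helper, assuming only the definitions
# above): build a frame from three points, then check the group inverse by
# mapping the frame's own origin back into the local frame.
def _example_frame_round_trip():
  n = Vecs(jnp.array(-1.458), jnp.array(0.0), jnp.array(0.0))
  ca = Vecs(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0))
  c = Vecs(jnp.array(0.55), jnp.array(1.42), jnp.array(0.0))
  frame = rigids_from_3_points(
      point_on_neg_x_axis=n, origin=ca, point_on_xy_plane=c)
  # Applying the inverted frame to its own origin gives the zero vector.
  return rigids_mul_vecs(invert_rigids(frame), ca)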
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model TensorFlow code."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data for AlphaFold."""
from alphafold.common import residue_constants
from alphafold.model.tf import shape_helpers
from alphafold.model.tf import shape_placeholders
from alphafold.model.tf import utils
import numpy as np
import tensorflow.compat.v1 as tf
# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def cast_64bit_ints(protein):
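  """Casts int64 features to int32."""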
for k, v in protein.items():
if v.dtype == tf.int64:
protein[k] = tf.cast(v, tf.int32)
return protein
_MSA_FEATURE_NAMES = [
'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
'true_msa'
]
def make_seq_mask(protein):
protein['seq_mask'] = tf.ones(
shape_helpers.shape_list(protein['aatype']), dtype=tf.float32)
return protein
def make_template_mask(protein):
protein['template_mask'] = tf.ones(
shape_helpers.shape_list(protein['template_domain_names']),
dtype=tf.float32)
return protein
def curry1(f):
"""Supply all arguments but the first."""
def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs)
return fc
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = tf.constant(float(distillation),
shape=[],
dtype=tf.float32)
return protein
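# A minimal sketch of the curry1 calling convention (hypothetical helper): a
# decorated transform is configured with everything except the protein, and
# the returned one-argument callable is what the input pipeline composes.
def _example_curry1_usage(protein):
  set_flag = add_distillation_flag(True)  # Returns `lambda protein: ...`.
  return set_flag(protein)  # protein['is_distillation'] is now 1.0.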
def make_all_atom_aatype(protein):
protein['all_atom_aatype'] = protein['aatype']
return protein
def fix_templates_aatype(protein):
"""Fixes aatype encoding of templates."""
# Map one-hot to indices.
protein['template_aatype'] = tf.argmax(
protein['template_aatype'], output_type=tf.int32, axis=-1)
# Map hhsearch-aatype to our aatype.
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = tf.constant(new_order_list, dtype=tf.int32)
protein['template_aatype'] = tf.gather(params=new_order,
indices=protein['template_aatype'])
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype)
protein['msa'] = tf.gather(new_order, protein['msa'], axis=0)
perm_matrix = np.zeros((22, 22), dtype=np.float32)
perm_matrix[range(len(new_order_list)), new_order_list] = 1.
for k in protein:
if 'profile' in k: # Include both hhblits and psiblast profiles
num_dim = protein[k].shape.as_list()[-1]
assert num_dim in [20, 21, 22], (
'num_dim for %s out of expected range: %s' % (k, num_dim))
protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1)
return protein
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein['aatype'] = tf.argmax(
protein['aatype'], axis=-1, output_type=tf.int32)
for k in [
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
'superfamily', 'deletion_matrix', 'resolution',
'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
if k in protein:
final_dim = shape_helpers.shape_list(protein[k])[-1]
if isinstance(final_dim, int) and final_dim == 1:
protein[k] = tf.squeeze(protein[k], axis=-1)
for k in ['seq_length', 'num_alignments']:
if k in protein:
protein[k] = protein[k][0] # Remove fake sequence dimension
return protein
def make_random_crop_to_size_seed(protein):
"""Random seed for cropping residues and templates."""
protein['random_crop_to_size_seed'] = utils.make_random_seed()
return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a proportion of the MSA with 'X'."""
msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) <
replace_proportion)
x_idx = 20
gap_idx = 21
msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx)
protein['msa'] = tf.where(msa_mask,
tf.ones_like(protein['msa']) * x_idx,
protein['msa'])
aatype_mask = (
tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) <
replace_proportion)
protein['aatype'] = tf.where(aatype_mask,
tf.ones_like(protein['aatype']) * x_idx,
protein['aatype'])
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
protein: batch to sample msa from.
max_seq: number of sequences to sample.
    keep_extra: When True, sequences not sampled are put into fields starting
      with 'extra_*'.
Returns:
Protein with sampled msa.
"""
num_seq = tf.shape(protein['msa'])[0]
shuffled = tf.random_shuffle(tf.range(1, num_seq))
index_order = tf.concat([[0], shuffled], axis=0)
num_sel = tf.minimum(max_seq, num_seq)
sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel])
for k in _MSA_FEATURE_NAMES:
if k in protein:
if keep_extra:
protein['extra_' + k] = tf.gather(protein[k], not_sel_seq)
protein[k] = tf.gather(protein[k], sel_seq)
return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
"""MSA features are cropped so only `max_extra_msa` sequences are kept."""
num_seq = tf.shape(protein['extra_msa'])[0]
num_sel = tf.minimum(max_extra_msa, num_seq)
select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel]
for k in _MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices)
return protein
def delete_extra_msa(protein):
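  """Deletes the 'extra_*' MSA features."""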
for k in _MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
del protein['extra_' + k]
return protein
@curry1
def block_delete_msa(protein, config):
"""Sample MSA by deleting contiguous blocks.
Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"
Arguments:
protein: batch dict containing the msa
config: ConfigDict with parameters
Returns:
updated protein
"""
num_seq = shape_helpers.shape_list(protein['msa'])[0]
block_num_seq = tf.cast(
tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
tf.int32)
if config.randomize_num_blocks:
nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
else:
nb = config.num_blocks
del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]
# Make sure we keep the original sequence
sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None],
del_indices[None])
keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
keep_indices = tf.concat([[0], keep_indices], axis=0)
for k in _MSA_FEATURE_NAMES:
if k in protein:
protein[k] = tf.gather(protein[k], keep_indices)
return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask
weights = tf.concat([
tf.ones(21),
gap_agreement_weight * tf.ones(1),
np.zeros(1)], 0)
  # Compute the agreement score as a weighted count of positions where the
  # sequences match (i.e. a weighted complement of the Hamming distance).
sample_one_hot = (protein['msa_mask'][:, :, None] *
tf.one_hot(protein['msa'], 23))
extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
tf.one_hot(protein['extra_msa'], 23))
num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot)
extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot)
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement = tf.matmul(
tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]),
transpose_b=True)
# Assign each sequence in the extra sequences to the closest MSA sample
protein['extra_cluster_assignment'] = tf.argmax(
agreement, axis=1, output_type=tf.int32)
return protein
@curry1
def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = shape_helpers.shape_list(protein['msa'])[0]
def csum(x):
return tf.math.unsorted_segment_sum(
x, protein['extra_cluster_assignment'], num_seq)
mask = protein['extra_msa_mask']
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center
msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
msa_sum += tf.one_hot(protein['msa'], 23) # Original sequence
protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]
del msa_sum
del_sum = csum(mask * protein['extra_deletion_matrix'])
del_sum += protein['deletion_matrix'] # Original sequence
protein['cluster_deletion_mean'] = del_sum / mask_counts
del del_sum
return protein
def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein['msa_mask'] = tf.ones(
shape_helpers.shape_list(protein['msa']), dtype=tf.float32)
protein['msa_row_mask'] = tf.ones(
shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
"""Create pseudo beta features."""
is_gly = tf.equal(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
pseudo_beta = tf.where(
tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
pseudo_beta_mask = tf.where(
is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
@curry1
def make_pseudo_beta(protein, prefix=''):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ['', 'template_']
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
pseudo_beta_fn(
protein['template_aatype' if prefix else 'all_atom_aatype'],
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
return protein
@curry1
def add_constant_field(protein, key, value):
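  """Adds a constant feature 'key' with the given 'value'."""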
protein[key] = tf.convert_to_tensor(value)
return protein
def shaped_categorical(probs, epsilon=1e-10):
ds = shape_helpers.shape_list(probs)
num_classes = ds[-1]
counts = tf.random.categorical(
tf.reshape(tf.log(probs + epsilon), [-1, num_classes]),
1,
dtype=tf.int32)
return tf.reshape(counts, ds[:-1])
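# A minimal sketch (hypothetical shapes): shaped_categorical keeps all leading
# dimensions and draws one class index per position, e.g.
#
#   probs = tf.fill([num_seq, num_res, 23], 1. / 23.)
#   samples = shaped_categorical(probs)  # int32, shape [num_seq, num_res]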
def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if 'hhblits_profile' in protein:
return protein
# Compute the profile for every residue (over all MSA sequences).
protein['hhblits_profile'] = tf.reduce_mean(
tf.one_hot(protein['msa'], 22), axis=0)
return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly
random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * protein['hhblits_profile'] +
config.same_prob * tf.one_hot(protein['msa'], 22))
# Put all remaining probability on [MASK] which is a new column
pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
pad_shapes[-1][1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
categorical_probs = tf.pad(
categorical_probs, pad_shapes, constant_values=mask_prob)
sh = shape_helpers.shape_list(protein['msa'])
mask_position = tf.random.uniform(sh) < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = tf.where(mask_position, bert_msa, protein['msa'])
# Mix real and masked MSA
protein['bert_mask'] = tf.cast(mask_position, tf.float32)
protein['true_msa'] = protein['msa']
protein['msa'] = bert_msa
return protein
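# A hedged summary of the mixture above: each MSA entry is independently
# replaced with probability `replace_fraction`, and its replacement is drawn
# from
#
#   uniform_prob * Uniform(20 aa) + profile_prob * hhblits_profile
#     + same_prob * one_hot(msa) + mask_prob * [MASK]
#
# where mask_prob = 1 - profile_prob - same_prob - uniform_prob, so the
# coefficients sum to one, and [MASK] is the appended class index 22.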
@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
num_res, num_templates=0):
"""Guess at the MSA and sequence dimensions to make fixed size."""
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
NUM_EXTRA_SEQ: extra_msa_size,
NUM_TEMPLATES: num_templates,
}
for k, v in protein.items():
# Don't transfer this to the accelerator.
if k == 'extra_cluster_assignment':
continue
shape = v.shape.as_list()
schema = shape_schema[k]
assert len(shape) == len(schema), (
f'Rank mismatch between shape and shape schema for {k}: '
f'{shape} vs {schema}')
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)]
if padding:
protein[k] = tf.pad(
v, padding, name=f'pad_to_fixed_{k}')
protein[k].set_shape(pad_size)
return protein
@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
has_break = tf.clip_by_value(
tf.cast(protein['between_segment_residues'], tf.float32),
0, 1)
aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1)
target_feat = [
tf.expand_dims(has_break, axis=-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1)
has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.)
deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi)
msa_feat = [
msa_1hot,
tf.expand_dims(has_deletion, axis=-1),
tf.expand_dims(deletion_value, axis=-1),
]
if 'cluster_profile' in protein:
deletion_mean_value = (
tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
msa_feat.extend([
protein['cluster_profile'],
tf.expand_dims(deletion_mean_value, axis=-1),
])
if 'extra_deletion_matrix' in protein:
protein['extra_has_deletion'] = tf.clip_by_value(
protein['extra_deletion_matrix'], 0., 1.)
protein['extra_deletion_value'] = tf.atan(
protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)
protein['msa_feat'] = tf.concat(msa_feat, axis=-1)
protein['target_feat'] = tf.concat(target_feat, axis=-1)
return protein
@curry1
def select_feat(protein, feature_list):
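  """Keeps only the features named in 'feature_list'."""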
return {k: v for k, v in protein.items() if k in feature_list}
@curry1
def crop_templates(protein, max_templates):
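  """Keeps only the first 'max_templates' templates."""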
for k, v in protein.items():
if k.startswith('template_'):
protein[k] = v[:max_templates]
return protein
@curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
subsample_templates=False):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein['seq_length']
if 'template_mask' in protein:
num_templates = tf.cast(
shape_helpers.shape_list(protein['template_mask'])[0], tf.int32)
else:
num_templates = tf.constant(0, dtype=tf.int32)
num_res_crop_size = tf.math.minimum(seq_length, crop_size)
# Ensures that the cropping of residues and templates happens in the same way
# across ensembling iterations.
# Do not use for randomness that should vary in ensembling.
seed_maker = utils.SeedMaker(initial_seed=protein['random_crop_to_size_seed'])
if subsample_templates:
templates_crop_start = tf.random.stateless_uniform(
shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
seed=seed_maker())
else:
templates_crop_start = 0
num_templates_crop_size = tf.math.minimum(
num_templates - templates_crop_start, max_templates)
num_res_crop_start = tf.random.stateless_uniform(
shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
dtype=tf.int32, seed=seed_maker())
templates_select_indices = tf.argsort(tf.random.stateless_uniform(
[num_templates], seed=seed_maker()))
for k, v in protein.items():
if k not in shape_schema or (
'template' not in k and NUM_RES not in shape_schema[k]):
continue
# randomly permute the templates before cropping them.
if k.startswith('template') and subsample_templates:
v = tf.gather(v, templates_select_indices)
crop_sizes = []
crop_starts = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
shape_helpers.shape_list(v))):
is_num_res = (dim_size == NUM_RES)
if i == 0 and k.startswith('template'):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
else:
crop_start = num_res_crop_start if is_num_res else 0
crop_size = (num_res_crop_size if is_num_res else
(-1 if dim is None else dim))
crop_sizes.append(crop_size)
crop_starts.append(crop_start)
protein[k] = tf.slice(v, crop_starts, crop_sizes)
protein['seq_length'] = num_res_crop_size
return protein
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37,
protein['aatype'])
residx_atom14_mask = tf.gather(restype_atom14_mask,
protein['aatype'])
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back
residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14,
protein['aatype'])
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = tf.gather(restype_atom37_mask,
protein['aatype'])
protein['atom37_atom_exists'] = residx_atom37_mask
return protein
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature pre-processing input pipeline for AlphaFold."""
from alphafold.model.tf import data_transforms
from alphafold.model.tf import shape_placeholders
import tensorflow.compat.v1 as tf
import tree
# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def nonensembled_map_fns(data_config):
"""Input pipeline functions which are not ensembled."""
common_cfg = data_config.common
map_fns = [
data_transforms.correct_msa_restypes,
data_transforms.add_distillation_flag(False),
data_transforms.cast_64bit_ints,
data_transforms.squeeze_features,
# Keep to not disrupt RNG.
data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask,
data_transforms.make_msa_mask,
# Compute the HHblits profile if it's not set. This has to be run before
# sampling the MSA.
data_transforms.make_hhblits_profile,
data_transforms.make_random_crop_to_size_seed,
]
if common_cfg.use_templates:
map_fns.extend([
data_transforms.fix_templates_aatype,
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta('template_')
])
map_fns.extend([
data_transforms.make_atom14_masks,
])
return map_fns
def ensembled_map_fns(data_config):
"""Input pipeline functions that can be ensembled and averaged."""
common_cfg = data_config.common
eval_cfg = data_config.eval
map_fns = []
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
else:
pad_msa_clusters = eval_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
map_fns.append(
data_transforms.sample_msa(
max_msa_clusters,
keep_extra=True))
if 'masked_msa' in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
map_fns.append(
data_transforms.make_masked_msa(common_cfg.masked_msa,
eval_cfg.masked_msa_replace_fraction))
if common_cfg.msa_cluster_features:
map_fns.append(data_transforms.nearest_neighbor_clusters())
map_fns.append(data_transforms.summarize_clusters())
# Crop after creating the cluster profiles.
if max_extra_msa:
map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
else:
map_fns.append(data_transforms.delete_extra_msa)
map_fns.append(data_transforms.make_msa_feat())
crop_feats = dict(eval_cfg.feat)
if eval_cfg.fixed_size:
map_fns.append(data_transforms.select_feat(list(crop_feats)))
map_fns.append(data_transforms.random_crop_to_size(
eval_cfg.crop_size,
eval_cfg.max_templates,
crop_feats,
eval_cfg.subsample_templates))
map_fns.append(data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
eval_cfg.crop_size,
eval_cfg.max_templates))
else:
map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))
return map_fns
def process_tensors_from_config(tensors, data_config):
"""Apply filters and maps to an existing dataset, based on the config."""
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_map_fns(data_config)
fn = compose(fns)
d['ensemble_index'] = i
return fn(d)
eval_cfg = data_config.eval
tensors = compose(
nonensembled_map_fns(
data_config))(
tensors)
tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
num_ensemble = eval_cfg.num_ensemble
if data_config.common.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= data_config.common.num_recycle + 1
if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
fn_output_signature = tree.map_structure(
tf.TensorSpec.from_tensor, tensors_0)
tensors = tf.map_fn(
lambda x: wrap_ensemble_fn(tensors, x),
tf.range(num_ensemble),
parallel_iterations=1,
fn_output_signature=fn_output_signature)
else:
tensors = tree.map_structure(lambda x: x[None],
tensors_0)
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
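# A minimal usage sketch, assuming a `tensors` dict of TF features: `compose`
# is curried, so it is given the list of transforms first and then applied to
# the protein dict, threading it through each transform in order.
#
#   fns = [data_transforms.cast_64bit_ints, data_transforms.make_seq_mask]
#   tensors = compose(fns)(tensors)
#   # Equivalent to data_transforms.make_seq_mask(
#   #     data_transforms.cast_64bit_ints(tensors)).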
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains descriptions of various protein features."""
import enum
from typing import Dict, Optional, Sequence, Tuple, Union
from alphafold.common import residue_constants
import tensorflow.compat.v1 as tf
# Type aliases.
FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
class FeatureType(enum.Enum):
ZERO_DIM = 0 # Shape [x]
ONE_DIM = 1 # Shape [num_res, x]
TWO_DIM = 2 # Shape [num_res, num_res, x]
MSA = 3 # Shape [msa_length, num_res, x]
# Placeholder values that will be replaced with their true value at runtime.
NUM_RES = "num residues placeholder"
NUM_SEQ = "length msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
# Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders
# to be replaced with the number of residues and the number of sequences in the
# multiple sequence alignment, respectively.
FEATURES = {
#### Static features of a protein sequence ####
"aatype": (tf.float32, [NUM_RES, 21]),
"between_segment_residues": (tf.int64, [NUM_RES, 1]),
"deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
"domain_name": (tf.string, [1]),
"msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
"num_alignments": (tf.int64, [NUM_RES, 1]),
"residue_index": (tf.int64, [NUM_RES, 1]),
"seq_length": (tf.int64, [NUM_RES, 1]),
"sequence": (tf.string, [1]),
"all_atom_positions": (tf.float32,
[NUM_RES, residue_constants.atom_type_num, 3]),
"all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
"resolution": (tf.float32, [1]),
"template_domain_names": (tf.string, [NUM_TEMPLATES]),
"template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
"template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
"template_all_atom_positions": (tf.float32, [
NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
]),
"template_all_atom_masks": (tf.float32, [
NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
]),
}
FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
def register_feature(name: str,
type_: tf.dtypes.DType,
shape_: Tuple[Union[str, int]]):
"""Register extra features used in custom datasets."""
FEATURES[name] = (type_, shape_)
FEATURE_TYPES[name] = type_
FEATURE_SIZES[name] = shape_
def shape(feature_name: str,
num_residues: int,
msa_length: int,
num_templates: Optional[int] = None,
features: Optional[FeaturesMetadata] = None):
"""Get the shape for the given feature name.
  This is nearly identical to _get_tf_shape_no_placeholders() but with two
  differences:
  * This method does not calculate a single placeholder from the total number
    of elements (e.g. given <NUM_RES, 3> and size := 12, this won't deduce
    that NUM_RES must be 4).
  * This method will work with tensors.
Args:
feature_name: String identifier for the feature. If the feature name ends
with "_unnormalized", this suffix is stripped off.
num_residues: The number of residues in the current domain - some elements
of the shape can be dynamic and will be replaced by this value.
msa_length: The number of sequences in the multiple sequence alignment, some
elements of the shape can be dynamic and will be replaced by this value.
If the number of alignments is unknown / not read, please pass None for
msa_length.
num_templates (optional): The number of templates in this tfexample.
features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
Returns:
    List of ints representing the tensor size.
Raises:
ValueError: If a feature is requested but no concrete placeholder value is
given.
"""
features = features or FEATURES
if feature_name.endswith("_unnormalized"):
feature_name = feature_name[:-13]
unused_dtype, raw_sizes = features[feature_name]
replacements = {NUM_RES: num_residues,
NUM_SEQ: msa_length}
if num_templates is not None:
replacements[NUM_TEMPLATES] = num_templates
sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
for dimension in sizes:
if isinstance(dimension, str):
raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
feature_name, raw_sizes, replacements))
return sizes
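# A minimal usage sketch: placeholder dimensions are substituted with the
# concrete values supplied by the caller. Since FEATURES["msa"] has shape
# [NUM_SEQ, NUM_RES, 1]:
#
#   shape("msa", num_residues=100, msa_length=512, num_templates=4)
#   # -> [512, 100, 1]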
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for protein_features."""
import uuid
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model.tf import protein_features
import tensorflow.compat.v1 as tf
def _random_bytes():
return str(uuid.uuid4()).encode('utf-8')
class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
def testFeatureNames(self):
self.assertEqual(len(protein_features.FEATURE_SIZES),
len(protein_features.FEATURE_TYPES))
sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys())
sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys())
for i, size_name in enumerate(sorted_size_names):
self.assertEqual(size_name, sorted_type_names[i])
def testReplacement(self):
for name in protein_features.FEATURE_SIZES.keys():
sizes = protein_features.shape(name,
num_residues=12,
msa_length=24,
num_templates=3)
for x in sizes:
self.assertEqual(type(x), int)
self.assertGreater(x, 0)
if __name__ == '__main__':
tf.disable_v2_behavior()
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datasets consisting of proteins."""
from typing import Dict, Mapping, Optional, Sequence
from alphafold.model.tf import protein_features
import numpy as np
import tensorflow.compat.v1 as tf
TensorDict = Dict[str, tf.Tensor]
def parse_tfexample(
raw_data: bytes,
features: protein_features.FeaturesMetadata,
key: Optional[str] = None) -> Dict[str, tf.train.Feature]:
"""Read a single TF Example proto and return a subset of its features.
Args:
raw_data: A serialized tf.Example proto.
features: A dictionary of features, mapping string feature names to a tuple
(dtype, shape). This dictionary should be a subset of
protein_features.FEATURES (or the dictionary itself for all features).
key: Optional string with the SSTable key of that tf.Example. This will be
added into features as a 'key' but only if requested in features.
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
feature_map = {
k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
for k, v in features.items()
}
parsed_features = tf.io.parse_single_example(raw_data, feature_map)
reshaped_features = parse_reshape_logic(parsed_features, features, key=key)
return reshaped_features
def _first(tensor: tf.Tensor) -> tf.Tensor:
"""Returns the 1st element - the input can be a tensor or a scalar."""
return tf.reshape(tensor, shape=(-1,))[0]
def parse_reshape_logic(
parsed_features: TensorDict,
features: protein_features.FeaturesMetadata,
key: Optional[str] = None) -> TensorDict:
"""Transforms parsed serial features to the correct shape."""
  # Find out the number of residues, alignments and templates.
num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
if "num_alignments" in parsed_features:
num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
else:
num_msa = 0
if "template_domain_names" in parsed_features:
num_templates = tf.cast(
tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
else:
num_templates = 0
if key is not None and "key" in features:
parsed_features["key"] = [key] # Expand dims from () to (1,).
# Reshape the tensors according to the sequence length and num alignments.
for k, v in parsed_features.items():
new_shape = protein_features.shape(
feature_name=k,
num_residues=num_residues,
msa_length=num_msa,
num_templates=num_templates,
features=features)
new_shape_size = tf.constant(1, dtype=tf.int32)
for dim in new_shape:
new_shape_size *= tf.cast(dim, tf.int32)
assert_equal = tf.assert_equal(
tf.size(v), new_shape_size,
name="assert_%s_shape_correct" % k,
message="The size of feature %s (%s) could not be reshaped "
"into %s" % (k, tf.size(v), new_shape))
if "template" not in k:
# Make sure the feature we are reshaping is not empty.
assert_non_empty = tf.assert_greater(
tf.size(v), 0, name="assert_%s_non_empty" % k,
message="The feature %s is not set in the tf.Example. Either do not "
"request the feature or use a tf.Example that has the "
"feature set." % k)
with tf.control_dependencies([assert_non_empty, assert_equal]):
parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
else:
with tf.control_dependencies([assert_equal]):
parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
return parsed_features
def _make_features_metadata(
feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
"""Makes a feature name to type and shape mapping from a list of names."""
# Make sure these features are always read.
required_features = ["aatype", "sequence", "seq_length"]
feature_names = list(set(feature_names) | set(required_features))
features_metadata = {name: protein_features.FEATURES[name]
for name in feature_names}
return features_metadata
def create_tensor_dict(
raw_data: bytes,
features: Sequence[str],
key: Optional[str] = None,
) -> TensorDict:
"""Creates a dictionary of tensor features.
Args:
raw_data: A serialized tf.Example proto.
features: A list of strings of feature names to be returned in the dataset.
key: Optional string with the SSTable key of that tf.Example. This will be
added into features as a 'key' but only if requested in features.
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
features_metadata = _make_features_metadata(features)
return parse_tfexample(raw_data, features_metadata, key)
def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray],
features: Sequence[str],
) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
np_example: A dict of NumPy feature arrays.
features: A list of strings of feature names to be returned in the dataset.
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
features_metadata = _make_features_metadata(features)
tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
if k in features_metadata}
# Ensures shapes are as expected. Needed for setting size of empty features
# e.g. when no template hits were found.
tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
return tensor_dict
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with shapes of TensorFlow tensors."""
import tensorflow.compat.v1 as tf
def shape_list(x):
"""Return list of dimensions of a tensor, statically where possible.
Like `x.shape.as_list()` but with tensors instead of `None`s.
Args:
x: A tensor.
Returns:
A list with length equal to the rank of the tensor. The n-th element of the
list is an integer when that dimension is statically known otherwise it is
the n-th element of `tf.shape(x)`.
"""
x = tf.convert_to_tensor(x)
# If unknown rank, return dynamic shape
if x.get_shape().dims is None:
return tf.shape(x)
static = x.get_shape().as_list()
shape = tf.shape(x)
ret = []
for i in range(len(static)):
dim = static[i]
if dim is None:
dim = shape[i]
ret.append(dim)
return ret
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for shape_helpers."""
from alphafold.model.tf import shape_helpers
import numpy as np
import tensorflow.compat.v1 as tf
class ShapeTest(tf.test.TestCase):
def test_shape_list(self):
"""Test that shape_list can allow for reshaping to dynamic shapes."""
a = tf.zeros([10, 4, 4, 2])
p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]
b = tf.reshape(a, shape_dyn)
with self.session() as sess:
out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})
self.assertAllEqual(out.shape, (20, 1, 4, 4))
if __name__ == '__main__':
tf.disable_v2_behavior()
tf.test.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Placeholder values for run-time varying dimension sizes."""
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for various components."""
import tensorflow.compat.v1 as tf
def tf_combine_mask(*masks):
"""Take the intersection of float-valued masks."""
ret = 1
for m in masks:
ret *= m
return ret
class SeedMaker(object):
"""Return unique seeds."""
def __init__(self, initial_seed=0):
self.next_seed = initial_seed
def __call__(self):
i = self.next_seed
self.next_seed += 1
return i
seed_maker = SeedMaker()
def make_random_seed():
return tf.random.uniform([2],
tf.int32.min,
tf.int32.max,
tf.int32,
seed=seed_maker())
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of JAX utility functions for use in protein folding."""
import collections.abc
import functools
import numbers
from typing import Mapping
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
def final_init(config):
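  """Returns the initializer name for final layers ('zeros' or 'linear')."""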
if config.zero_init:
return 'zeros'
else:
return 'linear'
def batched_gather(params, indices, axis=0, batch_dims=0):
"""Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode="clip")
for _ in range(batch_dims):
take_fn = jax.vmap(take_fn)
return take_fn(params, indices)
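# A minimal sketch of the batch_dims semantics (hypothetical arrays): with
# batch_dims=1 the take is vmapped over the leading axis, so each batch row
# is indexed with its own row of indices.
def _example_batched_gather():
  params = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
  indices = jnp.array([[2, 0], [1, 1]])
  # Row 0 gathers [2, 0]; row 1 gathers [4, 4].
  return batched_gather(params, indices, axis=0, batch_dims=1)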
def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
"""Masked mean."""
if drop_mask_channel:
mask = mask[..., 0]
mask_shape = mask.shape
value_shape = value.shape
assert len(mask_shape) == len(value_shape)
if isinstance(axis, numbers.Integral):
axis = [axis]
elif axis is None:
axis = list(range(len(mask_shape)))
  assert isinstance(axis, collections.abc.Iterable), (
      'axis needs to be either an iterable, integer or "None"')
broadcast_factor = 1.
for axis_ in axis:
value_size = value_shape[axis_]
mask_size = mask_shape[axis_]
if mask_size == 1:
broadcast_factor *= value_size
else:
assert mask_size == value_size
return (jnp.sum(mask * value, axis=axis) /
(jnp.sum(mask, axis=axis) * broadcast_factor + eps))
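# A minimal sketch (hypothetical arrays): with an all-ones mask this reduces
# to a plain mean; a size-1 mask axis contributes through the broadcast
# factor so the denominator still counts every value position.
def _example_mask_mean():
  value = jnp.array([[1., 2.], [3., 4.]])
  return mask_mean(jnp.ones((2, 2)), value)  # ~2.5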
def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
"""Convert a dictionary of NumPy arrays to Haiku parameters."""
hk_params = {}
for path, array in params.items():
scope, name = path.split('//')
if scope not in hk_params:
hk_params[scope] = {}
hk_params[scope][name] = jnp.array(array)
return hk_params
def padding_consistent_rng(f):
"""Modify any element-wise random function to be consistent with padding.
  Normally, if you take a function like jax.random.normal and generate an
  array, say of size (10, 10), you will get a different set of random numbers
  than if you add padding and take the first (10, 10) sub-array.
This function makes a random function that is consistent regardless of the
amount of padding added.
Note: The padding-consistent function is likely to be slower to compile and
run than the function it is wrapping, but these slowdowns are likely to be
negligible in a large network.
Args:
f: Any element-wise function that takes (PRNG key, shape) as the first 2
arguments.
Returns:
    A function equivalent to f that is consistent for different amounts of
    padding.
"""
def grid_keys(key, shape):
"""Generate a grid of rng keys that is consistent with different padding.
Generate random keys such that the keys will be identical, regardless of
how much padding is added to any dimension.
Args:
key: A PRNG key.
shape: The shape of the output array of keys that will be generated.
Returns:
An array of shape `shape` consisting of random keys.
"""
if not shape:
return key
new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))(
jnp.arange(shape[0]))
return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys)
def inner(key, shape, **kwargs):
return jnp.vectorize(
lambda key: f(key, shape=(), **kwargs),
signature='(2)->()')(
grid_keys(key, shape))
return inner
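# A minimal usage sketch (hypothetical helper): wrapping jax.random.uniform so
# that padding the requested shape does not change the values in the unpadded
# prefix.
def _example_padding_consistency():
  rng = jax.random.PRNGKey(0)
  uniform = padding_consistent_rng(jax.random.uniform)
  small = uniform(rng, shape=(10, 10))
  padded = uniform(rng, shape=(12, 12))
  return small, padded  # small == padded[:10, :10]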
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlphaFold Colab notebook."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for the AlphaFold Colab notebook."""
import enum
import json
from typing import Any, Mapping, Optional, Sequence, Tuple
from alphafold.common import residue_constants
from alphafold.data import parsers
from matplotlib import pyplot as plt
import numpy as np
@enum.unique
class ModelType(enum.Enum):
MONOMER = 0
MULTIMER = 1
def clean_and_validate_sequence(
input_sequence: str, min_length: int, max_length: int) -> str:
"""Checks that the input sequence is ok and returns a clean version of it."""
  # Remove all whitespace, tabs and newlines; upper-case.
clean_sequence = input_sequence.translate(
str.maketrans('', '', ' \n\t')).upper()
aatypes = set(residue_constants.restypes) # 20 standard aatypes.
if not set(clean_sequence).issubset(aatypes):
raise ValueError(
f'Input sequence contains non-amino acid letters: '
f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '
'amino acids as inputs.')
if len(clean_sequence) < min_length:
raise ValueError(
f'Input sequence is too short: {len(clean_sequence)} amino acids, '
f'while the minimum is {min_length}')
if len(clean_sequence) > max_length:
raise ValueError(
f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
f'the maximum is {max_length}. You may be able to run it with the full '
f'AlphaFold system depending on your resources (system memory, '
f'GPU memory).')
return clean_sequence
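# Illustrative usage sketch (not called anywhere): whitespace is stripped and
# the sequence upper-cased before validation.
def _example_clean_and_validate_sequence() -> str:
  # Returns 'MKTA'.
  return clean_and_validate_sequence(' mkta\n', min_length=1, max_length=100)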
def validate_input(
input_sequences: Sequence[str],
min_length: int,
max_length: int,
max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
"""Validates and cleans input sequences and determines which model to use."""
sequences = []
for input_sequence in input_sequences:
if input_sequence.strip():
input_sequence = clean_and_validate_sequence(
input_sequence=input_sequence,
min_length=min_length,
max_length=max_length)
sequences.append(input_sequence)
if len(sequences) == 1:
print('Using the single-chain model.')
return sequences, ModelType.MONOMER
elif len(sequences) > 1:
total_multimer_length = sum([len(seq) for seq in sequences])
if total_multimer_length > max_multimer_length:
raise ValueError(f'The total length of multimer sequences is too long: '
f'{total_multimer_length}, while the maximum is '
f'{max_multimer_length}. Please use the full AlphaFold '
f'system for long multimers.')
elif total_multimer_length > 1536:
print('WARNING: The accuracy of the system has not been fully validated '
'above 1536 residues, and you may experience long running times or '
f'run out of memory for your complex with {total_multimer_length} '
'residues.')
print(f'Using the multimer model with {len(sequences)} sequences.')
return sequences, ModelType.MULTIMER
else:
raise ValueError('No input amino acid sequence provided, please provide at '
'least one sequence.')
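# Illustrative usage sketch (not called anywhere): blank entries are dropped
# and the model type follows from the number of non-empty sequences.
def _example_validate_input():
  seqs, model_type = validate_input(
      input_sequences=['MKTA', '', 'GGG '],
      min_length=1, max_length=100, max_multimer_length=200)
  # seqs == ['MKTA', 'GGG'], model_type == ModelType.MULTIMER
  return seqs, model_type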
def merge_chunked_msa(
results: Sequence[Mapping[str, Any]],
max_hits: Optional[int] = None
) -> parsers.Msa:
"""Merges chunked database hits together into hits for the full database."""
unsorted_results = []
for chunk_index, chunk in enumerate(results):
msa = parsers.parse_stockholm(chunk['sto'])
e_values_dict = parsers.parse_e_values_from_tblout(chunk['tbl'])
# Jackhmmer lists sequences as <sequence name>/<residue from>-<residue to>.
e_values = [e_values_dict[t.partition('/')[0]] for t in msa.descriptions]
chunk_results = zip(
msa.sequences, msa.deletion_matrix, msa.descriptions, e_values)
if chunk_index != 0:
      next(chunk_results)  # Skip the query; it was taken from the first chunk.
unsorted_results.extend(chunk_results)
sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[-1])
merged_sequences, merged_deletion_matrix, merged_descriptions, _ = zip(
*sorted_by_evalue)
merged_msa = parsers.Msa(sequences=merged_sequences,
deletion_matrix=merged_deletion_matrix,
descriptions=merged_descriptions)
if max_hits is not None:
merged_msa = merged_msa.truncate(max_seqs=max_hits)
return merged_msa
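# Illustrative usage sketch (not called anywhere): `chunk_results` stands for
# the per-chunk Jackhmmer outputs (dicts with 'sto' and 'tbl' entries, as in
# the tests below); merged hits are re-sorted globally by e-value.
def _example_merge_chunked_msa(chunk_results) -> parsers.Msa:
  return merge_chunked_msa(results=chunk_results, max_hits=501)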
def show_msa_info(
single_chain_msas: Sequence[parsers.Msa],
sequence_index: int):
"""Prints info and shows a plot of the deduplicated single chain MSA."""
full_single_chain_msa = []
for single_chain_msa in single_chain_msas:
full_single_chain_msa.extend(single_chain_msa.sequences)
# Deduplicate but preserve order (hence can't use set).
deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa))
total_msa_size = len(deduped_full_single_chain_msa)
print(f'\n{total_msa_size} unique sequences found in total for sequence '
f'{sequence_index}\n')
aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}
msa_arr = np.array(
[[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa])
plt.figure(figsize=(12, 3))
plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence '
f'{sequence_index}')
plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')
plt.ylabel('Non-Gap Count')
plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3))))
plt.show()
def empty_placeholder_template_features(
num_templates: int, num_res: int) -> Mapping[str, np.ndarray]:
return {
'template_aatype': np.zeros(
(num_templates, num_res,
len(residue_constants.restypes_with_x_and_gap)), dtype=np.float32),
'template_all_atom_masks': np.zeros(
(num_templates, num_res, residue_constants.atom_type_num),
dtype=np.float32),
'template_all_atom_positions': np.zeros(
(num_templates, num_res, residue_constants.atom_type_num, 3),
dtype=np.float32),
      'template_domain_names': np.zeros([num_templates], dtype=object),
      'template_sequence': np.zeros([num_templates], dtype=object),
'template_sum_probs': np.zeros([num_templates], dtype=np.float32),
}
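# Illustrative usage sketch (not called anywhere): when template search is
# skipped, the model still expects correctly-shaped template features. The
# residue count below is an arbitrary example value.
def _example_empty_placeholder_template_features():
  return empty_placeholder_template_features(num_templates=0, num_res=59)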
def get_pae_json(pae: np.ndarray, max_pae: float) -> str:
"""Returns the PAE in the same format as is used in the AFDB."""
rounded_errors = np.round(pae.astype(np.float64), decimals=1)
indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1
indices_1 = indices[0].flatten().tolist()
indices_2 = indices[1].flatten().tolist()
return json.dumps(
[{'residue1': indices_1,
'residue2': indices_2,
'distance': rounded_errors.flatten().tolist(),
'max_predicted_aligned_error': max_pae}],
indent=None, separators=(',', ':'))
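# Illustrative usage sketch (not called anywhere): serialize a 2x2 PAE matrix
# with 1-based residue indices, matching the AFDB JSON layout.
def _example_get_pae_json() -> str:
  pae = np.array([[0.1, 12.34], [9.99, 0.0]])
  return get_pae_json(pae=pae, max_pae=31.75)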
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for notebook_utils."""
import io
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.data import parsers
from alphafold.data import templates
from alphafold.notebooks import notebook_utils
import mock
import numpy as np
ONLY_QUERY_HIT = {
'sto': (
'# STOCKHOLM 1.0\n'
'#=GF ID query-i1\n'
'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH\n'
'//\n'),
'tbl': '',
'stderr': b'',
'n_iter': 1,
'e_value': 0.0001}
# pylint: disable=line-too-long
MULTI_SEQUENCE_HIT_1 = {
'sto': (
'# STOCKHOLM 1.0\n'
'#=GF ID query-i1\n'
'#=GS ERR1700680_4602609/41-109 DE [subseq from] ERR1700680_4602609\n'
'#=GS ERR1019366_5760491/40-105 DE [subseq from] ERR1019366_5760491\n'
'#=GS SRR5580704_12853319/61-125 DE [subseq from] SRR5580704_12853319\n'
'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n'
'ERR1700680_4602609/41-109 --INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHTEK--\n'
'ERR1019366_5760491/40-105 ---RSGAQHHDAAAQHYEEAARHHRMAAKQYQASHHEKAAHYAQLAYAHHMYAEQHAAEAAK-AHAKNHG----\n'
'SRR5580704_12853319/61-125 ----PAADHHMKAAEHHEEAAKHHRAAAEHHTAGDHQKAGHHAHVANGHHVNAVHHAEEASK-HHATDHS----\n'
'//\n'),
'tbl': (
'ERR1700680_4602609 - query - 7.7e-09 47.7 33.8 1.1e-08 47.2 33.8 1.2 1 0 0 1 1 1 1 -\n'
'ERR1019366_5760491 - query - 1.7e-08 46.6 33.1 2.5e-08 46.1 33.1 1.3 1 0 0 1 1 1 1 -\n'
'SRR5580704_12853319 - query - 1.1e-07 44.0 41.6 2e-07 43.1 41.6 1.4 1 0 0 1 1 1 1 -\n'),
'stderr': b'',
'n_iter': 1,
'e_value': 0.0001}
MULTI_SEQUENCE_HIT_2 = {
'sto': (
'# STOCKHOLM 1.0\n'
'#=GF ID query-i1\n'
'#=GS ERR1700719_3476944/70-137 DE [subseq from] ERR1700719_3476944\n'
'#=GS ERR1700761_4254522/72-138 DE [subseq from] ERR1700761_4254522\n'
'#=GS SRR5438477_9761204/64-132 DE [subseq from] SRR5438477_9761204\n'
'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n'
'ERR1700719_3476944/70-137 ---KQAAEHHHQAAEHHEHAARHHREAAKHHEAGDHESAAHHAHTAQGHLHQATHHASEAAKLHVEHHGQK--\n'
'ERR1700761_4254522/72-138 ----QASEHHNLAAEHHEHAARHHRDAAKHHKAGDHEKAAHHAHVAHGHHLHATHHATEAAKHHVEAHGEK--\n'
'SRR5438477_9761204/64-132 MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE----\n'
'//\n'),
'tbl': (
'ERR1700719_3476944 - query - 2e-07 43.2 47.5 3.5e-07 42.4 47.5 1.4 1 0 0 1 1 1 1 -\n'
'ERR1700761_4254522 - query - 6.1e-07 41.6 48.1 8.1e-07 41.3 48.1 1.2 1 0 0 1 1 1 1 -\n'
'SRR5438477_9761204 - query - 1.8e-06 40.2 46.9 2.3e-06 39.8 46.9 1.2 1 0 0 1 1 1 1 -\n'),
'stderr': b'',
'n_iter': 1,
'e_value': 0.0001}
# pylint: enable=line-too-long
class NotebookUtilsTest(parameterized.TestCase):
@parameterized.parameters(
('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'),
('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
clean = notebook_utils.clean_and_validate_sequence(
sequence, min_length=1, max_length=100)
self.assertEqual(clean, exp_clean)
@parameterized.named_parameters(
('too_short', 'AA', 'too short'),
('too_long', 'AAAAAAAAAA', 'too long'),
('bad_amino_acids_B', 'BBBB', 'non-amino acid'),
('bad_amino_acids_J', 'JJJJ', 'non-amino acid'),
('bad_amino_acids_O', 'OOOO', 'non-amino acid'),
('bad_amino_acids_U', 'UUUU', 'non-amino acid'),
('bad_amino_acids_X', 'XXXX', 'non-amino acid'),
('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
notebook_utils.clean_and_validate_sequence(
sequence, min_length=4, max_length=8)
@parameterized.parameters(
(['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'],
notebook_utils.ModelType.MONOMER),
(['', 'A'], ['A'],
notebook_utils.ModelType.MONOMER),
(['A', 'C ', ''], ['A', 'C'],
notebook_utils.ModelType.MULTIMER),
(['', 'A', '', 'C '], ['A', 'C'],
notebook_utils.ModelType.MULTIMER))
def test_validate_input_ok(
self, input_sequences, exp_sequences, exp_model_type):
sequences, model_type = notebook_utils.validate_input(
input_sequences=input_sequences,
min_length=1, max_length=100, max_multimer_length=100)
self.assertSequenceEqual(sequences, exp_sequences)
self.assertEqual(model_type, exp_model_type)
@parameterized.named_parameters(
('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
def test_validate_input_bad(self, input_sequences, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
notebook_utils.validate_input(
input_sequences=input_sequences,
min_length=4, max_length=8, max_multimer_length=6)
def test_merge_chunked_msa_no_hits(self):
results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
merged_msa = notebook_utils.merge_chunked_msa(
results=results)
self.assertSequenceEqual(
merged_msa.sequences,
('MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH',))
self.assertSequenceEqual(merged_msa.deletion_matrix, ([0] * 56,))
def test_merge_chunked_msa(self):
results = [MULTI_SEQUENCE_HIT_1, MULTI_SEQUENCE_HIT_2]
merged_msa = notebook_utils.merge_chunked_msa(
results=results)
self.assertLen(merged_msa.sequences, 7)
# The 1st one is the query.
self.assertEqual(
merged_msa.sequences[0],
'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAP'
'KPH')
# The 2nd one is the one with the lowest e-value: ERR1700680_4602609.
self.assertEqual(
merged_msa.sequences[1],
'--INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHT'
'EK-')
# The last one is the one with the largest e-value: SRR5438477_9761204.
self.assertEqual(
merged_msa.sequences[-1],
'MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE-'
'---')
self.assertLen(merged_msa.deletion_matrix, 7)
@mock.patch('sys.stdout', new_callable=io.StringIO)
def test_show_msa_info(self, mocked_stdout):
single_chain_msas = [
parsers.Msa(sequences=['A', 'B', 'C', 'C'],
deletion_matrix=[None] * 4,
descriptions=[''] * 4),
parsers.Msa(sequences=['A', 'A', 'A', 'D'],
deletion_matrix=[None] * 4,
descriptions=[''] * 4)
]
notebook_utils.show_msa_info(
single_chain_msas=single_chain_msas, sequence_index=1)
self.assertEqual(mocked_stdout.getvalue(),
'\n4 unique sequences found in total for sequence 1\n\n')
@parameterized.named_parameters(
('some_templates', 4), ('no_templates', 0))
def test_empty_placeholder_template_features(self, num_templates):
template_features = notebook_utils.empty_placeholder_template_features(
num_templates=num_templates, num_res=16)
self.assertCountEqual(template_features.keys(),
templates.TEMPLATE_FEATURES.keys())
self.assertSameElements(
[v.shape[0] for v in template_features.values()], [num_templates])
self.assertSequenceEqual(
[t.dtype for t in template_features.values()],
[np.array([], dtype=templates.TEMPLATE_FEATURES[feat_name]).dtype
for feat_name in template_features])
def test_get_pae_json(self):
pae = np.array([[0.01, 13.12345], [20.0987, 0.0]])
pae_json = notebook_utils.get_pae_json(pae=pae, max_pae=31.75)
self.assertEqual(
pae_json,
'[{"residue1":[1,1,2,2],"residue2":[1,2,1,2],"distance":'
'[0.0,13.1,20.1,0.0],"max_predicted_aligned_error":31.75}]')
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Amber relaxation."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Restrained Amber Minimization of a structure."""
import io
import time
from typing import Collection, Optional, Sequence
from absl import logging
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.model import folding
from alphafold.relax import cleanup
from alphafold.relax import utils
import ml_collections
import numpy as np
try:
# openmm >= 7.6
import openmm
from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
"""Returns True if the atom will be restrained by the given restraint set."""
if rset == "non_hydrogen":
return atom.element.name != "hydrogen"
  elif rset == "c_alpha":
    return atom.name == "CA"
  else:
    raise ValueError(f"Unknown restraint set: {rset}")
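# Illustrative usage sketch (not called anywhere): list the atoms a "c_alpha"
# restraint set would pin, given an OpenMM Topology.
def _example_will_restrain(topology: openmm_app.Topology):
  return [atom for atom in topology.atoms() if will_restrain(atom, "c_alpha")]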
def _add_restraints(
system: openmm.System,
reference_pdb: openmm_app.PDBFile,
stiffness: unit.Unit,
rset: str,
exclude_residues: Sequence[int]):
"""Adds a harmonic potential that restrains the system to a structure."""
assert rset in ["non_hydrogen", "c_alpha"]
force = openmm.CustomExternalForce(
"0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
force.addGlobalParameter("k", stiffness)
for p in ["x0", "y0", "z0"]:
force.addPerParticleParameter(p)
for i, atom in enumerate(reference_pdb.topology.atoms()):
if atom.residue.index in exclude_residues:
continue
if will_restrain(atom, rset):
force.addParticle(i, reference_pdb.positions[i])
logging.info("Restraining %d / %d particles.",
force.getNumParticles(), system.getNumParticles())
system.addForce(force)
def _openmm_minimize(
pdb_str: str,
max_iterations: int,
tolerance: unit.Unit,
stiffness: unit.Unit,
restraint_set: str,
exclude_residues: Sequence[int],
use_gpu: bool):
"""Minimize energy via openmm."""
pdb_file = io.StringIO(pdb_str)
pdb = openmm_app.PDBFile(pdb_file)
force_field = openmm_app.ForceField("amber99sb.xml")
constraints = openmm_app.HBonds
system = force_field.createSystem(
pdb.topology, constraints=constraints)
if stiffness > 0 * ENERGY / (LENGTH**2):
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
  # Use the HIP (AMD GPU) platform when requested, otherwise fall back to CPU.
  platform = openmm.Platform.getPlatformByName("HIP" if use_gpu else "CPU")
simulation = openmm_app.Simulation(
pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
ret = {}
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
simulation.minimizeEnergy(maxIterations=max_iterations,
tolerance=tolerance)
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
return ret
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
"""Returns a pdb string provided OpenMM topology and positions."""
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, positions, f)
return f.getvalue()
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
"""Checks that no atom positions have been altered by cleaning."""
cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))
cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))
for ref_res, cl_res in zip(reference.topology.residues(),
cleaned.topology.residues()):
assert ref_res.name == cl_res.name
for rat in ref_res.atoms():
for cat in cl_res.atoms():
if cat.name == rat.name:
if not np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
raise ValueError(f"Coordinates of cleaned atom {cat} do not match "
f"coordinates of reference atom {rat}.")
def _check_residues_are_well_defined(prot: protein.Protein):
"""Checks that all residues contain non-empty atom sets."""
if (prot.atom_mask.sum(axis=-1) == 0).any():
raise ValueError("Amber minimization can only be performed on proteins with"
" well-defined residues. This protein contains at least"
" one residue with no atoms.")
def _check_atom_mask_is_ideal(prot):
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
atom_mask = prot.atom_mask
ideal_atom_mask = protein.ideal_atom_mask(prot)
utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
def clean_protein(
prot: protein.Protein,
checks: bool = True):
"""Adds missing atoms to Protein instance.
Args:
prot: A `protein.Protein` instance.
checks: A `bool` specifying whether to add additional checks to the cleaning
process.
Returns:
pdb_string: A string of the cleaned protein.
"""
_check_atom_mask_is_ideal(prot)
# Clean pdb.
prot_pdb_string = protein.to_pdb(prot)
pdb_file = io.StringIO(prot_pdb_string)
alterations_info = {}
fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
fixed_pdb_file = io.StringIO(fixed_pdb)
pdb_structure = PdbStructure(fixed_pdb_file)
cleanup.clean_structure(pdb_structure, alterations_info)
logging.info("alterations info: %s", alterations_info)
# Write pdb file of cleaned structure.
as_file = openmm_app.PDBFile(pdb_structure)
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
return pdb_string
def make_atom14_positions(prot):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'.
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein.
residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
residx_atom14_mask = restype_atom14_mask[prot["aatype"]]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
prot["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32)
# Gather the ground truth positions.
residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
np.take_along_axis(prot["all_atom_positions"],
residx_atom14_to_atom37[..., None],
axis=1))
prot["atom14_atom_exists"] = residx_atom14_mask
prot["atom14_gt_exists"] = residx_atom14_gt_mask
prot["atom14_gt_positions"] = residx_atom14_gt_positions
prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37
# Create the gather indices for mapping back.
residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14
# Create the corresponding mask.
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
prot["atom37_atom_exists"] = residx_atom37_mask
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
  # alternative ground truth coordinates where the naming is swapped.
restype_3 = [
residue_constants.restype_1to3[res] for res in residue_constants.restypes
]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform = renaming_matrices[prot["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = np.einsum("rac,rab->rbc",
residx_atom14_gt_positions,
renaming_transform)
prot["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = np.einsum("ra,rab->rb",
residx_atom14_gt_mask,
renaming_transform)
prot["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
prot["atom14_atom_is_ambiguous"] = (
restype_atom14_is_ambiguous[prot["aatype"]])
return prot
def find_violations(prot_np: protein.Protein):
"""Analyzes a protein and returns structural violation information.
Args:
prot_np: A protein.
Returns:
violations: A `dict` of structure components with structural violations.
violation_metrics: A `dict` of violation metrics.
"""
batch = {
"aatype": prot_np.aatype,
"all_atom_positions": prot_np.atom_positions.astype(np.float32),
"all_atom_mask": prot_np.atom_mask.astype(np.float32),
"residue_index": prot_np.residue_index,
}
batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
batch = make_atom14_positions(batch)
violations = folding.find_structural_violations(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
config=ml_collections.ConfigDict(
{"violation_tolerance_factor": 12, # Taken from model config.
"clash_overlap_tolerance": 1.5, # Taken from model config.
}))
violation_metrics = folding.compute_violation_metrics(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
violations=violations,
)
return violations, violation_metrics
def get_violation_metrics(prot: protein.Protein):
"""Computes violation and alignment metrics."""
structural_violations, struct_metrics = find_violations(prot)
violation_idx = np.flatnonzero(
structural_violations["total_per_residue_violations_mask"])
struct_metrics["residue_violations"] = violation_idx
struct_metrics["num_residue_violations"] = len(violation_idx)
struct_metrics["structural_violations"] = structural_violations
return struct_metrics
def _run_one_iteration(
*,
pdb_string: str,
max_iterations: int,
tolerance: float,
stiffness: float,
restraint_set: str,
max_attempts: int,
use_gpu: bool,
exclude_residues: Optional[Collection[int]] = None):
"""Runs the minimization pipeline.
Args:
pdb_string: A pdb string.
max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts.
use_gpu: Whether to run on GPU.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A `dict` of minimization info.
"""
exclude_residues = exclude_residues or []
# Assign physical dimensions.
tolerance = tolerance * ENERGY
stiffness = stiffness * ENERGY / (LENGTH**2)
start = time.time()
minimized = False
attempts = 0
while not minimized and attempts < max_attempts:
attempts += 1
try:
logging.info("Minimizing protein, attempt %d of %d.",
attempts, max_attempts)
ret = _openmm_minimize(
pdb_string, max_iterations=max_iterations,
tolerance=tolerance, stiffness=stiffness,
restraint_set=restraint_set,
exclude_residues=exclude_residues,
use_gpu=use_gpu)
minimized = True
except Exception as e: # pylint: disable=broad-except
logging.info(e)
if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
ret["opt_time"] = time.time() - start
ret["min_attempts"] = attempts
return ret
def run_pipeline(
prot: protein.Protein,
stiffness: float,
use_gpu: bool,
max_outer_iterations: int = 1,
place_hydrogens_every_iteration: bool = True,
max_iterations: int = 0,
tolerance: float = 2.39,
restraint_set: str = "non_hydrogen",
max_attempts: int = 100,
checks: bool = True,
exclude_residues: Optional[Sequence[int]] = None):
"""Run iterative amber relax.
Successive relax iterations are performed until all violations have been
resolved. Each iteration involves a restrained Amber minimization, with
restraint exclusions determined by violation-participating residues.
Args:
prot: A protein to be relaxed.
stiffness: kcal/mol A**2, the restraint stiffness.
use_gpu: Whether to run on GPU.
    max_outer_iterations: The maximum number of outer relax iterations.
place_hydrogens_every_iteration: Whether hydrogens are re-initialized
prior to every minimization.
max_iterations: An `int` specifying the maximum number of L-BFGS steps
per relax iteration. A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
The default value is the OpenMM default.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts per iteration.
checks: Whether to perform cleaning checks.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
out: A dictionary of output values.
"""
# `protein.to_pdb` will strip any poorly-defined residues so we need to
# perform this check before `clean_protein`.
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)
exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
iteration = 0
while violations > 0 and iteration < max_outer_iterations:
ret = _run_one_iteration(
pdb_string=pdb_string,
exclude_residues=exclude_residues,
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
max_attempts=max_attempts,
use_gpu=use_gpu)
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
else:
pdb_string = ret["min_pdb"]
ret.update(get_violation_metrics(prot))
ret.update({
"num_exclusions": len(exclude_residues),
"iteration": iteration,
})
violations = ret["violations_per_residue"]
exclude_residues = exclude_residues.union(ret["residue_violations"])
logging.info("Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
"num residue violations %d num residue exclusions %d ",
ret["einit"], ret["efinal"], ret["opt_time"],
ret["num_residue_violations"], ret["num_exclusions"])
iteration += 1
return ret
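# Illustrative usage sketch (not called anywhere): relax a predicted structure
# with a stiffness of 10 kcal/mol A**2, as used elsewhere in this codebase.
def _example_run_pipeline(prot: protein.Protein) -> str:
  out = run_pipeline(prot=prot, stiffness=10., use_gpu=False)
  return out["min_pdb"]  # The minimized structure as a PDB string.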
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for amber_minimize."""
import os
from absl.testing import absltest
from alphafold.common import protein
from alphafold.relax import amber_minimize
import numpy as np
# Internal import (7716).
_USE_GPU = False
def _load_test_protein(data_path):
pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
with open(pdb_path, 'r') as f:
return protein.from_pdb_string(f.read())
class AmberMinimizeTest(absltest.TestCase):
def test_multiple_disulfides_target(self):
prot = _load_test_protein(
'alphafold/relax/testdata/multiple_disulfides_target.pdb'
)
ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
stiffness=10., use_gpu=_USE_GPU)
self.assertIn('opt_time', ret)
self.assertIn('min_attempts', ret)
def test_raises_invalid_protein_assertion(self):
prot = _load_test_protein(
'alphafold/relax/testdata/multiple_disulfides_target.pdb'
)
prot.atom_mask[4, :] = 0
with self.assertRaisesRegex(
ValueError,
'Amber minimization can only be performed on proteins with well-defined'
' residues. This protein contains at least one residue with no atoms.'):
amber_minimize.run_pipeline(prot, max_iterations=10,
stiffness=1.,
max_attempts=1,
use_gpu=_USE_GPU)
def test_iterative_relax(self):
prot = _load_test_protein(
'alphafold/relax/testdata/with_violations.pdb'
)
violations = amber_minimize.get_violation_metrics(prot)
self.assertGreater(violations['num_residue_violations'], 0)
out = amber_minimize.run_pipeline(
prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU)
self.assertLess(out['efinal'], out['einit'])
self.assertEqual(0, out['num_residue_violations'])
def test_find_violations(self):
prot = _load_test_protein(
'alphafold/relax/testdata/multiple_disulfides_target.pdb'
)
viols, _ = amber_minimize.find_violations(prot)
expected_between_residues_connection_mask = np.zeros((191,), np.float32)
for residue in (42, 43, 59, 60, 135, 136):
expected_between_residues_connection_mask[residue] = 1.0
expected_clash_indices = np.array([
[8, 4],
[8, 5],
[13, 3],
[14, 1],
[14, 4],
[26, 4],
[26, 5],
[31, 8],
[31, 10],
[39, 0],
[39, 1],
[39, 2],
[39, 3],
[39, 4],
[42, 5],
[42, 6],
[42, 7],
[42, 8],
[47, 7],
[47, 8],
[47, 9],
[47, 10],
[64, 4],
[85, 5],
[102, 4],
[102, 5],
[109, 13],
[111, 5],
[118, 6],
[118, 7],
[118, 8],
[124, 4],
[124, 5],
[131, 5],
[139, 7],
[147, 4],
[152, 7]], dtype=np.int32)
expected_between_residues_clash_mask = np.zeros([191, 14])
expected_between_residues_clash_mask[expected_clash_indices[:, 0],
expected_clash_indices[:, 1]] += 1
expected_per_atom_violations = np.zeros([191, 14])
np.testing.assert_array_equal(
viols['between_residues']['connections_per_residue_violation_mask'],
expected_between_residues_connection_mask)
np.testing.assert_array_equal(
viols['between_residues']['clashes_per_atom_clash_mask'],
expected_between_residues_clash_mask)
np.testing.assert_array_equal(
viols['within_residues']['per_atom_violations'],
expected_per_atom_violations)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
fix_pdb uses a third-party tool. We also support fixing some additional edge
cases like removing chains of length one (see clean_structure).
"""
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
1) Replaces nonstandard residues.
2) Removes heterogens (non protein residues) including water.
3) Adds missing residues and missing atoms within existing residues.
4) Adds hydrogens assuming pH=7.0.
5) KeepIds is currently true, so the fixer must keep the existing chain and
     residue identifiers. This will fail for some files in the wider PDB that
     have invalid IDs.
Args:
pdbfile: Input PDB file handle.
alterations_info: A dict that will store details of changes made.
Returns:
A PDB string representing the fixed structure.
"""
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
fixer.findNonstandardResidues()
alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
fixer.replaceNonstandardResidues()
_remove_heterogens(fixer, alterations_info, keep_water=False)
fixer.findMissingResidues()
alterations_info['missing_residues'] = fixer.missingResidues
fixer.findMissingAtoms()
alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
alterations_info['missing_terminals'] = fixer.missingTerminals
fixer.addMissingAtoms(seed=0)
fixer.addMissingHydrogens()
out_handle = io.StringIO()
app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
keepIds=True)
return out_handle.getvalue()
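# Illustrative usage sketch (not called anywhere): fix a raw PDB string and
# inspect what was changed. `raw_pdb_string` is assumed to be supplied by the
# caller.
def _example_fix_pdb(raw_pdb_string: str) -> str:
  alterations = {}
  fixed = fix_pdb(io.StringIO(raw_pdb_string), alterations)
  # `alterations` now holds e.g. 'missing_residues', 'removed_heterogens'.
  return fixed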
def clean_structure(pdb_structure, alterations_info):
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
Args:
pdb_structure: An OpenMM structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
_replace_met_se(pdb_structure, alterations_info)
_remove_chains_of_length_one(pdb_structure, alterations_info)
def _remove_heterogens(fixer, alterations_info, keep_water):
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
initial_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
initial_resnames.add(residue.name)
fixer.removeHeterogens(keepWater=keep_water)
final_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
final_resnames.add(residue.name)
alterations_info['removed_heterogens'] = (
initial_resnames.difference(final_resnames))
def _replace_met_se(pdb_structure, alterations_info):
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues = []
for res in pdb_structure.iter_residues():
name = res.get_name_with_spaces().strip()
if name == 'MET':
s_atom = res.get_atom('SD')
if s_atom.element_symbol == 'Se':
s_atom.element_symbol = 'S'
s_atom.element = element.get_by_symbol('S')
modified_met_residues.append(s_atom.residue_number)
alterations_info['Se_in_MET'] = modified_met_residues
def _remove_chains_of_length_one(pdb_structure, alterations_info):
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
removed_chains = {}
for model in pdb_structure.iter_models():
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
model.chains = valid_chains
for chain_id in invalid_chain_ids:
model.chains_by_id.pop(chain_id)
removed_chains[model.number] = invalid_chain_ids
alterations_info['removed_chains'] = removed_chains
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for relax.cleanup."""
import io
from absl.testing import absltest
from alphafold.relax import cleanup
from simtk.openmm.app.internal import pdbstructure
def _pdb_to_structure(pdb_str):
handle = io.StringIO(pdb_str)
return pdbstructure.PdbStructure(handle)
def _lines_to_structure(pdb_lines):
return _pdb_to_structure('\n'.join(pdb_lines))
class CleanupTest(absltest.TestCase):
def test_missing_residues(self):
pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU',
'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 '
'19.08 N',
'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 '
'17.23 C',
'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 '
'15.38 C',
'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 '
'16.04 O',
'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 '
'14.75 N',
'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 '
'16.81 C',
'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 '
'16.95 C',
'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 '
'16.97 O']
input_handle = io.StringIO('\n'.join(pdb_lines))
alterations = {}
result = cleanup.fix_pdb(input_handle, alterations)
structure = _pdb_to_structure(result)
residue_names = [r.get_name() for r in structure.iter_residues()]
self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU'])
self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']])
def test_missing_atoms(self):
pdb_lines = ['SEQRES 1 A 1 PRO',
'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 '
' 0.00 C']
input_handle = io.StringIO('\n'.join(pdb_lines))
alterations = {}
result = cleanup.fix_pdb(input_handle, alterations)
structure = _pdb_to_structure(result)
atom_names = [a.get_name() for a in structure.iter_atoms()]
self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2',
'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA',
'C', 'O', 'H2', 'H3', 'OXT'])
missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values())
self.assertLen(missing_atoms_by_residue, 1)
atoms_added = [a.name for a in missing_atoms_by_residue[0]]
self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O'])
missing_terminals_by_residue = alterations['missing_terminals']
self.assertLen(missing_terminals_by_residue, 1)
has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()]
self.assertCountEqual(has_missing_terminal, ['PRO'])
self.assertCountEqual([t for t in missing_terminals_by_residue.values()],
[['OXT']])
def test_remove_heterogens(self):
pdb_lines = ['SEQRES 1 A 1 GLY',
'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
' 0.00 C',
'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 '
' 0.00 O']
input_handle = io.StringIO('\n'.join(pdb_lines))
alterations = {}
result = cleanup.fix_pdb(input_handle, alterations)
structure = _pdb_to_structure(result)
self.assertCountEqual([res.get_name() for res in structure.iter_residues()],
['GLY'])
self.assertEqual(alterations['removed_heterogens'], set(['HOH']))
def test_fix_nonstandard_residues(self):
pdb_lines = ['SEQRES 1 A 1 DAL',
'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 '
' 0.00 C']
input_handle = io.StringIO('\n'.join(pdb_lines))
alterations = {}
result = cleanup.fix_pdb(input_handle, alterations)
structure = _pdb_to_structure(result)
residue_names = [res.get_name() for res in structure.iter_residues()]
self.assertCountEqual(residue_names, ['ALA'])
self.assertLen(alterations['nonstandard_residues'], 1)
original_res, new_name = alterations['nonstandard_residues'][0]
self.assertEqual(original_res.id, '1')
self.assertEqual(new_name, 'ALA')
def test_replace_met_se(self):
pdb_lines = ['SEQRES 1 A 1 MET',
'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 '
' 0.00 Se']
structure = _lines_to_structure(pdb_lines)
alterations = {}
cleanup._replace_met_se(structure, alterations)
sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD']
self.assertLen(sd, 1)
self.assertEqual(sd[0].element_symbol, 'S')
self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number])
def test_remove_chains_of_length_one(self):
pdb_lines = ['SEQRES 1 A 1 GLY',
'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
' 0.00 C']
structure = _lines_to_structure(pdb_lines)
alterations = {}
cleanup._remove_chains_of_length_one(structure, alterations)
chains = list(structure.iter_chains())
self.assertEmpty(chains)
self.assertCountEqual(alterations['removed_chains'].values(), [['A']])
if __name__ == '__main__':
absltest.main()