Commit 0be2b30b authored by Augustin-Zidek's avatar Augustin-Zidek
Browse files

Add code for AlphaFold-Multimer.

PiperOrigin-RevId: 407076987
parent 1d43aaff
......@@ -13,72 +13,118 @@
# limitations under the License.
"""A collection of common Haiku modules for use in protein folding."""
import numbers
from typing import Union, Sequence
import haiku as hk
import jax.numpy as jnp
import numpy as np
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
dtype=np.float32)
def get_initializer_scale(initializer_name, input_shape):
  """Get Initializer for weights and scale to multiply activations by."""
  if initializer_name == 'zeros':
    return hk.initializers.Constant(0.0)

  # Fan-in scaling: variance is inversely proportional to the input size.
  scale = 1.
  for channel_dim in input_shape:
    scale /= channel_dim
  if initializer_name == 'relu':
    # He-style doubling for ReLU activations.
    scale *= 2

  # Divide by the truncated-normal std factor so samples from the truncated
  # distribution have the intended standard deviation.
  stddev = np.sqrt(scale) / TRUNCATED_NORMAL_STDDEV_FACTOR
  return hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
class Linear(hk.Module):
  """Protein folding specific Linear module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs and outputs of arbitrary rank
    * Initializers are specified by strings

  NOTE(review): this region of the diff view interleaved pre- and post-change
  lines; this is the reconstructed post-change implementation.
  """

  def __init__(self,
               num_output: Union[int, Sequence[int]],
               initializer: str = 'linear',
               num_input_dims: int = 1,
               use_bias: bool = True,
               bias_init: float = 0.,
               precision=None,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels. Can be tuple when outputting
          multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults
        to None.
      name: Name of module, used for name scopes.
    """
    super().__init__(name=name)
    # Normalize num_output to a tuple so a single int and a sequence are
    # handled uniformly below.
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision

  def __call__(self, inputs):
    """Connects Module.

    Args:
      inputs: Tensor with at least num_input_dims dimensions.

    Returns:
      output of shape [...] + num_output.
    """
    # The trailing num_input_dims axes of `inputs` are projected away.
    if self.num_input_dims > 0:
      in_shape = inputs.shape[-self.num_input_dims:]
    else:
      in_shape = ()

    weight_init = get_initializer_scale(self.initializer, in_shape)

    # Build an einsum over the contracted input letters and the new output
    # letters, e.g. '...a, ah->...h' for the common rank-1 -> rank-1 case.
    in_letters = 'abcde'[:self.num_input_dims]
    out_letters = 'hijkl'[:self.num_output_dims]
    weight_shape = in_shape + self.output_shape
    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)
    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'

    output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
......@@ -17,7 +17,6 @@ import copy
from alphafold.model.tf import shape_placeholders
import ml_collections
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
......@@ -27,6 +26,9 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""
if 'multimer' in name:
return CONFIG_MULTIMER
if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
cfg = copy.deepcopy(CONFIG)
......@@ -34,6 +36,32 @@ def model_config(name: str) -> ml_collections.ConfigDict:
return cfg
# Canonical model names for each released preset; `model_config` accepts any
# of the listed names.
MODEL_PRESETS = {
    'monomer': (
        'model_1',
        'model_2',
        'model_3',
        'model_4',
        'model_5',
    ),
    # Monomer models with the additional predicted-TM-score (pTM) head.
    'monomer_ptm': (
        'model_1_ptm',
        'model_2_ptm',
        'model_3_ptm',
        'model_4_ptm',
        'model_5_ptm',
    ),
    'multimer': (
        'model_1_multimer',
        'model_2_multimer',
        'model_3_multimer',
        'model_4_multimer',
        'model_5_multimer',
    ),
}
# CASP14 used the same five monomer models.
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
CONFIG_DIFFS = {
'model_1': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
......@@ -206,6 +234,7 @@ CONFIG = ml_collections.ConfigDict({
'shared_dropout': True
},
'outer_product_mean': {
'first': False,
'chunk_size': 128,
'dropout_rate': 0.0,
'num_outer_channel': 32,
......@@ -322,6 +351,7 @@ CONFIG = ml_collections.ConfigDict({
},
'global_config': {
'deterministic': False,
'multimer_mode': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
......@@ -400,3 +430,228 @@ CONFIG = ml_collections.ConfigDict({
'resample_msa_in_recycling': True
},
})
# Full configuration for AlphaFold-Multimer models (returned directly by
# `model_config` for any name containing 'multimer').
CONFIG_MULTIMER = ml_collections.ConfigDict({
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            # Per-block settings of the main Evoformer stack.
            'evoformer': {
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    # In multimer mode the outer product mean runs first
                    # (before MSA attention); the monomer CONFIG uses False.
                    'first': True,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            # Number of MSA sequences in the main / extra stacks.
            'num_msa': 252,
            'num_extra_msa': 1152,
            'masked_msa': {
                'profile_prob': 0.1,
                'replace_fraction': 0.15,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            # Multimer-specific relative-position encoding across chains.
            'use_chain_relative': True,
            'max_relative_chain': 2,
            'max_relative_idx': 32,
            'seq_channel': 384,
            'msa_channel': 256,
            'pair_channel': 128,
            # Binning of recycled previous positions into a distogram.
            'prev_pos': {
                'max_bin': 20.75,
                'min_bin': 3.25,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'template': {
                'attention': {
                    'gating': False,
                    'num_head': 4
                },
                'dgram_features': {
                    'max_bin': 50.75,
                    'min_bin': 3.25,
                    'num_bins': 39
                },
                'enabled': True,
                'max_templates': 4,
                'num_channels': 64,
                'subbatch_size': 128,
                'template_pair_stack': {
                    'num_block': 2,
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True
                    },
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                }
            },
        },
        'global_config': {
            'deterministic': False,
            # Distinguishes this config from the monomer CONFIG at runtime.
            'multimer_mode': True,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True
        },
        # Output heads and their loss weights.
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'masked_msa': {
                'weight': 2.0
            },
            'predicted_aligned_error': {
                'filter_by_resolution': True,
                'max_error_bin': 31.0,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 64,
                'num_channels': 128,
                'weight': 0.1
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'structure_module': {
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'dropout': 0.1,
                # FAPE loss uses a looser clamp across chain interfaces than
                # within a chain.
                'interface_fape': {
                    'atom_clamp_distance': 1000.0,
                    'loss_unit_distance': 20.0
                },
                'intra_chain_fape': {
                    'atom_clamp_distance': 10.0,
                    'loss_unit_distance': 10.0
                },
                'num_channel': 384,
                'num_head': 12,
                'num_layer': 8,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 20.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'loss_unit_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            }
        },
        'num_ensemble_eval': 1,
        'num_recycle': 3,
        'resample_msa_in_recycling': True
    }
})
......@@ -15,8 +15,10 @@
"""Code to generate processed features."""
import copy
from typing import List, Mapping, Tuple
from alphafold.model.tf import input_pipeline
from alphafold.model.tf import proteins_dataset
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
......
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""

from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector

# Public re-exports: core geometry types.
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array

# Public re-exports: vector helper functions.
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
from typing import Union
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
# Scalar or batched scalar accepted wherever a float factor is expected.
Float = Union[float, jnp.ndarray]

# Serialization format version emitted by `__getstate__`.
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Rigid3Array:
  """Rigid Transformation, i.e. element of special euclidean group."""

  # Batched rotation (struct-of-arrays 3x3 matrix) and matching translation.
  rotation: rotation_matrix.Rot3Array
  translation: vector.Vec3Array

  def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
    """Composes two transforms: result applies `other` first, then `self`."""
    new_rotation = self.rotation @ other.rotation
    new_translation = self.apply_to_point(other.translation)
    return Rigid3Array(new_rotation, new_translation)

  def inverse(self) -> Rigid3Array:
    """Return Rigid3Array corresponding to inverse transform."""
    inv_rotation = self.rotation.inverse()
    # Inverse translation is -R^-1 t, expressed via the inverse rotation.
    inv_translation = inv_rotation.apply_to_point(-self.translation)
    return Rigid3Array(inv_rotation, inv_translation)

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Apply Rigid3Array transform to point."""
    return self.rotation.apply_to_point(point) + self.translation

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Apply inverse Rigid3Array transform to point."""
    new_point = point - self.translation
    return self.rotation.apply_inverse_to_point(new_point)

  def compose_rotation(self, other_rotation):
    """Compose rotation with `other_rotation`, keeping this translation.

    The translation is broadcast to the composed rotation's batch shape.
    """
    rot = self.rotation @ other_rotation
    trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape),
                         self.translation)
    return Rigid3Array(rot, trans)

  @classmethod
  def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array:
    """Return identity Rigid3Array of given shape."""
    return cls(
        rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
        vector.Vec3Array.zeros(shape, dtype=dtype))

  def scale_translation(self, factor: Float) -> Rigid3Array:
    """Scale translation in Rigid3Array by 'factor'."""
    return Rigid3Array(self.rotation, self.translation * factor)

  def to_array(self):
    """Convert to a dense [..., 3, 4] array: rotation columns + translation."""
    rot_array = self.rotation.to_array()
    vec_array = self.translation.to_array()
    return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)

  @classmethod
  def from_array(cls, array):
    """Inverse of `to_array`: build from a [..., 3, 4] array."""
    rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
    # Last column holds the translation vector.
    vec = vector.Vec3Array.from_array(array[..., -1])
    return cls(rot, vec)

  @classmethod
  def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
    """Construct Rigid3Array from homogeneous 4x4 array."""
    assert array.shape[-1] == 4
    assert array.shape[-2] == 4
    rotation = rotation_matrix.Rot3Array(
        array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
        array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
        array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
        )
    translation = vector.Vec3Array(
        array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
    return cls(rotation, translation)

  def __getstate__(self):
    # Pickle as (format version, components) to allow future migration.
    return (VERSION, (self.rotation, self.translation))

  def __setstate__(self, state):
    version, (rot, trans) = state
    del version
    # Bypass the (frozen) dataclass setattr machinery.
    object.__setattr__(self, 'rotation', rot)
    object.__setattr__(self, 'translation', trans)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import utils
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
import numpy as np
# Field names of Rot3Array in row-major order; used for (de)serialization.
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']

# Serialization format version emitted by `__getstate__`.
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Rot3Array:
  """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""

  # Nine matrix entries in row-major order; the dtype metadata on the first
  # field fixes the dtype for the whole struct (same_dtype=True).
  xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  xy: jnp.ndarray
  xz: jnp.ndarray
  yx: jnp.ndarray
  yy: jnp.ndarray
  yz: jnp.ndarray
  zx: jnp.ndarray
  zy: jnp.ndarray
  zz: jnp.ndarray

  # Opt out of numpy ufunc dispatch so `ndarray op Rot3Array` raises instead
  # of broadcasting elementwise.
  __array_ufunc__ = None

  def inverse(self) -> Rot3Array:
    """Returns inverse of Rot3Array."""
    # For a rotation matrix the inverse is the transpose.
    return Rot3Array(self.xx, self.yx, self.zx,
                     self.xy, self.yy, self.zy,
                     self.xz, self.yz, self.zz)

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies Rot3Array to point."""
    return vector.Vec3Array(
        self.xx * point.x + self.xy * point.y + self.xz * point.z,
        self.yx * point.x + self.yy * point.y + self.yz * point.z,
        self.zx * point.x + self.zy * point.y + self.zz * point.z)

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies inverse Rot3Array to point."""
    return self.inverse().apply_to_point(point)

  def __matmul__(self, other: Rot3Array) -> Rot3Array:
    """Composes two Rot3Arrays."""
    # Rotate each column of `other`, then reassemble column-wise.
    c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
    c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
    c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
    return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

  @classmethod
  def identity(cls, shape, dtype=jnp.float32) -> Rot3Array:
    """Returns identity of given shape."""
    ones = jnp.ones(shape, dtype=dtype)
    zeros = jnp.zeros(shape, dtype=dtype)
    return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)

  @classmethod
  def from_two_vectors(cls, e0: vector.Vec3Array,
                       e1: vector.Vec3Array) -> Rot3Array:
    """Construct Rot3Array from two Vectors.

    Rot3Array is constructed such that in the corresponding frame 'e0' lies on
    the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

    Args:
      e0: Vector
      e1: Vector

    Returns:
      Rot3Array
    """
    # Normalize the unit vector for the x-axis, e0.
    e0 = e0.normalized()
    # make e1 perpendicular to e0 (Gram-Schmidt step).
    c = e1.dot(e0)
    e1 = (e1 - c * e0).normalized()
    # Compute e2 as cross product of e0 and e1.
    e2 = e0.cross(e1)
    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

  @classmethod
  def from_array(cls, array: jnp.ndarray) -> Rot3Array:
    """Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
    unstacked = utils.unstack(array, axis=-2)
    unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], [])
    return cls(*unstacked)

  def to_array(self) -> jnp.ndarray:
    """Convert Rot3Array to array of shape [..., 3, 3]."""
    return jnp.stack(
        [jnp.stack([self.xx, self.xy, self.xz], axis=-1),
         jnp.stack([self.yx, self.yy, self.yz], axis=-1),
         jnp.stack([self.zx, self.zy, self.zz], axis=-1)],
        axis=-2)

  @classmethod
  def from_quaternion(cls,
                      w: jnp.ndarray,
                      x: jnp.ndarray,
                      y: jnp.ndarray,
                      z: jnp.ndarray,
                      normalize: bool = True,
                      epsilon: float = 1e-6) -> Rot3Array:
    """Construct Rot3Array from components of quaternion."""
    if normalize:
      # epsilon guards against division by ~zero for degenerate quaternions.
      inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
      w *= inv_norm
      x *= inv_norm
      y *= inv_norm
      z *= inv_norm
    # Standard quaternion-to-rotation-matrix formula.
    xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
    xy = 2 * (x * y - w * z)
    xz = 2 * (x * z + w * y)
    yx = 2 * (x * y + w * z)
    yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
    yz = 2 * (y * z - w * x)
    zx = 2 * (x * z - w * y)
    zy = 2 * (y * z + w * x)
    zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
    return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

  @classmethod
  def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
    """Samples uniform random Rot3Array according to Haar Measure."""
    quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
    quats = utils.unstack(quat_array)
    return cls.from_quaternion(*quats)

  def __getstate__(self):
    # Pickle as (format version, component arrays) converted to numpy.
    return (VERSION,
            [np.asarray(getattr(self, field)) for field in COMPONENTS])

  def __setstate__(self, state):
    version, state = state
    del version
    # Bypass the (frozen) dataclass setattr machinery.
    for i, field in enumerate(COMPONENTS):
      object.__setattr__(self, field, state[i])
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import dataclasses
from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import vector
import jax.numpy as jnp
import numpy as np
def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
                                 matrix2: rotation_matrix.Rot3Array):
  """Asserts every component of the two rotation matrices matches exactly."""
  field_names = [f.name for f in dataclasses.fields(rotation_matrix.Rot3Array)]
  for name in field_names:
    np.testing.assert_array_equal(
        getattr(matrix1, name), getattr(matrix2, name))
def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
                                 mat2: rotation_matrix.Rot3Array):
  """Asserts the two rotation matrices agree to 6 decimal places."""
  arr1 = mat1.to_array()
  arr2 = mat2.to_array()
  np.testing.assert_array_almost_equal(arr1, arr2, 6)
def assert_array_equal_to_rotation_matrix(array: jnp.ndarray,
                                          matrix: rotation_matrix.Rot3Array):
  """Check that array and Matrix match."""
  # Compare each named component against its [..., i, j] slot, row-major.
  component_rows = (('xx', 'xy', 'xz'), ('yx', 'yy', 'yz'), ('zx', 'zy', 'zz'))
  for i, row in enumerate(component_rows):
    for j, name in enumerate(row):
      np.testing.assert_array_equal(getattr(matrix, name), array[..., i, j])
def assert_array_close_to_rotation_matrix(array: jnp.ndarray,
                                          matrix: rotation_matrix.Rot3Array):
  """Asserts `matrix` agrees with `array` to 6 decimal places."""
  matrix_as_array = matrix.to_array()
  np.testing.assert_array_almost_equal(matrix_as_array, array, decimal=6)
def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
  """Asserts exact equality of all three vector components."""
  for axis in ('x', 'y', 'z'):
    np.testing.assert_array_equal(getattr(vec1, axis), getattr(vec2, axis))
def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
  """Asserts componentwise closeness with absolute tolerance 1e-6."""
  for axis in ('x', 'y', 'z'):
    np.testing.assert_allclose(
        getattr(vec1, axis), getattr(vec2, axis), atol=1e-6, rtol=0.)
def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
  """Asserts `vec`, densified, is within atol=1e-6 of `array`."""
  vec_as_array = vec.to_array()
  np.testing.assert_allclose(vec_as_array, array, atol=1e-6, rtol=0.)
def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
  """Asserts `vec`, densified, equals `array` exactly."""
  vec_as_array = vec.to_array()
  np.testing.assert_array_equal(vec_as_array, array)
def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
                                rigid2: rigid_matrix_vector.Rigid3Array):
  """Asserts exact equality of two rigid transforms."""
  rot, trans = rigid1.rotation, rigid1.translation
  assert_rot_trans_equal_to_rigid(rot, trans, rigid2)
def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
                                rigid2: rigid_matrix_vector.Rigid3Array):
  """Asserts approximate equality of two rigid transforms."""
  rot, trans = rigid1.rotation, rigid1.translation
  assert_rot_trans_close_to_rigid(rot, trans, rigid2)
def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
                                    trans: vector.Vec3Array,
                                    rigid: rigid_matrix_vector.Rigid3Array):
  """Asserts `rigid` decomposes exactly into `rot` and `trans`."""
  # Rotation part is checked first, then translation.
  assert_rotation_matrix_equal(rot, rigid.rotation)
  assert_vectors_equal(trans, rigid.translation)
def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
                                    trans: vector.Vec3Array,
                                    rigid: rigid_matrix_vector.Rigid3Array):
  """Asserts `rigid` decomposes approximately into `rot` and `trans`."""
  # Rotation part is checked first, then translation.
  assert_rotation_matrix_close(rot, rigid.rotation)
  assert_vectors_close(trans, rigid.translation)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
from typing import List
import jax.numpy as jnp
def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]:
  """Splits `value` into a list of slices along `axis`, dropping that axis."""
  pieces = jnp.split(value, value.shape[axis], axis=axis)
  return [jnp.squeeze(piece, axis=axis) for piece in pieces]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -15,6 +15,7 @@
"""A collection of JAX utility functions for use in protein folding."""
import collections
import functools
import numbers
from typing import Mapping
......@@ -79,3 +80,52 @@ def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
hk_params[scope][name] = jnp.array(array)
return hk_params
def padding_consistent_rng(f):
  """Modify any element-wise random function to be consistent with padding.

  Normally if you take a function like jax.random.normal and generate an array,
  say of size (10,10), you will get a different set of random numbers to if you
  add padding and take the first (10,10) sub-array.

  This function makes a random function that is consistent regardless of the
  amount of padding added.

  Note: The padding-consistent function is likely to be slower to compile and
  run than the function it is wrapping, but these slowdowns are likely to be
  negligible in a large network.

  Args:
    f: Any element-wise function that takes (PRNG key, shape) as the first 2
      arguments.

  Returns:
    An equivalent function to f, that is now consistent for different amounts of
    padding.
  """
  def grid_keys(key, shape):
    """Generate a grid of rng keys that is consistent with different padding.

    Generate random keys such that the keys will be identical, regardless of
    how much padding is added to any dimension.

    Args:
      key: A PRNG key.
      shape: The shape of the output array of keys that will be generated.

    Returns:
      An array of shape `shape` consisting of random keys.
    """
    if not shape:
      # Base case: a scalar position is keyed directly by `key`.
      return key
    # Derive one key per index along the leading dimension. fold_in depends
    # only on (key, index), so padding the dimension leaves the keys of the
    # original indices unchanged.
    new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))(
        jnp.arange(shape[0]))
    # Recurse over the remaining dimensions for each derived key.
    return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys)

  def inner(key, shape, **kwargs):
    # Draw one scalar sample per grid position, each from its own key, so the
    # value at a given index does not depend on the overall array shape.
    # signature='(2)->()' maps over the trailing key axis — assumes raw
    # uint32[2] PRNG keys (TODO confirm against the jax key representation).
    return jnp.vectorize(
        lambda key: f(key, shape=(), **kwargs),
        signature='(2)->()')(
            grid_keys(key, shape))
  return inner
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment