Commit 0be2b30b authored by Augustin-Zidek's avatar Augustin-Zidek
Browse files

Add code for AlphaFold-Multimer.

PiperOrigin-RevId: 407076987
parent 1d43aaff
...@@ -13,72 +13,118 @@ ...@@ -13,72 +13,118 @@
# limitations under the License. # limitations under the License.
"""A collection of common Haiku modules for use in protein folding.""" """A collection of common Haiku modules for use in protein folding."""
import numbers
from typing import Union, Sequence
import haiku as hk import haiku as hk
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
dtype=np.float32)
def get_initializer_scale(initializer_name, input_shape):
"""Get Initializer for weights and scale to multiply activations by."""
if initializer_name == 'zeros':
w_init = hk.initializers.Constant(0.0)
else:
# fan-in scaling
scale = 1.
for channel_dim in input_shape:
scale /= channel_dim
if initializer_name == 'relu':
scale *= 2
noise_scale = scale
stddev = np.sqrt(noise_scale)
# Adjust stddev for truncation.
stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
return w_init
class Linear(hk.Module): class Linear(hk.Module):
"""Protein folding specific Linear Module. """Protein folding specific Linear module.
This differs from the standard Haiku Linear in a few ways: This differs from the standard Haiku Linear in a few ways:
* It supports inputs of arbitrary rank * It supports inputs and outputs of arbitrary rank
* Initializers are specified by strings * Initializers are specified by strings
""" """
def __init__(self, def __init__(self,
num_output: int, num_output: Union[int, Sequence[int]],
initializer: str = 'linear', initializer: str = 'linear',
num_input_dims: int = 1,
use_bias: bool = True, use_bias: bool = True,
bias_init: float = 0., bias_init: float = 0.,
precision = None,
name: str = 'linear'): name: str = 'linear'):
"""Constructs Linear Module. """Constructs Linear Module.
Args: Args:
num_output: number of output channels. num_output: Number of output channels. Can be tuple when outputting
multiple dimensions.
initializer: What initializer to use, should be one of {'linear', 'relu', initializer: What initializer to use, should be one of {'linear', 'relu',
'zeros'} 'zeros'}
num_input_dims: Number of dimensions from the end to project.
use_bias: Whether to include trainable bias use_bias: Whether to include trainable bias
bias_init: Value used to initialize bias. bias_init: Value used to initialize bias.
name: name of module, used for name scopes. precision: What precision to use for matrix multiplication, defaults
to None.
name: Name of module, used for name scopes.
""" """
super().__init__(name=name) super().__init__(name=name)
self.num_output = num_output if isinstance(num_output, numbers.Integral):
self.output_shape = (num_output,)
else:
self.output_shape = tuple(num_output)
self.initializer = initializer self.initializer = initializer
self.use_bias = use_bias self.use_bias = use_bias
self.bias_init = bias_init self.bias_init = bias_init
self.num_input_dims = num_input_dims
self.num_output_dims = len(self.output_shape)
self.precision = precision
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: def __call__(self, inputs):
"""Connects Module. """Connects Module.
Args: Args:
inputs: Tensor of shape [..., num_channel] inputs: Tensor with at least num_input_dims dimensions.
Returns: Returns:
output of shape [..., num_output] output of shape [...] + num_output.
""" """
n_channels = int(inputs.shape[-1])
weight_shape = [n_channels, self.num_output] num_input_dims = self.num_input_dims
if self.initializer == 'linear':
weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) if self.num_input_dims > 0:
elif self.initializer == 'relu': in_shape = inputs.shape[-self.num_input_dims:]
weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) else:
elif self.initializer == 'zeros': in_shape = ()
weight_init = hk.initializers.Constant(0.0)
weight_init = get_initializer_scale(self.initializer, in_shape)
in_letters = 'abcde'[:self.num_input_dims]
out_letters = 'hijkl'[:self.num_output_dims]
weight_shape = in_shape + self.output_shape
weights = hk.get_parameter('weights', weight_shape, inputs.dtype, weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
weight_init) weight_init)
# this is equivalent to einsum('...c,cd->...d', inputs, weights) equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
# but turns out to be slightly faster
inputs = jnp.swapaxes(inputs, -1, -2) output = jnp.einsum(equation, inputs, weights, precision=self.precision)
output = jnp.einsum('...cb,cd->...db', inputs, weights)
output = jnp.swapaxes(output, -1, -2)
if self.use_bias: if self.use_bias:
bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
hk.initializers.Constant(self.bias_init)) hk.initializers.Constant(self.bias_init))
output += bias output += bias
return output return output
...@@ -17,7 +17,6 @@ import copy ...@@ -17,7 +17,6 @@ import copy
from alphafold.model.tf import shape_placeholders from alphafold.model.tf import shape_placeholders
import ml_collections import ml_collections
NUM_RES = shape_placeholders.NUM_RES NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
...@@ -27,6 +26,9 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES ...@@ -27,6 +26,9 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict: def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model.""" """Get the ConfigDict of a CASP14 model."""
if 'multimer' in name:
return CONFIG_MULTIMER
if name not in CONFIG_DIFFS: if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.') raise ValueError(f'Invalid model name {name}.')
cfg = copy.deepcopy(CONFIG) cfg = copy.deepcopy(CONFIG)
...@@ -34,6 +36,32 @@ def model_config(name: str) -> ml_collections.ConfigDict: ...@@ -34,6 +36,32 @@ def model_config(name: str) -> ml_collections.ConfigDict:
return cfg return cfg
MODEL_PRESETS = {
'monomer': (
'model_1',
'model_2',
'model_3',
'model_4',
'model_5',
),
'monomer_ptm': (
'model_1_ptm',
'model_2_ptm',
'model_3_ptm',
'model_4_ptm',
'model_5_ptm',
),
'multimer': (
'model_1_multimer',
'model_2_multimer',
'model_3_multimer',
'model_4_multimer',
'model_5_multimer',
),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
CONFIG_DIFFS = { CONFIG_DIFFS = {
'model_1': { 'model_1': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.1 # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
...@@ -206,6 +234,7 @@ CONFIG = ml_collections.ConfigDict({ ...@@ -206,6 +234,7 @@ CONFIG = ml_collections.ConfigDict({
'shared_dropout': True 'shared_dropout': True
}, },
'outer_product_mean': { 'outer_product_mean': {
'first': False,
'chunk_size': 128, 'chunk_size': 128,
'dropout_rate': 0.0, 'dropout_rate': 0.0,
'num_outer_channel': 32, 'num_outer_channel': 32,
...@@ -322,6 +351,7 @@ CONFIG = ml_collections.ConfigDict({ ...@@ -322,6 +351,7 @@ CONFIG = ml_collections.ConfigDict({
}, },
'global_config': { 'global_config': {
'deterministic': False, 'deterministic': False,
'multimer_mode': False,
'subbatch_size': 4, 'subbatch_size': 4,
'use_remat': False, 'use_remat': False,
'zero_init': True 'zero_init': True
...@@ -400,3 +430,228 @@ CONFIG = ml_collections.ConfigDict({ ...@@ -400,3 +430,228 @@ CONFIG = ml_collections.ConfigDict({
'resample_msa_in_recycling': True 'resample_msa_in_recycling': True
}, },
}) })
CONFIG_MULTIMER = ml_collections.ConfigDict({
'model': {
'embeddings_and_evoformer': {
'evoformer_num_block': 48,
'evoformer': {
'msa_column_attention': {
'dropout_rate': 0.0,
'gating': True,
'num_head': 8,
'orientation': 'per_column',
'shared_dropout': True
},
'msa_row_attention_with_pair_bias': {
'dropout_rate': 0.15,
'gating': True,
'num_head': 8,
'orientation': 'per_row',
'shared_dropout': True
},
'msa_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'outer_product_mean': {
'chunk_size': 128,
'dropout_rate': 0.0,
'first': True,
'num_outer_channel': 32,
'orientation': 'per_row',
'shared_dropout': True
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'num_msa': 252,
'num_extra_msa': 1152,
'masked_msa': {
'profile_prob': 0.1,
'replace_fraction': 0.15,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'use_chain_relative': True,
'max_relative_chain': 2,
'max_relative_idx': 32,
'seq_channel': 384,
'msa_channel': 256,
'pair_channel': 128,
'prev_pos': {
'max_bin': 20.75,
'min_bin': 3.25,
'num_bins': 15
},
'recycle_features': True,
'recycle_pos': True,
'template': {
'attention': {
'gating': False,
'num_head': 4
},
'dgram_features': {
'max_bin': 50.75,
'min_bin': 3.25,
'num_bins': 39
},
'enabled': True,
'max_templates': 4,
'num_channels': 64,
'subbatch_size': 128,
'template_pair_stack': {
'num_block': 2,
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 2,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
}
}
},
},
'global_config': {
'deterministic': False,
'multimer_mode': True,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
},
'heads': {
'distogram': {
'first_break': 2.3125,
'last_break': 21.6875,
'num_bins': 64,
'weight': 0.3
},
'experimentally_resolved': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'weight': 0.01
},
'masked_msa': {
'weight': 2.0
},
'predicted_aligned_error': {
'filter_by_resolution': True,
'max_error_bin': 31.0,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 64,
'num_channels': 128,
'weight': 0.1
},
'predicted_lddt': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 50,
'num_channels': 128,
'weight': 0.01
},
'structure_module': {
'angle_norm_weight': 0.01,
'chi_weight': 0.5,
'clash_overlap_tolerance': 1.5,
'dropout': 0.1,
'interface_fape': {
'atom_clamp_distance': 1000.0,
'loss_unit_distance': 20.0
},
'intra_chain_fape': {
'atom_clamp_distance': 10.0,
'loss_unit_distance': 10.0
},
'num_channel': 384,
'num_head': 12,
'num_layer': 8,
'num_layer_in_transition': 3,
'num_point_qk': 4,
'num_point_v': 8,
'num_scalar_qk': 16,
'num_scalar_v': 16,
'position_scale': 20.0,
'sidechain': {
'atom_clamp_distance': 10.0,
'loss_unit_distance': 10.0,
'num_channel': 128,
'num_residual_block': 2,
'weight_frac': 0.5
},
'structural_violation_loss_weight': 1.0,
'violation_tolerance_factor': 12.0,
'weight': 1.0
}
},
'num_ensemble_eval': 1,
'num_recycle': 3,
'resample_msa_in_recycling': True
}
})
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
"""Code to generate processed features.""" """Code to generate processed features."""
import copy import copy
from typing import List, Mapping, Tuple from typing import List, Mapping, Tuple
from alphafold.model.tf import input_pipeline from alphafold.model.tf import input_pipeline
from alphafold.model.tf import proteins_dataset from alphafold.model.tf import proteins_dataset
import ml_collections import ml_collections
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
......
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
from typing import Union
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
Float = Union[float, jnp.ndarray]
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Rigid3Array:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation: rotation_matrix.Rot3Array
translation: vector.Vec3Array
def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
new_rotation = self.rotation @ other.rotation
new_translation = self.apply_to_point(other.translation)
return Rigid3Array(new_rotation, new_translation)
def inverse(self) -> Rigid3Array:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation = self.rotation.inverse()
inv_translation = inv_rotation.apply_to_point(-self.translation)
return Rigid3Array(inv_rotation, inv_translation)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape),
self.translation)
return Rigid3Array(rot, trans)
@classmethod
def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array:
"""Return identity Rigid3Array of given shape."""
return cls(
rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
vector.Vec3Array.zeros(shape, dtype=dtype))
def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'."""
return Rigid3Array(self.rotation, self.translation * factor)
def to_array(self):
rot_array = self.rotation.to_array()
vec_array = self.translation.to_array()
return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)
@classmethod
def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
vec = vector.Vec3Array.from_array(array[..., -1])
return cls(rot, vec)
@classmethod
def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
"""Construct Rigid3Array from homogeneous 4x4 array."""
assert array.shape[-1] == 4
assert array.shape[-2] == 4
rotation = rotation_matrix.Rot3Array(
array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
return cls(rotation, translation)
def __getstate__(self):
return (VERSION, (self.rotation, self.translation))
def __setstate__(self, state):
version, (rot, trans) = state
del version
object.__setattr__(self, 'rotation', rot)
object.__setattr__(self, 'translation', trans)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import utils
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
import numpy as np
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
VERSION = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Rot3Array:
"""Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
xy: jnp.ndarray
xz: jnp.ndarray
yx: jnp.ndarray
yy: jnp.ndarray
yz: jnp.ndarray
zx: jnp.ndarray
zy: jnp.ndarray
zz: jnp.ndarray
__array_ufunc__ = None
def inverse(self) -> Rot3Array:
"""Returns inverse of Rot3Array."""
return Rot3Array(self.xx, self.yx, self.zx,
self.xy, self.yy, self.zy,
self.xz, self.yz, self.zz)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies Rot3Array to point."""
return vector.Vec3Array(
self.xx * point.x + self.xy * point.y + self.xz * point.z,
self.yx * point.x + self.yy * point.y + self.yz * point.z,
self.zx * point.x + self.zy * point.y + self.zz * point.z)
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
@classmethod
def identity(cls, shape, dtype=jnp.float32) -> Rot3Array:
"""Returns identity of given shape."""
ones = jnp.ones(shape, dtype=dtype)
zeros = jnp.zeros(shape, dtype=dtype)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
@classmethod
def from_two_vectors(cls, e0: vector.Vec3Array,
e1: vector.Vec3Array) -> Rot3Array:
"""Construct Rot3Array from two Vectors.
Rot3Array is constructed such that in the corresponding frame 'e0' lies on
the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
Args:
e0: Vector
e1: Vector
Returns:
Rot3Array
"""
# Normalize the unit vector for the x-axis, e0.
e0 = e0.normalized()
# make e1 perpendicular to e0.
c = e1.dot(e0)
e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
@classmethod
def from_array(cls, array: jnp.ndarray) -> Rot3Array:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
unstacked = utils.unstack(array, axis=-2)
unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], [])
return cls(*unstacked)
def to_array(self) -> jnp.ndarray:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return jnp.stack(
[jnp.stack([self.xx, self.xy, self.xz], axis=-1),
jnp.stack([self.yx, self.yy, self.yz], axis=-1),
jnp.stack([self.zx, self.zy, self.zz], axis=-1)],
axis=-2)
@classmethod
def from_quaternion(cls,
w: jnp.ndarray,
x: jnp.ndarray,
y: jnp.ndarray,
z: jnp.ndarray,
normalize: bool = True,
epsilon: float = 1e-6) -> Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
w *= inv_norm
x *= inv_norm
y *= inv_norm
z *= inv_norm
xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
yx = 2 * (x * y + w * z)
yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
yz = 2 * (y * z - w * x)
zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x)
zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
@classmethod
def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
"""Samples uniform random Rot3Array according to Haar Measure."""
quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
quats = utils.unstack(quat_array)
return cls.from_quaternion(*quats)
def __getstate__(self):
return (VERSION,
[np.asarray(getattr(self, field)) for field in COMPONENTS])
def __setstate__(self, state):
version, state = state
del version
for i, field in enumerate(COMPONENTS):
object.__setattr__(self, field, state[i])
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class decorator to represent (nested) struct of arrays."""
import dataclasses
import jax
def get_item(instance, key):
sliced = {}
for field in get_array_fields(instance):
num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
this_key = key
if isinstance(key, tuple) and Ellipsis in this_key:
this_key += (slice(None),) * num_trailing_dims
sliced[field.name] = getattr(instance, field.name)[this_key]
return dataclasses.replace(instance, **sliced)
@property
def get_shape(instance):
"""Returns Shape for given instance of dataclass."""
first_field = dataclasses.fields(instance)[0]
num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
value = getattr(instance, first_field.name)
if num_trailing_dims:
return value.shape[:-num_trailing_dims]
else:
return value.shape
def get_len(instance):
"""Returns length for given instance of dataclass."""
shape = instance.shape
if shape:
return shape[0]
else:
raise TypeError('len() of unsized object') # Match jax.numpy behavior.
@property
def get_dtype(instance):
"""Returns Dtype for given instance of dataclass."""
fields = dataclasses.fields(instance)
sets_dtype = [
field.name for field in fields if field.metadata.get('sets_dtype', False)
]
if sets_dtype:
assert len(sets_dtype) == 1, 'at most field can set dtype'
field_value = getattr(instance, sets_dtype[0])
elif instance.same_dtype:
field_value = getattr(instance, fields[0].name)
else:
# Should this be Value Error?
raise AttributeError('Trying to access Dtype on Struct of Array without'
'either "same_dtype" or field setting dtype')
if hasattr(field_value, 'dtype'):
return field_value.dtype
else:
# Should this be Value Error?
raise AttributeError(f'field_value {field_value} does not have dtype')
def replace(instance, **kwargs):
return dataclasses.replace(instance, **kwargs)
def post_init(instance):
"""Validate instance has same shapes & dtypes."""
array_fields = get_array_fields(instance)
arrays = list(get_array_fields(instance, return_values=True).values())
first_field = array_fields[0]
# These slightly weird constructions about checking whether the leaves are
# actual arrays is since e.g. vmap internally relies on being able to
# construct pytree's with object() as leaves, this would break the checking
# as such we are only validating the object when the entries in the dataclass
# Are arrays or other dataclasses of arrays.
try:
dtype = instance.dtype
except AttributeError:
dtype = None
if dtype is not None:
first_shape = instance.shape
for array, field in zip(arrays, array_fields):
field_shape = array.shape
num_trailing_dims = field.metadata.get('num_trailing_dims', None)
if num_trailing_dims:
array_shape = array.shape
field_shape = array_shape[:-num_trailing_dims]
msg = (f'field {field} should have number of trailing dims'
' {num_trailing_dims}')
assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
else:
field_shape = array.shape
shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
f"match shape {first_shape} of field {first_field}")
assert field_shape == first_shape, shape_msg
field_dtype = array.dtype
allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
if allowed_metadata_dtypes:
msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
assert field_dtype in allowed_metadata_dtypes, msg
if 'dtype' in field.metadata:
target_dtype = field.metadata['dtype']
else:
target_dtype = dtype
msg = f'Dtype is {field_dtype} but must be {target_dtype}'
assert field_dtype == target_dtype, msg
def flatten(instance):
"""Flatten Struct of Array instance."""
array_likes = list(get_array_fields(instance, return_values=True).values())
flat_array_likes = []
inner_treedefs = []
num_arrays = []
for array_like in array_likes:
flat_array_like, inner_treedef = jax.tree_flatten(array_like)
inner_treedefs.append(inner_treedef)
flat_array_likes += flat_array_like
num_arrays.append(len(flat_array_like))
metadata = get_metadata_fields(instance, return_values=True)
metadata = type(instance).metadata_cls(**metadata)
return flat_array_likes, (inner_treedefs, metadata, num_arrays)
def make_metadata_class(cls):
metadata_fields = get_fields(cls,
lambda x: x.metadata.get('is_metadata', False))
metadata_cls = dataclasses.make_dataclass(
cls_name='Meta' + cls.__name__,
fields=[(field.name, field.type, field) for field in metadata_fields],
frozen=True,
eq=True)
return metadata_cls
def get_fields(cls_or_instance, filterfn, return_values=False):
fields = dataclasses.fields(cls_or_instance)
fields = [field for field in fields if filterfn(field)]
if return_values:
return {
field.name: getattr(cls_or_instance, field.name) for field in fields
}
else:
return fields
def get_array_fields(cls, return_values=False):
return get_fields(
cls,
lambda x: not x.metadata.get('is_metadata', False),
return_values=return_values)
def get_metadata_fields(cls, return_values=False):
return get_fields(
cls,
lambda x: x.metadata.get('is_metadata', False),
return_values=return_values)
class StructOfArray:
"""Class Decorator for Struct Of Arrays."""
def __init__(self, same_dtype=True):
self.same_dtype = same_dtype
def __call__(self, cls):
cls.__array_ufunc__ = None
cls.replace = replace
cls.same_dtype = self.same_dtype
cls.dtype = get_dtype
cls.shape = get_shape
cls.__len__ = get_len
cls.__getitem__ = get_item
cls.__post_init__ = post_init
new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args
# pytree claims to require metadata to be hashable, not sure why,
# But making derived dataclass that can just hold metadata
new_cls.metadata_cls = make_metadata_class(new_cls)
def unflatten(aux, data):
inner_treedefs, metadata, num_arrays = aux
array_fields = [field.name for field in get_array_fields(new_cls)]
value_dict = {}
array_start = 0
for num_array, inner_treedef, array_field in zip(num_arrays,
inner_treedefs,
array_fields):
value_dict[array_field] = jax.tree_unflatten(
inner_treedef, data[array_start:array_start + num_array])
array_start += num_array
metadata_fields = get_metadata_fields(new_cls)
for field in metadata_fields:
value_dict[field.name] = getattr(metadata, field.name)
return new_cls(**value_dict)
jax.tree_util.register_pytree_node(
nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
return new_cls
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlphaFold Colab notebook."""
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment