"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "43b8c6f991aefab6b05c600e3365449d3b3387b0"
Commit 6e68d6b0 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add geometry functions to multimer

parent dba44612
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
class QuatRigid(nn.Module):
def __init__(self, c_hidden, full_quat):
super().__init__()
self.full_quat = full_quat
if self.full_quat:
rigid_dim = 7
else:
rigid_dim = 6
self.linear = Linear(c_hidden, rigid_dim)
def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations.to(torch.float32))
print(rigid_flat.shape)
rigid_flat = torch.unbind(rigid_flat, dim=-1)
if(self.full_quat):
qw, qx, qy, qz = rigid_flat[:4]
translation = rigid_flat[4:]
else:
qx, qy, qz = rigid_flat[:3]
qw = torch.ones_like(qx)
translation = rigid_flat[3:]
rotation = Rot3Array.from_quaternion(
qw, qx, qy, qz, normalize=True,
)
translation = Vec3Array(*translation)
return Rigid3Array(rotation, translation)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import vector
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Rigid3Array:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation: rotation_matrix.Rot3Array
translation: vector.Vec3Array
def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
new_rotation = self.rotation @ other.rotation # __matmul__
new_translation = self.apply_to_point(other.translation)
return Rigid3Array(new_rotation, new_translation)
def __getitem__(self, index) -> Rigid3Array:
return Rigid3Array(
self.rotation[index],
self.translation[index],
)
def __mul__(self, other: torch.Tensor) -> Rigid3Array:
return Rigid3Array(
self.rotation * other,
self.translation * other,
)
def map_tensor_fn(self, fn) -> Rigid3Array:
return Rigid3Array(
self.rotation.map_tensor_fn(fn),
self.translation.map_tensor_fn(fn),
)
def inverse(self) -> Rigid3Array:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation = self.rotation.inverse()
inv_translation = inv_rotation.apply_to_point(-self.translation)
return Rigid3Array(inv_rotation, inv_translation)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
return Rigid3Array(rot, trans.clone())
@classmethod
def identity(cls, shape, device) -> Rigid3Array:
"""Return identity Rigid3Array of given shape."""
return cls(
rotation_matrix.Rot3Array.identity(shape, device),
vector.Vec3Array.zeros(shape, device)
)
@classmethod
def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
return cls(
Rot3Array.cat([r.rotation for r in rigids], dim=dim),
Vec3Array.cat([r.translation for r in rigids], dim=dim),
)
def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'."""
return Rigid3Array(self.rotation, self.translation * factor)
def to_array(self):
rot_array = self.rotation.to_array()
vec_array = self.translation.to_array()
return torch.cat([rot_array, vec_array[..., None]], dim=-1)
def reshape(self, new_shape) -> Rigid3Array:
rots = self.rotation.reshape(new_shape)
trans = self.translation.reshape(new_shape)
return Rigid3Aray(rots, trans)
@classmethod
def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
vec = vector.Vec3Array.from_array(array[..., -1])
return cls(rot, vec)
@classmethod
def from_tensor_4x4(cls, array):
return cls.from_array(array)
@classmethod
def from_array4x4(cls, array: torch.tensor) -> Rigid3Array:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation = rotation_matrix.Rot3Array(
array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
return cls(rotation, translation)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses
import torch
import numpy as np
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import utils
from openfold.utils.geometry import vector
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
@dataclasses.dataclass(frozen=True)
class Rot3Array:
"""Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
xy: torch.Tensor
xz: torch.Tensor
yx: torch.Tensor
yy: torch.Tensor
yz: torch.Tensor
zx: torch.Tensor
zy: torch.Tensor
zz: torch.Tensor
__array_ufunc__ = None
def __getitem__(self, index):
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: getattr(self, name)[index]
for name in field_names
}
)
def __mul__(self, other: torch.Tensor):
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: getattr(self, name) * other
for name in field_names
}
)
def map_tensor_fn(self, fn) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: fn(getattr(self, name))
for name in field_names
}
)
def inverse(self) -> Rot3Array:
"""Returns inverse of Rot3Array."""
return Rot3Array(
self.xx, self.yx, self.zx,
self.xy, self.yy, self.zy,
self.xz, self.yz, self.zz
)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies Rot3Array to point."""
return vector.Vec3Array(
self.xx * point.x + self.xy * point.y + self.xz * point.z,
self.yx * point.x + self.yy * point.y + self.yz * point.z,
self.zx * point.x + self.zy * point.y + self.zz * point.z
)
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
@classmethod
def identity(cls, shape, device) -> Rot3Array:
"""Returns identity of given shape."""
ones = torch.ones(shape, dtype=torch.float32, device=device)
zeros = torch.zeros(shape, dtype=torch.float32, device=device)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
@classmethod
def from_two_vectors(
cls, e0: vector.Vec3Array,
e1: vector.Vec3Array
) -> Rot3Array:
"""Construct Rot3Array from two Vectors.
Rot3Array is constructed such that in the corresponding frame 'e0' lies on
the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
Args:
e0: Vector
e1: Vector
Returns:
Rot3Array
"""
# Normalize the unit vector for the x-axis, e0.
e0 = e0.normalized()
# make e1 perpendicular to e0.
c = e1.dot(e0)
e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
@classmethod
def from_array(cls, array: torch.Tensor) -> Rot3Array:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return cls(torch.unbind(array, dim=-2))
def to_array(self) -> torch.Tensor:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return torch.stack(
[
torch.stack([self.xx, self.xy, self.xz], dim=-1),
torch.stack([self.yx, self.yy, self.yz], dim=-1),
torch.stack([self.zx, self.zy, self.zz], dim=-1)
],
dim=-2)
@classmethod
def from_quaternion(cls,
w: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
z: torch.Tensor,
normalize: bool = True,
eps: float = 1e-6
) -> Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
w *= inv_norm
x *= inv_norm
y *= inv_norm
z *= inv_norm
xx = 1 - 2 * (y ** 2 + z ** 2)
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
yx = 2 * (x * y + w * z)
yy = 1 - 2 * (x ** 2 + z ** 2)
yz = 2 * (y * z - w * x)
zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x)
zz = 1 - 2 * (x ** 2 + y ** 2)
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
def reshape(self, new_shape):
field_names = utils.get_field_names(Rot3Array)
reshape_fn = lambda t: t.reshape(new_shape)
return Rot3Array(
**{
name: reshape_fn(getattr(self, name))
for name in field_names
}
)
@classmethod
def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
cat_fn = lambda l: torch.cat(l, dim=dim)
return cls(
**{
name: cat_fn([getattr(r, name) for r in rots])
for name in field_names
}
)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class decorator to represent (nested) struct of arrays."""
import dataclasses
import jax
def get_item(instance, key):
sliced = {}
for field in get_array_fields(instance):
num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
this_key = key
if isinstance(key, tuple) and Ellipsis in this_key:
this_key += (slice(None),) * num_trailing_dims
sliced[field.name] = getattr(instance, field.name)[this_key]
return dataclasses.replace(instance, **sliced)
@property
def get_shape(instance):
"""Returns Shape for given instance of dataclass."""
first_field = dataclasses.fields(instance)[0]
num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
value = getattr(instance, first_field.name)
if num_trailing_dims:
return value.shape[:-num_trailing_dims]
else:
return value.shape
def get_len(instance):
"""Returns length for given instance of dataclass."""
shape = instance.shape
if shape:
return shape[0]
else:
raise TypeError('len() of unsized object') # Match jax.numpy behavior.
@property
def get_dtype(instance):
"""Returns Dtype for given instance of dataclass."""
fields = dataclasses.fields(instance)
sets_dtype = [
field.name for field in fields if field.metadata.get('sets_dtype', False)
]
if sets_dtype:
assert len(sets_dtype) == 1, 'at most field can set dtype'
field_value = getattr(instance, sets_dtype[0])
elif instance.same_dtype:
field_value = getattr(instance, fields[0].name)
else:
# Should this be Value Error?
raise AttributeError('Trying to access Dtype on Struct of Array without'
'either "same_dtype" or field setting dtype')
if hasattr(field_value, 'dtype'):
return field_value.dtype
else:
# Should this be Value Error?
raise AttributeError(f'field_value {field_value} does not have dtype')
def replace(instance, **kwargs):
return dataclasses.replace(instance, **kwargs)
def post_init(instance):
"""Validate instance has same shapes & dtypes."""
array_fields = get_array_fields(instance)
arrays = list(get_array_fields(instance, return_values=True).values())
first_field = array_fields[0]
# These slightly weird constructions about checking whether the leaves are
# actual arrays is since e.g. vmap internally relies on being able to
# construct pytree's with object() as leaves, this would break the checking
# as such we are only validating the object when the entries in the dataclass
# Are arrays or other dataclasses of arrays.
try:
dtype = instance.dtype
except AttributeError:
dtype = None
if dtype is not None:
first_shape = instance.shape
for array, field in zip(arrays, array_fields):
field_shape = array.shape
num_trailing_dims = field.metadata.get('num_trailing_dims', None)
if num_trailing_dims:
array_shape = array.shape
field_shape = array_shape[:-num_trailing_dims]
msg = (f'field {field} should have number of trailing dims'
' {num_trailing_dims}')
assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
else:
field_shape = array.shape
shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
f"match shape {first_shape} of field {first_field}")
assert field_shape == first_shape, shape_msg
field_dtype = array.dtype
allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
if allowed_metadata_dtypes:
msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
assert field_dtype in allowed_metadata_dtypes, msg
if 'dtype' in field.metadata:
target_dtype = field.metadata['dtype']
else:
target_dtype = dtype
msg = f'Dtype is {field_dtype} but must be {target_dtype}'
assert field_dtype == target_dtype, msg
def flatten(instance):
"""Flatten Struct of Array instance."""
array_likes = list(get_array_fields(instance, return_values=True).values())
flat_array_likes = []
inner_treedefs = []
num_arrays = []
for array_like in array_likes:
flat_array_like, inner_treedef = jax.tree_flatten(array_like)
inner_treedefs.append(inner_treedef)
flat_array_likes += flat_array_like
num_arrays.append(len(flat_array_like))
metadata = get_metadata_fields(instance, return_values=True)
metadata = type(instance).metadata_cls(**metadata)
return flat_array_likes, (inner_treedefs, metadata, num_arrays)
def make_metadata_class(cls):
metadata_fields = get_fields(cls,
lambda x: x.metadata.get('is_metadata', False))
metadata_cls = dataclasses.make_dataclass(
cls_name='Meta' + cls.__name__,
fields=[(field.name, field.type, field) for field in metadata_fields],
frozen=True,
eq=True)
return metadata_cls
def get_fields(cls_or_instance, filterfn, return_values=False):
fields = dataclasses.fields(cls_or_instance)
fields = [field for field in fields if filterfn(field)]
if return_values:
return {
field.name: getattr(cls_or_instance, field.name) for field in fields
}
else:
return fields
def get_array_fields(cls, return_values=False):
return get_fields(
cls,
lambda x: not x.metadata.get('is_metadata', False),
return_values=return_values)
def get_metadata_fields(cls, return_values=False):
return get_fields(
cls,
lambda x: x.metadata.get('is_metadata', False),
return_values=return_values)
class StructOfArray:
"""Class Decorator for Struct Of Arrays."""
def __init__(self, same_dtype=True):
self.same_dtype = same_dtype
def __call__(self, cls):
cls.__array_ufunc__ = None
cls.replace = replace
cls.same_dtype = self.same_dtype
cls.dtype = get_dtype
cls.shape = get_shape
cls.__len__ = get_len
cls.__getitem__ = get_item
cls.__post_init__ = post_init
new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args
# pytree claims to require metadata to be hashable, not sure why,
# But making derived dataclass that can just hold metadata
new_cls.metadata_cls = make_metadata_class(new_cls)
def unflatten(aux, data):
inner_treedefs, metadata, num_arrays = aux
array_fields = [field.name for field in get_array_fields(new_cls)]
value_dict = {}
array_start = 0
for num_array, inner_treedef, array_field in zip(num_arrays,
inner_treedefs,
array_fields):
value_dict[array_field] = jax.tree_unflatten(
inner_treedef, data[array_start:array_start + num_array])
array_start += num_array
metadata_fields = get_metadata_fields(new_cls)
for field in metadata_fields:
value_dict[field.name] = getattr(metadata, field.name)
return new_cls(**value_dict)
jax.tree_util.register_pytree_node(
nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
return new_cls
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import dataclasses
from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import vector
import jax.numpy as jnp
import numpy as np
def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
matrix2: rotation_matrix.Rot3Array):
for field in dataclasses.fields(rotation_matrix.Rot3Array):
field = field.name
np.testing.assert_array_equal(
getattr(matrix1, field), getattr(matrix2, field))
def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
mat2: rotation_matrix.Rot3Array):
np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6)
def assert_array_equal_to_rotation_matrix(array: jnp.ndarray,
matrix: rotation_matrix.Rot3Array):
"""Check that array and Matrix match."""
np.testing.assert_array_equal(matrix.xx, array[..., 0, 0])
np.testing.assert_array_equal(matrix.xy, array[..., 0, 1])
np.testing.assert_array_equal(matrix.xz, array[..., 0, 2])
np.testing.assert_array_equal(matrix.yx, array[..., 1, 0])
np.testing.assert_array_equal(matrix.yy, array[..., 1, 1])
np.testing.assert_array_equal(matrix.yz, array[..., 1, 2])
np.testing.assert_array_equal(matrix.zx, array[..., 2, 0])
np.testing.assert_array_equal(matrix.zy, array[..., 2, 1])
np.testing.assert_array_equal(matrix.zz, array[..., 2, 2])
def assert_array_close_to_rotation_matrix(array: jnp.ndarray,
matrix: rotation_matrix.Rot3Array):
np.testing.assert_array_almost_equal(matrix.to_array(), array, 6)
def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np.testing.assert_array_equal(vec1.x, vec2.x)
np.testing.assert_array_equal(vec1.y, vec2.y)
np.testing.assert_array_equal(vec1.z, vec2.z)
def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.)
def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
np.testing.assert_array_equal(vec.to_array(), array)
def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_equal(rot, rigid.rotation)
assert_vectors_equal(trans, rigid.translation)
def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_close(rot, rigid.rotation)
assert_vectors_close(trans, rigid.translation)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import dataclasses
def get_field_names(cls):
fields = dataclasses.fields(cls)
field_names = [f.name for f in fields]
return field_names
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from openfold.utils.geometry import utils
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Vec3Array:
x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
y: torch.Tensor
z: torch.Tensor
def __post_init__(self):
if hasattr(self.x, 'dtype'):
assert self.x.dtype == self.y.dtype
assert self.x.dtype == self.z.dtype
assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
def __add__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x + other.x,
self.y + other.y,
self.z + other.z,
)
def __sub__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x - other.x,
self.y - other.y,
self.z - other.z,
)
def __mul__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x * other,
self.y * other,
self.z * other,
)
def __rmul__(self, other: Float) -> Vec3Array:
return self * other
def __truediv__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x / other,
self.y / other,
self.z / other,
)
def __neg__(self) -> Vec3Array:
return self * -1
def __pos__(self) -> Vec3Array:
return self * 1
def __getitem__(self, index) -> Vec3Array:
return Vec3Array(
self.x[index],
self.y[index],
self.z[index],
)
def __iter__(self):
return iter((self.x, self.y, self.z))
@property
def shape(self):
return self.x.shape
def map_tensor_fn(self, fn) -> Vec3Array:
return Vec3Array(
fn(self.x),
fn(self.y),
fn(self.z),
)
def cross(self, other: Vec3Array) -> Vec3Array:
"""Compute cross product between 'self' and 'other'."""
new_x = self.y * other.z - self.z * other.y
new_y = self.z * other.x - self.x * other.z
new_z = self.x * other.y - self.y * other.x
return Vec3Array(new_x, new_y, new_z)
def dot(self, other: Vec3Array) -> Float:
"""Compute dot product between 'self' and 'other'."""
return self.x * other.x + self.y * other.y + self.z * other.z
def norm(self, epsilon: float = 1e-6) -> Float:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
norm2 = torch.clamp(norm2, max=epsilon**2)
return torch.sqrt(norm2)
def norm2(self):
return self.dot(self)
def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
"""Return unit vector with optional clipping."""
return self / self.norm(epsilon)
def clone(self) -> Vec3Array:
return Vec3Array(
self.x.clone(),
self.y.clone(),
self.z.clone(),
)
def reshape(self, new_shape) -> Vec3Array:
x = self.x.reshape(new_shape)
y = self.y.reshape(new_shape)
z = self.z.reshape(new_shape)
return Vec3Array(x, y, z)
def sum(self, dim) -> Vec3Array:
return Vec3Array(
torch.sum(self.x, dim=dim),
torch.sum(self.y, dim=dim),
torch.sum(self.z, dim=dim),
)
@classmethod
def zeros(cls, shape, device="cpu"):
"""Return Vec3Array corresponding to zeros of given shape."""
return cls(
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device)
)
def to_array(self) -> torch.Tensor:
return torch.stack([self.x, self.y, self.z], dim=-1)
@classmethod
def from_tensor(cls, tensor):
return cls(*torch.unbind(tensor, dim=-1))
@classmethod
def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
return cls(
torch.cat([v.x for v in vecs], dim=dim),
torch.cat([v.y for v in vecs], dim=dim),
torch.cat([v.z for v in vecs], dim=dim),
)
def square_euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference = vec1 - vec2
distance = difference.dot(difference)
if epsilon:
distance = torch.maximum(distance, epsilon)
return distance
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
return vector1.dot(vector2)
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float:
return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
return vector.norm(epsilon)
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
return vector.normalized(epsilon)
def euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
distance = torch.sqrt(distance_sq)
return distance
def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
d: Vec3Array) -> Float:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1 = a - b
v2 = b - c
v3 = d - c
c1 = v1.cross(v2)
c2 = v3.cross(v2)
c3 = c2.cross(c1)
v2_mag = v2.norm()
return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment