"ml/vscode:/vscode.git/clone" did not exist on "ed567ef43b5822423bd165f5f57fb6bad5fce1b3"
Commit b14e47f4 authored by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from fastfold.utils.geometry import rigid_matrix_vector
from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear
from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from fastfold.utils.geometry.rotation_matrix import Rot3Array
from fastfold.utils.geometry.vector import Vec3Array
class QuatRigid(nn.Module):
def __init__(self, c_hidden, full_quat):
super().__init__()
self.full_quat = full_quat
if self.full_quat:
rigid_dim = 7
else:
rigid_dim = 6
self.linear = Linear(c_hidden, rigid_dim)
def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations.to(torch.float32))
rigid_flat = torch.unbind(rigid_flat, dim=-1)
        if self.full_quat:
qw, qx, qy, qz = rigid_flat[:4]
translation = rigid_flat[4:]
else:
qx, qy, qz = rigid_flat[:3]
qw = torch.ones_like(qx)
translation = rigid_flat[3:]
rotation = Rot3Array.from_quaternion(
qw, qx, qy, qz, normalize=True,
)
translation = Vec3Array(*translation)
return Rigid3Array(rotation, translation)
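
# Usage sketch (not part of the original file): QuatRigid maps per-residue
# activations to a backbone frame update. The hidden width of 384 and the
# batch of 8 residues below are illustrative assumptions, not values taken
# from this repository.
def _quat_rigid_example():
    layer = QuatRigid(c_hidden=384, full_quat=False)
    activations = torch.randn(8, 384)   # [num_residues, c_hidden]
    rigids = layer(activations)         # Rigid3Array with batch shape [8]
    return rigids.to_tensor()           # [8, 4, 4] homogeneous transforms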
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Rigid3Array:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation: rotation_matrix.Rot3Array
translation: vector.Vec3Array
def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
new_rotation = self.rotation @ other.rotation # __matmul__
new_translation = self.apply_to_point(other.translation)
return Rigid3Array(new_rotation, new_translation)
def __getitem__(self, index) -> Rigid3Array:
return Rigid3Array(
self.rotation[index],
self.translation[index],
)
def __mul__(self, other: torch.Tensor) -> Rigid3Array:
return Rigid3Array(
self.rotation * other,
self.translation * other,
)
def map_tensor_fn(self, fn) -> Rigid3Array:
return Rigid3Array(
self.rotation.map_tensor_fn(fn),
self.translation.map_tensor_fn(fn),
)
def inverse(self) -> Rigid3Array:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation = self.rotation.inverse()
inv_translation = inv_rotation.apply_to_point(-self.translation)
return Rigid3Array(inv_rotation, inv_translation)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply(self, point: torch.Tensor) -> vector.Vec3Array:
return self.apply_to_point(vector.Vec3Array.from_array(point))
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
return Rigid3Array(rot, self.translation.clone())
def compose(self, other_rigid):
return self @ other_rigid
def unsqueeze(self, dim: int):
return Rigid3Array(
self.rotation.unsqueeze(dim),
self.translation.unsqueeze(dim),
)
@property
def shape(self) -> torch.Size:
return self.rotation.xx.shape
@property
def dtype(self) -> torch.dtype:
return self.rotation.xx.dtype
@property
def device(self) -> torch.device:
return self.rotation.xx.device
@classmethod
def identity(cls, shape, device) -> Rigid3Array:
"""Return identity Rigid3Array of given shape."""
return cls(
rotation_matrix.Rot3Array.identity(shape, device),
vector.Vec3Array.zeros(shape, device)
)
@classmethod
def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
return cls(
rotation_matrix.Rot3Array.cat(
[r.rotation for r in rigids], dim=dim
),
vector.Vec3Array.cat(
[r.translation for r in rigids], dim=dim
),
)
def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'."""
return Rigid3Array(self.rotation, self.translation * factor)
def to_tensor(self) -> torch.Tensor:
rot_array = self.rotation.to_tensor()
vec_array = self.translation.to_tensor()
array = torch.zeros(
rot_array.shape[:-2] + (4, 4),
device=rot_array.device,
dtype=rot_array.dtype
)
array[..., :3, :3] = rot_array
array[..., :3, 3] = vec_array
array[..., 3, 3] = 1.
return array
def to_tensor_4x4(self) -> torch.Tensor:
return self.to_tensor()
def reshape(self, new_shape) -> Rigid3Array:
rots = self.rotation.reshape(new_shape)
trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)
def stop_rot_gradient(self) -> Rigid3Array:
return Rigid3Array(
self.rotation.stop_gradient(),
self.translation,
)
@classmethod
def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(
array[..., :3, :3],
)
vec = vector.Vec3Array.from_array(array[..., :3, 3])
return cls(rot, vec)
@classmethod
def from_tensor_4x4(cls, array):
return cls.from_array(array)
@classmethod
    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation = rotation_matrix.Rot3Array(
array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
)
return cls(rotation, translation)
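
# Usage sketch (not in the original source): composing Rigid3Array transforms
# and applying them to points. The batch shape of (5,) is an illustrative
# assumption.
def _rigid3array_example():
    frames = Rigid3Array.identity((5,), device=torch.device("cpu"))
    points = vector.Vec3Array.zeros((5,))
    moved = frames.apply_to_point(points)   # identity frames leave points unchanged
    composed = frames @ frames              # composition via __matmul__
    round_trip = Rigid3Array.from_tensor_4x4(composed.to_tensor())
    return moved, round_trip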
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses
from typing import List
import torch
import numpy as np
from fastfold.utils.geometry import utils
from fastfold.utils.geometry import vector
from fastfold.utils.tensor_utils import tensor_tree_map
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
@dataclasses.dataclass(frozen=True)
class Rot3Array:
"""Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
xy: torch.Tensor
xz: torch.Tensor
yx: torch.Tensor
yy: torch.Tensor
yz: torch.Tensor
zx: torch.Tensor
zy: torch.Tensor
zz: torch.Tensor
__array_ufunc__ = None
def __getitem__(self, index):
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: getattr(self, name)[index]
for name in field_names
}
)
def __mul__(self, other: torch.Tensor):
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: getattr(self, name) * other
for name in field_names
}
)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
def map_tensor_fn(self, fn) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: fn(getattr(self, name))
for name in field_names
}
)
def inverse(self) -> Rot3Array:
"""Returns inverse of Rot3Array."""
return Rot3Array(
self.xx, self.yx, self.zx,
self.xy, self.yy, self.zy,
self.xz, self.yz, self.zz
)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies Rot3Array to point."""
return vector.Vec3Array(
self.xx * point.x + self.xy * point.y + self.xz * point.z,
self.yx * point.x + self.yy * point.y + self.yz * point.z,
self.zx * point.x + self.zy * point.y + self.zz * point.z
)
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point)
def unsqueeze(self, dim: int):
return Rot3Array(
*tensor_tree_map(
lambda t: t.unsqueeze(dim),
[getattr(self, c) for c in COMPONENTS]
)
)
def stop_gradient(self) -> Rot3Array:
return Rot3Array(
*[getattr(self, c).detach() for c in COMPONENTS]
)
@classmethod
def identity(cls, shape, device) -> Rot3Array:
"""Returns identity of given shape."""
ones = torch.ones(shape, dtype=torch.float32, device=device)
zeros = torch.zeros(shape, dtype=torch.float32, device=device)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
@classmethod
def from_two_vectors(
cls, e0: vector.Vec3Array,
e1: vector.Vec3Array
) -> Rot3Array:
"""Construct Rot3Array from two Vectors.
Rot3Array is constructed such that in the corresponding frame 'e0' lies on
the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
Args:
e0: Vector
e1: Vector
Returns:
Rot3Array
"""
# Normalize the unit vector for the x-axis, e0.
e0 = e0.normalized()
# make e1 perpendicular to e0.
c = e1.dot(e0)
e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
@classmethod
def from_array(cls, array: torch.Tensor) -> Rot3Array:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
rows = torch.unbind(array, dim=-2)
rc = [torch.unbind(e, dim=-1) for e in rows]
return cls(*[e for row in rc for e in row])
def to_tensor(self) -> torch.Tensor:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return torch.stack(
[
torch.stack([self.xx, self.xy, self.xz], dim=-1),
torch.stack([self.yx, self.yy, self.yz], dim=-1),
torch.stack([self.zx, self.zy, self.zz], dim=-1)
],
dim=-2
)
@classmethod
def from_quaternion(cls,
w: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
z: torch.Tensor,
normalize: bool = True,
eps: float = 1e-6
) -> Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
w *= inv_norm
x *= inv_norm
y *= inv_norm
z *= inv_norm
xx = 1 - 2 * (y ** 2 + z ** 2)
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
yx = 2 * (x * y + w * z)
yy = 1 - 2 * (x ** 2 + z ** 2)
yz = 2 * (y * z - w * x)
zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x)
zz = 1 - 2 * (x ** 2 + y ** 2)
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
def reshape(self, new_shape):
field_names = utils.get_field_names(Rot3Array)
reshape_fn = lambda t: t.reshape(new_shape)
return Rot3Array(
**{
name: reshape_fn(getattr(self, name))
for name in field_names
}
)
@classmethod
def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
cat_fn = lambda l: torch.cat(l, dim=dim)
return cls(
**{
name: cat_fn([getattr(r, name) for r in rots])
for name in field_names
}
)
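
# Usage sketch (not in the original source): building a Rot3Array from an
# unnormalized quaternion and round-tripping it through a [..., 3, 3] tensor.
def _rot3array_example():
    w = torch.tensor([1.0])
    x = torch.tensor([0.0])
    y = torch.tensor([1.0])
    z = torch.tensor([0.0])
    rot = Rot3Array.from_quaternion(w, x, y, z, normalize=True)
    as_matrix = rot.to_tensor()                  # shape [1, 3, 3]
    recovered = Rot3Array.from_array(as_matrix)  # back to struct-of-arrays form
    return rot.inverse() @ recovered             # approximately the identity rotation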
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import dataclasses
from fastfold.utils.geometry import rigid_matrix_vector
from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector
import numpy as np
def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
matrix2: rotation_matrix.Rot3Array):
for field in dataclasses.fields(rotation_matrix.Rot3Array):
field = field.name
np.testing.assert_array_equal(
getattr(matrix1, field), getattr(matrix2, field))
def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
mat2: rotation_matrix.Rot3Array):
    np.testing.assert_array_almost_equal(mat1.to_tensor(), mat2.to_tensor(), 6)
def assert_array_equal_to_rotation_matrix(array: np.ndarray,
matrix: rotation_matrix.Rot3Array):
"""Check that array and Matrix match."""
np.testing.assert_array_equal(matrix.xx, array[..., 0, 0])
np.testing.assert_array_equal(matrix.xy, array[..., 0, 1])
np.testing.assert_array_equal(matrix.xz, array[..., 0, 2])
np.testing.assert_array_equal(matrix.yx, array[..., 1, 0])
np.testing.assert_array_equal(matrix.yy, array[..., 1, 1])
np.testing.assert_array_equal(matrix.yz, array[..., 1, 2])
np.testing.assert_array_equal(matrix.zx, array[..., 2, 0])
np.testing.assert_array_equal(matrix.zy, array[..., 2, 1])
np.testing.assert_array_equal(matrix.zz, array[..., 2, 2])
def assert_array_close_to_rotation_matrix(array: np.ndarray,
matrix: rotation_matrix.Rot3Array):
    np.testing.assert_array_almost_equal(matrix.to_tensor(), array, 6)
def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np.testing.assert_array_equal(vec1.x, vec2.x)
np.testing.assert_array_equal(vec1.y, vec2.y)
np.testing.assert_array_equal(vec1.z, vec2.z)
def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
def assert_array_close_to_vector(array: np.ndarray, vec: vector.Vec3Array):
    np.testing.assert_allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.)
def assert_array_equal_to_vector(array: np.ndarray, vec: vector.Vec3Array):
    np.testing.assert_array_equal(vec.to_tensor(), array)
def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_equal(rot, rigid.rotation)
assert_vectors_equal(trans, rigid.translation)
def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_close(rot, rigid.rotation)
assert_vectors_close(trans, rigid.translation)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import dataclasses
def get_field_names(cls):
fields = dataclasses.fields(cls)
field_names = [f.name for f in fields]
return field_names
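
# Usage sketch (not in the original source): get_field_names lists dataclass
# field names in declaration order; the geometry classes use it to iterate over
# their per-component tensors generically. The _Point class is a toy example.
def _get_field_names_example():
    @dataclasses.dataclass
    class _Point:
        x: float
        y: float
    return get_field_names(_Point)  # ['x', 'y']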
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from fastfold.utils.geometry import utils
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Vec3Array:
x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
y: torch.Tensor
z: torch.Tensor
def __post_init__(self):
if hasattr(self.x, 'dtype'):
assert self.x.dtype == self.y.dtype
assert self.x.dtype == self.z.dtype
assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
def __add__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x + other.x,
self.y + other.y,
self.z + other.z,
)
def __sub__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x - other.x,
self.y - other.y,
self.z - other.z,
)
def __mul__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x * other,
self.y * other,
self.z * other,
)
def __rmul__(self, other: Float) -> Vec3Array:
return self * other
def __truediv__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x / other,
self.y / other,
self.z / other,
)
def __neg__(self) -> Vec3Array:
return self * -1
def __pos__(self) -> Vec3Array:
return self * 1
def __getitem__(self, index) -> Vec3Array:
return Vec3Array(
self.x[index],
self.y[index],
self.z[index],
)
def __iter__(self):
return iter((self.x, self.y, self.z))
@property
def shape(self):
return self.x.shape
def map_tensor_fn(self, fn) -> Vec3Array:
return Vec3Array(
fn(self.x),
fn(self.y),
fn(self.z),
)
def cross(self, other: Vec3Array) -> Vec3Array:
"""Compute cross product between 'self' and 'other'."""
new_x = self.y * other.z - self.z * other.y
new_y = self.z * other.x - self.x * other.z
new_z = self.x * other.y - self.y * other.x
return Vec3Array(new_x, new_y, new_z)
def dot(self, other: Vec3Array) -> Float:
"""Compute dot product between 'self' and 'other'."""
return self.x * other.x + self.y * other.y + self.z * other.z
def norm(self, epsilon: float = 1e-6) -> Float:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
norm2 = torch.clamp(norm2, min=epsilon**2)
return torch.sqrt(norm2)
def norm2(self):
return self.dot(self)
def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
"""Return unit vector with optional clipping."""
return self / self.norm(epsilon)
def clone(self) -> Vec3Array:
return Vec3Array(
self.x.clone(),
self.y.clone(),
self.z.clone(),
)
def reshape(self, new_shape) -> Vec3Array:
x = self.x.reshape(new_shape)
y = self.y.reshape(new_shape)
z = self.z.reshape(new_shape)
return Vec3Array(x, y, z)
def sum(self, dim: int) -> Vec3Array:
return Vec3Array(
torch.sum(self.x, dim=dim),
torch.sum(self.y, dim=dim),
torch.sum(self.z, dim=dim),
)
def unsqueeze(self, dim: int):
return Vec3Array(
self.x.unsqueeze(dim),
self.y.unsqueeze(dim),
self.z.unsqueeze(dim),
)
@classmethod
def zeros(cls, shape, device="cpu"):
"""Return Vec3Array corresponding to zeros of given shape."""
return cls(
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device)
)
def to_tensor(self) -> torch.Tensor:
return torch.stack([self.x, self.y, self.z], dim=-1)
@classmethod
def from_array(cls, tensor):
return cls(*torch.unbind(tensor, dim=-1))
@classmethod
def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
return cls(
torch.cat([v.x for v in vecs], dim=dim),
torch.cat([v.y for v in vecs], dim=dim),
torch.cat([v.z for v in vecs], dim=dim),
)
def square_euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference = vec1 - vec2
distance = difference.dot(difference)
if epsilon:
        distance = torch.clamp(distance, min=epsilon)
return distance
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
return vector1.dot(vector2)
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
return vector.norm(epsilon)
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
return vector.normalized(epsilon)
def euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
distance = torch.sqrt(distance_sq)
return distance
def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
d: Vec3Array) -> Float:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1 = a - b
v2 = b - c
v3 = d - c
c1 = v1.cross(v2)
c2 = v3.cross(v2)
c3 = c2.cross(c1)
v2_mag = v2.norm()
return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
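
# Usage sketch (not in the original source): for four coplanar points with 'a'
# and 'd' on opposite sides of the b-c axis, the dihedral is approximately pi.
# The coordinates are illustrative.
def _dihedral_example():
    a = Vec3Array.from_array(torch.tensor([-1.0, 1.0, 0.0]))
    b = Vec3Array.from_array(torch.tensor([0.0, 0.0, 0.0]))
    c = Vec3Array.from_array(torch.tensor([1.0, 0.0, 0.0]))
    d = Vec3Array.from_array(torch.tensor([2.0, -1.0, 0.0]))
    return dihedral_angle(a, b, c, d)  # approximately pi (trans arrangement)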
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
from typing import Union, List
from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: w.transpose(-1, -2)
)
LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
)
LinearMHAOutputWeight = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearWeightMultimer = partial(
lambda w: w.unsqueeze(-1)
if len(w.shape) == 1
else w.reshape(w.shape[0], -1).transpose(-1, -2)
)
LinearBiasMultimer = partial(lambda w: w.reshape(-1))
Other = partial(lambda w: w)
def __init__(self, fn):
self.transformation = fn
@dataclass
class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other
stacked: bool = False
def _process_translations_dict(d, top_layer=True):
flat = {}
for k, v in d.items():
if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = {
(prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict(
v, top_layer=False
).items()
}
flat.update(sub_flat)
else:
k = "/" + k if not top_layer else k
flat[k] = v
return flat
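
# Usage sketch (not in the original source): _process_translations_dict flattens
# a nested translation dict into AlphaFold-style npz keys, with the module path
# joined by "/" and a second "/" before the variable name. The integer leaf
# below stands in for a Param.
def _flatten_example():
    nested = {"evoformer": {"preprocess_1d": {"weights": 0}}}
    return _process_translations_dict(nested)
    # {'alphafold/alphafold_iteration/evoformer/preprocess_1d//weights': 0}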
def stacked(param_dict_list, out=None):
"""
Args:
param_dict_list:
A list of (nested) Param dicts to stack. The structure of
            each dict must be identical (down to the ParamTypes of
"parallel" Params). There must be at least one dict
in the list.
"""
if out is None:
out = {}
template = param_dict_list[0]
for k, _ in template.items():
v = [d[k] for d in param_dict_list]
if type(v[0]) is dict:
out[k] = {}
stacked(v, out=out[k])
elif type(v[0]) is Param:
stacked_param = Param(
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
)
out[k] = stacked_param
return out
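
# Usage sketch (not in the original source): `stacked` merges a list of
# per-block translation dicts into one dict whose Params hold lists of tensors,
# mirroring how AlphaFold stores the weights of stacked blocks. The tiny dicts
# below are placeholders, not real model parameters.
def _stacked_example():
    block_0 = {"transition": {"weights": Param(torch.zeros(4))}}
    block_1 = {"transition": {"weights": Param(torch.ones(4))}}
    merged = stacked([block_0, block_1])
    return merged["transition"]["weights"]  # Param(param=[zeros, ones], stacked=True)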
def assign(translation_dict, orig_weights):
for k, param in translation_dict.items():
with torch.no_grad():
weights = torch.as_tensor(orig_weights[k])
ref, param_type = param.param, param.param_type
if param.stacked:
weights = torch.unbind(weights, 0)
else:
weights = [weights]
ref = [ref]
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
p.copy_(w)
except:
print(k)
print(ref[0].shape)
print(weights[0].shape)
raise
def get_translation_dict(model, version):
is_multimer = "multimer" in version
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearWeightMultimer = lambda l: (
Param(l, param_type=ParamType.LinearWeightMultimer)
)
LinearBiasMultimer = lambda l: (Param(l, param_type=ParamType.LinearBiasMultimer))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
}
AttentionParams = lambda att: {
"query_w": LinearWeightMHA(att.linear_q.weight),
"key_w": LinearWeightMHA(att.linear_k.weight),
"value_w": LinearWeightMHA(att.linear_v.weight),
"output_w": Param(
att.linear_o.weight,
param_type=ParamType.LinearMHAOutputWeight,
),
"output_b": LinearBias(att.linear_o.bias),
}
AttentionGatedParams = lambda att: dict(
**AttentionParams(att),
**{
"gating_w": LinearWeightMHA(att.linear_g.weight),
"gating_b": LinearBiasMHA(att.linear_g.bias),
},
)
GlobalAttentionParams = lambda att: dict(
AttentionGatedParams(att),
key_w=LinearWeight(att.linear_k.weight),
value_w=LinearWeight(att.linear_v.weight),
)
TriAttParams = lambda tri_att: {
"query_norm": LayerNormParams(tri_att.layer_norm),
"feat_2d_weights": LinearWeight(tri_att.linear.weight),
"attention": AttentionGatedParams(tri_att.mha),
}
if is_fused_triangle_multiplication():
TriMulOutParams = lambda tri_mul: {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": LinearParams(tri_mul.linear_p),
"gate": LinearParams(tri_mul.linear_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_gate),
}
# see commit b88f8da on the Alphafold repo
        # Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": LinearParams(tri_mul.linear_p),
"gate": LinearParams(tri_mul.linear_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_gate),
}
else:
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_a_p),
"right_projection": LinearParams(tri_mul.linear_b_p),
"left_gate": LinearParams(tri_mul.linear_a_g),
"right_gate": LinearParams(tri_mul.linear_b_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
# see commit b88f8da on the Alphafold repo
        # Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_b_p),
"right_projection": LinearParams(tri_mul.linear_a_p),
"left_gate": LinearParams(tri_mul.linear_b_g),
"right_gate": LinearParams(tri_mul.linear_a_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
"transition1": LinearParams(pt.linear_1),
"transition2": LinearParams(pt.linear_2),
}
MSAAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": AttentionGatedParams(matt.mha),
}
MSAColAttParams = lambda matt: {
"query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
"attention": AttentionGatedParams(matt._msa_att.mha),
}
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention),
}
MSAAttPairBiasParams = lambda matt: dict(
**MSAAttParams(matt),
**{
"feat_2d_norm": LayerNormParams(matt.layer_norm_z),
"feat_2d_weights": LinearWeight(matt.linear_z.weight),
},
)
IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
# New style IPA param
# "q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points),
# New style IPA param
# "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMultimer(
pp.linear,
),
}
IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_q.weight,
),
},
"k_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_k.weight,
),
},
"v_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_v.weight,
),
},
"q_point_projection": PointProjectionParams(ipa.linear_q_points),
"k_point_projection": PointProjectionParams(ipa.linear_k_points),
"v_point_projection": PointProjectionParams(ipa.linear_v_points),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
TemplatePairBlockParams = lambda b: {
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"pair_transition": PairTransitionParams(b.pair_transition),
}
MSATransitionParams = lambda m: {
"input_layer_norm": LayerNormParams(m.layer_norm),
"transition1": LinearParams(m.linear_1),
"transition2": LinearParams(m.linear_2),
}
OuterProductMeanParams = lambda o: {
"layer_norm_input": LayerNormParams(o.layer_norm),
"left_projection": LinearParams(o.linear_1),
"right_projection": LinearParams(o.linear_2),
"output_w": LinearWeightOPM(o.linear_out.weight),
"output_b": LinearBias(o.linear_out.bias),
}
def EvoformerBlockParams(b, is_extra_msa=False):
if is_extra_msa:
col_att_name = "msa_column_global_attention"
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else:
col_att_name = "msa_column_attention"
msa_col_att_params = MSAColAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(b.msa_att_row),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.core.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
"pair_transition": PairTransitionParams(b.core.pair_transition),
}
return d
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
def FoldIterationParams(sm):
d = {
"invariant_point_attention": IPAParamsMultimer(sm.ipa)
if is_multimer
else IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
"transition_2": LinearParams(sm.transition.layers[0].linear_3),
"transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
"affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
if is_multimer:
d.pop("affine_update")
d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)}
return d
############################
# translations dict overflow
############################
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
if not is_multimer:
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"pair_activiations": LinearParams(model.input_embedder.linear_relpos),
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_embedder.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(
model.template_embedder.template_pointwise_att.mha
),
},
"extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams(
model.template_embedder.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_embedder.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(model.structure_module.linear_in),
"pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
else:
temp_embedder = model.template_embedder
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"~_relative_encoding": {
"position_activations": LinearParams(
model.input_embedder.linear_relpos
),
},
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParamsMultimer(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(temp_embedder.linear_t),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
"extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
"extra_msa_stack": ems_blocks_params,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(model.structure_module.linear_in),
"pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
no_templ = [
"model_3",
"model_4",
"model_5",
"model_3_ptm",
"model_4_ptm",
"model_5_ptm",
]
if version in no_templ:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if "template_" in k:
evo_dict.pop(k)
if "_ptm" in version or is_multimer:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear)
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, version)
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
# Sanity check
keys = list(data.keys())
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
# print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}")
assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
# Set weights
assign(flat, data)
if is_fused_triangle_multiplication():
        # NOTE: In multimer v3, AlphaFold uses fused triangle multiplication, so the left/right halves must be swapped here.
for b in model.template_embedder.template_pair_stack.blocks:
_change_tri_mul_in_left_right(b.tri_mul_in)
for b in model.extra_msa_stack.blocks:
_change_tri_mul_in_left_right(b.core.tri_mul_in)
for b in model.evoformer.blocks:
_change_tri_mul_in_left_right(b.core.tri_mul_in)
def _change_tri_mul_in_left_right(module):
def _change_para(para):
left_right_para = para.clone().chunk(2, dim=0)
return torch.cat((left_right_para[1], left_right_para[0]), dim=0)
with torch.no_grad():
module.linear_p.weight.copy_(_change_para(module.linear_p.weight))
module.linear_p.bias.copy_(_change_para(module.linear_p.bias))
module.linear_g.weight.copy_(_change_para(module.linear_g.weight))
module.linear_g.bias.copy_(_change_para(module.linear_g.bias))
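
# Sanity-check sketch (not in the original source): the swap performed by
# _change_para splits a fused left/right projection into halves along dim 0 and
# concatenates them in reverse order. The 4-row tensor is an illustrative
# stand-in for a real weight matrix.
def _swap_halves_example():
    weight = torch.arange(4.0).unsqueeze(-1)   # rows [0, 1, 2, 3]
    left, right = weight.clone().chunk(2, dim=0)
    return torch.cat((right, left), dim=0)     # rows [2, 3, 0, 1]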
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from fastfold.model.fastnn.embedders import TemplateEmbedder
from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication
def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias)
def copy_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
if model_fast.use_bias:
model_fast.bias.copy_(model_ori.bias)
def copy_native_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
try:
model_fast.bias.copy_(model_ori.bias)
except:
pass
def copy_kv_linear(model_fast, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))
def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))
def copy_attention(model_fast, model_ori):
copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except:
print("no gating_bias need copy")
def copy_left_right(model_fast, ori_left, ori_right):
model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))
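
# Usage sketch (not in the original source): copy_left_right fuses two separate
# Linear projections into one concatenated weight/bias, matching the layout the
# fastnn triangle-multiplication kernels expect. The layer sizes below are
# illustrative assumptions.
def _copy_left_right_example():
    left = torch.nn.Linear(8, 16)
    right = torch.nn.Linear(8, 16)
    fused = torch.nn.Linear(8, 32)
    with torch.no_grad():
        copy_left_right(fused, left, right)
    return fused.weight.shape  # torch.Size([32, 8])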
def copy_transition(model_fast, model_ori):
copy_layernorm(model_fast.norm, model_ori.layer_norm)
copy_linear(model_fast.linear1, model_ori.linear_1)
copy_linear(model_fast.linear2, model_ori.linear_2)
def copy_triangle(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
copy_linear(model_fast.output_projection, model_ori.linear_z)
model_fast.output_bias.copy_(model_ori.linear_z.bias)
if is_fused_triangle_multiplication():
copy_linear(model_fast.output_gate, model_ori.linear_gate)
copy_linear(model_fast.left_right_projection, model_ori.linear_p)
copy_linear(model_fast.left_right_gate, model_ori.linear_g)
else:
copy_linear(model_fast.output_gate, model_ori.linear_g)
copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)
def copy_triangle_att(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
copy_linear(model_fast.linear_b, model_ori.linear)
copy_attention(model_fast.attention, model_ori.mha)
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_native_att(model_fast, model_ori):
copy_native_linear(model_fast.linear_q, model_ori.linear_q)
copy_native_linear(model_fast.linear_k, model_ori.linear_k)
copy_native_linear(model_fast.linear_v, model_ori.linear_v)
copy_native_linear(model_fast.linear_o, model_ori.linear_o)
if model_ori.gating:
copy_native_linear(model_fast.linear_g, model_ori.linear_g)
def copy_evoformer_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m)
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z)
copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha)
block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
block_ori.msa_att_col._msa_att.layer_norm_m)
copy_attention(block_fast.msa.MSAColumnAttention.attention,
block_ori.msa_att_col._msa_att.mha)
# MSATransition
copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)
# communication
copy_layernorm(block_fast.communication.layernormM,
block_ori.core.outer_product_mean.layer_norm)
copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode,
block_ori.core.tri_att_start)
copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)
def copy_global_attention(model_fast, model_ori):
copy_linear(model_fast.to_q, model_ori.linear_q)
copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except:
print("no gating_bias need copy")
def copy_extra_msa_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m,
)
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z,
)
copy_attention(
block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha,
)
block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight
)
block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias
)
# MSAColumnAttention
copy_layernorm(
block_fast.msa_stack.MSAColumnAttention.layernormM,
block_ori.msa_att_col.layer_norm_m,
)
copy_global_attention(
block_fast.msa_stack.MSAColumnAttention.global_attention,
block_ori.msa_att_col.global_attention,
)
# MSATransition
copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)
# communication
comm_model = (
        block_ori.core.outer_product_mean  # if not block_ori.is_multimer else block_ori.outer_product_mean
)
copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
copy_linear(block_fast.communication.o_linear, comm_model.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(
block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out
)
# TriangleMultiplicationIncoming
copy_triangle(
block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in
)
# TriangleAttentionStartingNode
copy_triangle_att(
block_fast.pair_stack.TriangleAttentionStartingNode,
block_ori.core.tri_att_start,
)
copy_triangle_att(
block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end
)
copy_transition(
block_fast.pair_stack.PairTransition, block_ori.core.pair_transition
)
def copy_template_pair_stack_para(block_fast, block_ori):
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
copy_transition(block_fast.PairTransition, block_ori.pair_transition)
def copy_template_pair_block_para(fast_module, target_module):
with torch.no_grad():
for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_template_pair_stack_para(fast_block, ori_block)
            if not ori_block.training:
fast_block.eval()
def copy_template_para(block_fast, block_ori):
# TemplateAngleEmbedder
copy_linear(block_fast.template_angle_embedder.linear_1,
block_ori.template_angle_embedder.linear_1)
copy_linear(block_fast.template_angle_embedder.linear_2,
block_ori.template_angle_embedder.linear_2)
# TemplatePairEmbedder
copy_linear(block_fast.template_pair_embedder.linear,
block_ori.template_pair_embedder.linear)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack,
block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# TemplatePointwiseAttention
copy_native_att(block_fast.template_pointwise_att.mha,
block_ori.template_pointwise_att.mha)
def copy_template_multimer_para(block_fast, block_ori):
# TemplatePairEmbedderMultimer
copy_linear(block_fast.template_pair_embedder.dgram_linear,
block_ori.template_pair_embedder.dgram_linear)
copy_linear(block_fast.template_pair_embedder.aatype_linear_1,
block_ori.template_pair_embedder.aatype_linear_1)
copy_linear(block_fast.template_pair_embedder.aatype_linear_2,
block_ori.template_pair_embedder.aatype_linear_2)
copy_layernorm(block_fast.template_pair_embedder.query_embedding_layer_norm,
block_ori.template_pair_embedder.query_embedding_layer_norm)
copy_linear(block_fast.template_pair_embedder.query_embedding_linear,
block_ori.template_pair_embedder.query_embedding_linear)
copy_linear(block_fast.template_pair_embedder.pseudo_beta_mask_linear,
block_ori.template_pair_embedder.pseudo_beta_mask_linear)
copy_linear(block_fast.template_pair_embedder.x_linear,
block_ori.template_pair_embedder.x_linear)
copy_linear(block_fast.template_pair_embedder.y_linear,
block_ori.template_pair_embedder.y_linear)
copy_linear(block_fast.template_pair_embedder.z_linear,
block_ori.template_pair_embedder.z_linear)
copy_linear(block_fast.template_pair_embedder.backbone_mask_linear,
block_ori.template_pair_embedder.backbone_mask_linear)
# TemplateSingleEmbedderMultimer
copy_linear(block_fast.template_single_embedder.template_single_embedder,
block_ori.template_single_embedder.template_single_embedder)
copy_linear(block_fast.template_single_embedder.template_projector,
block_ori.template_single_embedder.template_projector)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack,
block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# linear_t
copy_linear(block_fast.linear_t, block_ori.linear_t)
def inject_evoformer(model):
with torch.no_grad():
target_module = model.evoformer
fast_module = EvoformerStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
c_s=target_module.linear.out_features,
no_blocks=len(target_module.blocks),
blocks_per_ckpt=target_module.blocks_per_ckpt,
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_evoformer_para(fast_block, target_block)
            if not target_block.training:
fast_block.eval()
copy_linear(fast_module.linear, target_module.linear)
model.evoformer = fast_module
def inject_extramsa(model):
with torch.no_grad():
target_module = model.extra_msa_stack
fast_module = ExtraMSAStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
no_blocks=len(target_module.blocks),
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_extra_msa_para(fast_block, target_block)
            if not target_block.training:
fast_block.eval()
model.extra_msa_stack = fast_module
def inject_template(model):
with torch.no_grad():
if model.evoformer.blocks[0].is_multimer:
target_module = model.template_embedder
fast_module = TemplateEmbedderMultimer(config=model.template_embedder.config)
copy_template_multimer_para(fast_module, target_module)
            if not target_module.training:
fast_module.eval()
model.template_embedder = fast_module
else:
target_module = model.template_embedder
fast_module = TemplateEmbedder(config=model.template_embedder.config)
copy_template_para(fast_module, target_module)
            if not target_module.training:
fast_module.eval()
model.template_embedder = fast_module
def inject_embedder(model):
if model.evoformer.blocks[0].is_multimer:
return
# recycle embedder
with torch.no_grad():
target_module = model.recycling_embedder
fast_module = RecyclingEmbedder(
c_m=target_module.c_m,
c_z=target_module.c_z,
min_bin=target_module.min_bin,
max_bin=target_module.max_bin,
no_bins=target_module.no_bins,
inf=target_module.inf
)
copy_native_linear(fast_module.linear, target_module.linear)
copy_layernorm(fast_module.layer_norm_m, target_module.layer_norm_m)
copy_layernorm(fast_module.layer_norm_z, target_module.layer_norm_z)
        if not target_module.training:
fast_module.eval()
model.recycling_embedder = fast_module
# input embedder
with torch.no_grad():
target_module = model.input_embedder
fast_module = InputEmbedder(
tf_dim=target_module.tf_dim,
msa_dim=target_module.msa_dim,
c_z=target_module.c_z,
c_m=target_module.c_m,
relpos_k=target_module.relpos_k,
)
copy_linear(fast_module.linear_tf_z_i, target_module.linear_tf_z_i)
copy_linear(fast_module.linear_tf_z_j, target_module.linear_tf_z_j)
copy_linear(fast_module.linear_tf_m, target_module.linear_tf_m)
copy_linear(fast_module.linear_msa_m, target_module.linear_msa_m)
copy_linear(fast_module.linear_relpos, target_module.linear_relpos)
        if not target_module.training:
fast_module.eval()
model.input_embedder = fast_module
def inject_fastnn(model):
inject_evoformer(model)
inject_extramsa(model)
inject_template(model)
inject_embedder(model)
return model
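# Added usage sketch (comment only, not part of the original module).
# `inject_fastnn` swaps the stock sub-modules of an OpenFold-style AlphaFold
# model for their FastFold counterparts in place, so pretrained weights must be
# loaded *before* injection for the copy_* helpers above to see them.
# The constructor and weight loader below are placeholders, not a documented API:
#
#   model = AlphaFold(config)            # hypothetical OpenFold-compatible model
#   load_pretrained_weights(model, ckpt) # hypothetical weight loader
#   model = inject_fastnn(model)
#   model = model.eval()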
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np
import torch
import fastfold.habana as habana
def rot_matmul(
a: torch.Tensor,
b: torch.Tensor
) -> torch.Tensor:
"""
Performs matrix multiplication of two rotation matrix tensors. Written
out by hand to avoid AMP downcasting.
Args:
a: [*, 3, 3] left multiplicand
b: [*, 3, 3] right multiplicand
Returns:
The product ab
"""
if habana.is_habana():
if len(a.shape) == 4 and a.shape[1] == 1:
aa = a.permute(0, 1, 3, 2)
bb = b.permute(0, 1, 3, 2)
cc = bb @ aa
cc = cc.permute(0, 1, 3, 2)
return cc
elif len(a.shape) == 4 and a.shape[1] != 1:
pass
else:
cc = a @ b
return cc
row_1 = torch.stack(
[
a[..., 0, 0] * b[..., 0, 0]
+ a[..., 0, 1] * b[..., 1, 0]
+ a[..., 0, 2] * b[..., 2, 0],
a[..., 0, 0] * b[..., 0, 1]
+ a[..., 0, 1] * b[..., 1, 1]
+ a[..., 0, 2] * b[..., 2, 1],
a[..., 0, 0] * b[..., 0, 2]
+ a[..., 0, 1] * b[..., 1, 2]
+ a[..., 0, 2] * b[..., 2, 2],
],
dim=-1,
)
row_2 = torch.stack(
[
a[..., 1, 0] * b[..., 0, 0]
+ a[..., 1, 1] * b[..., 1, 0]
+ a[..., 1, 2] * b[..., 2, 0],
a[..., 1, 0] * b[..., 0, 1]
+ a[..., 1, 1] * b[..., 1, 1]
+ a[..., 1, 2] * b[..., 2, 1],
a[..., 1, 0] * b[..., 0, 2]
+ a[..., 1, 1] * b[..., 1, 2]
+ a[..., 1, 2] * b[..., 2, 2],
],
dim=-1,
)
row_3 = torch.stack(
[
a[..., 2, 0] * b[..., 0, 0]
+ a[..., 2, 1] * b[..., 1, 0]
+ a[..., 2, 2] * b[..., 2, 0],
a[..., 2, 0] * b[..., 0, 1]
+ a[..., 2, 1] * b[..., 1, 1]
+ a[..., 2, 2] * b[..., 2, 1],
a[..., 2, 0] * b[..., 0, 2]
+ a[..., 2, 1] * b[..., 1, 2]
+ a[..., 2, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack([row_1, row_2, row_3], dim=-2)
def rot_vec_mul(
r: torch.Tensor,
t: torch.Tensor
) -> torch.Tensor:
"""
        Applies a rotation to a vector. Written out by hand to avoid AMP
        downcasting.
Args:
r: [*, 3, 3] rotation matrices
t: [*, 3] coordinate tensors
Returns:
[*, 3] rotated coordinates
"""
if habana.is_habana():
cont = True
if len(t.shape) == 4 and t.shape[1] == 1:
cont = False
elif len(t.shape) == 3 and t.shape[0] != r.shape[0] and t.shape[0] == 1:
cont = False
if cont:
tt = t.unsqueeze(-2)
rr = r.transpose(-2, -1)
cc = tt @ rr
cc = cc.squeeze(-2)
return cc
x = t[..., 0]
y = t[..., 1]
z = t[..., 2]
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
],
dim=-1,
)
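def _demo_rot_ops():
    # Added sanity-check sketch (not in the original file): compose two
    # 90-degree rotations about z with rot_matmul, then rotate the x-axis with
    # rot_vec_mul; the result should be the negated x-axis.
    rz90 = torch.tensor([[0., -1., 0.],
                         [1.,  0., 0.],
                         [0.,  0., 1.]])
    rz180 = rot_matmul(rz90, rz90)
    x_axis = torch.tensor([1., 0., 0.])
    rotated = rot_vec_mul(rz180, x_axis)
    assert torch.allclose(rotated, torch.tensor([-1., 0., 0.]), atol=1e-6)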
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
return rots
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
return trans
def identity_quats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
quat = torch.zeros(
(*batch_dims, 4),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
with torch.no_grad():
quat[..., 0] = 1
return quat
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def rot_to_quat(
rot: torch.Tensor,
):
if(rot.shape[-2:] != (3, 3)):
raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
k = [
[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
]
k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
_, vectors = torch.linalg.eigh(k)
return vectors[..., -1]
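def _demo_quat_roundtrip():
    # Added sanity-check sketch: the identity quaternion maps to the identity
    # rotation matrix, and rot_to_quat recovers it up to sign (q and -q encode
    # the same rotation, and eigenvectors are only defined up to sign).
    q = torch.tensor([1., 0., 0., 0.])
    r = quat_to_rot(q)
    assert torch.allclose(r, torch.eye(3), atol=1e-6)
    q_back = rot_to_quat(r)
    assert torch.allclose(q_back.abs(), q, atol=1e-5)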
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
[ 0,-1, 0, 0],
[ 0, 0,-1, 0],
[ 0, 0, 0,-1]]
_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
[ 1, 0, 0, 0],
[ 0, 0, 0, 1],
[ 0, 0,-1, 0]]
_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
[ 0, 0, 0,-1],
[ 1, 0, 0, 0],
[ 0, 1, 0, 0]]
_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
[ 0, 0, 1, 0],
[ 0,-1, 0, 0],
[ 1, 0, 0, 0]]
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
mat = quat1.new_tensor(_QUAT_MULTIPLY)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
quat1[..., :, None, None] *
quat2[..., None, :, None],
dim=(-3, -2)
)
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
quat[..., :, None, None] *
vec[..., None, :, None],
dim=(-3, -2)
)
def invert_rot_mat(rot_mat: torch.Tensor):
return rot_mat.transpose(-1, -2)
def invert_quat(quat: torch.Tensor):
quat_prime = quat.clone()
quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
return inv
class Rotation:
"""
A 3D rotation. Depending on how the object is initialized, the
rotation is represented by either a rotation matrix or a
quaternion, though both formats are made available by helper functions.
To simplify gradient computation, the underlying format of the
rotation cannot be changed in-place. Like Rigid, the class is designed
to mimic the behavior of a torch Tensor, almost as if each Rotation
object were a tensor of rotations, in one format or another.
"""
def __init__(self,
rot_mats: Optional[torch.Tensor] = None,
quats: Optional[torch.Tensor] = None,
normalize_quats: bool = True,
):
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If
normalize_quats is not True, must be a unit quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if((rot_mats is None and quats is None) or
(rot_mats is not None and quats is not None)):
raise ValueError("Exactly one input argument must be specified")
if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
(quats is not None and quats.shape[-1] != 4)):
raise ValueError(
"Incorrectly shaped rotation matrix or quaternion"
)
# Force full-precision
if(quats is not None):
quats = quats.to(dtype=torch.float32)
if(rot_mats is not None):
rot_mats = rot_mats.to(dtype=torch.float32)
if(quats is not None and normalize_quats):
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
self._rot_mats = rot_mats
self._quats = quats
@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rotation:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation
for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object
should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format
of the new object's rotation
Returns:
A new identity rotation
"""
if(fmt == "rot_mat"):
rot_mats = identity_rot_mats(
shape, dtype, device, requires_grad,
)
return Rotation(rot_mats=rot_mats, quats=None)
elif(fmt == "quat"):
quats = identity_quats(shape, dtype, device, requires_grad)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError(f"Invalid format: f{fmt}")
# Magic methods
def __getitem__(self, index: Any) -> Rotation:
"""
Allows torch-style indexing over the virtual shape of the rotation
object. See documentation for the shape property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None,))
Returns:
The indexed rotation
"""
if type(index) != tuple:
index = (index,)
if(self._rot_mats is not None):
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
return Rotation(rot_mats=rot_mats)
elif(self._quats is not None):
quats = self._quats[index + (slice(None),)]
return Rotation(quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __mul__(self,
right: torch.Tensor,
) -> Rotation:
"""
Pointwise left multiplication of the rotation with a tensor. Can be
used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not(isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
if(self._rot_mats is not None):
rot_mats = self._rot_mats * right[..., None, None]
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = self._quats * right[..., None]
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __rmul__(self,
left: torch.Tensor,
) -> Rotation:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
# Properties
@property
def shape(self) -> torch.Size:
"""
Returns the virtual shape of the rotation object. This shape is
defined as the batch dimensions of the underlying rotation matrix
or quaternion. If the Rotation was initialized with a [10, 3, 3]
rotation matrix tensor, for example, the resulting shape would be
[10].
Returns:
The virtual shape of the rotation object
"""
s = None
if(self._quats is not None):
s = self._quats.shape[:-1]
else:
s = self._rot_mats.shape[:-2]
return s
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if(self._rot_mats is not None):
return self._rot_mats.dtype
elif(self._quats is not None):
return self._quats.dtype
else:
raise ValueError("Both rotations are None")
@property
def device(self) -> torch.device:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if(self._rot_mats is not None):
return self._rot_mats.device
elif(self._quats is not None):
return self._quats.device
else:
raise ValueError("Both rotations are None")
@property
def requires_grad(self) -> bool:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if(self._rot_mats is not None):
return self._rot_mats.requires_grad
elif(self._quats is not None):
return self._quats.requires_grad
else:
raise ValueError("Both rotations are None")
def get_rot_mats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats = self._rot_mats
if(rot_mats is None):
if(self._quats is None):
raise ValueError("Both rotations are None")
else:
rot_mats = quat_to_rot(self._quats)
return rot_mats
def get_quats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a
quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
quats = self._quats
if(quats is None):
if(self._rot_mats is None):
raise ValueError("Both rotations are None")
else:
quats = rot_to_quat(self._rot_mats)
return quats
def get_cur_rot(self) -> torch.Tensor:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if(self._rot_mats is not None):
return self._rot_mats
elif(self._quats is not None):
return self._quats
else:
raise ValueError("Both rotations are None")
# Rotation functions
def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
normalize_quats: bool = True
) -> Rotation:
"""
Returns a new quaternion Rotation after updating the current
object's underlying rotation with a quaternion update, formatted
            as a [*, 3] tensor whose three columns represent x, y, z such
that (1, x, y, z) is the desired (not necessarily unit) quaternion
update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats = self.get_quats()
new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
return Rotation(
rot_mats=None,
quats=new_quats,
normalize_quats=normalize_quats,
)
def compose_r(self, r: Rotation) -> Rotation:
"""
Compose the rotation matrices of the current Rotation object with
those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1 = self.get_rot_mats()
r2 = r.get_rot_mats()
new_rot_mats = rot_matmul(r1, r2)
return Rotation(rot_mats=new_rot_mats, quats=None)
def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
"""
Compose the quaternions of the current Rotation object with those
of another.
Depending on whether either Rotation was initialized with
quaternions, this function may call torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1 = self.get_quats()
q2 = r.get_quats()
new_quats = quat_multiply(q1, q2)
return Rotation(
rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Apply the current Rotation as a rotation matrix to a set of 3D
coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats = self.get_rot_mats()
return rot_vec_mul(rot_mats, pts)
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats = self.get_rot_mats()
inv_rot_mats = invert_rot_mat(rot_mats)
return rot_vec_mul(inv_rot_mats, pts)
def invert(self) -> Rotation:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if(self._rot_mats is not None):
return Rotation(
rot_mats=invert_rot_mat(self._rot_mats),
quats=None
)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=invert_quat(self._quats),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
# "Tensor" stuff
def unsqueeze(self,
dim: int,
    ) -> Rotation:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
if(self._rot_mats is not None):
rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
@staticmethod
def cat(
rs: Sequence[Rotation],
dim: int,
    ) -> Rotation:
"""
Concatenates rotations along one of the batch dimensions. Analogous
to torch.cat().
Note that the output of this operation is always a rotation matrix,
regardless of the format of input rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be
concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats = [r.get_rot_mats() for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self,
        fn: Callable[[torch.Tensor], torch.Tensor]
) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
mapping over the rotation dimension(s). Can be used e.g. to sum out
a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if(self._rot_mats is not None):
rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
rot_mats = torch.stack(
list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
)
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = torch.stack(
list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def cuda(self) -> Rotation:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if(self._rot_mats is not None):
return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.cuda(),
normalize_quats=False
)
else:
raise ValueError("Both rotations are None")
def to(self,
device: Optional[torch.device],
dtype: Optional[torch.dtype]
) -> Rotation:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if(self._rot_mats is not None):
return Rotation(
rot_mats=self._rot_mats.to(device=device, dtype=dtype),
quats=None,
)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.to(device=device, dtype=dtype),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
def detach(self) -> Rotation:
"""
Returns a copy of the Rotation whose underlying Tensor has been
detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached
from its torch graph
"""
if(self._rot_mats is not None):
return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.detach(),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper
        around two objects: a Rotation object and a [*, 3] translation tensor.
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
"""
def __init__(self,
rots: Optional[Rotation],
trans: Optional[torch.Tensor],
):
"""
Args:
                rots: A Rotation object with shape [*]
trans: A corresponding [*, 3] translation tensor
"""
# (we need device, dtype, etc. from at least one input)
batch_dims, dtype, device, requires_grad = None, None, None, None
if(trans is not None):
batch_dims = trans.shape[:-1]
dtype = trans.dtype
device = trans.device
requires_grad = trans.requires_grad
elif(rots is not None):
batch_dims = rots.shape
dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if(rots is None):
rots = Rotation.identity(
batch_dims, dtype, device, requires_grad,
)
elif(trans is None):
trans = identity_trans(
batch_dims, dtype, device, requires_grad,
)
if((rots.shape != trans.shape[:-1]) or
(rots.device != trans.device)):
raise ValueError("Rots and trans incompatible")
# Force full precision. Happens to the rotations automatically.
trans = trans.to(dtype=torch.float32)
self._rots = rots
self._trans = trans
@staticmethod
def identity(
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(self,
index: Any,
) -> Rigid:
"""
Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation
and the translation.
E.g.::
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.get_rots().shape == (2,))
assert(indexed.get_trans().shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
or (3, slice(0, 1, None))
Returns:
The indexed tensor
"""
if type(index) != tuple:
index = (index,)
return Rigid(
self._rots[index],
self._trans[index + (slice(None),)],
)
def __mul__(self,
right: torch.Tensor,
) -> Rigid:
"""
Pointwise left multiplication of the transformation with a tensor.
Can be used to e.g. mask the Rigid.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not(isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
new_rots = self._rots * right
new_trans = self._trans * right[..., None]
return Rigid(new_rots, new_trans)
def __rmul__(self,
left: torch.Tensor,
) -> Rigid:
"""
Reverse pointwise multiplication of the transformation with a
tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
@property
def shape(self) -> torch.Size:
"""
Returns the shape of the shared dimensions of the rotation and
the translation.
Returns:
The shape of the transformation
"""
s = self._trans.shape[:-1]
return s
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
Returns:
The rotation object
"""
return self._rots
def get_trans(self) -> torch.Tensor:
"""
Getter for the translation.
Returns:
The stored translation
"""
return self._trans
def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
) -> Rigid:
"""
            Composes the transformation with a quaternion update vector of
            shape [*, 6], whose first three columns are the x, y, and z values
            of a quaternion update of form (1, x, y, z) and whose final three
            columns are a 3D translation.
Args:
                q_update_vec: The quaternion update vector.
Returns:
The composed transformation.
"""
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
return Rigid(new_rots, new_translation)
def compose(self,
r: Rigid,
) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(self,
pts: torch.Tensor,
) -> torch.Tensor:
"""
Applies the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor.
Returns:
The transformed points.
"""
rotated = self._rots.apply(pts)
return rotated + self._trans
def invert_apply(self,
pts: torch.Tensor
) -> torch.Tensor:
"""
Applies the inverse of the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor
Returns:
The transformed points.
"""
pts = pts - self._trans
return self._rots.invert_apply(pts)
def invert(self) -> Rigid:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv = self._rots.invert()
trn_inv = rot_inv.apply(self._trans)
return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self,
        fn: Callable[[torch.Tensor], torch.Tensor]
) -> Rigid:
"""
Apply a Tensor -> Tensor function to underlying translation and
rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
The transformed Rigid object
"""
new_rots = self._rots.map_tensor_fn(fn)
new_trans = torch.stack(
list(map(fn, torch.unbind(self._trans, dim=-1))),
dim=-1
)
return Rigid(new_rots, new_trans)
def to_tensor_4x4(self) -> torch.Tensor:
"""
            Converts a transformation to a homogeneous transformation tensor.
            Returns:
                A [*, 4, 4] homogeneous transformation tensor
"""
tensor = self._trans.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._rots.get_rot_mats()
tensor[..., :3, 3] = self._trans
tensor[..., 3, 3] = 1
return tensor
@staticmethod
def from_tensor_4x4(
t: torch.Tensor
) -> Rigid:
"""
            Constructs a transformation from a homogeneous transformation
            tensor.
            Args:
                t: [*, 4, 4] homogeneous transformation tensor
Returns:
                A Rigid object with shape [*]
"""
if(t.shape[-2:] != (4, 4)):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3]
return Rigid(rots, trans)
def to_tensor_7(self) -> torch.Tensor:
"""
Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns:
A [*, 7] tensor representation of the transformation
"""
tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :4] = self._rots.get_quats()
tensor[..., 4:] = self._trans
return tensor
@staticmethod
def from_tensor_7(
t: torch.Tensor,
normalize_quats: bool = False,
) -> Rigid:
if(t.shape[-1] != 7):
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(
rot_mats=None,
quats=quats,
normalize_quats=normalize_quats
)
return Rigid(rots, trans)
@staticmethod
def from_3_points(
p_neg_x_axis: torch.Tensor,
origin: torch.Tensor,
p_xy_plane: torch.Tensor,
eps: float = 1e-8
) -> Rigid:
"""
Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm.
Args:
p_neg_x_axis: [*, 3] coordinates
origin: [*, 3] coordinates used as frame origins
p_xy_plane: [*, 3] coordinates
eps: Small epsilon value
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
denom = torch.sqrt(sum((c * c for c in e0)) + eps)
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
e0[2] * e1[0] - e0[0] * e1[2],
e0[0] * e1[1] - e0[1] * e1[0],
]
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod
def cat(
ts: Sequence[Rigid],
dim: int,
) -> Rigid:
"""
Concatenates transformations along a new dimension.
Args:
ts:
                A list of Rigid objects
dim:
The dimension along which the transformations should be
concatenated
Returns:
A concatenated transformation object
"""
rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat(
[t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
)
return Rigid(rots, trans)
    def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
"""
Applies a Rotation -> Rotation function to the stored rotation
object.
Args:
fn: A function of type Rotation -> Rotation
Returns:
A transformation object with a transformed rotation.
"""
return Rigid(fn(self._rots), self._trans)
    def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def scale_translation(self, trans_scale_factor: float) -> Rigid:
"""
Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns:
A transformation object with a scaled translation.
"""
fn = lambda t: t * trans_scale_factor
return self.apply_trans_fn(fn)
def stop_rot_gradient(self) -> Rigid:
"""
Detaches the underlying rotation object
Returns:
A transformation object with detached rotations
"""
fn = lambda r: r.detach()
return self.apply_rot_fn(fn)
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
"""
Returns a transformation object from reference coordinates.
Note that this method does not take care of symmetries. If you
provide the atom positions in the non-standard way, the N atom will
end up not at [-0.527250, 1.359329, 0.0] but instead at
[-0.527250, -1.359329, 0.0]. You need to take care of such cases in
your code.
Args:
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
Returns:
                A transformation object. After applying the translation and
                rotation to the reference backbone, the coordinates will
                approximately equal the input coordinates.
"""
translation = -1 * ca_xyz
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
zeros = sin_c1.new_zeros(sin_c1.shape)
ones = sin_c1.new_ones(sin_c1.shape)
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
c1_rots[..., 0, 0] = cos_c1
c1_rots[..., 0, 1] = -1 * sin_c1
c1_rots[..., 1, 0] = sin_c1
c1_rots[..., 1, 1] = cos_c1
c1_rots[..., 2, 2] = 1
norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
sin_c2 = c_z / norm
cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
sin_n = -n_z / norm
cos_n = n_y / norm
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n
rots = rot_matmul(n_rots, c_rots)
rots = rots.transpose(-1, -2)
translation = -1 * translation
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, translation)
def cuda(self) -> Rigid:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return Rigid(self._rots.cuda(), self._trans.cuda())
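def _demo_rigid():
    # Added usage sketch for Rigid: build backbone frames from three points
    # per residue (Algorithm 21) and check that apply / invert_apply are
    # inverse maps of each other.
    torch.manual_seed(0)
    n = torch.randn(8, 3)
    ca = torch.randn(8, 3)
    c = torch.randn(8, 3)
    frames = Rigid.from_3_points(n, ca, c)
    pts = torch.randn(8, 3)
    recovered = frames.invert_apply(frames.apply(pts))
    assert torch.allclose(recovered, pts, atol=1e-4)
    assert frames.shape == torch.Size((8,))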
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Bio.SVDSuperimposer import SVDSuperimposer
import torch
def _superimpose_np(reference, coords):
"""
Superimposes coordinates onto a reference by minimizing RMSD using SVD.
Args:
reference:
[N, 3] reference array
coords:
[N, 3] array
Returns:
A tuple of [N, 3] superimposed coords and the final RMSD.
"""
sup = SVDSuperimposer()
sup.set(reference, coords)
sup.run()
return sup.get_transformed(), sup.get_rms()
def _superimpose_single(reference, coords):
reference_np = reference.detach().cpu().numpy()
coords_np = coords.detach().cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
def superimpose(reference, coords, mask):
"""
Superimposes coordinates onto a reference by minimizing RMSD using SVD.
Args:
reference:
[*, N, 3] reference tensor
coords:
[*, N, 3] tensor
mask:
[*, N] tensor
Returns:
A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
"""
def select_unmasked_coords(coords, mask):
return torch.masked_select(
coords,
(mask > 0.)[..., None],
).reshape(-1, 3)
batch_dims = reference.shape[:-2]
flat_reference = reference.reshape((-1,) + reference.shape[-2:])
flat_coords = coords.reshape((-1,) + reference.shape[-2:])
flat_mask = mask.reshape((-1,) + mask.shape[-1:])
superimposed_list = []
rmsds = []
for r, c, m in zip(flat_reference, flat_coords, flat_mask):
r_unmasked_coords = select_unmasked_coords(r, m)
c_unmasked_coords = select_unmasked_coords(c, m)
superimposed, rmsd = _superimpose_single(
r_unmasked_coords,
c_unmasked_coords
)
# This is very inelegant, but idk how else to invert the masking
# procedure.
count = 0
superimposed_full_size = torch.zeros_like(r)
for i, unmasked in enumerate(m):
if(unmasked):
superimposed_full_size[i] = superimposed[count]
count += 1
superimposed_list.append(superimposed_full_size)
rmsds.append(rmsd)
superimposed_stacked = torch.stack(superimposed_list, dim=0)
rmsds_stacked = torch.stack(rmsds, dim=0)
superimposed_reshaped = superimposed_stacked.reshape(
batch_dims + coords.shape[-2:]
)
rmsds_reshaped = rmsds_stacked.reshape(
batch_dims
)
return superimposed_reshaped, rmsds_reshaped
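def _demo_superimpose():
    # Added usage sketch: superimposing a rigidly translated copy of a
    # structure back onto the reference should report an RMSD close to zero.
    # Requires biopython (Bio.SVDSuperimposer), as imported above.
    reference = torch.randn(1, 10, 3)
    coords = reference + torch.tensor([1.0, -2.0, 0.5])
    mask = torch.ones(1, 10)
    superimposed, rmsd = superimpose(reference, coords, mask)
    assert superimposed.shape == coords.shape
    assert float(rmsd[0]) < 1e-3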
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
import fastfold.habana as habana
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int):
return t.reshape(t.shape[:-no_dims] + (-1,))
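def _demo_shape_helpers():
    # Added sketch: permute_final_dims reorders only the trailing dimensions
    # (here the last three of a [2, 3, 4, 5] tensor) and flatten_final_dims
    # merges the trailing dimensions into one.
    t = torch.randn(2, 3, 4, 5)
    assert permute_final_dims(t, [2, 0, 1]).shape == (2, 5, 3, 4)
    assert flatten_final_dims(t, 2).shape == (2, 3, 20)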
def masked_mean(mask, value, dim, eps=1e-4):
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
boundaries = torch.linspace(
min_bin, max_bin, no_bins - 1, device=pts.device
)
dists = torch.sqrt(
torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
)
return torch.bucketize(dists, boundaries)
def dict_multimap(fn, dicts):
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
# when bs = 1, returns [...] rather than [1, ...]
new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]
return new_dict
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
ranges = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims = [
slice(None) for _ in range(len(data.shape) - no_batch_dims)
]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
return data[ranges]
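def _demo_batched_gather():
    # Added sketch of batched_gather semantics: for each batch row, pick the
    # listed positions along `dim` while keeping the batch dimension intact.
    data = torch.arange(12).view(2, 6)      # two batches of six values
    inds = torch.tensor([[0, 2], [5, 1]])   # per-batch index lists
    out = batched_gather(data, inds, dim=-1, no_batch_dims=1)
    assert out.tolist() == [[0, 2], [11, 7]]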
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
new_dict = {}
for k, v in dic.items():
if type(v) is dict:
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
return new_dict
def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif isinstance(tree, leaf_type):
return fn(tree)
else:
print(type(tree))
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
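# Added worked example: _flat_idx_to_idx(5, (2, 3)) walks the dims from the
# right (5 % 3 = 2, then 1 % 2 = 1) and returns (1, 2), i.e. row 1, column 2
# of a 2 x 3 array stored in row-major order.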
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
if habana.is_habana():
import habana_frameworks.torch.core as htcore
htcore.mark_step()
return out
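def _demo_chunk_layer():
    # Added usage sketch: running a simple elementwise "layer" over a
    # [4, 6, 8] input in chunks of 3 sub-batches reproduces the unchunked
    # result exactly.
    layer = lambda x: {"out": x * 2.0}
    x = torch.randn(4, 6, 8)
    chunked = chunk_layer(layer, {"x": x}, chunk_size=3, no_batch_dims=2)
    assert torch.allclose(chunked["out"], x * 2.0)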
import os
import random
import torch
import numpy as np
def get_param_path():
# develop
if os.path.exists('/data/scratch/alphafold/alphafold/params/params_model_1.npz'):
return '/data/scratch/alphafold/alphafold/params/params_model_1.npz'
# test
return '/data/scratch/fastfold/weight.npz'
def get_data_path():
# develop
if os.path.exists('/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'):
return '/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'
# test
return '/data/scratch/fastfold/mono_batch.pkl'
def get_train_data_path():
return '/data/scratch/fastfold/std_train_batch.pkl'
def set_seed(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fastfold.model.hub.loss import lddt_ca
from fastfold.common import residue_constants
from fastfold.utils.superimposition import superimpose
def drmsd(structure_1, structure_2, mask=None):
def prep_d(structure):
d = structure[..., :, None, :] - structure[..., None, :, :]
d = d ** 2
d = torch.sqrt(torch.sum(d, dim=-1))
return d
d1 = prep_d(structure_1)
d2 = prep_d(structure_2)
drmsd = d1 - d2
drmsd = drmsd ** 2
if(mask is not None):
drmsd = drmsd * (mask[..., None] * mask[..., None, :])
drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd)
return drmsd
def drmsd_np(structure_1, structure_2, mask=None):
structure_1 = torch.tensor(structure_1)
structure_2 = torch.tensor(structure_2)
if(mask is not None):
mask = torch.tensor(mask)
return drmsd(structure_1, structure_2, mask)
def gdt(p1, p2, mask, cutoffs):
n = torch.sum(mask, dim=-1)
p1 = p1.float()
p2 = p2.float()
distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))
scores = []
for c in cutoffs:
score = torch.sum((distances <= c) * mask, dim=-1) / n
score = torch.mean(score)
scores.append(score)
return sum(scores) / len(scores)
def gdt_ts(p1, p2, mask):
return gdt(p1, p2, mask, [1., 2., 4., 8.])
def gdt_ha(p1, p2, mask):
return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
def compute_validation_metrics(
batch,
outputs,
superimposition_metrics=False,
):
metrics = {}
gt_coords = batch["all_atom_positions"]
pred_coords = outputs["final_atom_positions"]
all_atom_mask = batch["all_atom_mask"]
# This is super janky for superimposition. Fix later
gt_coords_masked = gt_coords * all_atom_mask[..., None]
pred_coords_masked = pred_coords * all_atom_mask[..., None]
ca_pos = residue_constants.atom_order["CA"]
gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
all_atom_mask_ca = all_atom_mask[..., ca_pos]
lddt_ca_score = lddt_ca(
pred_coords,
gt_coords,
all_atom_mask,
eps=1e-8,
per_residue=False,
)
metrics["lddt_ca"] = lddt_ca_score
drmsd_ca_score = drmsd(
pred_coords_masked_ca,
gt_coords_masked_ca,
mask=all_atom_mask_ca, # still required here to compute n
)
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
)
gdt_ts_score = gdt_ts(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
gdt_ha_score = gdt_ha(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
metrics["alignment_rmsd"] = alignment_rmsd
metrics["gdt_ts"] = gdt_ts_score
metrics["gdt_ha"] = gdt_ha_score
return metrics
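def _demo_validation_metrics():
    # Added usage sketch with random tensors shaped like AlphaFold outputs:
    # [*, N_res, 37, 3] atom coordinates and a matching [*, N_res, 37] mask.
    # The values are meaningless; this only illustrates the expected inputs
    # and output keys, and it assumes lddt_ca accepts full atom arrays as the
    # call above suggests.
    n_res = 16
    batch = {
        "all_atom_positions": torch.randn(1, n_res, 37, 3),
        "all_atom_mask": torch.ones(1, n_res, 37),
    }
    outputs = {"final_atom_positions": torch.randn(1, n_res, 37, 3)}
    metrics = compute_validation_metrics(batch, outputs)
    assert "lddt_ca" in metrics and "drmsd_ca" in metrics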
from .workflow_run import batch_run
from .task_factory import TaskFactory
from .hhblits import HHBlitsFactory
from .hhsearch import HHSearchFactory
from .jackhmmer import JackHmmerFactory
from .hhfilter import HHfilterFactory
from .hmmsearch import HmmSearchFactory
from typing import List
import ray
from ray.dag.function_node import FunctionNode
from fastfold.workflow.factory import TaskFactory
import fastfold.data.tools.hhblits as ffHHBlits
class HHBlitsFactory(TaskFactory):
keywords = ['binary_path', 'databases', 'n_cpu']
def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
# setup runner
runner = ffHHBlits.HHBlits(
**self.config
)
# generate function node
@ray.remote
def hhblits_node_func(after: List[FunctionNode]) -> None:
result = runner.query(fasta_path)
with open(output_path, 'w') as f:
f.write(result['a3m'])
return hhblits_node_func.bind(after)
import subprocess
import logging
from typing import List
import ray
from ray.dag.function_node import FunctionNode
from fastfold.workflow.factory import TaskFactory
class HHfilterFactory(TaskFactory):
keywords = ['binary_path']
def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
# generate function node
@ray.remote
def hhfilter_node_func(after: List[FunctionNode]) -> None:
cmd = [
self.config.get('binary_path'),
]
if 'id' in self.config:
cmd += ['-id', str(self.config.get('id'))]
if 'cov' in self.config:
cmd += ['-cov', str(self.config.get('cov'))]
cmd += ['-i', fasta_path, '-o', output_path]
            # Run the argv list directly; passing a list together with
            # shell=True would drop the arguments after the binary path.
            subprocess.run(cmd)
return hhfilter_node_func.bind(after)
from typing import List
import inspect
import ray
from ray.dag.function_node import FunctionNode
import fastfold.data.tools.hhsearch as ffHHSearch
from fastfold.workflow.factory import TaskFactory
class HHSearchFactory(TaskFactory):
keywords = ['binary_path', 'databases', 'n_cpu']
def gen_node(self, a3m_path: str, output_path: str, atab_path: str = None, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
params = { k: self.config.get(k) for k in inspect.getfullargspec(ffHHSearch.HHSearch.__init__).kwonlyargs if self.config.get(k) }
# setup runner with a filtered config dict
runner = ffHHSearch.HHSearch(
**params
)
# generate function node
@ray.remote
def hhsearch_node_func(after: List[FunctionNode]) -> None:
with open(a3m_path, "r") as f:
a3m = f.read()
if atab_path:
hhsearch_result, atab = runner.query(a3m, gen_atab=True)
else:
hhsearch_result = runner.query(a3m)
with open(output_path, "w") as f:
f.write(hhsearch_result)
if atab_path:
with open(atab_path, "w") as f:
f.write(atab)
return hhsearch_node_func.bind(after)
from typing import List
import inspect
import ray
from ray.dag.function_node import FunctionNode
from fastfold.data.tools import hmmsearch, hmmbuild
from fastfold.data import parsers
from fastfold.workflow.factory import TaskFactory
from typing import Optional
class HmmSearchFactory(TaskFactory):
keywords = ['binary_path', 'hmmbuild_binary_path', 'database_path', 'n_cpu']
def gen_node(self, msa_sto_path: str, output_dir: Optional[str] = None, after: List[FunctionNode]=None) -> FunctionNode:
self.isReady()
params = { k: self.config.get(k) for k in inspect.getfullargspec(hmmsearch.Hmmsearch.__init__).kwonlyargs if self.config.get(k) }
# setup runner with a filtered config dict
runner = hmmsearch.Hmmsearch(
**params
)
# generate function node
@ray.remote
def hmmsearch_node_func(after: List[FunctionNode]) -> None:
with open(msa_sto_path, "r") as f:
msa_sto = f.read()
msa_sto = parsers.deduplicate_stockholm_msa(msa_sto)
msa_sto = parsers.remove_empty_columns_from_stockholm_msa(
msa_sto
)
hmmsearch_result = runner.query(msa_sto, output_dir=output_dir)
return hmmsearch_node_func.bind(after)
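# Added illustration (comment only): every factory above follows the same
# pattern -- `gen_node(...)` returns a Ray FunctionNode bound to the nodes in
# `after`, so alignment and search steps can be chained into a DAG and executed
# together (e.g. via the package's batch_run helper). A hypothetical chain,
# assuming TaskFactory is constructed from a plain config dict:
#
#   hhblits = HHBlitsFactory(config={"binary_path": "...", "databases": [...], "n_cpu": 4})
#   hhfilter = HHfilterFactory(config={"binary_path": "..."})
#   a3m_node = hhblits.gen_node("query.fasta", "raw.a3m")
#   filter_node = hhfilter.gen_node("raw.a3m", "filtered.a3m", after=[a3m_node])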