# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# According to DeepMind, writing the rotation composition out element by
# element prevents it from being computed on low-precision tensor cores. I'm
# personally skeptical that it makes a difference, but to get as close as
# possible to their outputs, it is done that way here.
def rot_matmul(a, b):
    """Multiply two batches of 3x3 matrices using explicit elementwise sums.

    Args:
        a: [*, 3, 3] tensor.
        b: [*, 3, 3] tensor; its batch dims must broadcast against ``a``'s.
    Returns:
        [*, 3, 3] tensor holding the matrix product ``a @ b``.
    """
    e = ...
    row_1 = torch.stack(
        [
            a[e, 0, 0] * b[e, 0, 0] + a[e, 0, 1] * b[e, 1, 0] + a[e, 0, 2] * b[e, 2, 0],
            a[e, 0, 0] * b[e, 0, 1] + a[e, 0, 1] * b[e, 1, 1] + a[e, 0, 2] * b[e, 2, 1],
            a[e, 0, 0] * b[e, 0, 2] + a[e, 0, 1] * b[e, 1, 2] + a[e, 0, 2] * b[e, 2, 2],
        ],
        dim=-1,
    )
    row_2 = torch.stack(
        [
            a[e, 1, 0] * b[e, 0, 0] + a[e, 1, 1] * b[e, 1, 0] + a[e, 1, 2] * b[e, 2, 0],
            a[e, 1, 0] * b[e, 0, 1] + a[e, 1, 1] * b[e, 1, 1] + a[e, 1, 2] * b[e, 2, 1],
            a[e, 1, 0] * b[e, 0, 2] + a[e, 1, 1] * b[e, 1, 2] + a[e, 1, 2] * b[e, 2, 2],
        ],
        dim=-1,
    )
    row_3 = torch.stack(
        [
            a[e, 2, 0] * b[e, 0, 0] + a[e, 2, 1] * b[e, 1, 0] + a[e, 2, 2] * b[e, 2, 0],
            a[e, 2, 0] * b[e, 0, 1] + a[e, 2, 1] * b[e, 1, 1] + a[e, 2, 2] * b[e, 2, 1],
            a[e, 2, 0] * b[e, 0, 2] + a[e, 2, 1] * b[e, 1, 2] + a[e, 2, 2] * b[e, 2, 2],
        ],
        dim=-1,
    )

    return torch.stack([row_1, row_2, row_3], dim=-2)


def rot_vec_mul(r, t):
    """Apply a batch of 3x3 matrices to a batch of 3-vectors.

    Args:
        r: [*, 3, 3] tensor.
        t: [*, 3] tensor; its batch dims must broadcast against ``r``'s.
    Returns:
        [*, 3] tensor holding ``r @ t`` for each batch element.
    """
    x = t[..., 0]
    y = t[..., 1]
    z = t[..., 2]
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )


class T:
    """A batch of rigid transformations x -> rots @ x + trans.

    Attributes:
        rots: [*, 3, 3] rotation matrices (orthonormality is not checked).
        trans: [*, 3] translations.
    """

    def __init__(self, rots, trans):
        """
        Args:
            rots: [*, 3, 3] tensor, or None to use identity rotations with
                the batch shape, dtype and device of ``trans``.
            trans: [*, 3] tensor, or None to use zero translations with the
                batch shape, dtype and device of ``rots``.
        Raises:
            ValueError: if both arguments are None, or if the trailing dims
                are not 3x3 / 3, or if the batch shapes differ.
        """
        self.rots = rots
        self.trans = trans

        if self.rots is None and self.trans is None:
            raise ValueError("Only one of rots and trans can be None")
        elif self.rots is None:
            self.rots = T.identity_rot(
                self.trans.shape[:-1],
                self.trans.dtype,
                self.trans.device,
            )
        elif self.trans is None:
            self.trans = T.identity_trans(
                self.rots.shape[:-2],
                self.rots.dtype,
                self.rots.device,
            )

        if (
            self.rots.shape[-2:] != (3, 3)
            or self.trans.shape[-1] != 3
            or self.rots.shape[:-2] != self.trans.shape[:-1]
        ):
            raise ValueError("Incorrectly shaped input")

    def __getitem__(self, index):
        """Index the batch dims; the trailing 3x3 / 3 dims are kept whole."""
        if not isinstance(index, tuple):
            index = (index,)
        return T(
            self.rots[index + (slice(None), slice(None))],
            self.trans[index + (slice(None),)],
        )

    def __eq__(self, obj):
        """Exact elementwise equality of rotations and translations.

        Returns the result of ``and`` on two 0-dim boolean tensors, which is
        itself a 0-dim boolean tensor usable in ``if``.
        """
        return (
            torch.all(self.rots == obj.rots)
            and torch.all(self.trans == obj.trans)
        )

    def __mul__(self, right):
        """Scale rotations and translations elementwise by ``right``.

        ``right`` is presumably a tensor with this transform's batch shape
        (e.g. a mask); it is broadcast over the trailing 3x3 / 3 dims.
        """
        rots = self.rots * right[..., None, None]
        trans = self.trans * right[..., None]

        return T(rots, trans)

    def __rmul__(self, left):
        return self.__mul__(left)

    @property
    def shape(self):
        """Batch shape of the transform; ``torch.Size([1])`` if unbatched."""
        s = self.rots.shape[:-2]
        return s if len(s) > 0 else torch.Size([1])

    def get_trans(self):
        return self.trans

    def get_rots(self):
        return self.rots

    def compose(self, t):
        """Return the transform that applies ``t`` first, then ``self``."""
        rot_1, trn_1 = self.rots, self.trans
        rot_2, trn_2 = t.rots, t.trans

        rot = rot_matmul(rot_1, rot_2)
        trn = rot_vec_mul(rot_1, trn_2) + trn_1

        return T(rot, trn)

    def apply(self, pts):
        """Map [*, 3] points through the transform: rots @ pts + trans."""
        r, t = self.rots, self.trans
        rotated = rot_vec_mul(r, pts)
        return rotated + t

    def invert_apply(self, pts):
        """Map [*, 3] points through the inverse: rots^T @ (pts - trans)."""
        r, t = self.rots, self.trans
        pts = pts - t
        return rot_vec_mul(r.transpose(-1, -2), pts)

    def invert(self):
        """Return the inverse transform, using rots^T as the inverse rotation.

        This is only the true inverse when ``rots`` is orthonormal.
        """
        rot_inv = self.rots.transpose(-1, -2)
        trn_inv = rot_vec_mul(rot_inv, self.trans)

        return T(rot_inv, -1 * trn_inv)

    def unsqueeze(self, dim):
        """Insert a singleton batch dimension at position ``dim``.

        ``dim`` indexes the batch dims only, with torch.unsqueeze
        conventions. For n batch dims, 0 <= dim <= n inserts before that
        position, and negative values count from the end (-1 appends).

        Raises:
            ValueError: if ``dim`` exceeds the number of batch dims.
        """
        num_batch_dims = self.rots.dim() - 2
        # dim == num_batch_dims is the positive spelling of dim == -1; larger
        # values would reach into the trailing 3x3 / 3 dims.
        if dim > num_batch_dims:
            raise ValueError("Invalid dimension")
        rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2)
        trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1)

        return T(rots, trans)

    @staticmethod
    def identity_rot(shape, dtype, device, requires_grad=False):
        """Return [*shape, 3, 3] identity matrices.

        The result is an expanded (stride-0) view of a single 3x3 identity,
        so it should not be modified in place.
        """
        rots = torch.eye(
            3, dtype=dtype, device=device, requires_grad=requires_grad
        )
        rots = rots.view(*((1,) * len(shape)), 3, 3)
        rots = rots.expand(*shape, -1, -1)

        return rots

    @staticmethod
    def identity_trans(shape, dtype, device, requires_grad=False):
        """Return [*shape, 3] zero translations."""
        trans = torch.zeros(
            (*shape, 3),
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )
        return trans

    @staticmethod
    def identity(shape, dtype, device, requires_grad=False):
        """Return identity transforms with batch shape ``shape``."""
        return T(
            T.identity_rot(shape, dtype, device, requires_grad),
            T.identity_trans(shape, dtype, device, requires_grad),
        )

    @staticmethod
    def from_4x4(t):
        """Build a transform from [*, 4, 4] homogeneous matrices.

        Only the top-left 3x3 block and the first three entries of the last
        column are read; the bottom row is ignored.
        """
        rots = t[..., :3, :3]
        trans = t[..., :3, 3]
        return T(rots, trans)

    def to_4x4(self):
        """Return [*, 4, 4] homogeneous matrices for this transform.

        The output has the dtype and device of ``rots``. Unbatched
        transforms yield a [1, 4, 4] tensor, following ``shape``.
        """
        # new_zeros inherits dtype/device; a bare torch.zeros would silently
        # fall back to the default dtype (float32).
        tensor = self.rots.new_zeros((*self.shape, 4, 4))
        tensor[..., :3, :3] = self.rots
        tensor[..., :3, 3] = self.trans
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor(t):
        """Alias of ``from_4x4``."""
        return T.from_4x4(t)

    @staticmethod
    def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
        """Build a frame from three [*, 3] points via Gram-Schmidt.

        The frame's origin is ``origin``. Its x axis points from
        ``p_neg_x_axis`` towards ``origin``. Its y axis lies in the plane of
        the three points, on the side of ``p_xy_plane``. Its z axis is
        x cross y. ``eps`` guards the normalizations against zero length.
        """
        v1 = origin - p_neg_x_axis
        v2 = p_xy_plane - origin
        e1 = v1 / torch.sqrt(torch.sum(v1 ** 2, dim=-1) + eps)[..., None]
        # Remove the component of v2 along e1 before normalizing.
        u2 = v2 - e1 * (torch.einsum("...i,...i->...", v2, e1)[..., None])
        e2 = u2 / torch.sqrt(torch.sum(u2 ** 2, dim=-1) + eps)[..., None]
        e3 = torch.cross(e1, e2, dim=-1)
        # The basis vectors become the columns of the rotation matrix.
        rots = torch.cat(
            (
                e1.unsqueeze(-1),
                e2.unsqueeze(-1),
                e3.unsqueeze(-1),
            ),
            dim=-1,
        )

        return T(rots, origin)

    @staticmethod
    def concat(ts, dim):
        """Concatenate transforms along batch dimension ``dim``."""
        rots = torch.cat(
            [t.rots for t in ts], dim=dim if dim >= 0 else dim - 2
        )
        trans = torch.cat(
            [t.trans for t in ts], dim=dim if dim >= 0 else dim - 1
        )

        return T(rots, trans)

    def map_tensor_fn(self, fn):
        """
        Apply a function that takes a tensor as its only argument to the
        rotations and translations. The final two/one dimension(s) are
        treated as batch dimensions: fn is called once per rotation entry
        and once per translation component.

        E.g.: Given t, an instance of T of shape [N, M], this function can
        be used to sum out the second dimension thereof as follows:

            t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

        The resulting object has rotations of shape [N, 3, 3] and
        translations of shape [N, 3].
        """
        # reshape rather than view: rots are frequently non-contiguous
        # (e.g. after invert(), which transposes), and view would raise.
        rots = self.rots.reshape(*self.rots.shape[:-2], 9)
        rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
        rots = rots.reshape(*rots.shape[:-1], 3, 3)

        trans = torch.stack(
            list(map(fn, torch.unbind(self.trans, -1))), dim=-1
        )

        return T(rots, trans)

    def stop_rot_gradient(self):
        """Return a copy whose rotations are detached from autograd."""
        return T(self.rots.detach(), self.trans)

    def scale_translation(self, factor):
        """Return a copy with translations multiplied by ``factor``."""
        return T(self.rots, self.trans * factor)

    @staticmethod
    def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
        """Build frames from three [*, 3] reference points.

        The inputs are presumably backbone N, CA and C coordinates. The
        returned transform maps local coordinates into the input frame. In
        the local frame, ``ca_xyz`` is the origin, ``c_xyz`` lies on the
        positive x axis, and ``n_xyz`` lies in the xy plane with y >= 0.
        ``eps`` guards the normalizations against zero length.

        Returns:
            T with rotations [*, 3, 3] and translation ``ca_xyz``.
        """
        # Place CA at the origin.
        translation = -1 * ca_xyz
        n_xyz = n_xyz + translation
        c_xyz = c_xyz + translation

        # Rotate about z so that C lands in the xz half-plane with x >= 0.
        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
        sin_c1 = -c_y / norm
        cos_c1 = c_x / norm

        c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        c1_rots[..., 0, 0] = cos_c1
        c1_rots[..., 0, 1] = -1 * sin_c1
        c1_rots[..., 1, 0] = sin_c1
        c1_rots[..., 1, 1] = cos_c1
        c1_rots[..., 2, 2] = 1

        # Rotate about y so that C lands on the positive x axis.
        norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
        sin_c2 = c_z / norm
        cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm

        c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        c2_rots[..., 0, 0] = cos_c2
        c2_rots[..., 0, 2] = sin_c2
        c2_rots[..., 1, 1] = 1
        c2_rots[..., 2, 0] = -1 * sin_c2
        c2_rots[..., 2, 2] = cos_c2

        c_rots = rot_matmul(c2_rots, c1_rots)
        n_xyz = rot_vec_mul(c_rots, n_xyz)

        # Rotate about x so that N lands in the xy plane with y >= 0.
        _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
        sin_n = -n_z / norm
        cos_n = n_y / norm

        n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        n_rots[..., 0, 0] = 1
        n_rots[..., 1, 1] = cos_n
        n_rots[..., 1, 2] = -1 * sin_n
        n_rots[..., 2, 1] = sin_n
        n_rots[..., 2, 2] = cos_n

        # The composed rotation maps global -> local; transpose it to get
        # local -> global, and undo the centring for the translation.
        rots = rot_matmul(n_rots, c_rots)

        rots = rots.transpose(-1, -2)
        translation = -1 * translation

        return T(rots, translation)


_quat_elements = ['a', 'b', 'c', 'd']
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}


def _to_mat(pairs):
    """Build a 4x4 coefficient matrix from (key, value) pairs.

    Each key names two quaternion components (e.g. 'bc'). Its value is
    written at the row/column given by those components' positions in
    ``_quat_elements``; all other entries are zero.
    """
    mat = torch.zeros((4, 4))
    for key, value in pairs:
        row, col = divmod(_qtr_ind_dict[key], 4)
        mat[row, col] = value

    return mat


# _qtr_mat[i, j, r, s] is the coefficient of q_i * q_j in entry (r, s) of
# the rotation matrix produced by quat_to_rot.
_qtr_mat = torch.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
_qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
_qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
_qtr_mat[..., 1, 0] = _to_mat([('bc', 2), ('ad', 2)])
_qtr_mat[..., 1, 1] = _to_mat([('aa', 1), ('bb', -1), ('cc', 1), ('dd', -1)])
_qtr_mat[..., 1, 2] = _to_mat([('cd', 2), ('ab', -2)])
_qtr_mat[..., 2, 0] = _to_mat([('bd', 2), ('ac', -2)])
_qtr_mat[..., 2, 1] = _to_mat([('cd', 2), ('ab', 2)])
_qtr_mat[..., 2, 2] = _to_mat([('aa', 1), ('bb', -1), ('cc', -1), ('dd', 1)])


def quat_to_rot(
    quat  # [*, 4]
):
    """Convert quaternions to rotation matrices.

    Args:
        quat: [*, 4] quaternions ordered (a, b, c, d), with ``a`` the scalar
            part. No normalization is applied, so the result is a proper
            rotation only for unit quaternions.
    Returns:
        [*, 3, 3] matrices with the same dtype and device as ``quat``.
    """
    # [*, 4, 4] pairwise products q_i * q_j
    quat = quat[..., None] * quat[..., None, :]

    # Cast the float32 table to the input's dtype so that low-precision
    # inputs are not silently promoted to float32.
    mat = _qtr_mat.to(dtype=quat.dtype, device=quat.device)

    # [*, 4, 4, 3, 3]
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
    quat = quat[..., None, None] * shaped_qtr_mat

    # [*, 3, 3]
    return torch.sum(quat, dim=(-3, -4))