from __future__ import annotations from typing import Tuple, Any, Sequence, Callable, Optional, Iterable import numpy as np import torch def zero_translation( batch_dims: Tuple[int], dtype: Optional[torch.dtype] = torch.float, device: Optional[torch.device] = torch.device("cpu"), requires_grad: bool = False, ) -> torch.Tensor: trans = torch.zeros( (*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad ) return trans # pylint: disable=bad-whitespace _QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) _QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr _QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii _QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj _QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk _QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij _QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik _QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk _QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir _QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr _QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr _QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9) _QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT) _QUAT_MULTIPLY = np.zeros((4, 4, 4)) _QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]] _QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]] _QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]] _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]] _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] _QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC) class Rotation: def __init__( self, mat: torch.Tensor, ): if mat.shape[-2:] != (3, 3): raise ValueError(f"incorrect rotation shape: {mat.shape}") self._mat = mat @staticmethod def identity( shape, dtype: Optional[torch.dtype] = torch.float, device: Optional[torch.device] = torch.device("cpu"), requires_grad: bool = False, ) -> Rotation: mat = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad) mat = mat.view(*((1,) * len(shape)), 3, 3) mat = mat.expand(*shape, -1, -1) return Rotation(mat) @staticmethod def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return (a.float() @ b.float()).type(a.dtype) @staticmethod def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype) def __getitem__(self, index: Any) -> Rotation: if not isinstance(index, tuple): index = (index,) return Rotation(mat=self._mat[index + (slice(None), slice(None))]) def __mul__(self, right: Any) -> Rotation: if isinstance(right, (int, float)): return Rotation(mat=self._mat * right) elif isinstance(right, torch.Tensor): return Rotation(mat=self._mat * right[..., None, None]) else: raise TypeError( f"multiplicand must be a tensor or a number, got {type(right)}." ) def __rmul__(self, left: Any) -> Rotation: return self.__mul__(left) def __matmul__(self, other: Rotation) -> Rotation: new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat) return Rotation(mat=new_mat) @property def _inv_mat(self): return self._mat.transpose(-1, -2) @property def rot_mat(self) -> torch.Tensor: return self._mat def invert(self) -> Rotation: return Rotation(mat=self._inv_mat) def apply(self, pts: torch.Tensor) -> torch.Tensor: return Rotation.mat_mul_vec(self._mat, pts) def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: return Rotation.mat_mul_vec(self._inv_mat, pts) # inherit tensor behaviors @property def shape(self) -> torch.Size: s = self._mat.shape[:-2] return s @property def dtype(self) -> torch.dtype: return self._mat.dtype @property def device(self) -> torch.device: return self._mat.device @property def requires_grad(self) -> bool: return self._mat.requires_grad def unsqueeze(self, dim: int) -> Rotation: if dim >= len(self.shape): raise ValueError("Invalid dimension") rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2) return Rotation(mat=rot_mats) def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation: mat = self._mat.view(self._mat.shape[:-2] + (9,)) mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1) mat = mat.view(mat.shape[:-1] + (3, 3)) return Rotation(mat=mat) @staticmethod def cat(rs: Sequence[Rotation], dim: int) -> Rotation: rot_mats = [r.rot_mat for r in rs] rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) return Rotation(mat=rot_mats) def cuda(self) -> Rotation: return Rotation(mat=self._mat.cuda()) def to( self, device: Optional[torch.device], dtype: Optional[torch.dtype] ) -> Rotation: return Rotation(mat=self._mat.to(device=device, dtype=dtype)) def type(self, dtype: Optional[torch.dtype]) -> Rotation: return Rotation(mat=self._mat.type(dtype)) def detach(self) -> Rotation: return Rotation(mat=self._mat.detach()) class Frame: def __init__( self, rotation: Optional[Rotation], translation: Optional[torch.Tensor], ): if rotation is None and translation is None: rotation = Rotation.identity((0,)) translation = zero_translation((0,)) elif translation is None: translation = zero_translation( rotation.shape, rotation.dtype, rotation.device, rotation.requires_grad ) elif rotation is None: rotation = Rotation.identity( translation.shape[:-1], translation.dtype, translation.device, translation.requires_grad, ) if (rotation.shape != translation.shape[:-1]) or ( rotation.device != translation.device ): raise ValueError("RotationMatrix and translation incompatible") self._r = rotation self._t = translation @staticmethod def identity( shape: Iterable[int], dtype: Optional[torch.dtype] = torch.float, device: Optional[torch.device] = torch.device("cpu"), requires_grad: bool = False, ) -> Frame: return Frame( Rotation.identity(shape, dtype, device, requires_grad), zero_translation(shape, dtype, device, requires_grad), ) def __getitem__( self, index: Any, ) -> Frame: if type(index) != tuple: index = (index,) return Frame( self._r[index], self._t[index + (slice(None),)], ) def __mul__( self, right: torch.Tensor, ) -> Frame: if not (isinstance(right, torch.Tensor)): raise TypeError("The other multiplicand must be a Tensor") new_rots = self._r * right new_trans = self._t * right[..., None] return Frame(new_rots, new_trans) def __rmul__( self, left: torch.Tensor, ) -> Frame: return self.__mul__(left) @property def shape(self) -> torch.Size: s = self._t.shape[:-1] return s @property def device(self) -> torch.device: return self._t.device def get_rots(self) -> Rotation: return self._r def get_trans(self) -> torch.Tensor: return self._t def compose( self, other: Frame, ) -> Frame: new_rot = self._r @ other._r new_trans = self._r.apply(other._t) + self._t return Frame(new_rot, new_trans) def apply( self, pts: torch.Tensor, ) -> torch.Tensor: rotated = self._r.apply(pts) return rotated + self._t def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: pts = pts - self._t return self._r.invert_apply(pts) def invert(self) -> Frame: rot_inv = self._r.invert() trn_inv = rot_inv.apply(self._t) return Frame(rot_inv, -1 * trn_inv) def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Frame: new_rots = self._r.map_tensor_fn(fn) new_trans = torch.stack(list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1) return Frame(new_rots, new_trans) def to_tensor_4x4(self) -> torch.Tensor: tensor = self._t.new_zeros((*self.shape, 4, 4)) tensor[..., :3, :3] = self._r.rot_mat tensor[..., :3, 3] = self._t tensor[..., 3, 3] = 1 return tensor @staticmethod def from_tensor_4x4(t: torch.Tensor) -> Frame: if t.shape[-2:] != (4, 4): raise ValueError("Incorrectly shaped input tensor") rots = Rotation(mat=t[..., :3, :3]) trans = t[..., :3, 3] return Frame(rots, trans) @staticmethod def from_3_points( p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8, ) -> Frame: p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) origin = torch.unbind(origin, dim=-1) p_xy_plane = torch.unbind(p_xy_plane, dim=-1) e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] denom = torch.sqrt(sum((c * c for c in e0)) + eps) e0 = [c / denom for c in e0] dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] denom = torch.sqrt(sum((c * c for c in e1)) + eps) e1 = [c / denom for c in e1] e2 = [ e0[1] * e1[2] - e0[2] * e1[1], e0[2] * e1[0] - e0[0] * e1[2], e0[0] * e1[1] - e0[1] * e1[0], ] rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) rots = rots.reshape(rots.shape[:-1] + (3, 3)) rot_obj = Rotation(mat=rots) return Frame(rot_obj, torch.stack(origin, dim=-1)) def unsqueeze( self, dim: int, ) -> Frame: if dim >= len(self.shape): raise ValueError("Invalid dimension") rots = self._r.unsqueeze(dim) trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1) return Frame(rots, trans) @staticmethod def cat( Ts: Sequence[Frame], dim: int, ) -> Frame: rots = Rotation.cat([T._r for T in Ts], dim) trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1) return Frame(rots, trans) def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame: return Frame(fn(self._r), self._t) def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Frame: return Frame(self._r, fn(self._t)) def scale_translation(self, trans_scale_factor: float) -> Frame: fn = lambda t: t * trans_scale_factor return self.apply_trans_fn(fn) def stop_rot_gradient(self) -> Frame: fn = lambda r: r.detach() return self.apply_rot_fn(fn) @staticmethod def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): input_dtype = ca_xyz.dtype n_xyz = n_xyz.float() ca_xyz = ca_xyz.float() c_xyz = c_xyz.float() n_xyz = n_xyz - ca_xyz c_xyz = c_xyz - ca_xyz c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)] norm = torch.sqrt(eps + c_x**2 + c_y**2) sin_c1 = -c_y / norm cos_c1 = c_x / norm c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) c1_rots[..., 0, 0] = cos_c1 c1_rots[..., 0, 1] = -1 * sin_c1 c1_rots[..., 1, 0] = sin_c1 c1_rots[..., 1, 1] = cos_c1 c1_rots[..., 2, 2] = 1 norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2) sin_c2 = d_pair / norm cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) c2_rots[..., 0, 0] = cos_c2 c2_rots[..., 0, 2] = sin_c2 c2_rots[..., 1, 1] = 1 c2_rots[..., 2, 0] = -1 * sin_c2 c2_rots[..., 2, 2] = cos_c2 c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots) n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz) _, n_y, n_z = [n_xyz[..., i] for i in range(3)] norm = torch.sqrt(eps + n_y**2 + n_z**2) sin_n = -n_z / norm cos_n = n_y / norm n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) n_rots[..., 0, 0] = 1 n_rots[..., 1, 1] = cos_n n_rots[..., 1, 2] = -1 * sin_n n_rots[..., 2, 1] = sin_n n_rots[..., 2, 2] = cos_n rots = Rotation.mat_mul_mat(n_rots, c_rots) rots = rots.transpose(-1, -2) rot_obj = Rotation(mat=rots.type(input_dtype)) return Frame(rot_obj, ca_xyz.type(input_dtype)) def cuda(self) -> Frame: return Frame(self._r.cuda(), self._t.cuda()) @property def dtype(self) -> torch.dtype: assert self._r.dtype == self._t.dtype return self._r.dtype def type(self, dtype) -> Frame: return Frame(self._r.type(dtype), self._t.type(dtype)) class Quaternion: def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor): if quaternion.shape[-1] != 4: raise ValueError(f"incorrect quaternion shape: {quaternion.shape}") self._q = quaternion self._t = translation @staticmethod def identity( shape: Iterable[int], dtype: Optional[torch.dtype] = torch.float, device: Optional[torch.device] = torch.device("cpu"), requires_grad: bool = False, ) -> Quaternion: trans = zero_translation(shape, dtype, device, requires_grad) quats = torch.zeros( (*shape, 4), dtype=dtype, device=device, requires_grad=requires_grad ) with torch.no_grad(): quats[..., 0] = 1 return Quaternion(quats, trans) def get_quats(self): return self._q def get_trans(self): return self._t def get_rot_mats(self): quats = self.get_quats() rot_mats = Quaternion.quat_to_rot(quats) return rot_mats @staticmethod def quat_to_rot(normalized_quat): global _QUAT_TO_ROT_tensor dtype = normalized_quat.dtype normalized_quat = normalized_quat.float() if _QUAT_TO_ROT_tensor.device != normalized_quat.device: _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to(normalized_quat.device) rot_tensor = torch.sum( _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None] * normalized_quat[..., None, :, None], dim=(-3, -2), ) rot_tensor = rot_tensor.type(dtype) rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3) return rot_tensor @staticmethod def normalize_quat(quats): dtype = quats.dtype quats = quats.float() quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) quats = quats.type(dtype) return quats @staticmethod def quat_multiply_by_vec(quat, vec): dtype = quat.dtype quat = quat.float() vec = vec.float() global _QUAT_MULTIPLY_BY_VEC_tensor if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device: _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to(quat.device) mat = _QUAT_MULTIPLY_BY_VEC_tensor reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) return torch.sum( reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2), ).type(dtype) def compose_q_update_vec( self, q_update_vec: torch.Tensor, normalize_quats: bool = True ) -> torch.Tensor: quats = self.get_quats() new_quats = quats + Quaternion.quat_multiply_by_vec(quats, q_update_vec) if normalize_quats: new_quats = Quaternion.normalize_quat(new_quats) return new_quats def compose_update_vec( self, update_vec: torch.Tensor, pre_rot_mat: Rotation, ) -> Quaternion: q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:] new_quats = self.compose_q_update_vec(q_vec) trans_update = pre_rot_mat.apply(t_vec) new_trans = self._t + trans_update return Quaternion(new_quats, new_trans) def stop_rot_gradient(self) -> Quaternion: return Quaternion(self._q.detach(), self._t)