Commit 91887c3b authored by zachteed

added differentiable conversions (InitFromVec, ToVec, ToMatrix), bump version

parent df566ea4
...@@ -25,7 +25,7 @@ Zachary Teed and Jia Deng, CVPR 2021 ...@@ -25,7 +25,7 @@ Zachary Teed and Jia Deng, CVPR 2021
### Requirements: ### Requirements:
* Cuda >= 10.1 (with nvcc compiler) * Cuda >= 10.1 (with nvcc compiler)
* PyTorch >= 1.6 * PyTorch >= 1.7
We recommend installing within a virtual environment. Make sure you clone using the `--recursive` flag. If you are using Anaconda, the following command can be used to install all dependencies We recommend installing within a virtual environment. Make sure you clone using the `--recursive` flag. If you are using Anaconda, the following command can be used to install all dependencies
``` ```
...@@ -42,7 +42,7 @@ To run the examples, you will need OpenCV and Open3D. Depending on your operatin ...@@ -42,7 +42,7 @@ To run the examples, you will need OpenCV and Open3D. Depending on your operatin
pip install opencv-python open3d pip install opencv-python open3d
``` ```
### Installing: ### Installing (from source)
Clone the repo using the `--recursive` flag and install using `setup.py` (may take up to 10 minutes) Clone the repo using the `--recursive` flag and install using `setup.py` (may take up to 10 minutes)
``` ```
...@@ -51,6 +51,14 @@ python setup.py install ...@@ -51,6 +51,14 @@ python setup.py install
./run_tests.sh ./run_tests.sh
``` ```
### Installing (pip)
You can install the library directly using pip
```bash
pip install git+https://github.com/princeton-vl/lietorch.git
```
## Overview ## Overview
LieTorch currently supports the 3D transformation groups. LieTorch currently supports the 3D transformation groups.
...@@ -62,7 +70,7 @@ LieTorch currently supports the 3D transformation groups. ...@@ -62,7 +70,7 @@ LieTorch currently supports the 3D transformation groups.
| SE3 | 6 | rotation + translation | | SE3 | 6 | rotation + translation |
| Sim3 | 7 | rotation + translation + scaling | | Sim3 | 7 | rotation + translation + scaling |
Each group supports the following operations: Each group supports the following differentiable operations:
| Operation | Map | Description | | Operation | Map | Description |
| -------| --------| ------------- | | -------| --------| ------------- |
...@@ -72,11 +80,16 @@ Each group supports the following operations: ...@@ -72,11 +80,16 @@ Each group supports the following operations:
| mul | G x G -> G | group multiplication | | mul | G x G -> G | group multiplication |
| adj | G x g -> g | adjoint | | adj | G x g -> g | adjoint |
| adjT | G x g*-> g* | dual adjoint | | adjT | G x g*-> g* | dual adjoint |
| act | G x R3 -> R3 | action on point (set) | | act | G x R^3 -> R^3 | action on point (set) |
| act4 | G x P3 -> P3 | action on homogeneous point (set) | | act4 | G x P^3 -> P^3 | action on homogeneous point (set) |
| matrix | G -> R^{4x4} | convert to 4x4 matrix |
| vec | G -> R^D | map to Euclidean embedding vector |
| InitFromVec | R^D -> G | initialize group from Euclidean embedding |
#### Simple Example: ### Simple Example:
Compute the angles between all pairs of rotation matrices Compute the angles between all pairs of rotation matrices
```python ```python
...@@ -86,10 +99,10 @@ from lietorch import SO3 ...@@ -86,10 +99,10 @@ from lietorch import SO3
phi = torch.randn(8000, 3, device='cuda', requires_grad=True) phi = torch.randn(8000, 3, device='cuda', requires_grad=True)
R = SO3.exp(phi) R = SO3.exp(phi)
# relative rotation matrix, SO3 ^ {100 x 100} # relative rotation matrix, SO3 ^ {8000 x 8000}
dR = R[:,None].inv() * R[None,:] dR = R[:,None].inv() * R[None,:]
# 100x100 matrix of angles # 8000x8000 matrix of angles
ang = dR.log().norm(dim=-1) ang = dR.log().norm(dim=-1)
# backpropagation in tangent space # backpropagation in tangent space
...@@ -97,6 +110,27 @@ loss = ang.sum() ...@@ -97,6 +110,27 @@ loss = ang.sum()
loss.backward() loss.backward()
``` ```
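For intuition, the `exp` and `log` maps used above can be written out in plain PyTorch without lietorch, via Rodrigues' formula; `so3_exp` and `so3_log_angle` below are illustrative helper names, not library API:

```python
import torch

def so3_exp(phi):
    """Rodrigues' formula: axis-angle vector -> 3x3 rotation matrix."""
    theta = phi.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    k = phi / theta                                  # unit rotation axis
    K = torch.zeros(*phi.shape[:-1], 3, 3, dtype=phi.dtype)
    K[..., 0, 1], K[..., 0, 2] = -k[..., 2],  k[..., 1]
    K[..., 1, 0], K[..., 1, 2] =  k[..., 2], -k[..., 0]
    K[..., 2, 0], K[..., 2, 1] = -k[..., 1],  k[..., 0]
    t = theta.unsqueeze(-1)
    I = torch.eye(3, dtype=phi.dtype)
    return I + torch.sin(t) * K + (1 - torch.cos(t)) * (K @ K)

def so3_log_angle(R):
    """Rotation angle from the trace: cos(theta) = (tr(R) - 1) / 2."""
    cos = ((R.diagonal(dim1=-2, dim2=-1).sum(-1) - 1) / 2).clamp(-1, 1)
    return torch.acos(cos)

# keep |phi| < pi so that log . exp recovers the input angle exactly
axis = torch.nn.functional.normalize(torch.randn(5, 3, dtype=torch.double), dim=-1)
phi = axis * (torch.rand(5, 1, dtype=torch.double) * 3.0)
angles = so3_log_angle(so3_exp(phi))   # should match phi.norm(dim=-1)
```

lietorch implements the same maps in quaternion form, with analytic gradients on both CPU and GPU.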
### Converting between Group Elements and Euclidean Embeddings
We provide differentiable `InitFromVec` and `vec` functions for converting between LieGroup elements and their Euclidean vector embeddings. Additionally, the `.matrix()` function returns a 4x4 transformation matrix.
```python
# random quaternion
q = torch.randn(1, 4, requires_grad=True)
q = q / q.norm(dim=-1, keepdim=True)
# create SO3 object from quaternion (differentiable w.r.t q)
R = SO3.InitFromVec(q)
# 4x4 transformation matrix (differentiable w.r.t R)
T = R.matrix()
# map back to quaternion (differentiable w.r.t R)
q = R.vec()
```
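As a cross-check on the conventions above, the rotation block that `.matrix()` produces from a unit quaternion can be written out by hand; this sketch assumes an (x, y, z, w) storage order, and `quat_to_matrix` is an illustrative helper, not part of lietorch:

```python
import torch

def quat_to_matrix(q):
    """Unit quaternion (x, y, z, w) -> 3x3 rotation matrix."""
    x, y, z, w = q.unbind(-1)
    return torch.stack([
        torch.stack([1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)], dim=-1),
        torch.stack([2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)], dim=-1),
        torch.stack([2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)], dim=-1),
    ], dim=-2)

q = torch.randn(4, dtype=torch.double)
q = q / q.norm()                 # project onto the unit sphere
R = quat_to_matrix(q)            # orthogonal with determinant +1
```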
## Examples ## Examples
We provide real use cases in the examples directory We provide real use cases in the examples directory
1. Pose Graph Optimization 1. Pose Graph Optimization
......
...@@ -2,6 +2,8 @@ import lietorch_backends ...@@ -2,6 +2,8 @@ import lietorch_backends
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
class GroupOp(torch.autograd.Function): class GroupOp(torch.autograd.Function):
""" group operation base class """ """ group operation base class """
...@@ -22,6 +24,7 @@ class GroupOp(torch.autograd.Function): ...@@ -22,6 +24,7 @@ class GroupOp(torch.autograd.Function):
grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs) grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
return (None, ) + tuple(grad_inputs) return (None, ) + tuple(grad_inputs)
class Exp(GroupOp): class Exp(GroupOp):
""" exponential map """ """ exponential map """
forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward
...@@ -63,21 +66,37 @@ class ToMatrix(GroupOp): ...@@ -63,21 +66,37 @@ class ToMatrix(GroupOp):
forward_op, backward_op = lietorch_backends.as_matrix, None forward_op, backward_op = lietorch_backends.as_matrix, None
class ExtractTranslation(torch.autograd.Function):
""" group operation base class """
@staticmethod
def forward(ctx, data):
ctx.save_for_backward(data)
return data[...,:3]
@staticmethod ### conversion operations to/from Euclidean embeddings ###
def backward(ctx, dt):
data, = ctx.saved_tensors class FromVec(torch.autograd.Function):
t = data[...,:3] """ convert vector into group object """
@classmethod
def forward(cls, ctx, group_id, *inputs):
ctx.group_id = group_id
ctx.save_for_backward(*inputs)
return inputs[0]
diff_tau_phi = torch.zeros_like(data) @classmethod
diff_tau_phi[...,0:3] = dt def backward(cls, ctx, grad):
diff_tau_phi[...,3:6] = torch.cross(t, dt) inputs = ctx.saved_tensors
J = lietorch_backends.projector(ctx.group_id, *inputs)
return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J))
class ToVec(torch.autograd.Function):
""" convert group object to vector """
@classmethod
def forward(cls, ctx, group_id, *inputs):
ctx.group_id = group_id
ctx.save_for_backward(*inputs)
return inputs[0]
@classmethod
def backward(cls, ctx, grad):
inputs = ctx.saved_tensors
J = lietorch_backends.projector(ctx.group_id, *inputs)
return None, torch.matmul(grad.unsqueeze(-2), J)
return diff_tau_phi
\ No newline at end of file
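The new `FromVec` backward multiplies the incoming gradient by `pinv(J)`, where `J` is the Jacobian of the Euclidean embedding returned by the `projector` backend. For a full-column-rank `J`, the pseudoinverse acts as a left inverse; a minimal sketch of that fact on a toy unit-circle embedding (names illustrative, not part of lietorch):

```python
import torch

# Circle-group embedding: angle t -> (cos t, sin t).
# Its Jacobian J is 2x1 with full column rank, so pinv(J) @ J = I:
# gradients already lying in the tangent image pass through unchanged.
t = torch.tensor(0.3, dtype=torch.double)
J = torch.stack([-torch.sin(t), torch.cos(t)]).unsqueeze(-1)  # 2x1
left_inv = torch.linalg.pinv(J)                               # 1x2
identity = left_inv @ J                                       # ~= [[1.]]
```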
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import numpy as np import numpy as np
# group operations implemented in cuda # group operations implemented in cuda
from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ExtractTranslation from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ToVec, FromVec
from .broadcasting import broadcast_inputs from .broadcasting import broadcast_inputs
...@@ -70,6 +70,9 @@ class LieGroup: ...@@ -70,6 +70,9 @@ class LieGroup:
def dtype(self): def dtype(self):
return self.data.dtype return self.data.dtype
def vec(self):
return self.apply_op(ToVec, self.data)
@property @property
def tangent_shape(self): def tangent_shape(self):
return self.data.shape[:-1] + (self.manifold_dim,) return self.data.shape[:-1] + (self.manifold_dim,)
...@@ -100,6 +103,10 @@ class LieGroup: ...@@ -100,6 +103,10 @@ class LieGroup:
def IdentityLike(cls, G): def IdentityLike(cls, G):
return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype) return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)
@classmethod
def InitFromVec(cls, data):
return cls(cls.apply_op(FromVec, data))
@classmethod @classmethod
def Random(cls, *batch_shape, sigma=1.0, **kwargs): def Random(cls, *batch_shape, sigma=1.0, **kwargs):
""" Construct random element with batch_shape by random sampling in tangent space""" """ Construct random element with batch_shape by random sampling in tangent space"""
...@@ -127,6 +134,10 @@ class LieGroup: ...@@ -127,6 +134,10 @@ class LieGroup:
""" exponential map: x -> X """ """ exponential map: x -> X """
return cls(cls.apply_op(Exp, x)) return cls(cls.apply_op(Exp, x))
def quaternion(self):
""" extract quaternion """
return self.apply_op(Quat, self.data)
def log(self): def log(self):
""" logarithm map """ """ logarithm map """
return self.apply_op(Log, self.data) return self.apply_op(Log, self.data)
...@@ -166,18 +177,18 @@ class LieGroup: ...@@ -166,18 +177,18 @@ class LieGroup:
elif p.shape[-1] == 4: elif p.shape[-1] == 4:
return self.apply_op(Act4, self.data, p) return self.apply_op(Act4, self.data, p)
# def matrix(self):
# """ convert element to 4x4 matrix """
# input_shape = self.data.shape
# mat = ToMatrix.apply(self.group_id, self.data.reshape(-1, self.embedded_dim))
# return mat.view(input_shape[:-1] + (4,4))
def matrix(self): def matrix(self):
""" convert element to 4x4 matrix """ """ convert element to 4x4 matrix """
I = torch.eye(4, dtype=self.dtype, device=self.device) I = torch.eye(4, dtype=self.dtype, device=self.device)
I = I.view([1] * (len(self.data.shape) - 1) + [4, 4]) I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2) return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2)
def translation(self):
""" extract translation component """
p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
p = p.view([1] * (len(self.data.shape) - 1) + [4,])
return self.apply_op(Act4, self.data, p)
def detach(self): def detach(self):
return self.__class__(self.data.detach()) return self.__class__(self.data.detach())
...@@ -272,9 +283,6 @@ class SE3(LieGroup): ...@@ -272,9 +283,6 @@ class SE3(LieGroup):
t = t * s.unsqueeze(-1) t = t * s.unsqueeze(-1)
return SE3(torch.cat([t, q], dim=-1)) return SE3(torch.cat([t, q], dim=-1))
def translation(self):
return ExtractTranslation.apply(self.data)
class Sim3(LieGroup): class Sim3(LieGroup):
group_name = 'Sim3' group_name = 'Sim3'
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
switch (_st) { \ switch (_st) { \
PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__) \ PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__) \
PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__) \ PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__) \
default: break; \
} \ } \
}() }()
......
...@@ -32,7 +32,14 @@ std::vector<torch::Tensor> act_backward_cpu(int, torch::Tensor, torch::Tensor, t ...@@ -32,7 +32,14 @@ std::vector<torch::Tensor> act_backward_cpu(int, torch::Tensor, torch::Tensor, t
torch::Tensor act4_forward_cpu(int, torch::Tensor, torch::Tensor); torch::Tensor act4_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act4_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor); std::vector<torch::Tensor> act4_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
// conversion operations
// std::vector<torch::Tensor> to_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
// std::vector<torch::Tensor> from_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
// utility operations // utility operations
torch::Tensor orthogonal_projector_cpu(int, torch::Tensor);
torch::Tensor as_matrix_forward_cpu(int, torch::Tensor); torch::Tensor as_matrix_forward_cpu(int, torch::Tensor);
torch::Tensor jleft_forward_cpu(int, torch::Tensor, torch::Tensor); torch::Tensor jleft_forward_cpu(int, torch::Tensor, torch::Tensor);
......
...@@ -34,7 +34,12 @@ std::vector<torch::Tensor> act_backward_gpu(int, torch::Tensor, torch::Tensor, t ...@@ -34,7 +34,12 @@ std::vector<torch::Tensor> act_backward_gpu(int, torch::Tensor, torch::Tensor, t
torch::Tensor act4_forward_gpu(int, torch::Tensor, torch::Tensor); torch::Tensor act4_forward_gpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act4_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor); std::vector<torch::Tensor> act4_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
// conversion operations
// std::vector<torch::Tensor> to_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
// std::vector<torch::Tensor> from_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
// utility operators // utility operators
torch::Tensor orthogonal_projector_gpu(int, torch::Tensor);
torch::Tensor as_matrix_forward_gpu(int, torch::Tensor); torch::Tensor as_matrix_forward_gpu(int, torch::Tensor);
......
...@@ -84,6 +84,23 @@ class RxSO3 { ...@@ -84,6 +84,23 @@ class RxSO3 {
return T; return T;
} }
EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,5,5> orthogonal_projector() const {
// Jacobian of the Euclidean embedding with respect to the tangent space
Eigen::Matrix<Scalar,5,5> J = Eigen::Matrix<Scalar,5,5>::Zero();
J.template block<3,3>(0,0) = 0.5 * (
unit_quaternion.w() * Matrix3::Identity() +
SO3<Scalar>::hat(-unit_quaternion.vec())
);
J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());
// scale
J(4,3) = scale;
return J;
}
EIGEN_DEVICE_FUNC Transformation Rotation() const { EIGEN_DEVICE_FUNC Transformation Rotation() const {
return unit_quaternion.toRotationMatrix(); return unit_quaternion.toRotationMatrix();
} }
......
...@@ -111,6 +111,16 @@ class SE3 { ...@@ -111,6 +111,16 @@ class SE3 {
return ad; return ad;
} }
EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,7,7> orthogonal_projector() const {
// Jacobian of the Euclidean embedding with respect to the tangent space
Eigen::Matrix<Scalar,7,7> J = Eigen::Matrix<Scalar,7,7>::Zero();
J.template block<3,3>(0,0) = Matrix3::Identity();
J.template block<3,3>(0,3) = SO3<Scalar>::hat(-translation);
J.template block<4,4>(3,3) = so3.orthogonal_projector();
return J;
}
EIGEN_DEVICE_FUNC Tangent Log() const { EIGEN_DEVICE_FUNC Tangent Log() const {
Vector3 phi = so3.Log(); Vector3 phi = so3.Log();
Matrix3 Vinv = SO3<Scalar>::left_jacobian_inverse(phi); Matrix3 Vinv = SO3<Scalar>::left_jacobian_inverse(phi);
...@@ -164,7 +174,6 @@ class SE3 { ...@@ -164,7 +174,6 @@ class SE3 {
EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi) { EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi) {
// left jacobian // left jacobian
Vector3 tau = tau_phi.template segment<3>(0);
Vector3 phi = tau_phi.template segment<3>(3); Vector3 phi = tau_phi.template segment<3>(3);
Matrix3 J = SO3<Scalar>::left_jacobian(phi); Matrix3 J = SO3<Scalar>::left_jacobian(phi);
Matrix3 Q = SE3<Scalar>::calcQ(tau_phi); Matrix3 Q = SE3<Scalar>::calcQ(tau_phi);
...@@ -207,6 +216,9 @@ class SE3 { ...@@ -207,6 +216,9 @@ class SE3 {
return J; return J;
} }
private: private:
SO3<Scalar> so3; SO3<Scalar> so3;
Vector3 translation; Vector3 translation;
......
...@@ -76,6 +76,16 @@ class Sim3 { ...@@ -76,6 +76,16 @@ class Sim3 {
return T; return T;
} }
EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,8,8> orthogonal_projector() const {
// Jacobian of the Euclidean embedding with respect to the tangent space
Eigen::Matrix<Scalar,8,8> J = Eigen::Matrix<Scalar,8,8>::Zero();
J.template block<3,3>(0,0) = Matrix3::Identity();
J.template block<3,3>(0,3) = SO3<Scalar>::hat(-translation);
J.template block<3,1>(0,6) = translation;
J.template block<5,5>(3,3) = rxso3.orthogonal_projector();
return J;
}
EIGEN_DEVICE_FUNC Adjoint Adj() const { EIGEN_DEVICE_FUNC Adjoint Adj() const {
Adjoint Ad = Adjoint::Identity(); Adjoint Ad = Adjoint::Identity();
Matrix3 sR = rxso3.Matrix(); Matrix3 sR = rxso3.Matrix();
......
...@@ -78,6 +78,18 @@ class SO3 { ...@@ -78,6 +78,18 @@ class SO3 {
return T; return T;
} }
EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> orthogonal_projector() const {
// Jacobian of the Euclidean embedding with respect to the tangent space
Eigen::Matrix<Scalar,4,4> J = Eigen::Matrix<Scalar,4,4>::Zero();
J.template block<3,3>(0,0) = 0.5 * (
unit_quaternion.w() * Matrix3::Identity() +
SO3<Scalar>::hat(-unit_quaternion.vec())
);
J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());
return J;
}
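A brief sketch of where these blocks come from, assuming the left-perturbation convention used in the tests: for a unit quaternion with vector part $q_v$ and scalar part $q_w$, a tangent perturbation $\phi$ acts to first order as quaternion multiplication by $(\tfrac{1}{2}\phi,\,1)$, so the embedding changes as

```latex
\delta q \;=\;
\underbrace{\begin{pmatrix}
  \tfrac{1}{2}\bigl(q_w I - [q_v]_\times\bigr) \\[2pt]
  -\tfrac{1}{2}\, q_v^{\top}
\end{pmatrix}}_{J}\,\phi
```

which matches the `0.5 * (w I + hat(-vec))` and `0.5 * (-vec)` blocks in the code above.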
EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const { EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
return Adj() * a; return Adj() * a;
} }
...@@ -174,7 +186,7 @@ class SO3 { ...@@ -174,7 +186,7 @@ class SO3 {
Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta2 : Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta2 :
(theta - sin(theta)) / (theta2 * theta); (theta - sin(theta)) / (theta2 * theta);
return I + coef1 * Phi + coef2 * Phi * Phi; return I + coef1 * Phi + coef2 * Phi2;
} }
EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi) { EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi) {
......
...@@ -161,50 +161,96 @@ def test_matrix_grad(Group, device='cuda'): ...@@ -161,50 +161,96 @@ def test_matrix_grad(Group, device='cuda'):
print("\t-", Group, "Passed matrix-grad test") print("\t-", Group, "Passed matrix-grad test")
def scale(device='cuda'): def extract_translation_grad(Group, device='cuda'):
""" prototype function """
def fn(a, s):
X = SE3.exp(a)
X.scale(s)
return X.log()
s = torch.rand(1, requires_grad=True, device=device).double() D = Group.manifold_dim
a = torch.randn(1, 6, requires_grad=True, device=device).double() X = Group.exp(5*torch.randn(1,D, device=device).double())
analytical, numerical = gradcheck(fn, [a, s], eps=1e-3) def fn(a):
print(analytical[1]) return (Group.exp(a)*X).translation()
print(numerical[1])
assert torch.allclose(analytical[0], numerical[0], atol=1e-8) a = torch.zeros(1, D, requires_grad=True, device=device).double()
assert torch.allclose(analytical[1], numerical[1], atol=1e-8)
print("\t-", "Passed se3-to-sim3 test") analytical, numerical = gradcheck(fn, [a], eps=1e-4)
assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
print("\t-", Group, "Passed translation grad test")
def extract_translation(Group, device='cuda'): def test_vec_grad(Group, device='cuda', tol=1e-6):
""" prototype function """
D = Group.manifold_dim D = Group.manifold_dim
X = Group.exp(5*torch.randn(1,D, device=device).double()) X = Group.exp(5*torch.randn(1,D, device=device).double())
def fn(a): def fn(a):
return (Group.exp(a)*X).translation() return (Group.exp(a)*X).vec()
a = torch.zeros(1, D, requires_grad=True, device=device).double() a = torch.zeros(1, D, requires_grad=True, device=device).double()
analytical, numerical = gradcheck(fn, [a], eps=1e-4) analytical, numerical = gradcheck(fn, [a], eps=1e-4)
print(analytical[0]) assert torch.allclose(analytical[0], numerical[0], atol=tol)
print(numerical[0]) print("\t-", Group, "Passed tovec grad test")
def test_fromvec_grad(Group, device='cuda', tol=1e-6):
def fn(a):
if Group == SO3:
a = a / a.norm(dim=-1, keepdim=True)
elif Group == RxSO3:
q, s = a.split([4, 1], dim=-1)
q = q / q.norm(dim=-1, keepdim=True)
a = torch.cat([q, s.exp()], dim=-1)
elif Group == SE3:
t, q = a.split([3, 4], dim=-1)
q = q / q.norm(dim=-1, keepdim=True)
a = torch.cat([t, q], dim=-1)
elif Group == Sim3:
t, q, s = a.split([3, 4, 1], dim=-1)
q = q / q.norm(dim=-1, keepdim=True)
a = torch.cat([t, q, s.exp()], dim=-1)
return Group.InitFromVec(a).vec()
D = Group.embedded_dim
a = torch.randn(1, D, requires_grad=True, device=device).double()
analytical, numerical = gradcheck(fn, [a], eps=1e-4)
assert torch.allclose(analytical[0], numerical[0], atol=tol)
print("\t-", Group, "Passed fromvec grad test")
def scale(device='cuda'):
def fn(a, s):
X = SE3.exp(a)
X.scale(s)
return X.log()
s = torch.rand(1, requires_grad=True, device=device).double()
a = torch.randn(1, 6, requires_grad=True, device=device).double()
analytical, numerical = gradcheck(fn, [a, s], eps=1e-3)
print(analytical[1])
print(numerical[1])
assert torch.allclose(analytical[0], numerical[0], atol=1e-8) assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
print("\t-", Group, "Passed translation test") assert torch.allclose(analytical[1], numerical[1], atol=1e-8)
print("\t-", "Passed se3-to-sim3 test")
if __name__ == '__main__': if __name__ == '__main__':
print("Testing lietorch forward pass (CPU) ...") print("Testing lietorch forward pass (CPU) ...")
for Group in [SO3, RxSO3, SE3, Sim3]: for Group in [SO3, RxSO3, SE3, Sim3]:
test_exp_log(Group, device='cpu') test_exp_log(Group, device='cpu')
...@@ -225,7 +271,9 @@ if __name__ == '__main__': ...@@ -225,7 +271,9 @@ if __name__ == '__main__':
test_adjT_grad(Group, device='cpu') test_adjT_grad(Group, device='cpu')
test_act_grad(Group, device='cpu') test_act_grad(Group, device='cpu')
test_matrix_grad(Group, device='cpu') test_matrix_grad(Group, device='cpu')
extract_translation_grad(Group, device='cpu')
test_vec_grad(Group, device='cpu')
test_fromvec_grad(Group, device='cpu')
print("Testing lietorch forward pass (GPU) ...") print("Testing lietorch forward pass (GPU) ...")
for Group in [SO3, RxSO3, SE3, Sim3]: for Group in [SO3, RxSO3, SE3, Sim3]:
...@@ -247,5 +295,8 @@ if __name__ == '__main__': ...@@ -247,5 +295,8 @@ if __name__ == '__main__':
test_adjT_grad(Group, device='cuda') test_adjT_grad(Group, device='cuda')
test_act_grad(Group, device='cuda') test_act_grad(Group, device='cuda')
test_matrix_grad(Group, device='cuda') test_matrix_grad(Group, device='cuda')
extract_translation_grad(Group, device='cuda')
test_vec_grad(Group, device='cuda')
test_fromvec_grad(Group, device='cuda')
...@@ -23,6 +23,8 @@ torch::Tensor expm(int group_index, torch::Tensor a) { ...@@ -23,6 +23,8 @@ torch::Tensor expm(int group_index, torch::Tensor a) {
} else if (a.device().type() == torch::DeviceType::CUDA) { } else if (a.device().type() == torch::DeviceType::CUDA) {
return exp_forward_gpu(group_index, a); return exp_forward_gpu(group_index, a);
} }
return a;
} }
std::vector<torch::Tensor> expm_backward(int group_index, torch::Tensor grad, torch::Tensor a) { std::vector<torch::Tensor> expm_backward(int group_index, torch::Tensor grad, torch::Tensor a) {
...@@ -34,6 +36,8 @@ std::vector<torch::Tensor> expm_backward(int group_index, torch::Tensor grad, to ...@@ -34,6 +36,8 @@ std::vector<torch::Tensor> expm_backward(int group_index, torch::Tensor grad, to
} else if (a.device().type() == torch::DeviceType::CUDA) { } else if (a.device().type() == torch::DeviceType::CUDA) {
return exp_backward_gpu(group_index, grad, a); return exp_backward_gpu(group_index, grad, a);
} }
return {};
} }
torch::Tensor logm(int group_index, torch::Tensor X) { torch::Tensor logm(int group_index, torch::Tensor X) {
...@@ -44,6 +48,8 @@ torch::Tensor logm(int group_index, torch::Tensor X) { ...@@ -44,6 +48,8 @@ torch::Tensor logm(int group_index, torch::Tensor X) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return log_forward_gpu(group_index, X); return log_forward_gpu(group_index, X);
} }
return X;
} }
std::vector<torch::Tensor> logm_backward(int group_index, torch::Tensor grad, torch::Tensor X) { std::vector<torch::Tensor> logm_backward(int group_index, torch::Tensor grad, torch::Tensor X) {
...@@ -56,6 +62,8 @@ std::vector<torch::Tensor> logm_backward(int group_index, torch::Tensor grad, to ...@@ -56,6 +62,8 @@ std::vector<torch::Tensor> logm_backward(int group_index, torch::Tensor grad, to
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return log_backward_gpu(group_index, grad, X); return log_backward_gpu(group_index, grad, X);
} }
return {};
} }
torch::Tensor inv(int group_index, torch::Tensor X) { torch::Tensor inv(int group_index, torch::Tensor X) {
...@@ -66,6 +74,8 @@ torch::Tensor inv(int group_index, torch::Tensor X) { ...@@ -66,6 +74,8 @@ torch::Tensor inv(int group_index, torch::Tensor X) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return inv_forward_gpu(group_index, X); return inv_forward_gpu(group_index, X);
} }
return X;
} }
std::vector<torch::Tensor> inv_backward(int group_index, torch::Tensor grad, torch::Tensor X) { std::vector<torch::Tensor> inv_backward(int group_index, torch::Tensor grad, torch::Tensor X) {
...@@ -78,6 +88,8 @@ std::vector<torch::Tensor> inv_backward(int group_index, torch::Tensor grad, tor ...@@ -78,6 +88,8 @@ std::vector<torch::Tensor> inv_backward(int group_index, torch::Tensor grad, tor
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return inv_backward_gpu(group_index, grad, X); return inv_backward_gpu(group_index, grad, X);
} }
return {};
} }
// Binary operations // Binary operations
...@@ -92,6 +104,8 @@ torch::Tensor mul(int group_index, torch::Tensor X, torch::Tensor Y) { ...@@ -92,6 +104,8 @@ torch::Tensor mul(int group_index, torch::Tensor X, torch::Tensor Y) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return mul_forward_gpu(group_index, X, Y); return mul_forward_gpu(group_index, X, Y);
} }
return X;
} }
std::vector<torch::Tensor> mul_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) { std::vector<torch::Tensor> mul_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
...@@ -105,6 +119,8 @@ std::vector<torch::Tensor> mul_backward(int group_index, torch::Tensor grad, tor ...@@ -105,6 +119,8 @@ std::vector<torch::Tensor> mul_backward(int group_index, torch::Tensor grad, tor
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return mul_backward_gpu(group_index, grad, X, Y); return mul_backward_gpu(group_index, grad, X, Y);
} }
return {};
} }
torch::Tensor adj(int group_index, torch::Tensor X, torch::Tensor a) { torch::Tensor adj(int group_index, torch::Tensor X, torch::Tensor a) {
...@@ -117,6 +133,8 @@ torch::Tensor adj(int group_index, torch::Tensor X, torch::Tensor a) { ...@@ -117,6 +133,8 @@ torch::Tensor adj(int group_index, torch::Tensor X, torch::Tensor a) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return adj_forward_gpu(group_index, X, a); return adj_forward_gpu(group_index, X, a);
} }
return X;
} }
std::vector<torch::Tensor> adj_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) { std::vector<torch::Tensor> adj_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
...@@ -130,6 +148,8 @@ std::vector<torch::Tensor> adj_backward(int group_index, torch::Tensor grad, tor ...@@ -130,6 +148,8 @@ std::vector<torch::Tensor> adj_backward(int group_index, torch::Tensor grad, tor
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return adj_backward_gpu(group_index, grad, X, a); return adj_backward_gpu(group_index, grad, X, a);
} }
return {};
} }
torch::Tensor adjT(int group_index, torch::Tensor X, torch::Tensor a) { torch::Tensor adjT(int group_index, torch::Tensor X, torch::Tensor a) {
...@@ -142,6 +162,8 @@ torch::Tensor adjT(int group_index, torch::Tensor X, torch::Tensor a) { ...@@ -142,6 +162,8 @@ torch::Tensor adjT(int group_index, torch::Tensor X, torch::Tensor a) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return adjT_forward_gpu(group_index, X, a); return adjT_forward_gpu(group_index, X, a);
} }
return X;
} }
std::vector<torch::Tensor> adjT_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) { std::vector<torch::Tensor> adjT_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
...@@ -155,6 +177,8 @@ std::vector<torch::Tensor> adjT_backward(int group_index, torch::Tensor grad, to ...@@ -155,6 +177,8 @@ std::vector<torch::Tensor> adjT_backward(int group_index, torch::Tensor grad, to
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return adjT_backward_gpu(group_index, grad, X, a); return adjT_backward_gpu(group_index, grad, X, a);
} }
return {};
} }
...@@ -168,6 +192,8 @@ torch::Tensor act(int group_index, torch::Tensor X, torch::Tensor p) { ...@@ -168,6 +192,8 @@ torch::Tensor act(int group_index, torch::Tensor X, torch::Tensor p) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return act_forward_gpu(group_index, X, p); return act_forward_gpu(group_index, X, p);
} }
return X;
} }
std::vector<torch::Tensor> act_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) { std::vector<torch::Tensor> act_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
...@@ -181,6 +207,8 @@ std::vector<torch::Tensor> act_backward(int group_index, torch::Tensor grad, tor ...@@ -181,6 +207,8 @@ std::vector<torch::Tensor> act_backward(int group_index, torch::Tensor grad, tor
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return act_backward_gpu(group_index, grad, X, p); return act_backward_gpu(group_index, grad, X, p);
} }
return {};
} }
torch::Tensor act4(int group_index, torch::Tensor X, torch::Tensor p) { torch::Tensor act4(int group_index, torch::Tensor X, torch::Tensor p) {
...@@ -193,6 +221,8 @@ torch::Tensor act4(int group_index, torch::Tensor X, torch::Tensor p) { ...@@ -193,6 +221,8 @@ torch::Tensor act4(int group_index, torch::Tensor X, torch::Tensor p) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return act4_forward_gpu(group_index, X, p); return act4_forward_gpu(group_index, X, p);
} }
return X;
} }
std::vector<torch::Tensor> act4_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) { std::vector<torch::Tensor> act4_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
...@@ -206,8 +236,25 @@ std::vector<torch::Tensor> act4_backward(int group_index, torch::Tensor grad, to ...@@ -206,8 +236,25 @@ std::vector<torch::Tensor> act4_backward(int group_index, torch::Tensor grad, to
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return act4_backward_gpu(group_index, grad, X, p); return act4_backward_gpu(group_index, grad, X, p);
} }
return {};
}
torch::Tensor projector(int group_index, torch::Tensor X) {
CHECK_CONTIGUOUS(X);
if (X.device().type() == torch::DeviceType::CPU) {
return orthogonal_projector_cpu(group_index, X);
} else if (X.device().type() == torch::DeviceType::CUDA) {
return orthogonal_projector_gpu(group_index, X);
}
return X;
} }
torch::Tensor as_matrix(int group_index, torch::Tensor X) { torch::Tensor as_matrix(int group_index, torch::Tensor X) {
CHECK_CONTIGUOUS(X); CHECK_CONTIGUOUS(X);
...@@ -217,6 +264,8 @@ torch::Tensor as_matrix(int group_index, torch::Tensor X) { ...@@ -217,6 +264,8 @@ torch::Tensor as_matrix(int group_index, torch::Tensor X) {
} else if (X.device().type() == torch::DeviceType::CUDA) { } else if (X.device().type() == torch::DeviceType::CUDA) {
return as_matrix_forward_gpu(group_index, X); return as_matrix_forward_gpu(group_index, X);
} }
return X;
} }
torch::Tensor Jinv(int group_index, torch::Tensor X, torch::Tensor a) {
@@ -229,6 +278,8 @@ torch::Tensor Jinv(int group_index, torch::Tensor X, torch::Tensor a) {
} else if (X.device().type() == torch::DeviceType::CUDA) {
return jleft_forward_gpu(group_index, X, a);
}
return a;
}
// {exp, log, inv, mul, adj, adjT, act, act4} forward/backward bindings
@@ -259,6 +310,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// functions with no gradient
m.def("as_matrix", &as_matrix, "convert to matrix");
m.def("projector", &projector, "orthogonal projection matrix");
m.def("Jinv", &Jinv, "left inverse jacobian operator");
};
...
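The wrapper functions in this file all follow the same pattern: dispatch on the tensor's device, and (with this commit) fall through to a safe default return for unsupported devices instead of leaving the function without a return value. A minimal pure-Python sketch of that pattern (illustrative only, not lietorch's actual API):

```python
# Sketch of the device-dispatch-with-fallback pattern used by the C++
# wrappers above. `dispatch`, `op_cpu`, and `op_gpu` are hypothetical names.

def dispatch(op_cpu, op_gpu, device, x):
    """Pick a backend by device string; fall back to the input unchanged."""
    if device == "cpu":
        return op_cpu(x)
    elif device == "cuda":
        return op_gpu(x)
    # fallthrough for unsupported devices, analogous to the added `return X;`
    return x

result = dispatch(lambda v: [2 * e for e in v], lambda v: v, "cpu", [1, 2, 3])
```

The explicit fallthrough matters in C++: without it, a call on an unsupported device would run off the end of a non-void function, which is undefined behavior.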
@@ -31,7 +31,6 @@ void exp_backward_kernel(const scalar_t* grad, const scalar_t* a_ptr, scalar_t*
// exponential map backward kernel
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
for (int64_t i=start; i<end; i++) {
@@ -46,7 +45,6 @@ template <typename Group, typename scalar_t>
void log_forward_kernel(const scalar_t* X_ptr, scalar_t* a_ptr, int batch_size) {
// logarithm map forward kernel
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
for (int64_t i=start; i<end; i++) {
@@ -61,7 +59,6 @@ void log_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t*
// logarithm map backward kernel
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
for (int64_t i=start; i<end; i++) {
@@ -91,7 +88,6 @@ void inv_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t *
// group inverse backward kernel
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
for (int64_t i=start; i<end; i++) {
@@ -241,6 +237,38 @@ void act_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scal
});
}
// template <typename Group, typename scalar_t>
// void tovec_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
// // ToVec backward kernel (vector-Jacobian product)
// using Data = Eigen::Matrix<scalar_t,Group::N,1>;
// using Grad = Eigen::Matrix<scalar_t,1,Group::N>;
// at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
// for (int64_t i=start; i<end; i++) {
// Group X(X_ptr + i*Group::N);
// Grad g(grad + i*Group::N);
// Eigen::Map<Grad>(dX + i*Group::N) = g * X.vec_jacobian();
// }
// });
// }
// template <typename Group, typename scalar_t>
// void fromvec_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
// // FromVec backward kernel (vector-Jacobian product)
// using Data = Eigen::Matrix<scalar_t,Group::N,1>;
// using Grad = Eigen::Matrix<scalar_t,1,Group::N>;
// at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
// for (int64_t i=start; i<end; i++) {
// Group X(X_ptr + i*Group::N);
// Grad g(grad + i*Group::N);
// Eigen::Map<Grad>(dX + i*Group::N) = g * X.vec_jacobian();
// }
// });
// }
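The commented-out backward kernels above both compute `dX = g * X.vec_jacobian()`: a gradient row vector multiplied by a Jacobian matrix, one batch element at a time. That is the standard vector-Jacobian product used throughout these backward kernels; a minimal pure-Python sketch (illustrative only, no lietorch internals):

```python
# Vector-Jacobian product: row vector g (length n) times Jacobian J (n x m).
# This mirrors the `g * X.vec_jacobian()` expression in the kernels above.

def vjp(g, J):
    """Return the row vector g^T J as a list of length m."""
    n, m = len(J), len(J[0])
    assert len(g) == n
    return [sum(g[i] * J[i][j] for i in range(n)) for j in range(m)]

# For a linear map f(x) = A x the Jacobian is A itself, so the VJP is g^T A.
A = [[1.0, 2.0],
     [3.0, 4.0]]
g = [1.0, 1.0]
dX = vjp(g, A)  # sums the rows of A weighted by g
```

Multiplying the gradient on the left by the Jacobian (rather than `J^T g` on the right) is what lets these kernels store the result directly as a row-major gradient slice.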
template <typename Group, typename scalar_t>
void act4_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int batch_size) {
// action on homogeneous point forward kernel
@@ -295,6 +323,18 @@ void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int batch_size) {
});
}
template <typename Group, typename scalar_t>
void orthogonal_projector_kernel(const scalar_t* X_ptr, scalar_t* P_ptr, int batch_size) {
// orthogonal projection matrix kernel
using Proj = Eigen::Matrix<scalar_t,Group::N,Group::N,Eigen::RowMajor>;
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
for (int64_t i=start; i<end; i++) {
Group X(X_ptr + i*Group::N);
Eigen::Map<Proj>(P_ptr + i*Group::N*Group::N) = X.orthogonal_projector();
}
});
}
template <typename Group, typename scalar_t>
void jleft_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
// left-jacobian inverse action
@@ -586,6 +626,21 @@ torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
return T4x4;
}
torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor P;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
orthogonal_projector_kernel<group_t, scalar_t>(X.data_ptr<scalar_t>(), P.data_ptr<scalar_t>(), batch_size);
}));
return P;
}
torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
...
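The new `orthogonal_projector` kernels return an N×N projection matrix per group element. The implementation of `X.orthogonal_projector()` is lietorch-internal, but any orthogonal projector P satisfies two defining properties that are useful as a sanity check: idempotence (P·P = P) and symmetry (P = Pᵀ). A pure-Python sketch using the simplest example, projection onto the span of a vector:

```python
# Defining properties of an orthogonal projector, checked on the rank-1
# projector P = v v^T / (v^T v). Pure-Python matrices as lists of rows.

def outer_projector(v):
    """Orthogonal projector onto span(v)."""
    n = len(v)
    s = sum(c * c for c in v)
    return [[v[i] * v[j] / s for j in range(n)] for i in range(n)]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

P = outer_projector([1.0, 1.0])  # projects R^2 onto the line y = x
PP = matmul(P, P)                # idempotence: applying twice changes nothing
```

Projecting a gradient with such a matrix keeps only the component tangent to the constraint set, which is why a projector (rather than a general linear map) appears here.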
@@ -215,7 +215,7 @@ __global__ void act_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr,
Point p(p_ptr + i*3);
PointGrad dq(grad + i*3);
Eigen::Map<PointGrad>(dp + i*3) = dq * X.Matrix4x4().block<3,3>(0,0);
Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act_jacobian(X*p);
}
}
@@ -235,7 +235,6 @@ __global__ void act4_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr
}
}
template <typename Group, typename scalar_t>
__global__ void act4_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int num_threads) {
// adjoint backward kernel
@@ -256,10 +255,9 @@ __global__ void act4_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr
}
}
template <typename Group, typename scalar_t>
__global__ void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int num_threads) {
// convert to 4x4 matrix representation
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
using Matrix4 = Eigen::Matrix<scalar_t,4,4,Eigen::RowMajor>;
@@ -270,6 +268,17 @@ __global__ void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int num_threads) {
}
}
template <typename Group, typename scalar_t>
__global__ void orthogonal_projector_kernel(const scalar_t* X_ptr, scalar_t* P_ptr, int num_threads) {
// orthogonal projection matrix
using Proj = Eigen::Matrix<scalar_t,Group::N,Group::N,Eigen::RowMajor>;
GPU_1D_KERNEL_LOOP(i, num_threads) {
Group X(X_ptr + i*Group::N);
Eigen::Map<Proj>(P_ptr + i*Group::N*Group::N) = X.orthogonal_projector();
}
}
template <typename Group, typename scalar_t>
__global__ void jleft_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
// left jacobian inverse action
@@ -560,6 +569,22 @@ torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
}
torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor P;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
orthogonal_projector_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
P.data_ptr<scalar_t>(),
batch_size);
}));
return P;
}
torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
...
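The GPU wrappers above launch one CUDA thread per batch element via `NUM_BLOCKS(batch_size)` and `NUM_THREADS`. The exact constants live in lietorch's CUDA sources; the usual shape of such helpers is ceil-division, sketched here in Python (256 threads per block is an assumption, not lietorch's actual value):

```python
# Assumed shape of the NUM_BLOCKS / NUM_THREADS launch helpers used above.

NUM_THREADS = 256  # hypothetical threads-per-block constant

def num_blocks(n):
    """Ceil-division grid size: smallest block count covering n elements."""
    return (n + NUM_THREADS - 1) // NUM_THREADS

# Every batch element gets a thread; the kernel's GPU_1D_KERNEL_LOOP guard
# (i < num_threads) discards the overshoot threads in the last block.
for n in (1, 255, 256, 257, 10_000):
    assert num_blocks(n) * NUM_THREADS >= n          # full coverage
    assert (num_blocks(n) - 1) * NUM_THREADS < n     # no wasted block
```

This is why the kernels take an explicit element count: the grid is rounded up, so the in-kernel bound check is what keeps the extra threads from writing out of range.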
@@ -5,11 +5,10 @@ import os.path as osp
ROOT = osp.dirname(osp.abspath(__file__))
print(ROOT)
setup(
name='lietorch',
version='0.2',
description='Lie Groups for PyTorch',
author='teedrz',
packages=['lietorch'],
...