Commit 67b2ec3f authored by zachteed's avatar zachteed
Browse files

ToMatrix gradient implemented

parent 355a5174
...@@ -166,11 +166,17 @@ class LieGroup: ...@@ -166,11 +166,17 @@ class LieGroup:
elif p.shape[-1] == 4: elif p.shape[-1] == 4:
return self.apply_op(Act4, self.data, p) return self.apply_op(Act4, self.data, p)
# def matrix(self):
# """ convert element to 4x4 matrix """
# input_shape = self.data.shape
# mat = ToMatrix.apply(self.group_id, self.data.reshape(-1, self.embedded_dim))
# return mat.view(input_shape[:-1] + (4,4))
def matrix(self): def matrix(self):
""" convert element to 4x4 matrix """ """ convert element to 4x4 matrix """
input_shape = self.data.shape I = torch.eye(4, dtype=self.dtype, device=self.device)
mat = ToMatrix.apply(self.group_id, self.data.reshape(-1, self.embedded_dim)) I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
return mat.view(input_shape[:-1] + (4,4)) return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2)
def detach(self): def detach(self):
return self.__class__(self.data.detach()) return self.__class__(self.data.detach())
......
...@@ -147,6 +147,21 @@ def test_act_grad(Group, device='cuda'): ...@@ -147,6 +147,21 @@ def test_act_grad(Group, device='cuda'):
print("\t-", Group, "Passed act-grad test") print("\t-", Group, "Passed act-grad test")
def test_matrix_grad(Group, device='cuda'):
D = Group.manifold_dim
X = Group.exp(5*torch.randn(1, 2, 3, D, device=device).double())
def fn(a):
return (Group.exp(a) * X).matrix()
a = torch.zeros(1, 2, 3, D, requires_grad=True, device=device).double()
analytical, numerical = gradcheck(fn, [a], eps=1e-4)
assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
print("\t-", Group, "Passed matrix-grad test")
def scale(device='cuda'): def scale(device='cuda'):
def fn(a, s): def fn(a, s):
...@@ -210,6 +225,7 @@ if __name__ == '__main__': ...@@ -210,6 +225,7 @@ if __name__ == '__main__':
test_adj_grad(Group, device='cpu') test_adj_grad(Group, device='cpu')
test_adjT_grad(Group, device='cpu') test_adjT_grad(Group, device='cpu')
test_act_grad(Group, device='cpu') test_act_grad(Group, device='cpu')
test_matrix_grad(Group, device='cpu')
print("Testing lietorch forward pass (GPU) ...") print("Testing lietorch forward pass (GPU) ...")
...@@ -231,5 +247,6 @@ if __name__ == '__main__': ...@@ -231,5 +247,6 @@ if __name__ == '__main__':
test_adj_grad(Group, device='cuda') test_adj_grad(Group, device='cuda')
test_adjT_grad(Group, device='cuda') test_adjT_grad(Group, device='cuda')
test_act_grad(Group, device='cuda') test_act_grad(Group, device='cuda')
test_matrix_grad(Group, device='cuda')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment