Commit 0fa9ce8f authored by zachteed's avatar zachteed
Browse files

FromVec broadcasting, increase torch version

parent 82b02233
Pipeline #2047 canceled with stages
......@@ -25,7 +25,7 @@ Zachary Teed and Jia Deng, CVPR 2021
### Requirements:
* Cuda >= 10.1 (with nvcc compiler)
* PyTorch >= 1.7
* PyTorch >= 1.8
We recommend installing within a virtual enviornment. Make sure you clone using the `--recursive` flag. If you are using Anaconda, the following command can be used to install all dependencies
```
......
......@@ -83,7 +83,7 @@ class FromVec(torch.autograd.Function):
def backward(cls, ctx, grad):
inputs = ctx.saved_tensors
J = lietorch_backends.projector(ctx.group_id, *inputs)
return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J))
return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
class ToVec(torch.autograd.Function):
""" convert group object to vector """
......@@ -98,5 +98,5 @@ class ToVec(torch.autograd.Function):
def backward(cls, ctx, grad):
inputs = ctx.saved_tensors
J = lietorch_backends.projector(ctx.group_id, *inputs)
return None, torch.matmul(grad.unsqueeze(-2), J)
return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)
......@@ -218,7 +218,7 @@ def test_fromvec_grad(Group, device='cuda', tol=1e-6):
return Group.InitFromVec(a).vec()
D = Group.embedded_dim
a = torch.randn(1, D, requires_grad=True, device=device).double()
a = torch.randn(1, 2, D, requires_grad=True, device=device).double()
analytical, numerical = gradcheck(fn, [a], eps=1e-4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment