Commit df566ea4 authored by zachteed's avatar zachteed
Browse files

updated test script for matrix function

parent 67b2ec3f
...@@ -149,15 +149,14 @@ def test_act_grad(Group, device='cuda'): ...@@ -149,15 +149,14 @@ def test_act_grad(Group, device='cuda'):
def test_matrix_grad(Group, device='cuda'): def test_matrix_grad(Group, device='cuda'):
D = Group.manifold_dim D = Group.manifold_dim
X = Group.exp(5*torch.randn(1, 2, 3, D, device=device).double()) X = Group.exp(torch.randn(1, D, device=device).double())
def fn(a): def fn(a):
return (Group.exp(a) * X).matrix() return (Group.exp(a) * X).matrix()
a = torch.zeros(1, 2, 3, D, requires_grad=True, device=device).double() a = torch.zeros(1, D, requires_grad=True, device=device).double()
analytical, numerical = gradcheck(fn, [a], eps=1e-4) analytical, numerical = gradcheck(fn, [a], eps=1e-4)
assert torch.allclose(analytical[0], numerical[0], atol=1e-6)
assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
print("\t-", Group, "Passed matrix-grad test") print("\t-", Group, "Passed matrix-grad test")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment