Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
lietorch
Commits
df566ea4
Commit
df566ea4
authored
Jul 18, 2021
by
zachteed
Browse files
updated test script for matrix function
parent
67b2ec3f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
4 deletions
+3
-4
lietorch/run_tests.py
lietorch/run_tests.py
+3
-4
No files found.
lietorch/run_tests.py
View file @
df566ea4
...
@@ -149,15 +149,14 @@ def test_act_grad(Group, device='cuda'):
...
@@ -149,15 +149,14 @@ def test_act_grad(Group, device='cuda'):
def
test_matrix_grad
(
Group
,
device
=
'cuda'
):
def
test_matrix_grad
(
Group
,
device
=
'cuda'
):
D
=
Group
.
manifold_dim
D
=
Group
.
manifold_dim
X
=
Group
.
exp
(
5
*
torch
.
randn
(
1
,
2
,
3
,
D
,
device
=
device
).
double
())
X
=
Group
.
exp
(
torch
.
randn
(
1
,
D
,
device
=
device
).
double
())
def
fn
(
a
):
def
fn
(
a
):
return
(
Group
.
exp
(
a
)
*
X
).
matrix
()
return
(
Group
.
exp
(
a
)
*
X
).
matrix
()
a
=
torch
.
zeros
(
1
,
2
,
3
,
D
,
requires_grad
=
True
,
device
=
device
).
double
()
a
=
torch
.
zeros
(
1
,
D
,
requires_grad
=
True
,
device
=
device
).
double
()
analytical
,
numerical
=
gradcheck
(
fn
,
[
a
],
eps
=
1e-4
)
analytical
,
numerical
=
gradcheck
(
fn
,
[
a
],
eps
=
1e-4
)
assert
torch
.
allclose
(
analytical
[
0
],
numerical
[
0
],
atol
=
1e-6
)
assert
torch
.
allclose
(
analytical
[
0
],
numerical
[
0
],
atol
=
1e-8
)
print
(
"
\t
-"
,
Group
,
"Passed matrix-grad test"
)
print
(
"
\t
-"
,
Group
,
"Passed matrix-grad test"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment