OpenDAS / lietorch · Commits

Commit 67b2ec3f
Authored Jul 18, 2021 by zachteed
Parent 355a5174

ToMatrix gradient implemented

Showing 2 changed files with 26 additions and 3 deletions (+26 -3):

    lietorch/groups.py     +9  -3
    lietorch/run_tests.py  +17 -0
lietorch/groups.py  (+9 -3)

@@ -166,11 +166,17 @@ class LieGroup:
         elif p.shape[-1] == 4:
             return self.apply_op(Act4, self.data, p)
 
+    # def matrix(self):
+    #     """ convert element to 4x4 matrix """
+    #     input_shape = self.data.shape
+    #     mat = ToMatrix.apply(self.group_id, self.data.reshape(-1, self.embedded_dim))
+    #     return mat.view(input_shape[:-1] + (4,4))
+
     def matrix(self):
         """ convert element to 4x4 matrix """
-        input_shape = self.data.shape
-        mat = ToMatrix.apply(self.group_id, self.data.reshape(-1, self.embedded_dim))
-        return mat.view(input_shape[:-1] + (4,4))
+        I = torch.eye(4, dtype=self.dtype, device=self.device)
+        I = I.view([1] * (len(self.data.shape) - 1) + [4,4])
+        return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2)
 
     def detach(self):
         return self.__class__(self.data.detach())
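
The rewritten matrix() gets its gradient from act(), which already has a backward pass: applying a transform T to the rows of the 4x4 identity (treated as four homogeneous points) yields the columns of T's matrix, so a final transpose recovers the matrix itself while autograd flows through the existing Act4 op. The following is a minimal sketch of that equivalence, not part of the commit; SE3 in double precision and the direct SE3(data) construction are arbitrary choices for illustration:

import torch
from lietorch import SE3

# random batch of rigid transforms; the SE3 tangent space is 6-dimensional
T = SE3.exp(torch.randn(2, SE3.manifold_dim, dtype=torch.double))

# identity viewed as a batch of four homogeneous points (its rows e_0..e_3)
I = torch.eye(4, dtype=torch.double)
I = I.view([1] * (len(T.data.shape) - 1) + [4, 4])

# acting on the rows of I gives the columns of T's matrix; transpose fixes the layout
M = SE3(T.data[..., None, :]).act(I).transpose(-1, -2)

# the resulting 4x4 matrices move homogeneous points exactly like T.act()
p = torch.randn(2, 4, dtype=torch.double)
assert torch.allclose((M @ p[..., None]).squeeze(-1), T.act(p))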
lietorch/run_tests.py  (+17 -0)

@@ -147,6 +147,21 @@ def test_act_grad(Group, device='cuda'):
     print("\t-", Group, "Passed act-grad test")
 
 
+def test_matrix_grad(Group, device='cuda'):
+    D = Group.manifold_dim
+    X = Group.exp(5*torch.randn(1, 2, 3, D, device=device).double())
+
+    def fn(a):
+        return (Group.exp(a) * X).matrix()
+
+    a = torch.zeros(1, 2, 3, D, requires_grad=True, device=device).double()
+
+    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
+    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
+
+    print("\t-", Group, "Passed matrix-grad test")
+
+
 def scale(device='cuda'):
 
     def fn(a, s):
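
Two details of the added test are worth noting: gradcheck here appears to be the repository's own helper rather than torch.autograd.gradcheck, since it returns the (analytical, numerical) Jacobian pair that the assertion compares, and the check is performed at a = 0, i.e. for left-perturbations around the random element X. Roughly the same check can be reproduced with the stock PyTorch checker; the sketch below is an illustration only, with SE3, the CPU device, and the atol value chosen as assumptions:

import torch
from lietorch import SE3

D = SE3.manifold_dim
X = SE3.exp(5 * torch.randn(1, 2, 3, D, dtype=torch.double))

def fn(a):
    # left-perturb X by exp(a) and map the result to 4x4 matrices
    return (SE3.exp(a) * X).matrix()

# gradients are checked at a = 0, i.e. around X itself
a = torch.zeros(1, 2, 3, D, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(fn, [a], eps=1e-4, atol=1e-6)

The remaining two hunks register the new test in the CPU and GPU loops of the __main__ driver.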
@@ -210,6 +225,7 @@ if __name__ == '__main__':
         test_adj_grad(Group, device='cpu')
         test_adjT_grad(Group, device='cpu')
         test_act_grad(Group, device='cpu')
+        test_matrix_grad(Group, device='cpu')
 
     print("Testing lietorch forward pass (GPU) ...")

@@ -231,5 +247,6 @@ if __name__ == '__main__':
         test_adj_grad(Group, device='cuda')
         test_adjT_grad(Group, device='cuda')
         test_act_grad(Group, device='cuda')
+        test_matrix_grad(Group, device='cuda')
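
With those call sites in place, test_matrix_grad runs alongside the other backward-pass tests on both devices. Run in isolation it would look roughly like the sketch below; the import path and the group list are assumptions that mirror how the rest of run_tests.py exercises the groups:

import torch
from lietorch import SO3, RxSO3, SE3, Sim3
from lietorch.run_tests import test_matrix_grad  # import path is an assumption

for Group in [SO3, RxSO3, SE3, Sim3]:
    test_matrix_grad(Group, device='cpu')
    if torch.cuda.is_available():
        test_matrix_grad(Group, device='cuda')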