Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
5e006b95
Commit
5e006b95
authored
Mar 11, 2018
by
rusty1s
Browse files
added bias test
parent
f8ca386a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
15 deletions
+13
-15
test/test_spline_conv.py
test/test_spline_conv.py
+10
-10
torch_spline_conv/functions/spline_conv.py
torch_spline_conv/functions/spline_conv.py
+0
-3
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+3
-2
No files found.
test/test_spline_conv.py
View file @
5e006b95
...
@@ -17,9 +17,10 @@ def test_spline_conv_cpu(tensor):
...
@@ -17,9 +17,10 @@ def test_spline_conv_cpu(tensor):
kernel_size
=
torch
.
LongTensor
([
3
,
4
])
kernel_size
=
torch
.
LongTensor
([
3
,
4
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
0
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
0
])
root_weight
=
torch
.
arange
(
12.5
,
13.5
,
step
=
0.5
,
out
=
x
.
new
()).
view
(
2
,
1
)
root_weight
=
torch
.
arange
(
12.5
,
13.5
,
step
=
0.5
,
out
=
x
.
new
()).
view
(
2
,
1
)
bias
=
Tensor
(
tensor
,
[
1
])
output
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
output
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
root_weight
)
is_open_spline
,
root_weight
,
1
,
bias
)
edgewise_output
=
[
edgewise_output
=
[
1
*
0.25
*
(
0.5
+
1.5
+
4.5
+
5.5
)
+
2
*
0.25
*
(
1
+
2
+
5
+
6
),
1
*
0.25
*
(
0.5
+
1.5
+
4.5
+
5.5
)
+
2
*
0.25
*
(
1
+
2
+
5
+
6
),
...
@@ -29,21 +30,20 @@ def test_spline_conv_cpu(tensor):
...
@@ -29,21 +30,20 @@ def test_spline_conv_cpu(tensor):
]
]
expected_output
=
[
expected_output
=
[
[
12.5
*
9
+
13
*
10
+
sum
(
edgewise_output
)
/
4
],
[
1
+
12.5
*
9
+
13
*
10
+
sum
(
edgewise_output
)
/
4
],
[
12.5
*
1
+
13
*
2
],
[
1
+
12.5
*
1
+
13
*
2
],
[
12.5
*
3
+
13
*
4
],
[
1
+
12.5
*
3
+
13
*
4
],
[
12.5
*
5
+
13
*
6
],
[
1
+
12.5
*
5
+
13
*
6
],
[
12.5
*
7
+
13
*
8
],
[
1
+
12.5
*
7
+
13
*
8
],
]
]
assert
output
.
tolist
()
==
expected_output
assert
output
.
tolist
()
==
expected_output
x
=
Variable
(
x
,
requires_grad
=
True
)
x
,
weight
=
Variable
(
x
),
Variable
(
weight
)
weight
=
Variable
(
weight
,
requires_grad
=
True
)
root_weight
,
bias
=
Variable
(
root_weight
),
Variable
(
bias
)
root_weight
=
Variable
(
root_weight
,
requires_grad
=
True
)
output
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
output
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
root_weight
)
is_open_spline
,
root_weight
,
1
,
bias
)
assert
output
.
data
.
tolist
()
==
expected_output
assert
output
.
data
.
tolist
()
==
expected_output
...
...
torch_spline_conv/functions/spline_conv.py
View file @
5e006b95
...
@@ -15,9 +15,6 @@ def spline_conv(x,
...
@@ -15,9 +15,6 @@ def spline_conv(x,
degree
=
1
,
degree
=
1
,
bias
=
None
):
bias
=
None
):
# TODO: degree of 0
# TODO: kernel size of 1
n
,
e
=
x
.
size
(
0
),
edge_index
.
size
(
1
)
n
,
e
=
x
.
size
(
0
),
edge_index
.
size
(
1
)
K
,
m_in
,
m_out
=
weight
.
size
()
K
,
m_in
,
m_out
=
weight
.
size
()
...
...
torch_spline_conv/functions/utils.py
View file @
5e006b95
...
@@ -35,7 +35,8 @@ def spline_weighting_forward(x, weight, basis, weight_index):
...
@@ -35,7 +35,8 @@ def spline_weighting_forward(x, weight, basis, weight_index):
return
output
return
output
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
# pragma: no cover
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
# grad_weight computation via `atomic_add` => Initialize with zeros.
# grad_weight computation via `atomic_add` => Initialize with zeros.
grad_weight
=
x
.
new
(
weight
.
size
()).
fill_
(
0
)
grad_weight
=
x
.
new
(
weight
.
size
()).
fill_
(
0
)
...
@@ -55,7 +56,7 @@ class SplineWeighting(Function):
...
@@ -55,7 +56,7 @@ class SplineWeighting(Function):
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
return
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
)
return
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
)
def
backward
(
self
,
grad_output
):
def
backward
(
self
,
grad_output
):
# pragma: no cover
x
,
weight
=
self
.
saved_tensors
x
,
weight
=
self
.
saved_tensors
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
return
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
return
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment