Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
781766e4
Commit
781766e4
authored
Apr 28, 2018
by
rusty1s
Browse files
rename
parent
74199575
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
41 additions
and
41 deletions
+41
-41
README.md
README.md
+4
-4
test/test_conv.py
test/test_conv.py
+4
-4
test/test_weighting.py
test/test_weighting.py
+3
-3
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+9
-9
torch_spline_conv/utils/ffi.py
torch_spline_conv/utils/ffi.py
+6
-6
torch_spline_conv/weighting.py
torch_spline_conv/weighting.py
+15
-15
No files found.
README.md
View file @
781766e4
...
@@ -34,7 +34,7 @@ pip install cffi torch-spline-conv
...
@@ -34,7 +34,7 @@ pip install cffi torch-spline-conv
```
python
```
python
from
torch_spline_conv
import
SplineConv
from
torch_spline_conv
import
SplineConv
out
put
=
SplineConv
.
apply
(
src
,
out
=
SplineConv
.
apply
(
src
,
edge_index
,
edge_index
,
pseudo
,
pseudo
,
weight
,
weight
,
...
@@ -71,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
...
@@ -71,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
### Returns
### Returns
*
**out
put
**
*(Tensor)*
-
Outp
ut node features of shape
`(number_of_nodes x out_channels)`
.
*
**out**
*(Tensor)*
-
o
ut node features of shape
`(number_of_nodes x out_channels)`
.
### Example
### Example
...
@@ -89,10 +89,10 @@ degree = 1 # B-spline degree of 1
...
@@ -89,10 +89,10 @@ degree = 1 # B-spline degree of 1
root_weight
=
torch
.
Tensor
(
2
,
4
)
# separately weight root nodes
root_weight
=
torch
.
Tensor
(
2
,
4
)
# separately weight root nodes
bias
=
None
# do not apply an additional bias
bias
=
None
# do not apply an additional bias
out
put
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
out
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
root_weight
,
bias
)
is_open_spline
,
degree
,
root_weight
,
bias
)
print
(
out
put
.
size
())
print
(
out
.
size
())
torch
.
Size
([
4
,
4
])
# 4 nodes with 4 features each
torch
.
Size
([
4
,
4
])
# 4 nodes with 4 features each
```
```
...
...
test/test_conv.py
View file @
781766e4
...
@@ -30,7 +30,7 @@ tests = [{
...
@@ -30,7 +30,7 @@ tests = [{
'is_open_spline'
:
[
1
,
0
],
'is_open_spline'
:
[
1
,
0
],
'root_weight'
:
[[
12.5
],
[
13
]],
'root_weight'
:
[[
12.5
],
[
13
]],
'bias'
:
[
1
],
'bias'
:
[
1
],
'
output
'
:
[
'
expected
'
:
[
[
1
+
12.5
*
9
+
13
*
10
+
(
8.5
+
40.5
+
107.5
+
101.5
)
/
4
],
[
1
+
12.5
*
9
+
13
*
10
+
(
8.5
+
40.5
+
107.5
+
101.5
)
/
4
],
[
1
+
12.5
*
1
+
13
*
2
],
[
1
+
12.5
*
1
+
13
*
2
],
[
1
+
12.5
*
3
+
13
*
4
],
[
1
+
12.5
*
3
+
13
*
4
],
...
@@ -51,9 +51,9 @@ def test_spline_conv_forward(test, dtype, device):
...
@@ -51,9 +51,9 @@ def test_spline_conv_forward(test, dtype, device):
root_weight
=
tensor
(
test
[
'root_weight'
],
dtype
,
device
)
root_weight
=
tensor
(
test
[
'root_weight'
],
dtype
,
device
)
bias
=
tensor
(
test
[
'bias'
],
dtype
,
device
)
bias
=
tensor
(
test
[
'bias'
],
dtype
,
device
)
out
put
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
out
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
1
,
root_weight
,
bias
)
is_open_spline
,
1
,
root_weight
,
bias
)
assert
out
put
.
tolist
()
==
test
[
'
output
'
]
assert
out
.
tolist
()
==
test
[
'
expected
'
]
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
.
keys
(),
devices
))
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
.
keys
(),
devices
))
...
...
test/test_weighting.py
View file @
781766e4
...
@@ -13,7 +13,7 @@ tests = [{
...
@@ -13,7 +13,7 @@ tests = [{
'weight'
:
[[[
1
],
[
2
]],
[[
3
],
[
4
]],
[[
5
],
[
6
]],
[[
7
],
[
8
]]],
'weight'
:
[[[
1
],
[
2
]],
[[
3
],
[
4
]],
[[
5
],
[
6
]],
[[
7
],
[
8
]]],
'basis'
:
[[
0.5
,
0
,
0.5
,
0
],
[
0
,
0
,
0.5
,
0.5
]],
'basis'
:
[[
0.5
,
0
,
0.5
,
0
],
[
0
,
0
,
0.5
,
0.5
]],
'weight_index'
:
[[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
]],
'weight_index'
:
[[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
]],
'
output
'
:
[
'
expected
'
:
[
[
0.5
*
((
1
*
(
1
+
5
))
+
(
2
*
(
2
+
6
)))],
[
0.5
*
((
1
*
(
1
+
5
))
+
(
2
*
(
2
+
6
)))],
[
0.5
*
((
3
*
(
5
+
7
))
+
(
4
*
(
6
+
8
)))],
[
0.5
*
((
3
*
(
5
+
7
))
+
(
4
*
(
6
+
8
)))],
]
]
...
@@ -27,8 +27,8 @@ def test_spline_weighting_forward(test, dtype, device):
...
@@ -27,8 +27,8 @@ def test_spline_weighting_forward(test, dtype, device):
basis
=
tensor
(
test
[
'basis'
],
dtype
,
device
)
basis
=
tensor
(
test
[
'basis'
],
dtype
,
device
)
weight_index
=
tensor
(
test
[
'weight_index'
],
torch
.
long
,
device
)
weight_index
=
tensor
(
test
[
'weight_index'
],
torch
.
long
,
device
)
out
put
=
SplineWeighting
.
apply
(
src
,
weight
,
basis
,
weight_index
)
out
=
SplineWeighting
.
apply
(
src
,
weight
,
basis
,
weight_index
)
assert
out
put
.
tolist
()
==
test
[
'
output
'
]
assert
out
.
tolist
()
==
test
[
'
expected
'
]
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
...
...
torch_spline_conv/conv.py
View file @
781766e4
...
@@ -56,22 +56,22 @@ class SplineConv(object):
...
@@ -56,22 +56,22 @@ class SplineConv(object):
# Weight each node.
# Weight each node.
data
=
SplineBasis
.
apply
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
data
=
SplineBasis
.
apply
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
out
put
=
SplineWeighting
.
apply
(
src
[
col
],
weight
,
*
data
)
out
=
SplineWeighting
.
apply
(
src
[
col
],
weight
,
*
data
)
# Convert e x m_out to n x m_out features.
# Convert e x m_out to n x m_out features.
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
out
put
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
out
)
out
put
=
src
.
new_zeros
((
n
,
m_out
)).
scatter_add_
(
0
,
row_expand
,
out
put
)
out
=
src
.
new_zeros
((
n
,
m_out
)).
scatter_add_
(
0
,
row_expand
,
out
)
# Normalize out
put
by node degree.
# Normalize out by node degree.
deg
=
node_degree
(
row
,
n
,
out
put
.
dtype
,
out
put
.
device
)
deg
=
node_degree
(
row
,
n
,
out
.
dtype
,
out
.
device
)
out
put
/=
deg
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
out
/=
deg
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
# Weight root node separately (if wished).
# Weight root node separately (if wished).
if
root_weight
is
not
None
:
if
root_weight
is
not
None
:
out
put
+=
torch
.
mm
(
src
,
root_weight
)
out
+=
torch
.
mm
(
src
,
root_weight
)
# Add bias (if wished).
# Add bias (if wished).
if
bias
is
not
None
:
if
bias
is
not
None
:
out
put
+=
bias
out
+=
bias
return
out
put
return
out
torch_spline_conv/utils/ffi.py
View file @
781766e4
...
@@ -33,16 +33,16 @@ def fw_weighting(self, src, weight, basis, weight_index):
...
@@ -33,16 +33,16 @@ def fw_weighting(self, src, weight, basis, weight_index):
func
(
self
,
src
,
weight
,
basis
,
weight_index
)
func
(
self
,
src
,
weight
,
basis
,
weight_index
)
def
bw_weighting_src
(
self
,
grad_out
put
,
weight
,
basis
,
weight_index
):
def
bw_weighting_src
(
self
,
grad_out
,
weight
,
basis
,
weight_index
):
func
=
get_func
(
'weightingBackwardSrc'
,
self
)
func
=
get_func
(
'weightingBackwardSrc'
,
self
)
func
(
self
,
grad_out
put
,
weight
,
basis
,
weight_index
)
func
(
self
,
grad_out
,
weight
,
basis
,
weight_index
)
def
bw_weighting_weight
(
self
,
grad_out
put
,
src
,
basis
,
weight_index
):
def
bw_weighting_weight
(
self
,
grad_out
,
src
,
basis
,
weight_index
):
func
=
get_func
(
'weightingBackwardWeight'
,
self
)
func
=
get_func
(
'weightingBackwardWeight'
,
self
)
func
(
self
,
grad_out
put
,
src
,
basis
,
weight_index
)
func
(
self
,
grad_out
,
src
,
basis
,
weight_index
)
def
bw_weighting_basis
(
self
,
grad_out
put
,
src
,
weight
,
weight_index
):
def
bw_weighting_basis
(
self
,
grad_out
,
src
,
weight
,
weight_index
):
func
=
get_func
(
'weightingBackwardBasis'
,
self
)
func
=
get_func
(
'weightingBackwardBasis'
,
self
)
func
(
self
,
grad_out
put
,
src
,
weight
,
weight_index
)
func
(
self
,
grad_out
,
src
,
weight
,
weight_index
)
torch_spline_conv/weighting.py
View file @
781766e4
...
@@ -5,26 +5,26 @@ from .utils.ffi import bw_weighting_weight, bw_weighting_basis
...
@@ -5,26 +5,26 @@ from .utils.ffi import bw_weighting_weight, bw_weighting_basis
def
fw
(
src
,
weight
,
basis
,
weight_index
):
def
fw
(
src
,
weight
,
basis
,
weight_index
):
out
put
=
src
.
new_empty
((
src
.
size
(
0
),
weight
.
size
(
2
)))
out
=
src
.
new_empty
((
src
.
size
(
0
),
weight
.
size
(
2
)))
fw_weighting
(
out
put
,
src
,
weight
,
basis
,
weight_index
)
fw_weighting
(
out
,
src
,
weight
,
basis
,
weight_index
)
return
out
put
return
out
def
bw_src
(
grad_out
put
,
weight
,
basis
,
weight_index
):
def
bw_src
(
grad_out
,
weight
,
basis
,
weight_index
):
grad_src
=
grad_out
put
.
new_empty
((
grad_out
put
.
size
(
0
),
weight
.
size
(
1
)))
grad_src
=
grad_out
.
new_empty
((
grad_out
.
size
(
0
),
weight
.
size
(
1
)))
bw_weighting_src
(
grad_src
,
grad_out
put
,
weight
,
basis
,
weight_index
)
bw_weighting_src
(
grad_src
,
grad_out
,
weight
,
basis
,
weight_index
)
return
grad_src
return
grad_src
def
bw_weight
(
grad_out
put
,
src
,
basis
,
weight_index
,
K
):
def
bw_weight
(
grad_out
,
src
,
basis
,
weight_index
,
K
):
grad_weight
=
src
.
new_empty
((
K
,
src
.
size
(
1
),
grad_out
put
.
size
(
1
)))
grad_weight
=
src
.
new_empty
((
K
,
src
.
size
(
1
),
grad_out
.
size
(
1
)))
bw_weighting_weight
(
grad_weight
,
grad_out
put
,
src
,
basis
,
weight_index
)
bw_weighting_weight
(
grad_weight
,
grad_out
,
src
,
basis
,
weight_index
)
return
grad_weight
return
grad_weight
def
bw_basis
(
grad_out
put
,
src
,
weight
,
weight_index
):
def
bw_basis
(
grad_out
,
src
,
weight
,
weight_index
):
grad_basis
=
src
.
new_empty
(
weight_index
.
size
())
grad_basis
=
src
.
new_empty
(
weight_index
.
size
())
bw_weighting_basis
(
grad_basis
,
grad_out
put
,
src
,
weight
,
weight_index
)
bw_weighting_basis
(
grad_basis
,
grad_out
,
src
,
weight
,
weight_index
)
return
grad_basis
return
grad_basis
...
@@ -35,18 +35,18 @@ class SplineWeighting(Function):
...
@@ -35,18 +35,18 @@ class SplineWeighting(Function):
return
fw
(
src
,
weight
,
basis
,
weight_index
)
return
fw
(
src
,
weight
,
basis
,
weight_index
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
put
):
# pragma: no cover
def
backward
(
ctx
,
grad_out
):
# pragma: no cover
grad_src
=
grad_weight
=
grad_basis
=
None
grad_src
=
grad_weight
=
grad_basis
=
None
src
,
weight
,
basis
,
weight_index
=
ctx
.
saved_tensors
src
,
weight
,
basis
,
weight_index
=
ctx
.
saved_tensors
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
grad_src
=
bw_src
(
grad_out
put
,
weight
,
basis
,
weight_index
)
grad_src
=
bw_src
(
grad_out
,
weight
,
basis
,
weight_index
)
if
ctx
.
needs_input_grad
[
1
]:
if
ctx
.
needs_input_grad
[
1
]:
K
=
weight
.
size
(
0
)
K
=
weight
.
size
(
0
)
grad_weight
=
bw_weight
(
grad_out
put
,
src
,
basis
,
weight_index
,
K
)
grad_weight
=
bw_weight
(
grad_out
,
src
,
basis
,
weight_index
,
K
)
if
ctx
.
needs_input_grad
[
2
]:
if
ctx
.
needs_input_grad
[
2
]:
grad_basis
=
bw_basis
(
grad_out
put
,
src
,
weight
,
weight_index
)
grad_basis
=
bw_basis
(
grad_out
,
src
,
weight
,
weight_index
)
return
grad_src
,
grad_weight
,
grad_basis
,
None
return
grad_src
,
grad_weight
,
grad_basis
,
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment