Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
e6632e3f
Commit
e6632e3f
authored
Apr 26, 2018
by
rusty1s
Browse files
to function
parent
6729e5b9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
38 deletions
+43
-38
README.md
README.md
+7
-6
torch_spline_conv/__init__.py
torch_spline_conv/__init__.py
+2
-2
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+34
-30
No files found.
README.md
View file @
e6632e3f
...
@@ -32,10 +32,11 @@ pip install cffi torch-spline-conv
...
@@ -32,10 +32,11 @@ pip install cffi torch-spline-conv
## Usage
## Usage
```
python
```
python
from
torch_spline_conv
import
s
pline
_c
onv
from
torch_spline_conv
import
S
pline
C
onv
output
=
spline_conv
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
output
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
=
1
,
root_weight
=
None
,
bias
=
None
)
is_open_spline
,
degree
=
1
,
root_weight
=
None
,
bias
=
None
)
```
```
Applies the spline-based convolution operator
Applies the spline-based convolution operator
...
@@ -70,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
...
@@ -70,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
```
python
```
python
import
torch
import
torch
from
torch_spline_conv
import
s
pline
_c
onv
from
torch_spline_conv
import
S
pline
C
onv
src
=
torch
.
Tensor
(
4
,
2
)
# 4 nodes with 2 features each
src
=
torch
.
Tensor
(
4
,
2
)
# 4 nodes with 2 features each
edge_index
=
torch
.
LongTensor
([[
0
,
1
,
1
,
2
,
2
,
3
],
[
1
,
0
,
2
,
1
,
3
,
2
]])
# 6 edges
edge_index
=
torch
.
LongTensor
([[
0
,
1
,
1
,
2
,
2
,
3
],
[
1
,
0
,
2
,
1
,
3
,
2
]])
# 6 edges
...
@@ -82,8 +83,8 @@ degree = torch.tensor(1) # B-spline degree of 1
...
@@ -82,8 +83,8 @@ degree = torch.tensor(1) # B-spline degree of 1
root_weight
=
torch
.
Tensor
(
2
,
4
)
# separately weight root nodes
root_weight
=
torch
.
Tensor
(
2
,
4
)
# separately weight root nodes
bias
=
None
# do not apply an additional bias
bias
=
None
# do not apply an additional bias
output
=
s
pline
_c
onv
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
output
=
S
pline
C
onv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
root_weight
,
bias
)
is_open_spline
,
degree
,
root_weight
,
bias
)
print
(
output
.
size
())
print
(
output
.
size
())
torch
.
Size
([
4
,
4
])
# 4 nodes with 4 features each
torch
.
Size
([
4
,
4
])
# 4 nodes with 4 features each
...
...
torch_spline_conv/__init__.py
View file @
e6632e3f
from
.conv
import
s
pline
_c
onv
from
.conv
import
S
pline
C
onv
__version__
=
'0.1.0'
__version__
=
'0.1.0'
__all__
=
[
'
s
pline
_c
onv'
,
'__version__'
]
__all__
=
[
'
S
pline
C
onv'
,
'__version__'
]
torch_spline_conv/conv.py
View file @
e6632e3f
import
torch
import
torch
from
torch.autograd
import
Function
from
.basis
import
SplineBasis
from
.basis
import
SplineBasis
from
.weighting
import
SplineWeighting
from
.weighting
import
SplineWeighting
...
@@ -6,15 +7,7 @@ from .weighting import SplineWeighting
...
@@ -6,15 +7,7 @@ from .weighting import SplineWeighting
from
.utils.degree
import
degree
as
node_degree
from
.utils.degree
import
degree
as
node_degree
def
spline_conv
(
src
,
class
SplineConv
(
Function
):
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
root_weight
=
None
,
bias
=
None
):
"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
\f
rac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
\f
rac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
...
@@ -45,31 +38,42 @@ def spline_conv(src,
...
@@ -45,31 +38,42 @@ def spline_conv(src,
:rtype: :class:`Tensor`
:rtype: :class:`Tensor`
"""
"""
src
=
src
.
unsqueeze
(
-
1
)
if
src
.
dim
()
==
1
else
src
@
staticmethod
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
def
forward
(
ctx
,
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
root_weight
=
None
,
bias
=
None
):
src
=
src
.
unsqueeze
(
-
1
)
if
src
.
dim
()
==
1
else
src
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
row
,
col
=
edge_index
row
,
col
=
edge_index
n
,
m_out
=
src
.
size
(
0
),
weight
.
size
(
2
)
n
,
m_out
=
src
.
size
(
0
),
weight
.
size
(
2
)
# Weight each node.
# Weight each node.
basis
,
weight_index
=
SplineBasis
.
apply
(
degree
,
pseudo
,
kernel_size
,
b
,
wi
=
SplineBasis
.
apply
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
)
is_open_spline
)
output
=
SplineWeighting
.
apply
(
src
[
col
],
weight
,
b
,
wi
)
output
=
SplineWeighting
.
apply
(
src
[
col
],
weight
,
basis
,
weight_index
)
# Perform the real convolution =>
Convert e x m_out to n x m_out features.
#
Convert e x m_out to n x m_out features.
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
output
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
output
)
output
=
src
.
new_zeros
((
n
,
m_out
)).
scatter_add_
(
0
,
row_expand
,
output
)
output
=
src
.
new_zeros
((
n
,
m_out
)).
scatter_add_
(
0
,
row_expand
,
output
)
# Normalize output by node degree.
# Normalize output by node degree.
deg
=
node_degree
(
row
,
n
,
out
=
src
.
new_empty
(()))
deg
=
node_degree
(
row
,
n
,
out
=
src
.
new_empty
(()))
output
/=
deg
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
output
/=
deg
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
# Weight root node separately (if wished).
# Weight root node separately (if wished).
if
root_weight
is
not
None
:
if
root_weight
is
not
None
:
output
+=
torch
.
mm
(
src
,
root_weight
)
output
+=
torch
.
mm
(
src
,
root_weight
)
# Add bias (if wished).
# Add bias (if wished).
if
bias
is
not
None
:
if
bias
is
not
None
:
output
+=
bias
output
+=
bias
return
output
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment