Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
c4b33b49
Commit
c4b33b49
authored
Apr 11, 2018
by
rusty1s
Browse files
added conv impl
parent
0440e1f4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
2 deletions
+52
-2
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+37
-2
torch_spline_conv/utils/degree.py
torch_spline_conv/utils/degree.py
+9
-0
torch_spline_conv/utils/new.py
torch_spline_conv/utils/new.py
+6
-0
No files found.
torch_spline_conv/conv.py
View file @
c4b33b49
#
import torch
import
torch
from
.basis
import
spline_basis
from
.weighting
import
spline_weighting
def
spline_conv
(
x
,
from
.utils.new
import
new
from
.utils.degree
import
node_degree
def
spline_conv
(
src
,
edge_index
,
pseudo
,
weight
,
...
...
@@ -10,4 +16,33 @@ def spline_conv(x,
degree
=
1
,
root_weight
=
None
,
bias
=
None
):
src
=
src
.
unsqueeze
(
-
1
)
if
src
.
dim
()
==
1
else
src
row
,
col
=
edge_index
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
n
,
e
,
m_out
=
src
.
size
(
0
),
row
.
size
(
0
),
weight
.
size
(
2
)
# Weight each node.
basis
,
weight_index
=
spline_basis
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
)
output
=
spline_weighting
(
src
[
col
],
weight
,
basis
,
weight_index
)
# Perform the real convolution => Convert e x m_out to n x m_out features.
zero
=
new
(
src
,
n
,
m_out
).
fill_
(
0
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand
(
e
,
m_out
)
output
=
zero
.
scatter_add_
(
0
,
row_expand
,
output
)
# Normalize output by node degree.
degree
=
node_degree
(
row
,
n
,
out
=
new
(
src
))
output
/=
degree
.
unsqueeze
(
-
1
).
clamp_
(
min
=
1
)
# Weight root node separately (if wished).
if
root_weight
is
not
None
:
output
+=
torch
.
mm
(
src
,
root_weight
)
# Add bias (if wished).
if
bias
is
not
None
:
output
+=
bias
return
output
torch_spline_conv/utils/degree.py
0 → 100644
View file @
c4b33b49
import
torch
from
.new
import
new
def
node_degree
(
index
,
num_nodes
,
out
=
None
):
zero
=
torch
.
zeros
(
num_nodes
,
out
=
out
)
one
=
torch
.
ones
(
index
,
out
=
new
(
zero
))
return
zero
.
scatter_add_
(
0
,
index
,
one
)
torch_spline_conv/utils/new.py
0 → 100644
View file @
c4b33b49
import
torch
from
torch.autograd
import
Variable
def
new
(
x
,
*
sizes
):
return
x
.
new
(
sizes
)
if
torch
.
is_tensor
(
x
)
else
Variable
(
x
.
data
.
new
(
sizes
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment