Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
d9d943e8
Commit
d9d943e8
authored
Apr 16, 2018
by
rusty1s
Browse files
better degree impl
parent
ad330355
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+4
-6
torch_spline_conv/utils/degree.py
torch_spline_conv/utils/degree.py
+11
-7
No files found.
torch_spline_conv/conv.py
View file @
d9d943e8
...
...
@@ -5,7 +5,7 @@ from .basis import spline_basis
from
.weighting
import
spline_weighting
from
.utils.new
import
new
from
.utils.degree
import
node_degree
from
.utils.degree
import
degree
as
node_degree
def
spline_conv
(
src
,
...
...
@@ -49,7 +49,7 @@ def spline_conv(src,
row
,
col
=
edge_index
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
n
,
e
,
m_out
=
src
.
size
(
0
),
row
.
size
(
0
),
weight
.
size
(
2
)
n
,
m_out
=
src
.
size
(
0
),
weight
.
size
(
2
)
# Weight each node.
basis
,
weight_index
=
spline_basis
(
degree
,
pseudo
,
kernel_size
,
...
...
@@ -58,14 +58,12 @@ def spline_conv(src,
# Perform the real convolution => Convert e x m_out to n x m_out features.
zero
=
new
(
src
,
n
,
m_out
).
fill_
(
0
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand
(
e
,
m_o
ut
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand
_as
(
outp
ut
)
row_expand
=
row_expand
if
torch
.
is_tensor
(
src
)
else
Variable
(
row_expand
)
output
=
zero
.
scatter_add_
(
0
,
row_expand
,
output
)
# Normalize output by node degree.
index
=
row
if
torch
.
is_tensor
(
src
)
else
Variable
(
row
)
degree
=
node_degree
(
index
,
n
,
out
=
new
(
src
))
output
/=
degree
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
output
/=
node_degree
(
row
,
n
,
out
=
new
(
src
)).
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
# Weight root node separately (if wished).
if
root_weight
is
not
None
:
...
...
torch_spline_conv/utils/degree.py
View file @
d9d943e8
import
torch
from
torch.autograd
import
Variable
from
.new
import
new
def
node_degree
(
index
,
n
,
out
=
None
):
if
out
is
None
:
# pragma: no cover
zero
=
torch
.
zeros
(
n
)
def
degree
(
index
,
num_nodes
=
None
,
out
=
None
):
num_nodes
=
index
.
max
()
+
1
if
num_nodes
is
None
else
num_nodes
out
=
index
.
new
().
float
()
if
out
is
None
else
out
index
=
index
if
torch
.
is_tensor
(
out
)
else
Variable
(
index
)
if
torch
.
is_tensor
(
out
):
out
.
resize_
(
num_nodes
)
else
:
out
.
resize_
(
n
)
if
torch
.
is_tensor
(
out
)
else
out
.
data
.
resize_
(
n
)
zero
=
out
.
fill_
(
0
)
out
.
data
.
resize_
(
num_nodes
)
one
=
new
(
zero
,
index
.
size
(
0
)).
fill_
(
1
)
return
zero
.
scatter_add_
(
0
,
index
,
one
)
one
=
new
(
out
,
index
.
size
(
0
)).
fill_
(
1
)
return
out
.
fill_
(
0
)
.
scatter_add_
(
0
,
index
,
one
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment