Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
d9100e71
Commit
d9100e71
authored
Mar 10, 2018
by
rusty1s
Browse files
bugfixes
parent
ffcc4df7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
18 deletions
+17
-18
torch_spline_conv/functions/degree.py
torch_spline_conv/functions/degree.py
+4
-4
torch_spline_conv/functions/spline_conv.py
torch_spline_conv/functions/spline_conv.py
+9
-9
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+4
-5
No files found.
torch_spline_conv/functions/degree.py
View file @
d9100e71
import
torch
import
torch
def
node_degree
(
index
,
out
=
None
):
def
node_degree
(
edge_
index
,
n
,
out
=
None
):
one
=
torch
.
ones
(
index
.
size
(
1
),
out
)
zero
=
torch
.
zeros
(
n
,
out
=
out
)
zero
=
torch
.
zeros
(
index
.
size
(
1
),
out
)
one
=
torch
.
ones
(
edge_
index
.
size
(
1
),
out
=
zero
.
new
()
)
return
zero
.
scatter_add_
(
0
,
index
[
0
],
one
)
return
zero
.
scatter_add_
(
0
,
edge_
index
[
0
],
one
)
torch_spline_conv/functions/spline_conv.py
View file @
d9100e71
import
torch
import
torch
# from torch.autograd import Variable as Var
from
.degree
import
node_degree
from
.degree
import
node_degree
from
.utils
import
spline_basis
,
spline_weighting
from
.utils
import
spline_basis
,
spline_weighting
def
spline_conv
(
x
,
def
spline_conv
(
x
,
index
,
edge_
index
,
pseudo
,
pseudo
,
weight
,
weight
,
kernel_size
,
kernel_size
,
...
@@ -15,27 +14,28 @@ def spline_conv(x,
...
@@ -15,27 +14,28 @@ def spline_conv(x,
degree
=
1
,
degree
=
1
,
bias
=
None
):
bias
=
None
):
n
,
e
=
x
.
size
(
0
),
edge_index
.
size
(
1
)
K
,
m_in
,
m_out
=
weight
.
size
()
x
=
x
.
unsqueeze
(
-
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
unsqueeze
(
-
1
)
if
x
.
dim
()
==
1
else
x
# Get features for every target node => |E| x M_in
# Get features for every target node => |E| x M_in
output
=
x
[
index
[
1
]]
output
=
x
[
edge_
index
[
1
]]
# Get B-spline basis products and weight indices for each edge.
# Get B-spline basis products and weight indices for each edge.
basis
,
weight_index
=
spline_basis
(
degree
,
pseudo
,
kernel_size
,
basis
,
weight_index
=
spline_basis
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
,
weight
.
size
(
0
)
)
is_open_spline
,
K
)
# Weight gathered features based on B-spline basis and trainable weights.
# Weight gathered features based on B-spline basis and trainable weights.
output
=
spline_weighting
(
output
,
weight
,
basis
,
weight_index
)
output
=
spline_weighting
(
output
,
weight
,
basis
,
weight_index
)
# Perform the real convolution => Convert |E| x M_out to N x M_out output.
# Perform the real convolution => Convert |E| x M_out to N x M_out output.
row
=
index
[
0
].
unsqueeze
(
-
1
).
expand
(
-
1
,
output
.
size
(
1
))
row
=
edge_index
[
0
].
unsqueeze
(
-
1
).
expand
(
e
,
m_out
)
# zero = x if torch.is_tensor(x) else x.data
zero
=
x
.
new
(
n
,
m_out
).
fill_
(
0
)
zero
=
x
.
new
(
row
.
size
()).
fill_
(
0
)
# row, zero = row, zero if torch.is_tensor(x) else Var(row), Var(zero)
output
=
zero
.
scatter_add_
(
0
,
row
,
output
)
output
=
zero
.
scatter_add_
(
0
,
row
,
output
)
# Normalize output by node degree.
# Normalize output by node degree.
output
/=
node_degree
(
index
,
out
=
x
.
new
()).
unsqueeze
(
-
1
).
clamp_
(
min
=
1
)
output
/=
node_degree
(
edge_
index
,
n
,
out
=
x
.
new
()).
clamp_
(
min
=
1
)
# Weight root node separately (if wished).
# Weight root node separately (if wished).
if
root_weight
is
not
None
:
if
root_weight
is
not
None
:
...
...
torch_spline_conv/functions/utils.py
View file @
d9100e71
...
@@ -20,9 +20,8 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
...
@@ -20,9 +20,8 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
weight_index
=
kernel_size
.
new
(
pseudo
.
size
(
0
),
s
)
weight_index
=
kernel_size
.
new
(
pseudo
.
size
(
0
),
s
)
degree
=
implemented_degrees
.
get
(
degree
)
degree
=
implemented_degrees
.
get
(
degree
)
if
degree
is
None
:
assert
degree
is
not
None
,
(
raise
NotImplementedError
(
'Basis computation not implemented for '
'Basis computation not implemented for specified B-spline degree'
)
'specified B-spline degree'
)
func
=
get_func
(
'basis_{}'
.
format
(
degree
),
pseudo
)
func
=
get_func
(
'basis_{}'
.
format
(
degree
),
pseudo
)
func
(
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
)
func
(
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
)
...
@@ -31,7 +30,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
...
@@ -31,7 +30,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
def
spline_weighting_fw
(
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_fw
(
x
,
weight
,
basis
,
weight_index
):
output
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
2
))
output
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
2
))
func
=
get_func
(
'
spline_
weighting_fw'
,
x
)
func
=
get_func
(
'weighting_fw'
,
x
)
func
(
output
,
x
,
weight
,
basis
,
weight_index
)
func
(
output
,
x
,
weight
,
basis
,
weight_index
)
return
output
return
output
...
@@ -39,7 +38,7 @@ def spline_weighting_fw(x, weight, basis, weight_index):
...
@@ -39,7 +38,7 @@ def spline_weighting_fw(x, weight, basis, weight_index):
def
spline_weighting_bw
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_bw
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
grad_weight
=
x
.
new
(
weight
)
grad_weight
=
x
.
new
(
weight
)
func
=
get_func
(
'
spline_
weighting_bw'
,
x
)
func
=
get_func
(
'weighting_bw'
,
x
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
return
grad_input
,
grad_weight
return
grad_input
,
grad_weight
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment