OpenDAS / torch-spline-conv / Commits / 8418614e

Commit 8418614e, authored Mar 11, 2018 by rusty1s
parent 12a47ebc

    prepare for backward to pseudo

Showing 3 changed files with 45 additions and 50 deletions (+45 -50)
test/test_spline_conv.py                      +6  -6
torch_spline_conv/functions/spline_conv.py    +21 -33
torch_spline_conv/functions/utils.py          +18 -11
test/test_spline_conv.py  (view file @ 8418614e)

@@ -39,7 +39,7 @@ def test_spline_conv_cpu(tensor):
     assert output.tolist() == expected_output

-    x, weight = Variable(x), Variable(weight)
+    x, weight, pseudo = Variable(x), Variable(weight), Variable(pseudo)
     root_weight, bias = Variable(root_weight), Variable(bias)
     output = spline_conv(x, edge_index, pseudo, weight, kernel_size,

@@ -48,16 +48,16 @@ def test_spline_conv_cpu(tensor):
 def test_spline_weighting_backward_cpu():
-    pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
-    pseudo = torch.DoubleTensor(pseudo)
     kernel_size = torch.LongTensor([5, 5])
     is_open_spline = torch.ByteTensor([1, 1])
-    basis, index = spline_basis(1, pseudo, kernel_size, is_open_spline, 25)
-    op = SplineWeighting(basis, index)
+    op = SplineWeighting(kernel_size, is_open_spline, 1)

     x = torch.DoubleTensor([[1, 2], [3, 4], [5, 6], [7, 8]])
     x = Variable(x, requires_grad=True)
+    pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
+    # pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True)
+    pseudo = Variable(torch.DoubleTensor(pseudo))
     weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1)
     weight = Variable(weight, requires_grad=True)

-    assert gradcheck(op, (x, weight), eps=1e-6, atol=1e-4) is True
+    assert gradcheck(op, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True
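The test now runs gradcheck over the op's new (x, pseudo, weight) argument order, while pseudo is still built without requires_grad (the requires_grad=True variant stays commented out). The sketch below is an illustration only, written in current PyTorch style with a made-up stand-in function rather than the library op: it shows the same gradcheck pattern, where only double-precision inputs flagged with requires_grad=True are checked against finite differences.

# Illustration only: toy_op is a hypothetical surrogate, not torch_spline_conv code.
import torch
from torch.autograd import gradcheck

def toy_op(x, pseudo, weight):
    # Differentiable stand-in with the same (x, pseudo, weight) argument order.
    return (x * pseudo.sigmoid()) @ weight

x = torch.rand(4, 2, dtype=torch.double, requires_grad=True)
pseudo = torch.rand(4, 2, dtype=torch.double)   # no requires_grad: its gradient is not checked yet
weight = torch.rand(2, 4, dtype=torch.double, requires_grad=True)

assert gradcheck(toy_op, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True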
torch_spline_conv/functions/spline_conv.py  (view file @ 8418614e)

@@ -2,39 +2,7 @@ import torch
 from torch.autograd import Variable as Var

 from .degree import node_degree
-from .utils import spline_basis, spline_weighting
-
-
-def _spline_conv(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
-                 degree=1):
-    n, e = x.size(0), edge_index.size(1)
-    K, m_in, m_out = weight.size()
-
-    x = x.unsqueeze(-1) if x.dim() == 1 else x
-
-    # Get features for every target node => |E| x M_in
-    output = x[edge_index[1]]
-
-    # Get B-spline basis products and weight indices for each edge.
-    basis, weight_index = spline_basis(degree, pseudo, kernel_size,
-                                       is_open_spline, K)
-
-    # Weight gathered features based on B-spline basis and trainable weights.
-    output = spline_weighting(output, weight, basis, weight_index)
-
-    # Perform the real convolution => Convert |E| x M_out to N x M_out output.
-    row = edge_index[0].unsqueeze(-1).expand(e, m_out)
-    row = row if torch.is_tensor(x) else Var(row)
-    zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
-    output = zero.fill_(0).scatter_add_(0, row, output)
-
-    return output
+from .utils import spline_weighting


 def spline_conv(x,

@@ -67,3 +35,23 @@ def spline_conv(x,
         output += bias

     return output
+
+
+def _spline_conv(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
+                 degree):
+    n, e, m_out = x.size(0), edge_index.size(1), weight.size(2)
+
+    x = x.unsqueeze(-1) if x.dim() == 1 else x
+
+    # Weight gathered features based on B-spline bases and trainable weights.
+    output = spline_weighting(x[edge_index[1]], pseudo, weight, kernel_size,
+                              is_open_spline, degree)
+
+    # Perform the real convolution => Convert e x m_out to n x m_out features.
+    row = edge_index[0].unsqueeze(-1).expand(e, m_out)
+    row = row if torch.is_tensor(x) else Var(row)
+    zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
+    output = zero.fill_(0).scatter_add_(0, row, output)
+
+    return output
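The relocated `_spline_conv` now delegates the basis computation to `spline_weighting` and keeps only the gather/scatter skeleton. The snippet below is a minimal, self-contained sketch of that final `scatter_add_` step (current PyTorch style, toy sizes chosen here, not repo code): per-edge outputs of shape (e, m_out) are summed into per-node features of shape (n, m_out), indexed by `edge_index[0]`, the index the diff calls `row`.

import torch

# Toy sizes and edges chosen purely for illustration.
n, e, m_out = 3, 4, 2
edge_index = torch.tensor([[0, 0, 1, 2],    # edge_index[0]: node each edge's output is scattered to
                           [1, 2, 2, 0]])   # edge_index[1]: node the features were gathered from
per_edge = torch.ones(e, m_out)             # stand-in for the spline-weighted edge features

row = edge_index[0].unsqueeze(-1).expand(e, m_out)
out = torch.zeros(n, m_out).scatter_add_(0, row, per_edge)
print(out)  # node 0 accumulates two edges; nodes 1 and 2 accumulate one edge each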
torch_spline_conv/functions/utils.py  (view file @ 8418614e)

@@ -37,34 +37,41 @@ def spline_weighting_forward(x, weight, basis, weight_index):
 def spline_weighting_backward(grad_output, x, weight, basis,
                               weight_index):  # pragma: no cover
-    grad_input = x.new(x.size(0), weight.size(1))
     # grad_weight computation via `atomic_add` => Initialize with zeros.
     grad_weight = x.new(weight.size()).fill_(0)
+    grad_input = x.new(x.size(0), weight.size(1))

     func = get_func('weighting_backward', x)
     func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
     return grad_input, grad_weight


 class SplineWeighting(Function):
-    def __init__(self, basis, weight_index):
+    def __init__(self, kernel_size, is_open_spline, degree):
         super(SplineWeighting, self).__init__()
-        self.basis = basis
-        self.weight_index = weight_index
+        self.kernel_size = kernel_size
+        self.is_open_spline = is_open_spline
+        self.degree = degree

-    def forward(self, x, weight):
+    def forward(self, x, pseudo, weight):
         self.save_for_backward(x, weight)
-        basis, weight_index = self.basis, self.weight_index
+        K = weight.size(0)
+        basis, weight_index = spline_basis(self.degree, pseudo,
+                                           self.kernel_size,
+                                           self.is_open_spline, K)
+        self.basis, self.weight_index = basis, weight_index
         return spline_weighting_forward(x, weight, basis, weight_index)

     def backward(self, grad_output):  # pragma: no cover
         x, weight = self.saved_tensors
-        basis, weight_index = self.basis, self.weight_index
-        return spline_weighting_backward(grad_output, x, weight, basis,
-                                         weight_index)
+        grad_input, grad_weight = spline_weighting_backward(
+            grad_output, x, weight, self.basis, self.weight_index)
+        return grad_input, None, grad_weight


-def spline_weighting(x, weight, basis, weight_index):
+def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree):
     if torch.is_tensor(x):
+        basis, weight_index = spline_basis(degree, pseudo, kernel_size,
+                                           is_open_spline, weight.size(0))
         return spline_weighting_forward(x, weight, basis, weight_index)
     else:
-        return SplineWeighting(basis, weight_index)(x, weight)
+        op = SplineWeighting(kernel_size, is_open_spline, degree)
+        return op(x, pseudo, weight)
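With the basis now computed inside forward, backward returns one gradient per input in the new (x, pseudo, weight) order, with None in the pseudo slot until that gradient is implemented. Below is a sketch of the same contract in the modern static-method autograd.Function style (the file itself uses the older class-based API); ToyWeighting and its forward math are illustrative assumptions, not the library's kernels.

import torch
from torch.autograd import Function, gradcheck

class ToyWeighting(Function):
    @staticmethod
    def forward(ctx, x, pseudo, weight):
        ctx.save_for_backward(x, pseudo, weight)
        return (x * pseudo) @ weight

    @staticmethod
    def backward(ctx, grad_output):
        x, pseudo, weight = ctx.saved_tensors
        grad_x = (grad_output @ weight.t()) * pseudo
        grad_weight = (x * pseudo).t() @ grad_output
        # The pseudo slot stays None, mirroring the state this commit prepares to change.
        return grad_x, None, grad_weight

x = torch.rand(4, 2, dtype=torch.double, requires_grad=True)
pseudo = torch.rand(4, 2, dtype=torch.double)   # not differentiated yet
weight = torch.rand(2, 3, dtype=torch.double, requires_grad=True)
assert gradcheck(ToyWeighting.apply, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True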