Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
42542bff
Commit
42542bff
authored
Mar 11, 2018
by
rusty1s
Browse files
test bw function
parent
40f5b757
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
4 deletions
+27
-4
test/benchmark.py
test/benchmark.py
+22
-1
torch_spline_conv/functions/spline_conv.py
torch_spline_conv/functions/spline_conv.py
+3
-0
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+1
-1
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+1
-2
No files found.
test/benchmark.py
View file @
42542bff
import
torch
from
torch.autograd
import
Variable
,
gradcheck
from
torch_spline_conv
import
spline_conv
from
torch_spline_conv.functions.utils
import
SplineWeighting
,
spline_basis
x
=
torch
.
Tensor
([[
9
,
10
],
[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]])
index
=
torch
.
LongTensor
([[
0
,
0
,
0
,
0
],
[
1
,
2
,
3
,
4
]])
pseudo
=
[[
0.25
,
0.125
],
[
0.25
,
0.375
],
[
0.75
,
0.625
],
[
0.75
,
0.875
]]
pseudo
=
torch
.
Tensor
(
pseudo
)
weight
=
torch
.
arange
(
0.5
,
0.5
*
25
,
step
=
0.5
).
view
(
12
,
2
,
1
)
# print(weight[:, 0].squeeze())
kernel_size
=
torch
.
LongTensor
([
3
,
4
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
0
])
root_weight
=
torch
.
arange
(
12.5
,
13.5
,
step
=
0.5
).
view
(
2
,
1
)
...
...
@@ -28,3 +29,23 @@ expected_output = [
[
12.5
*
5
+
13
*
6
],
[
12.5
*
7
+
13
*
8
],
]
print
(
output
.
tolist
(),
expected_output
)
x
=
Variable
(
x
,
requires_grad
=
True
)
weight
=
Variable
(
weight
,
requires_grad
=
True
)
root_weight
=
Variable
(
root_weight
,
requires_grad
=
True
)
output
=
spline_conv
(
x
,
index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
root_weight
)
print
(
output
.
data
.
tolist
())
x
,
pseudo
,
weight
=
x
.
data
.
double
(),
pseudo
.
double
(),
weight
.
data
.
double
()
x
=
x
[
index
[
1
]]
x
=
Variable
(
x
,
requires_grad
=
True
)
weight
=
Variable
(
weight
,
requires_grad
=
True
)
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
,
weight
.
size
(
0
))
op
=
SplineWeighting
(
basis
,
weight_index
)
test
=
gradcheck
(
op
,
(
x
,
weight
),
eps
=
1e-6
,
atol
=
1e-4
)
print
(
test
)
torch_spline_conv/functions/spline_conv.py
View file @
42542bff
...
...
@@ -15,6 +15,9 @@ def spline_conv(x,
degree
=
1
,
bias
=
None
):
print
(
'TODO: Degree of 0'
)
print
(
'TODO: Kernel size of 1'
)
n
,
e
=
x
.
size
(
0
),
edge_index
.
size
(
1
)
K
,
m_in
,
m_out
=
weight
.
size
()
...
...
torch_spline_conv/functions/utils.py
View file @
42542bff
...
...
@@ -37,7 +37,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
)).
fill_
(
0
)
grad_weight
=
x
.
new
(
weight
).
fill_
(
0
)
grad_weight
=
x
.
new
(
weight
.
size
()
).
fill_
(
0
)
func
=
get_func
(
'weighting_backward'
,
x
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
return
grad_input
,
grad_weight
...
...
torch_spline_conv/src/generic/cpu.c
View file @
42542bff
...
...
@@ -75,7 +75,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void
spline_
(
weighting_backward
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
*
grad_weight_data
=
grad_weight
->
storage
->
data
+
grad_weight
->
storageOffset
;
int64_t
M_out
=
THTensor_
(
size
)(
grad_
in
put
,
1
);
int64_t
M_out
=
THTensor_
(
size
)(
grad_
out
put
,
1
);
int64_t
M_in
=
THTensor_
(
size
)(
input
,
1
);
int64_t
S
=
THLongTensor_size
(
weight_index
,
1
);
int64_t
m_out
,
m_in
,
s
,
i
,
w_idx
;
real
g
,
b
;
...
...
@@ -90,7 +90,6 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
w_idx
=
i
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
;
grad_input_data
[
m_in
]
+=
b
*
g
*
*
(
weight_data
+
w_idx
);
grad_weight_data
[
w_idx
]
+=
b
*
g
*
*
(
input_data
+
m_in
*
input_stride
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment