Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
03f4110b
Commit
03f4110b
authored
Apr 09, 2018
by
rusty1s
Browse files
gradcheck
parent
dff93289
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
15 deletions
+36
-15
test/test_basis.py
test/test_basis.py
+21
-7
torch_spline_conv/basis.py
torch_spline_conv/basis.py
+15
-8
No files found.
test/test_basis.py
View file @
03f4110b
...
...
@@ -2,18 +2,20 @@ from itertools import product
import
pytest
import
torch
from
torch_spline_conv.basis
import
basis_forward
from
torch.autograd
import
Variable
,
gradcheck
from
torch_spline_conv.basis
import
spline_basis
,
SplineBasis
from
torch_spline_conv.utils.ffi
import
implemented_degrees
from
.tensor
import
tensors
tests
=
[{
'pseudo'
:
[
0
,
0.0625
,
0.25
,
0.75
,
0.9375
,
1
],
'pseudo'
:
[
[
0
]
,
[
0.0625
]
,
[
0.25
]
,
[
0.75
]
,
[
0.9375
]
,
[
1
]
],
'kernel_size'
:
[
5
],
'is_open_spline'
:
[
1
],
'basis'
:
[[
1
,
0
],
[
0.75
,
0.25
],
[
1
,
0
],
[
1
,
0
],
[
0.25
,
0.75
],
[
1
,
0
]],
'weight_index'
:
[[
0
,
1
],
[
0
,
1
],
[
1
,
2
],
[
3
,
4
],
[
3
,
4
],
[
4
,
0
]],
},
{
'pseudo'
:
[
0
,
0.0625
,
0.25
,
0.75
,
0.9375
,
1
],
'pseudo'
:
[
[
0
]
,
[
0.0625
]
,
[
0.25
]
,
[
0.75
]
,
[
0.9375
]
,
[
1
]
],
'kernel_size'
:
[
4
],
'is_open_spline'
:
[
0
],
'basis'
:
[[
1
,
0
],
[
0.75
,
0.25
],
[
1
,
0
],
[
1
,
0
],
[
0.25
,
0.75
],
[
1
,
0
]],
...
...
@@ -28,27 +30,39 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_
basis_forward
_cpu
(
tensor
,
i
):
def
test_
spline_basis
_cpu
(
tensor
,
i
):
data
=
tests
[
i
]
pseudo
=
getattr
(
torch
,
tensor
)(
data
[
'pseudo'
])
kernel_size
=
torch
.
LongTensor
(
data
[
'kernel_size'
])
is_open_spline
=
torch
.
ByteTensor
(
data
[
'is_open_spline'
])
basis
,
weight_index
=
basis_forward
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
assert
basis
.
tolist
()
==
data
[
'basis'
]
assert
weight_index
.
tolist
()
==
data
[
'weight_index'
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_
basis_forward
_gpu
(
tensor
,
i
):
# pragma: no cover
def
test_
spline_basis
_gpu
(
tensor
,
i
):
# pragma: no cover
data
=
tests
[
i
]
pseudo
=
getattr
(
torch
.
cuda
,
tensor
)(
data
[
'pseudo'
])
kernel_size
=
torch
.
cuda
.
LongTensor
(
data
[
'kernel_size'
])
is_open_spline
=
torch
.
cuda
.
ByteTensor
(
data
[
'is_open_spline'
])
basis
,
weight_index
=
basis_forward
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
assert
basis
.
cpu
().
tolist
()
==
data
[
'basis'
]
assert
weight_index
.
cpu
().
tolist
()
==
data
[
'weight_index'
]
def
test_spline_basis_grad_cpu
():
degree
=
1
kernel_size
=
torch
.
LongTensor
([
5
,
5
,
5
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
0
,
1
])
op
=
SplineBasis
(
degree
,
kernel_size
,
is_open_spline
)
pseudo
=
torch
.
DoubleTensor
(
4
,
3
).
uniform_
(
0
,
1
)
pseudo
=
Variable
(
pseudo
,
requires_grad
=
True
)
assert
gradcheck
(
op
,
(
pseudo
,
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
torch_spline_conv/basis.py
View file @
03f4110b
...
...
@@ -6,7 +6,6 @@ from .utils.ffi import basis_backward as ffi_basis_backward
def
basis_forward
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
):
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
num_nodes
,
S
=
pseudo
.
size
(
0
),
(
degree
+
1
)
**
kernel_size
.
size
(
0
)
basis
=
pseudo
.
new
(
num_nodes
,
S
)
weight_index
=
kernel_size
.
new
(
num_nodes
,
S
)
...
...
@@ -17,28 +16,36 @@ def basis_forward(degree, pseudo, kernel_size, is_open_spline):
def
basis_backward
(
degree
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
):
grad_pseudo
=
pseudo
.
new
(
pseudo
.
size
())
ffi_basis_backward
(
degree
,
grad_pseudo
,
pseudo
,
kernel_size
,
ffi_basis_backward
(
degree
,
grad_pseudo
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
return
grad_pseudo
class
Basis
(
Function
):
class
Spline
Basis
(
Function
):
def
__init__
(
self
,
degree
,
kernel_size
,
is_open_spline
):
super
(
Basis
,
self
).
__init__
()
super
(
Spline
Basis
,
self
).
__init__
()
self
.
degree
=
degree
self
.
kernel_size
=
kernel_size
self
.
is_open_spline
=
is_open_spline
def
forward
(
self
,
pseudo
):
self
.
save_for_back
a
wrd
(
pseudo
)
self
.
save_for_backw
a
rd
(
pseudo
)
return
basis_forward
(
self
.
degree
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
)
def
backward
(
self
,
grad_basis
,
grad_weight_index
):
pass
pseudo
,
=
self
.
saved_tensors
grad_pseudo
=
None
if
self
.
needs_input_grad
[
0
]:
grad_pseudo
=
basis_backward
(
self
.
degree
,
grad_basis
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
)
def
basis
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
):
return
grad_pseudo
def
spline_basis
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
):
if
torch
.
is_tensor
(
pseudo
):
return
basis_forward
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
)
else
:
return
Basis
(
degree
,
kernel_size
,
is_open_spline
)(
pseudo
)
return
Spline
Basis
(
degree
,
kernel_size
,
is_open_spline
)(
pseudo
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment