Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
b5ac9f33
Commit
b5ac9f33
authored
Apr 10, 2018
by
rusty1s
Browse files
weighting forward (cpu+gpu)
parent
d48533ea
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
137 additions
and
8 deletions
+137
-8
test/test_conv.py
test/test_conv.py
+0
-0
test/test_weighting.py
test/test_weighting.py
+45
-0
torch_spline_conv/__init__.py
torch_spline_conv/__init__.py
+1
-1
torch_spline_conv/basis.py
torch_spline_conv/basis.py
+6
-7
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+0
-0
torch_spline_conv/utils/ffi.py
torch_spline_conv/utils/ffi.py
+20
-0
torch_spline_conv/weighting.py
torch_spline_conv/weighting.py
+65
-0
No files found.
test/test_
spline_
conv.py
→
test/test_conv.py
View file @
b5ac9f33
File moved
test/test_weighting.py
0 → 100644
View file @
b5ac9f33
from
itertools
import
product
import
pytest
import
torch
from
torch_spline_conv.weighting
import
spline_weighting
from
.tensor
import
tensors
# Hand-computed fixtures for the spline-weighting forward pass.
# `src` holds two nodes with two input features each; `weight` holds four
# kernel weights of shape (in=2, out=1); `basis`/`weight_index` select which
# kernel weights contribute (and by how much) per node.  `output` spells out
# the expected result as an explicit arithmetic expression.
tests = [{
    'src': [[1, 2], [3, 4]],
    'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
    'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]],
    'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]],
    'output': [
        [0.5 * ((1 * (1 + 5)) + (2 * (2 + 6)))],
        [0.5 * ((3 * (5 + 7)) + (4 * (6 + 8)))],
    ],
}]
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_spline_weighting_forward_cpu(tensor, i):
    """Check the CPU forward pass of `spline_weighting` against fixtures.

    Renamed from `test_spline_basis_forward_cpu`: the body exercises
    `spline_weighting`, not the basis computation — the old name was a
    copy-paste leftover from the basis test module.
    """
    data = tests[i]
    src = getattr(torch, tensor)(data['src'])
    weight = getattr(torch, tensor)(data['weight'])
    basis = getattr(torch, tensor)(data['basis'])
    weight_index = torch.LongTensor(data['weight_index'])
    output = spline_weighting(src, weight, basis, weight_index)
    assert output.tolist() == data['output']
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_spline_weighting_forward_gpu(tensor, i):
    """Check the CUDA forward pass of `spline_weighting` against fixtures.

    Renamed from `test_spline_basis_forward_gpu`: the body exercises
    `spline_weighting`, not the basis computation — the old name was a
    copy-paste leftover from the basis test module.
    """
    data = tests[i]
    src = getattr(torch.cuda, tensor)(data['src'])
    weight = getattr(torch.cuda, tensor)(data['weight'])
    basis = getattr(torch.cuda, tensor)(data['basis'])
    weight_index = torch.cuda.LongTensor(data['weight_index'])
    output = spline_weighting(src, weight, basis, weight_index)
    # Compare on the host; `.tolist()` on a CUDA tensor would still work but
    # the explicit `.cpu()` mirrors the original intent.
    assert output.cpu().tolist() == data['output']
torch_spline_conv/__init__.py
View file @
b5ac9f33
from
.
spline_
conv
import
spline_conv
from
.conv
import
spline_conv
__version__
=
'0.1.0'
...
...
torch_spline_conv/basis.py
View file @
b5ac9f33
import
torch
from
torch.autograd
import
Function
from
.utils.ffi
import
basis_forward
as
ffi_
basis_f
orward
from
.utils.ffi
import
basis_backward
as
ffi_
basis_b
ackward
from
.utils.ffi
import
basis_forward
as
basis_f
w
from
.utils.ffi
import
basis_backward
as
basis_b
w
def basis_forward(degree, pseudo, kernel_size, is_open_spline):
    """Compute B-spline basis values and kernel-weight indices per node.

    Allocates `basis` (float, like `pseudo`) and `weight_index` (integer,
    like `kernel_size`), both of shape (num_nodes, S), and fills them via
    the FFI kernel `basis_fw`.
    """
    n = pseudo.size(0)
    # (degree + 1) contributing knots per dimension, multiplied across dims.
    s = (degree + 1) ** kernel_size.size(0)
    basis = pseudo.new(n, s)
    weight_index = kernel_size.new(n, s)
    basis_fw(degree, basis, weight_index, pseudo, kernel_size, is_open_spline)
    return basis, weight_index
def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
    """Backpropagate basis gradients onto the pseudo-coordinates.

    Allocates `grad_pseudo` with the same shape/type as `pseudo` and fills
    it via the FFI kernel `basis_bw`.
    """
    grad_pseudo = pseudo.new(pseudo.size())
    basis_bw(degree, grad_pseudo, grad_basis, pseudo, kernel_size,
             is_open_spline)
    return grad_pseudo
...
...
@@ -34,8 +33,8 @@ class SplineBasis(Function):
self
.
is_open_spline
)
def
backward
(
self
,
grad_basis
,
grad_weight_index
):
pseudo
,
=
self
.
saved_tensors
grad_pseudo
=
None
pseudo
,
=
self
.
saved_tensors
if
self
.
needs_input_grad
[
0
]:
grad_pseudo
=
basis_backward
(
self
.
degree
,
grad_basis
,
pseudo
,
...
...
torch_spline_conv/
spline_
conv.py
→
torch_spline_conv/conv.py
View file @
b5ac9f33
File moved
torch_spline_conv/utils/ffi.py
View file @
b5ac9f33
...
...
@@ -28,3 +28,23 @@ def basis_backward(degree, self, grad_basis, pseudo, kernel_size,
name
=
'{}BasisBackward'
.
format
(
get_degree_str
(
degree
))
func
=
get_func
(
name
,
self
.
is_cuda
,
self
)
func
(
self
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
def weighting_forward(self, src, weight, basis, weight_index):
    """Dispatch the 'weightingForward' kernel for `self`'s backend.

    `self` is the pre-allocated output tensor; the kernel writes into it
    in place.
    """
    kernel = get_func('weightingForward', self.is_cuda, self)
    kernel(self, src, weight, basis, weight_index)
def weighting_backward_src(self, grad_output, weight, basis, weight_index):
    """Dispatch the 'weightingBackwardSrc' kernel for `self`'s backend.

    `self` is the pre-allocated src-gradient tensor, filled in place.
    """
    kernel = get_func('weightingBackwardSrc', self.is_cuda, self)
    kernel(self, grad_output, weight, basis, weight_index)
def weighting_backward_weight(self, grad_output, src, basis, weight_index):
    """Dispatch the 'weightingBackwardWeight' kernel for `self`'s backend.

    `self` is the pre-allocated weight-gradient tensor, filled in place.
    """
    kernel = get_func('weightingBackwardWeight', self.is_cuda, self)
    kernel(self, grad_output, src, basis, weight_index)
def weighting_backward_basis(self, grad_output, src, weight, weight_index):
    """Dispatch the 'weightingBackwardBasis' kernel for `self`'s backend.

    `self` is the pre-allocated basis-gradient tensor, filled in place.
    """
    kernel = get_func('weightingBackwardBasis', self.is_cuda, self)
    kernel(self, grad_output, src, weight, weight_index)
torch_spline_conv/weighting.py
0 → 100644
View file @
b5ac9f33
import
torch
from
torch.autograd
import
Function
from
.utils.ffi
import
weighting_forward
as
weighting_fw
from
.utils.ffi
import
weighting_backward_src
as
weighting_bw_src
from
.utils.ffi
import
weighting_backward_weight
as
weighting_bw_weight
from
.utils.ffi
import
weighting_backward_basis
as
weighting_bw_basis
def weighting_forward(src, weight, basis, weight_index):
    """Apply spline weighting: allocate and fill the output via FFI.

    Result has shape (num_nodes, M_out) where M_out = weight.size(2).
    """
    out = src.new(src.size(0), weight.size(2))
    weighting_fw(out, src, weight, basis, weight_index)
    return out
def weighting_backward_src(grad_output, weight, basis, weight_index):
    """Gradient of the weighting op w.r.t. the input features.

    Result has shape (num_nodes, M_in) where M_in = weight.size(1).
    """
    grad_src = grad_output.new(grad_output.size(0), weight.size(1))
    # Swap the in/out channel dims up front — coalesced memory access
    # inside the kernel.
    weight_t = weight.transpose(1, 2).contiguous()
    weighting_bw_src(grad_src, grad_output, weight_t, basis, weight_index)
    return grad_src
def weighting_backward_weight(grad_output, src, basis, weight_index, K):
    """Gradient of the weighting op w.r.t. the kernel weights.

    `K` is the number of kernel weights; the result has shape
    (K, M_in, M_out).
    """
    grad_weight = src.new(K, src.size(1), grad_output.size(1))
    weighting_bw_weight(grad_weight, grad_output, src, basis, weight_index)
    return grad_weight
def weighting_backward_basis(grad_output, src, weight, weight_index):
    """Gradient of the weighting op w.r.t. the basis coefficients.

    Result has the same shape as `weight_index`.
    """
    grad_basis = src.new(weight_index.size())
    weighting_bw_basis(grad_basis, grad_output, src, weight, weight_index)
    return grad_basis
class SplineWeighting(Function):
    """Old-style (instance-based) autograd Function for spline weighting.

    `weight_index` is treated as a constant and stashed on the instance;
    only `src`, `weight` and `basis` are differentiable inputs.
    """

    def __init__(self, weight_index):
        super(SplineWeighting, self).__init__()
        self.weight_index = weight_index

    def forward(self, src, weight, basis):
        self.save_for_backward(src, weight, basis)
        return weighting_forward(src, weight, basis, self.weight_index)

    def backward(self, grad_output):
        grad_src = grad_weight = grad_basis = None
        src, weight, basis = self.saved_tensors
        if self.needs_input_grad[0]:
            grad_src = weighting_backward_src(grad_output, weight, basis,
                                              self.weight_index)
        if self.needs_input_grad[1]:
            # Bug fix: `weighting_backward_weight` requires the kernel size K
            # as its fifth argument; the original call omitted it and would
            # raise a TypeError whenever the weight gradient was needed.
            # K is the leading dim of the (K, M_in, M_out) weight tensor.
            grad_weight = weighting_backward_weight(grad_output, src, basis,
                                                    self.weight_index,
                                                    weight.size(0))
        if self.needs_input_grad[2]:
            grad_basis = weighting_backward_basis(grad_output, src, weight,
                                                  self.weight_index)
        return grad_src, grad_weight, grad_basis
def spline_weighting(src, weight, basis, weight_index):
    """Entry point: weight `src` by spline basis coefficients.

    Plain tensors take the direct (non-differentiable) path; anything else
    (old-style autograd Variables) is routed through `SplineWeighting`.
    """
    if torch.is_tensor(src):
        return weighting_forward(src, weight, basis, weight_index)
    return SplineWeighting(weight_index)(src, weight, basis)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment