Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
2cde9023
Commit
2cde9023
authored
Feb 14, 2018
by
Jan Eric Lenssen
Browse files
complied to coding standards
parent
a8109737
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
60 additions
and
52 deletions
+60
-52
compute_spline_basis.py
compute_spline_basis.py
+15
-14
edgewise_spline_weighting.py
edgewise_spline_weighting.py
+1
-1
edgewise_spline_weighting_gpu.py
edgewise_spline_weighting_gpu.py
+28
-28
spline.py
spline.py
+2
-1
spline_cubic_gpu_test.py
spline_cubic_gpu_test.py
+4
-2
spline_linear_gpu_test.py
spline_linear_gpu_test.py
+4
-2
spline_quadratic_gpu_test.py
spline_quadratic_gpu_test.py
+6
-4
No files found.
compute_spline_basis.py
View file @
2cde9023
...
...
@@ -144,11 +144,11 @@ const long* kernel_size, const long* is_open_spline, int num_threads) {
}
'''
def
get_basis_kernel
(
k_max
,
K
,
dim
,
degree
):
if
degree
==
3
:
def
get_basis_kernel
(
k_max
,
K
,
dim
,
degree
):
if
degree
==
3
:
_spline_kernel
=
_spline_kernel_cubic
elif
degree
==
2
:
elif
degree
==
2
:
_spline_kernel
=
_spline_kernel_quadratic
else
:
_spline_kernel
=
_spline_kernel_linear
...
...
@@ -164,12 +164,13 @@ def get_basis_kernel(k_max,K,dim,degree):
K
=
K
)
return
f
def
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
K
,
basis_kernel
):
assert
input
.
is_cuda
and
kernel_size
.
is_cuda
and
is_open_spline
.
is_cuda
input
=
input
.
unsqueeze
(
1
)
if
len
(
input
.
size
())
<
2
else
input
num_edges
,
dim
=
input
.
size
()
k_max
=
2
**
dim
k_max
=
2
**
dim
amount
=
input
.
new
(
num_edges
,
k_max
)
index
=
input
.
new
(
num_edges
,
k_max
).
long
()
...
...
@@ -177,15 +178,15 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
with
torch
.
cuda
.
device_of
(
input
):
basis_kernel
(
block
=
(
cuda_num_threads
,
1
,
1
),
grid
=
(
get_blocks
(
num_threads
),
1
,
1
),
args
=
[
input
.
data_ptr
(),
amount
.
data_ptr
(),
index
.
data_ptr
(),
kernel_size
.
data_ptr
(),
is_open_spline
.
data_ptr
(),
num_threads
],
stream
=
Stream
(
ptr
=
torch
.
cuda
.
current_stream
().
cuda_stream
))
grid
=
(
get_blocks
(
num_threads
),
1
,
1
),
args
=
[
input
.
data_ptr
(),
amount
.
data_ptr
(),
index
.
data_ptr
(),
kernel_size
.
data_ptr
(),
is_open_spline
.
data_ptr
(),
num_threads
],
stream
=
Stream
(
ptr
=
torch
.
cuda
.
current_stream
().
cuda_stream
))
return
amount
,
index
edgewise_spline_weighting.py
View file @
2cde9023
...
...
@@ -8,6 +8,6 @@ def edgewise_spline_weighting(input, weight, amount, index, k_fw, k_bw):
if
input
.
is_cuda
:
K
,
M_in
,
M_out
=
weight
.
size
()
return
EdgewiseSplineWeightingGPU
(
amount
,
index
,
K
,
M_in
,
M_out
,
k_fw
,
k_bw
)(
input
,
weight
)
,
k_fw
,
k_bw
)(
input
,
weight
)
else
:
raise
NotImplementedError
edgewise_spline_weighting_gpu.py
View file @
2cde9023
...
...
@@ -95,7 +95,8 @@ const long* index, int num_threads) {
}
'''
def
get_forward_kernel
(
M_in
,
M_out
,
k_max
):
def
get_forward_kernel
(
M_in
,
M_out
,
k_max
):
cuda_tensor
=
torch
.
FloatTensor
([
1
]).
cuda
()
with
torch
.
cuda
.
device_of
(
cuda_tensor
):
f_fw
=
load_kernel
(
...
...
@@ -107,7 +108,8 @@ def get_forward_kernel(M_in,M_out,k_max):
k_max
=
k_max
)
return
f_fw
def
get_backward_kernel
(
M_in
,
M_out
,
k_max
,
K
):
def
get_backward_kernel
(
M_in
,
M_out
,
k_max
,
K
):
cuda_tensor
=
torch
.
FloatTensor
([
1
]).
cuda
()
with
torch
.
cuda
.
device_of
(
cuda_tensor
):
f_bw
=
load_kernel
(
...
...
@@ -133,36 +135,33 @@ class EdgewiseSplineWeightingGPU(Function):
self
.
f_fw
=
k_fw
self
.
f_bw
=
k_bw
def
forward
(
self
,
input
,
weight
):
assert
input
.
is_cuda
and
weight
.
is_cuda
self
.
save_for_backward
(
input
,
weight
)
output
=
input
.
new
(
input
.
size
(
0
),
self
.
M_out
)
num_threads
=
output
.
numel
()
with
torch
.
cuda
.
device_of
(
input
):
self
.
f_fw
(
block
=
(
cuda_num_threads
,
1
,
1
),
grid
=
(
get_blocks
(
num_threads
),
1
,
1
),
args
=
[
input
.
data_ptr
(),
weight
.
data_ptr
(),
output
.
data_ptr
(),
self
.
amount
.
data_ptr
(),
self
.
index
.
data_ptr
(),
num_threads
],
stream
=
Stream
(
ptr
=
torch
.
cuda
.
current_stream
().
cuda_stream
))
grid
=
(
get_blocks
(
num_threads
),
1
,
1
),
args
=
[
input
.
data_ptr
(),
weight
.
data_ptr
(),
output
.
data_ptr
(),
self
.
amount
.
data_ptr
(),
self
.
index
.
data_ptr
(),
num_threads
],
stream
=
Stream
(
ptr
=
torch
.
cuda
.
current_stream
().
cuda_stream
))
return
output
def
backward
(
self
,
grad_output
):
input
,
weight
=
self
.
saved_tensors
grad_input
=
grad_output
.
new
(
input
.
size
(
0
),
self
.
M_in
).
fill_
(
0
)
grad_weight
=
grad_output
.
new
(
self
.
K
,
self
.
M_in
,
self
.
M_out
).
fill_
(
0
)
...
...
@@ -170,17 +169,18 @@ class EdgewiseSplineWeightingGPU(Function):
with
torch
.
cuda
.
device_of
(
grad_output
):
self
.
f_bw
(
block
=
(
cuda_num_threads
,
1
,
1
),
grid
=
(
get_blocks
(
num_threads
),
1
,
1
),
args
=
[
grad_output
.
data_ptr
(),
grad_input
.
data_ptr
(),
grad_weight
.
data_ptr
(),
input
.
data_ptr
(),
weight
.
data_ptr
(),
self
.
amount
.
data_ptr
(),
self
.
index
.
data_ptr
(),
num_threads
],
stream
=
Stream
(
ptr
=
torch
.
cuda
.
current_stream
().
cuda_stream
))
grid
=
(
get_blocks
(
num_threads
),
1
,
1
),
args
=
[
grad_output
.
data_ptr
(),
grad_input
.
data_ptr
(),
grad_weight
.
data_ptr
(),
input
.
data_ptr
(),
weight
.
data_ptr
(),
self
.
amount
.
data_ptr
(),
self
.
index
.
data_ptr
(),
num_threads
],
stream
=
Stream
(
ptr
=
torch
.
cuda
.
current_stream
().
cuda_stream
))
return
grad_input
,
grad_weight
spline.py
View file @
2cde9023
...
...
@@ -6,6 +6,7 @@ if torch.cuda.is_available():
def
spline
(
input
,
kernel_size
,
is_open_spline
,
K
,
degree
,
basis_kernel
):
if
input
.
is_cuda
:
return
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
K
,
basis_kernel
)
return
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
K
,
basis_kernel
)
else
:
raise
NotImplementedError
()
spline_cubic_gpu_test.py
View file @
2cde9023
...
...
@@ -18,7 +18,8 @@ class SplineQuadraticGPUTest(unittest.TestCase):
K
=
7
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
3
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
7
,
basis_kernel
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
7
,
basis_kernel
)
a2
=
[
[
0.1667
,
0.6667
,
0.1667
,
0
],
...
...
@@ -44,7 +45,8 @@ class SplineQuadraticGPUTest(unittest.TestCase):
K
=
4
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
3
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
4
,
basis_kernel
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
4
,
basis_kernel
)
a2
=
[
[
0.1667
,
0.6667
,
0.1667
,
0
],
...
...
spline_linear_gpu_test.py
View file @
2cde9023
...
...
@@ -18,7 +18,8 @@ class SplineLinearGPUTest(unittest.TestCase):
K
=
5
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
1
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
5
,
basis_kernel
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
5
,
basis_kernel
)
a2
=
[[
0
,
1
],
[
0.2
,
0.8
],
[
0
,
1
],
[
0
,
1
],
[
0
,
1
],
[
0.8
,
0.2
],
[
0
,
1
]]
i2
=
[[
1
,
0
],
[
1
,
0
],
[
2
,
1
],
[
3
,
2
],
[
4
,
3
],
[
4
,
3
],
[
0
,
4
]]
...
...
@@ -35,7 +36,8 @@ class SplineLinearGPUTest(unittest.TestCase):
K
=
4
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
1
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
4
,
basis_kernel
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
4
,
basis_kernel
)
a2
=
[[
0
,
1
],
[
0.2
,
0.8
],
[
0
,
1
],
[
0
,
1
],
[
0
,
1
],
[
0.8
,
0.2
],
[
0
,
1
]]
i2
=
[[
1
,
0
],
[
1
,
0
],
[
2
,
1
],
[
3
,
2
],
[
0
,
3
],
[
0
,
3
],
[
1
,
0
]]
...
...
spline_quadratic_gpu_test.py
View file @
2cde9023
...
...
@@ -16,10 +16,11 @@ class SplineQuadraticGPUTest(unittest.TestCase):
is_open_spline
=
torch
.
cuda
.
LongTensor
([
1
])
k_max
=
3
K
=
6
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
2
)
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
2
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
6
,
basis_kernel
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
6
,
basis_kernel
)
a2
=
[[
0.5
,
0.5
,
0
],
[
0.32
,
0.66
,
0.02
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.02
,
0.66
,
0.32
],
[
0.5
,
0.5
,
0
]]
...
...
@@ -38,7 +39,8 @@ class SplineQuadraticGPUTest(unittest.TestCase):
K
=
4
dim
=
1
basis_kernel
=
get_basis_kernel
(
k_max
,
K
,
dim
,
2
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
4
,
basis_kernel
)
a1
,
i1
=
compute_spline_basis
(
input
,
kernel_size
,
is_open_spline
,
4
,
basis_kernel
)
a2
=
[[
0.5
,
0.5
,
0
],
[
0.32
,
0.66
,
0.02
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.02
,
0.66
,
0.32
],
[
0.5
,
0.5
,
0
]]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment