OpenDAS / torch-spline-conv · Commits

Commit 8383b476
authored Mar 05, 2018 by Jan Eric Lenssen

merged master into bp_to_u

Parents: f622968c, 4df5512a

Showing 10 changed files with 322 additions and 254 deletions.
compute_spline_basis.py           +192    -0
edgewise_spline_weighting.py        +4    -2
edgewise_spline_weighting_gpu.py   +67   -56
spline.py                           +4   -12
spline_conv.py                     +13    -2
spline_cubic_gpu_test.py           +14    -5
spline_linear_gpu.py                +0   -84
spline_linear_gpu_test.py          +14    -5
spline_quadratic_gpu.py             +0   -84
spline_quadratic_gpu_test.py       +14    -4
spline_cubic_gpu.py → compute_spline_basis.py (view file @ 8383b476)
...
@@ -3,13 +3,105 @@ import torch
 from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
                             kernel_loop, get_blocks)

-_spline_kernel = kernel_loop + '''
+_spline_kernel_linear = kernel_loop + '''
 extern "C"
 __global__ void spline_kernel(
 const ${Dtype}* input, ${Dtype}* amount, long* index,
-const long* kernel_size, const long* is_open_spline) {
-CUDA_KERNEL_LOOP(idx, ${num_threads}) {
+const long* kernel_size, const long* is_open_spline, int num_threads) {
+CUDA_KERNEL_LOOP(idx, num_threads) {
 const int e_idx = idx / ${k_max};
 int k_idx = idx % ${k_max};
 int K = ${K};
 int k_idx_mod;
 int bot;
 int top;
 ${Dtype} value;
 ${Dtype} frac;
 ${Dtype} a = 1.0;
 long i = 0;
 for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
 K /= kernel_size[d_idx];
 k_idx_mod = k_idx % 2;
 k_idx >>= 1;
 value = input[e_idx * ${dim} + d_idx];
 value *= kernel_size[d_idx] - is_open_spline[d_idx];
 frac = value - floor(value);
 a *= (1 - k_idx_mod) * frac + k_idx_mod * (1 - frac);
 bot = int(floor(value));
 top = (bot + 1) % kernel_size[d_idx];
 bot %= kernel_size[d_idx];
 i += (k_idx_mod * bot + (1 - k_idx_mod) * top) * K;
 }
 amount[idx] = a;
 index[idx] = i;
 }
 }
 '''

+_spline_kernel_quadratic = kernel_loop + '''
+extern "C"
+__global__ void spline_kernel(
+const ${Dtype}* input, ${Dtype}* amount, long* index,
+const long* kernel_size, const long* is_open_spline, int num_threads) {
+CUDA_KERNEL_LOOP(idx, num_threads) {
+const int e_idx = idx / ${k_max};
+int k_idx = idx % ${k_max};
+int K = ${K};
+int k_idx_mod;
+int pos;
+${Dtype} value;
+${Dtype} frac;
+${Dtype} a = 1.0;
+long i = 0;
+for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
+K /= kernel_size[d_idx];
+k_idx_mod = k_idx % 3;
+k_idx /= 3;
+value = input[e_idx * ${dim} + d_idx] *
+(kernel_size[d_idx] - (2 * is_open_spline[d_idx]));
+frac = value - floor(value);
+if (k_idx_mod == 0) a *= 0.5 * (1 - frac) * (1 - frac);
+else if (k_idx_mod == 1) a *= -frac * frac + frac + 0.5;
+else a *= 0.5 * frac * frac;
+pos = int(floor(value)) + k_idx_mod;
+pos %= kernel_size[d_idx];
+i += pos * K;
+}
+amount[idx] = a;
+index[idx] = i;
+}
+}
+'''

+_spline_kernel_cubic = kernel_loop + '''
+extern "C"
+__global__ void spline_kernel(
+const ${Dtype}* input, ${Dtype}* amount, long* index,
+const long* kernel_size, const long* is_open_spline, int num_threads) {
+CUDA_KERNEL_LOOP(idx, num_threads) {
+const int e_idx = idx / ${k_max};
+int k_idx = idx % ${k_max};
...
@@ -53,35 +145,48 @@ const long* kernel_size, const long* is_open_spline) {
 '''

-def spline_cubic_gpu(input, kernel_size, is_open_spline, K):
+def get_basis_kernel(k_max, K, dim, degree):
+    if degree == 3:
+        _spline_kernel = _spline_kernel_cubic
+    elif degree == 2:
+        _spline_kernel = _spline_kernel_quadratic
+    else:
+        _spline_kernel = _spline_kernel_linear
+    cuda_tensor = torch.FloatTensor([1]).cuda()
+    with torch.cuda.device_of(cuda_tensor):
+        f = load_kernel('spline_kernel', _spline_kernel, Dtype='float',
+                        k_max=k_max, dim=dim, K=K)
+    return f
+
+
+def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
     assert input.is_cuda and kernel_size.is_cuda and is_open_spline.is_cuda

     input = input.unsqueeze(1) if len(input.size()) < 2 else input
     num_edges, dim = input.size()
-    k_max = 4 ** dim
+    k_max = 2 ** dim

     amount = input.new(num_edges, k_max)
     index = input.new(num_edges, k_max).long()
     num_threads = amount.numel()

-    with torch.cuda.device_of(input):
-        f = load_kernel('spline_kernel', _spline_kernel, Dtype=Dtype(input),
-                        num_threads=num_threads, k_max=k_max, dim=dim, K=K)
-    f(block=(cuda_num_threads, 1, 1),
-      grid=(get_blocks(num_threads), 1, 1),
-      args=[input.data_ptr(), amount.data_ptr(), index.data_ptr(),
-            kernel_size.data_ptr(), is_open_spline.data_ptr()],
-      stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
+    basis_kernel(block=(cuda_num_threads, 1, 1),
+                 grid=(get_blocks(num_threads), 1, 1),
+                 args=[input.data_ptr(), amount.data_ptr(), index.data_ptr(),
+                       kernel_size.data_ptr(), is_open_spline.data_ptr(),
+                       num_threads],
+                 stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

     return amount, index
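The net effect is a compile-once, run-many API: get_basis_kernel selects the
kernel source by degree and compiles it a single time, while
compute_spline_basis only launches the precompiled kernel. A minimal usage
sketch, mirroring the updated tests further down (tensor values and sizes are
illustrative):

    import torch
    from .compute_spline_basis import get_basis_kernel, compute_spline_basis

    # Pseudo-coordinates of 7 edges in one dimension.
    input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
    kernel_size = torch.cuda.LongTensor([5])
    is_open_spline = torch.cuda.LongTensor([1])

    # Compile the degree-1 basis kernel once (k_max = 2 ** dim for linear).
    basis_kernel = get_basis_kernel(k_max=2, K=5, dim=1, degree=1)

    # Reuse the compiled kernel across calls; only the launch happens here.
    amount, index = compute_spline_basis(input, kernel_size, is_open_spline,
                                         5, basis_kernel)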
edgewise_spline_weighting.py (view file @ 8383b476)

...
@@ -4,8 +4,10 @@ if torch.cuda.is_available():
     from .edgewise_spline_weighting_gpu import EdgewiseSplineWeightingGPU


-def edgewise_spline_weighting(input, weight, amount, index):
+def edgewise_spline_weighting(input, weight, amount, index, k_fw, k_bw):
     if input.is_cuda:
-        return EdgewiseSplineWeightingGPU(amount, index)(input, weight)
+        K, M_in, M_out = weight.size()
+        return EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out,
+                                          k_fw, k_bw)(input, weight)
     else:
         raise NotImplementedError
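The wrapper now expects the caller to hand in precompiled forward and backward
kernels instead of compiling them inside the autograd function. A sketch of
the intended call pattern, using the get_forward_kernel/get_backward_kernel
helpers introduced in edgewise_spline_weighting_gpu.py below (shapes are
illustrative):

    from .edgewise_spline_weighting_gpu import (get_forward_kernel,
                                                get_backward_kernel)

    K, M_in, M_out = weight.size()   # weight: [K, M_in, M_out]
    k_max = amount.size(1)           # basis functions per edge
    k_fw = get_forward_kernel(M_in, M_out, k_max)
    k_bw = get_backward_kernel(M_in, M_out, k_max, K)
    output = edgewise_spline_weighting(input, weight, amount, index,
                                       k_fw, k_bw)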
edgewise_spline_weighting_gpu.py (view file @ 8383b476)

 import torch
 from torch.autograd import Function

-from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
-                            kernel_loop, get_blocks)
+from ....utils.cuda import (cuda_num_threads, Stream, load_kernel,
+                            kernel_loop, get_blocks)

 _edgewise_spline_weighting_forward_kernel = kernel_loop + '''
 extern "C"
 __global__ void edgewise_spline_weighting_forward_kernel(
 const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
-const ${Dtype}* amount, const long* index) {
-CUDA_KERNEL_LOOP(idx, ${num_threads}) {
+const ${Dtype}* amount, const long* index, int num_threads) {
+CUDA_KERNEL_LOOP(idx, num_threads) {
 const int e_idx = idx / ${M_out};
 const int m_out_idx = idx % ${M_out};
...
@@ -50,9 +50,9 @@ extern "C"
 __global__ void edgewise_spline_weighting_backward_kernel(
 const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
 const ${Dtype}* input, const ${Dtype}* weight, const ${Dtype}* amount,
-const long* index) {
-CUDA_KERNEL_LOOP(idx, ${num_threads}) {
+const long* index, int num_threads) {
+CUDA_KERNEL_LOOP(idx, num_threads) {
 const int e_idx = idx / ${M_out};
 const int m_out_idx = idx % ${M_out};
...
@@ -86,7 +86,7 @@ const long* index) {
 // Calculate weight gradient.
 f = input[e_idx * ${M_in} + m_in_idx];
-w_grad = f * b * grad_output[e_idx * ${M_out} + m_out_idx];
+w_grad = f * b * g;
 atomicAdd(&(grad_weight[w_idx]), w_grad);
 // Not so efficient either, but not avoidable.
...
@@ -96,78 +96,89 @@ const long* index) {
 '''

+
+def get_forward_kernel(M_in, M_out, k_max):
+    cuda_tensor = torch.FloatTensor([1]).cuda()
+    with torch.cuda.device_of(cuda_tensor):
+        f_fw = load_kernel('edgewise_spline_weighting_forward_kernel',
+                           _edgewise_spline_weighting_forward_kernel,
+                           Dtype='float', M_in=M_in, M_out=M_out,
+                           k_max=k_max)
+    return f_fw
+
+
+def get_backward_kernel(M_in, M_out, k_max, K):
+    cuda_tensor = torch.FloatTensor([1]).cuda()
+    with torch.cuda.device_of(cuda_tensor):
+        f_bw = load_kernel('edgewise_spline_weighting_backward_kernel',
+                           _edgewise_spline_weighting_backward_kernel,
+                           Dtype='float', M_in=M_in, M_out=M_out,
+                           k_max=k_max, K=K)
+    return f_bw
+
+
 class EdgewiseSplineWeightingGPU(Function):
-    def __init__(self, amount, index):
+    def __init__(self, amount, index, K, M_in, M_out, k_fw, k_bw):
         super(EdgewiseSplineWeightingGPU, self).__init__()
         assert amount.is_cuda and index.is_cuda
         self.amount = amount
         self.index = index
+        self.M_in = M_in
+        self.M_out = M_out
+        self.K = K
+        self.f_fw = k_fw
+        self.f_bw = k_bw

     def forward(self, input, weight):
         assert input.is_cuda and weight.is_cuda
         self.save_for_backward(input, weight)

-        _, M_in, M_out = weight.size()
-        k_max = self.amount.size(1)
-
-        output = input.new(input.size(0), M_out)
+        output = input.new(input.size(0), self.M_out)
         num_threads = output.numel()

-        with torch.cuda.device_of(input):
-            f = load_kernel('edgewise_spline_weighting_forward_kernel',
-                            _edgewise_spline_weighting_forward_kernel,
-                            Dtype=Dtype(input), num_threads=num_threads,
-                            M_in=M_in, M_out=M_out, k_max=k_max)
-        f(block=(cuda_num_threads, 1, 1),
-          grid=(get_blocks(num_threads), 1, 1),
-          args=[input.data_ptr(), weight.data_ptr(), output.data_ptr(),
-                self.amount.data_ptr(), self.index.data_ptr()],
-          stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
+        self.f_fw(block=(cuda_num_threads, 1, 1),
+                  grid=(get_blocks(num_threads), 1, 1),
+                  args=[input.data_ptr(), weight.data_ptr(),
+                        output.data_ptr(), self.amount.data_ptr(),
+                        self.index.data_ptr(), num_threads],
+                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

         return output

     def backward(self, grad_output):
         input, weight = self.saved_tensors
-        K, M_in, M_out = weight.size()
-        k_max = self.amount.size(1)

-        grad_input = grad_output.new(input.size(0), M_in).fill_(0)
-        grad_weight = grad_output.new(K, M_in, M_out).fill_(0)
+        grad_input = grad_output.new(input.size(0), self.M_in).fill_(0)
+        grad_weight = grad_output.new(self.K, self.M_in,
+                                      self.M_out).fill_(0)
         num_threads = grad_output.numel()

-        with torch.cuda.device_of(grad_output):
-            f = load_kernel('edgewise_spline_weighting_backward_kernel',
-                            _edgewise_spline_weighting_backward_kernel,
-                            Dtype=Dtype(input), num_threads=num_threads,
-                            M_in=M_in, M_out=M_out, k_max=k_max, K=K)
-        f(block=(cuda_num_threads, 1, 1),
-          grid=(get_blocks(num_threads), 1, 1),
-          args=[grad_output.data_ptr(), grad_input.data_ptr(),
-                grad_weight.data_ptr(), input.data_ptr(), weight.data_ptr(),
-                self.amount.data_ptr(), self.index.data_ptr()],
-          stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
+        self.f_bw(block=(cuda_num_threads, 1, 1),
+                  grid=(get_blocks(num_threads), 1, 1),
+                  args=[grad_output.data_ptr(), grad_input.data_ptr(),
+                        grad_weight.data_ptr(), input.data_ptr(),
+                        weight.data_ptr(), self.amount.data_ptr(),
+                        self.index.data_ptr(), num_threads],
+                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

         return grad_input, grad_weight
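For orientation, the forward kernel's index arithmetic (one thread per edge
and output channel, accumulating over the k_max basis functions) corresponds
to the pure-Python reference below. This is a reading aid inferred from the
kernel's indexing, not code from the commit:

    import numpy as np

    def edgewise_spline_weighting_reference(input, weight, amount, index):
        # input:  [|E|, M_in]    weight: [K, M_in, M_out]
        # amount: [|E|, k_max]   index:  [|E|, k_max]
        num_edges, k_max = amount.shape
        output = np.zeros((num_edges, weight.shape[2]))
        for e in range(num_edges):
            for k in range(k_max):
                # Each edge mixes k_max weight matrices, scaled by the
                # B-spline basis products stored in `amount`.
                output[e] += amount[e, k] * (input[e] @ weight[index[e, k]])
        return output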
spline.py (view file @ 8383b476)

 import torch

 if torch.cuda.is_available():
-    from .spline_linear_gpu import spline_linear_gpu
-    from .spline_quadratic_gpu import spline_quadratic_gpu
-    from .spline_cubic_gpu import spline_cubic_gpu
+    from .compute_spline_basis import compute_spline_basis


-def spline(input, kernel_size, is_open_spline, K, degree):
+def spline(input, kernel_size, is_open_spline, K, degree, basis_kernel):
     if input.is_cuda:
-        if degree == 1:
-            return spline_linear_gpu(input, kernel_size, is_open_spline, K)
-        if degree == 2:
-            return spline_quadratic_gpu(input, kernel_size, is_open_spline,
-                                        K)
-        if degree == 3:
-            return spline_cubic_gpu(input, kernel_size, is_open_spline, K)
-        else:
-            raise NotImplementedError()
+        return compute_spline_basis(input, kernel_size, is_open_spline, K,
+                                    basis_kernel)
     else:
         raise NotImplementedError()
spline_conv.py (view file @ 8383b476)

 import torch
 from torch.autograd import Variable
+import time

 from .spline import spline
...
@@ -13,17 +14,27 @@ def spline_conv(
         kernel_size,
         is_open_spline,
         K,
+        forward_kernel,
+        backward_kernel,
+        basis_kernel,
         degree=1,
         bias=None):

     if input.dim() == 1:
         input = input.unsqueeze(1)

     values = adj._values()
     row, col = adj._indices()

     # Get features for every end vertex with shape [|E| x M_in].
     output = input[col]

     # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
-    amount, index = spline(values, kernel_size, is_open_spline, K, degree)
-    output = edgewise_spline_weighting(output, weight[:-1], amount, index)
+    amount, index = spline(values, kernel_size, is_open_spline, K, degree,
+                           basis_kernel)
+    output = edgewise_spline_weighting(output, weight[:-1], amount, index,
+                                       forward_kernel, backward_kernel)

     # Convolution via `scatter_add`. Converts [|E| x M_out] feature matrix to
     # [n x M_out] feature matrix.
...
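The `scatter_add` step referenced in the comment reduces the [|E| x M_out]
edgewise features to per-vertex features. A minimal sketch of that reduction,
assuming `row` holds the target vertex of each edge and `n` is the number of
vertices (the exact call site is collapsed in this hunk, so this is an
illustration, not the commit's code):

    # output: [|E|, M_out] edgewise features; row: [|E|] target vertices.
    zero = output.new(n, output.size(1)).fill_(0)
    row_expand = row.view(-1, 1).expand(-1, output.size(1))
    vertex_features = zero.scatter_add_(0, row_expand, output)  # [n, M_out]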
spline_cubic_gpu_test.py (view file @ 8383b476)

...
@@ -4,7 +4,8 @@ import torch
 from numpy.testing import assert_equal, assert_almost_equal

 if torch.cuda.is_available():
-    from .spline_cubic_gpu import spline_cubic_gpu
+    from .compute_spline_basis import compute_spline_basis
+    from .compute_spline_basis import get_basis_kernel


 class SplineQuadraticGPUTest(unittest.TestCase):
...
@@ -13,8 +14,12 @@ class SplineQuadraticGPUTest(unittest.TestCase):
         input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
         kernel_size = torch.cuda.LongTensor([7])
         is_open_spline = torch.cuda.LongTensor([1])
-        a1, i1 = spline_cubic_gpu(input, kernel_size, is_open_spline, 7)
+        k_max = 4
+        K = 7
+        dim = 1
+        basis_kernel = get_basis_kernel(k_max, K, dim, 3)
+        a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 7,
+                                      basis_kernel)

         a2 = [[0.1667, 0.6667, 0.1667, 0],
...
@@ -36,8 +41,12 @@ class SplineQuadraticGPUTest(unittest.TestCase):
         input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
         kernel_size = torch.cuda.LongTensor([4])
         is_open_spline = torch.cuda.LongTensor([0])
-        a1, i1 = spline_cubic_gpu(input, kernel_size, is_open_spline, 4)
+        k_max = 4
+        K = 4
+        dim = 1
+        basis_kernel = get_basis_kernel(k_max, K, dim, 3)
+        a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4,
+                                      basis_kernel)

         a2 = [[0.1667, 0.6667, 0.1667, 0],
...
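A quick sanity check on the cubic expectations: the uniform cubic B-spline
basis evaluated at frac = 0 gives [1/6, 4/6, 1/6, 0], which is exactly the
rounded row [0.1667, 0.6667, 0.1667, 0] the test asserts for inputs that fall
on a knot. Worked in Python (standard cubic B-spline pieces, stated here as
background, since the cubic kernel body is collapsed in this diff):

    t = 0.0
    b = [(1 - t) ** 3 / 6,
         (3 * t ** 3 - 6 * t ** 2 + 4) / 6,
         (-3 * t ** 3 + 3 * t ** 2 + 3 * t + 1) / 6,
         t ** 3 / 6]
    print([round(x, 4) for x in b])  # [0.1667, 0.6667, 0.1667, 0.0]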
spline_linear_gpu.py deleted (100644 → 0, view file @ f622968c)

import torch

from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
                            kernel_loop, get_blocks)

_spline_kernel = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
const ${Dtype}* input, ${Dtype}* amount, long* index,
const long* kernel_size, const long* is_open_spline) {
CUDA_KERNEL_LOOP(idx, ${num_threads}) {
const int e_idx = idx / ${k_max};
int k_idx = idx % ${k_max};
int K = ${K};
int k_idx_mod;
int bot;
int top;
${Dtype} value;
${Dtype} frac;
${Dtype} a = 1.0;
long i = 0;
for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
K /= kernel_size[d_idx];
k_idx_mod = k_idx % 2;
k_idx >>= 1;
value = input[e_idx * ${dim} + d_idx] *
(kernel_size[d_idx] - is_open_spline[d_idx]);
frac = value - floor(value);
a *= (1 - k_idx_mod) * frac + k_idx_mod * (1 - frac);
bot = int(floor(value));
top = (bot + 1) % kernel_size[d_idx];
bot %= kernel_size[d_idx];
i += (k_idx_mod * bot + (1 - k_idx_mod) * top) * K;
}
amount[idx] = a;
index[idx] = i;
}
}
'''


def spline_linear_gpu(input, kernel_size, is_open_spline, K):
    assert input.is_cuda and kernel_size.is_cuda and is_open_spline.is_cuda

    input = input.unsqueeze(1) if len(input.size()) < 2 else input
    num_edges, dim = input.size()
    k_max = 2 ** dim

    amount = input.new(num_edges, k_max)
    index = input.new(num_edges, k_max).long()
    num_threads = amount.numel()

    with torch.cuda.device_of(input):
        f = load_kernel('spline_kernel', _spline_kernel, Dtype=Dtype(input),
                        num_threads=num_threads, k_max=k_max, dim=dim, K=K)
    f(block=(cuda_num_threads, 1, 1),
      grid=(get_blocks(num_threads), 1, 1),
      args=[input.data_ptr(), amount.data_ptr(), index.data_ptr(),
            kernel_size.data_ptr(), is_open_spline.data_ptr()],
      stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

    return amount, index
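The arithmetic of the deleted kernel survives unchanged as
_spline_kernel_linear in compute_spline_basis.py. As a sanity check, here is
the per-edge computation for one dimension in plain Python; for input 0.05
with an open spline of kernel size 5 it yields amounts [0.2, 0.8] at indices
[1, 0], matching the expectations in spline_linear_gpu_test.py:

    import math

    def linear_basis_1d(x, kernel_size, is_open_spline):
        # Mirrors the CUDA kernel for dim = 1, k_max = 2.
        value = x * (kernel_size - is_open_spline)
        frac = value - math.floor(value)
        bot = int(math.floor(value))
        top = (bot + 1) % kernel_size
        bot %= kernel_size
        # k_idx_mod = 0 weights frac on the upper knot,
        # k_idx_mod = 1 weights (1 - frac) on the lower knot.
        return [frac, 1 - frac], [top, bot]

    print(linear_basis_1d(0.05, 5, 1))  # ([0.2, 0.8], [1, 0]), up to rounding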
spline_linear_gpu_test.py (view file @ 8383b476)

...
@@ -4,7 +4,8 @@ import torch
 from numpy.testing import assert_equal, assert_almost_equal

 if torch.cuda.is_available():
-    from .spline_linear_gpu import spline_linear_gpu
+    from .compute_spline_basis import compute_spline_basis
+    from .compute_spline_basis import get_basis_kernel


 class SplineLinearGPUTest(unittest.TestCase):
...
@@ -13,8 +14,12 @@ class SplineLinearGPUTest(unittest.TestCase):
         input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
         kernel_size = torch.cuda.LongTensor([5])
         is_open_spline = torch.cuda.LongTensor([1])
-        a1, i1 = spline_linear_gpu(input, kernel_size, is_open_spline, 5)
+        k_max = 2
+        K = 5
+        dim = 1
+        basis_kernel = get_basis_kernel(k_max, K, dim, 1)
+        a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 5,
+                                      basis_kernel)

         a2 = [[0, 1], [0.2, 0.8], [0, 1], [0, 1], [0, 1], [0.8, 0.2], [0, 1]]
         i2 = [[1, 0], [1, 0], [2, 1], [3, 2], [4, 3], [4, 3], [0, 4]]
...
@@ -27,8 +32,12 @@ class SplineLinearGPUTest(unittest.TestCase):
         input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
         kernel_size = torch.cuda.LongTensor([4])
         is_open_spline = torch.cuda.LongTensor([0])
-        a1, i1 = spline_linear_gpu(input, kernel_size, is_open_spline, 4)
+        k_max = 2
+        K = 4
+        dim = 1
+        basis_kernel = get_basis_kernel(k_max, K, dim, 1)
+        a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4,
+                                      basis_kernel)

         a2 = [[0, 1], [0.2, 0.8], [0, 1], [0, 1], [0, 1], [0.8, 0.2], [0, 1]]
         i2 = [[1, 0], [1, 0], [2, 1], [3, 2], [0, 3], [0, 3], [1, 0]]
...
spline_quadratic_gpu.py deleted (100644 → 0, view file @ f622968c)

import torch

from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
                            kernel_loop, get_blocks)

_spline_kernel = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
const ${Dtype}* input, ${Dtype}* amount, long* index,
const long* kernel_size, const long* is_open_spline) {
CUDA_KERNEL_LOOP(idx, ${num_threads}) {
const int e_idx = idx / ${k_max};
int k_idx = idx % ${k_max};
int K = ${K};
int k_idx_mod;
int pos;
${Dtype} value;
${Dtype} frac;
${Dtype} a = 1.0;
long i = 0;
for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
K /= kernel_size[d_idx];
k_idx_mod = k_idx % 3;
k_idx /= 3;
value = input[e_idx * ${dim} + d_idx] *
(kernel_size[d_idx] - (2 * is_open_spline[d_idx]));
frac = value - floor(value);
if (k_idx_mod == 0) a *= 0.5 * (1 - frac) * (1 - frac);
else if (k_idx_mod == 1) a *= -frac * frac + frac + 0.5;
else a *= 0.5 * frac * frac;
pos = int(floor(value)) + k_idx_mod;
pos %= kernel_size[d_idx];
i += pos * K;
}
amount[idx] = a;
index[idx] = i;
}
}
'''


def spline_quadratic_gpu(input, kernel_size, is_open_spline, K):
    assert input.is_cuda and kernel_size.is_cuda and is_open_spline.is_cuda

    input = input.unsqueeze(1) if len(input.size()) < 2 else input
    num_edges, dim = input.size()
    k_max = 3 ** dim

    amount = input.new(num_edges, k_max)
    index = input.new(num_edges, k_max).long()
    num_threads = amount.numel()

    with torch.cuda.device_of(input):
        f = load_kernel('spline_kernel', _spline_kernel, Dtype=Dtype(input),
                        num_threads=num_threads, k_max=k_max, dim=dim, K=K)
    f(block=(cuda_num_threads, 1, 1),
      grid=(get_blocks(num_threads), 1, 1),
      args=[input.data_ptr(), amount.data_ptr(), index.data_ptr(),
            kernel_size.data_ptr(), is_open_spline.data_ptr()],
      stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

    return amount, index
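Likewise, the quadratic arithmetic moves verbatim into
_spline_kernel_quadratic. A numeric check against
spline_quadratic_gpu_test.py: for input 0.05 with an open spline of kernel
size 6, value = 0.05 * (6 - 2) = 0.2, so frac = 0.2 and the three basis
weights are 0.5 * 0.8^2 = 0.32, -0.2^2 + 0.2 + 0.5 = 0.66 and
0.5 * 0.2^2 = 0.02, i.e. the row [0.32, 0.66, 0.02] the test asserts:

    import math

    def quadratic_basis_1d(x, kernel_size, is_open_spline):
        # Mirrors the CUDA kernel for dim = 1, k_max = 3.
        value = x * (kernel_size - 2 * is_open_spline)
        frac = value - math.floor(value)
        amounts = [0.5 * (1 - frac) ** 2,      # k_idx_mod == 0
                   -frac * frac + frac + 0.5,  # k_idx_mod == 1
                   0.5 * frac * frac]          # k_idx_mod == 2
        indices = [(int(math.floor(value)) + k) % kernel_size
                   for k in range(3)]
        return amounts, indices

    print(quadratic_basis_1d(0.05, 6, 1))  # amounts ~ [0.32, 0.66, 0.02]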
spline_quadratic_gpu_test.py (view file @ 8383b476)

...
@@ -4,7 +4,8 @@ import torch
 from numpy.testing import assert_equal, assert_almost_equal

 if torch.cuda.is_available():
-    from .spline_quadratic_gpu import spline_quadratic_gpu
+    from .compute_spline_basis import compute_spline_basis
+    from .compute_spline_basis import get_basis_kernel


 class SplineQuadraticGPUTest(unittest.TestCase):
...
@@ -13,8 +14,13 @@ class SplineQuadraticGPUTest(unittest.TestCase):
         input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
         kernel_size = torch.cuda.LongTensor([6])
         is_open_spline = torch.cuda.LongTensor([1])
+        k_max = 3
+        K = 6
+        dim = 1
+        basis_kernel = get_basis_kernel(k_max, K, dim, 2)
-        a1, i1 = spline_quadratic_gpu(input, kernel_size, is_open_spline, 6)
+        a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 6,
+                                      basis_kernel)

         a2 = [[0.5, 0.5, 0], [0.32, 0.66, 0.02], [0.5, 0.5, 0],
               [0.5, 0.5, 0], [0.5, 0.5, 0], [0.02, 0.66, 0.32],
               [0.5, 0.5, 0]]
...
@@ -29,8 +35,12 @@ class SplineQuadraticGPUTest(unittest.TestCase):
         input = torch.cuda.FloatTensor([0, 0.05, 0.25, 0.5, 0.75, 0.95, 1])
         kernel_size = torch.cuda.LongTensor([4])
         is_open_spline = torch.cuda.LongTensor([0])
-        a1, i1 = spline_quadratic_gpu(input, kernel_size, is_open_spline, 4)
+        k_max = 3
+        K = 4
+        dim = 1
+        basis_kernel = get_basis_kernel(k_max, K, dim, 2)
+        a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4,
+                                      basis_kernel)

         a2 = [[0.5, 0.5, 0], [0.32, 0.66, 0.02], [0.5, 0.5, 0],
               [0.5, 0.5, 0], [0.5, 0.5, 0], [0.02, 0.66, 0.32],
               [0.5, 0.5, 0]]
...