OpenDAS / torch-spline-conv / Commits / f622968c

Commit f622968c, authored Nov 28, 2017 by Jan Eric Lenssen
Parent: c16689ca

    started backprop to adj

Showing 3 changed files with 358 additions and 1 deletion:

    __init__.py                               +2    -1
    edgewise_spline_weighting_bp2adj_gpu.py   +309  -0
    spline_conv_bp2adj.py                     +47   -0
__init__.py

 from .spline_conv import spline_conv
+from .spline_conv_bp2adj import spline_conv_bp2adj

-__all__ = ['spline_conv']
+__all__ = ['spline_conv', 'spline_conv_bp2adj']
edgewise_spline_weighting_bp2adj_gpu.py (new file, mode 100644)
import torch
from torch.autograd import Function

from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
                            kernel_loop, get_blocks)

_edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_forward_kernel(
    const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
    const ${Dtype}* amount, const long* index) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {
    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} result = 0.0;
    ${Dtype} w;
    ${Dtype} f;
    int k;
    ${Dtype} b;
    long c;
    long w_idx;

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      b = amount[k];
      c = index[k];
      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} + m_in_idx * ${M_out} + m_out_idx;
        w = weight[w_idx];
        f = input[e_idx * ${M_in} + m_in_idx];
        result += b * w * f;
      }
    }
    output[idx] = result;
  }
}
'''
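
# A hypothetical pure-PyTorch reference for the forward kernel above, handy
# for checking the CUDA output on small inputs. This helper is not part of
# this commit; it just restates what the kernel computes:
# output[e, m_out] = sum_k b[e, k] * sum_{m_in} W[c(e, k), m_in, m_out] * f[e, m_in].
def _edgewise_spline_weighting_forward_reference(input, weight, amount, index):
    # input: [E x M_in], weight: [K x M_in x M_out],
    # amount: [E x k_max], index: [E x k_max] -> output: [E x M_out]
    E, k_max = index.size()
    K, M_in, M_out = weight.size()
    w = weight[index.view(-1)].view(E, k_max, M_in, M_out)
    f = input.view(E, 1, 1, M_in)
    out = torch.matmul(f, w).view(E, k_max, M_out)  # f_e * W_{c(e, k)}
    out = out * amount.unsqueeze(2)  # scale by B-spline amount b_{e, k}
    return out.sum(1)  # sum contributions of all 2**dim kernel corners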
_edgewise_spline_weighting_backward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_backward_kernel(
    const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
    ${Dtype}* grad_b, const ${Dtype}* input, const ${Dtype}* weight,
    const ${Dtype}* amount, const long* index) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {
    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} w;
    ${Dtype} g;
    ${Dtype} f;
    ${Dtype} w_grad;
    ${Dtype} b_grad;
    int k;
    ${Dtype} b;
    long c;
    long w_idx;

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      b = amount[k];
      c = index[k];
      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} + m_in_idx * ${M_out} + m_out_idx;
        w = weight[w_idx];

        // Calculate input gradient.
        g = grad_output[e_idx * ${M_out} + m_out_idx];
        atomicAdd(&(grad_input[e_idx * ${M_in} + m_in_idx]), b * w * g);
        // This is inefficient: `reduce_sum` shouldn't be done like this.
        // Looping over `M_out` would be better to avoid the `atomicAdd`.

        // Calculate weight gradient.
        f = input[e_idx * ${M_in} + m_in_idx];
        w_grad = f * b * grad_output[e_idx * ${M_out} + m_out_idx];
        atomicAdd(&(grad_weight[w_idx]), w_grad);
        // Not so efficient either, but not avoidable.

        // Calculate gradient of B.
        b_grad = f * w * g;
        atomicAdd(&(grad_b[k]), b_grad);
      }
    }
  }
}
'''
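
# The new piece in this backward kernel is `grad_b`, the gradient w.r.t. the
# B-spline amounts (the "backprop to adj" of the commit message). A
# hypothetical dense reference for it, not part of this commit:
# dL/db[e, k] = sum_{m_in, m_out} f[e, m_in] * W[c(e, k), m_in, m_out] * g[e, m_out].
def _grad_b_reference(grad_output, input, weight, amount, index):
    # grad_output: [E x M_out], input: [E x M_in], weight: [K x M_in x M_out]
    E, k_max = index.size()
    K, M_in, M_out = weight.size()
    w = weight[index.view(-1)].view(E, k_max, M_in, M_out)
    fg = input.view(E, 1, M_in, 1) * grad_output.view(E, 1, 1, M_out)
    return (w * fg).sum(3).sum(2)  # -> [E x k_max]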
_spline_kernel = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
    const ${Dtype}* input, ${Dtype}* amount, long* index,
    const long* kernel_size, const long* is_open_spline) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {
    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};
    int K = ${K};
    int k_idx_mod;
    int bot;
    int top;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
      K /= kernel_size[d_idx];
      k_idx_mod = k_idx % 2;
      k_idx >>= 1;
      value = input[e_idx * ${dim} + d_idx] *
              (kernel_size[d_idx] - is_open_spline[d_idx]);
      frac = value - floor(value);
      a *= (1 - k_idx_mod) * frac + k_idx_mod * (1 - frac);
      bot = int(floor(value));
      top = (bot + 1) % kernel_size[d_idx];
      bot %= kernel_size[d_idx];
      i += (k_idx_mod * bot + (1 - k_idx_mod) * top) * K;
    }
    amount[idx] = a;
    index[idx] = i;
  }
}
'''
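
# A straight Python port of the spline kernel above, useful for unit-testing
# `amount` and `index` on CPU. Hypothetical helper, not part of this commit.
import math

def _spline_basis_reference(values, kernel_size, is_open_spline):
    # values: [E x dim] pseudo-coordinates scaled to [0, 1].
    ks, op = kernel_size.tolist(), is_open_spline.tolist()
    E, dim = values.size()
    k_max = 2 ** dim
    K = 1
    for s in ks:
        K *= s
    amount = values.new(E, k_max)
    index = torch.LongTensor(E, k_max)
    for e in range(E):
        for k in range(k_max):
            bits, Kd, a, i = k, K, 1.0, 0
            for d in range(dim):
                Kd //= ks[d]
                bit = bits % 2  # selects bottom/top knot in this dimension
                bits >>= 1
                v = float(values[e, d]) * (ks[d] - op[d])
                frac = v - math.floor(v)
                a *= (1 - bit) * frac + bit * (1 - frac)
                bot = int(math.floor(v)) % ks[d]
                top = (int(math.floor(v)) + 1) % ks[d]
                i += (bit * bot + (1 - bit) * top) * Kd
            amount[e, k] = a
            index[e, k] = i
    return amount, index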
# NOTE: currently an exact copy of `_spline_kernel` above (including the
# `spline_kernel` function name); the actual backward pass through the spline
# basis is still work in progress, per the commit message.
_spline_kernel_backwards = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
    const ${Dtype}* input, ${Dtype}* amount, long* index,
    const long* kernel_size, const long* is_open_spline) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {
    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};
    int K = ${K};
    int k_idx_mod;
    int bot;
    int top;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
      K /= kernel_size[d_idx];
      k_idx_mod = k_idx % 2;
      k_idx >>= 1;
      value = input[e_idx * ${dim} + d_idx] *
              (kernel_size[d_idx] - is_open_spline[d_idx]);
      frac = value - floor(value);
      a *= (1 - k_idx_mod) * frac + k_idx_mod * (1 - frac);
      bot = int(floor(value));
      top = (bot + 1) % kernel_size[d_idx];
      bot %= kernel_size[d_idx];
      i += (k_idx_mod * bot + (1 - k_idx_mod) * top) * K;
    }
    amount[idx] = a;
    index[idx] = i;
  }
}
'''
class EdgewiseSplineWeightingGPU(Function):
    def __init__(self, kernel_size, is_open_spline, K):
        super(EdgewiseSplineWeightingGPU, self).__init__()
        assert kernel_size.is_cuda and is_open_spline.is_cuda
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline
        self.K = K

    def forward(self, input, values, weight):
        assert input.is_cuda and weight.is_cuda
        self.save_for_backward(input, weight)

        values = values.unsqueeze(1) if len(values.size()) < 2 else values
        num_edges, dim = values.size()
        k_max = 2 ** dim

        amount = values.new(num_edges, k_max)
        index = values.new(num_edges, k_max).long()
        num_threads = amount.numel()

        with torch.cuda.device_of(input):
            f = load_kernel('spline_kernel', _spline_kernel,
                            Dtype=Dtype(input), num_threads=num_threads,
                            k_max=k_max, dim=dim, K=self.K)
            f(block=(cuda_num_threads, 1, 1),
              grid=(get_blocks(num_threads), 1, 1),
              args=[values.data_ptr(), amount.data_ptr(), index.data_ptr(),
                    self.kernel_size.data_ptr(),
                    self.is_open_spline.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        self.amount = amount
        self.index = index

        _, M_in, M_out = weight.size()
        k_max = self.amount.size(1)

        output = input.new(input.size(0), M_out)
        num_threads = output.numel()

        with torch.cuda.device_of(input):
            f = load_kernel('edgewise_spline_weighting_forward_kernel',
                            _edgewise_spline_weighting_forward_kernel,
                            Dtype=Dtype(input), num_threads=num_threads,
                            M_in=M_in, M_out=M_out, k_max=k_max)
            f(block=(cuda_num_threads, 1, 1),
              grid=(get_blocks(num_threads), 1, 1),
              args=[input.data_ptr(), weight.data_ptr(), output.data_ptr(),
                    self.amount.data_ptr(), self.index.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        return output
    def backward(self, grad_output):
        input, weight = self.saved_tensors
        K, M_in, M_out = weight.size()
        E, k_max = self.amount.size()

        grad_input = grad_output.new(input.size(0), M_in).fill_(0)
        grad_weight = grad_output.new(K, M_in, M_out).fill_(0)
        grad_b = grad_output.new(E, k_max).fill_(0)

        num_threads = grad_output.numel()

        with torch.cuda.device_of(grad_output):
            f = load_kernel('edgewise_spline_weighting_backward_kernel',
                            _edgewise_spline_weighting_backward_kernel,
                            Dtype=Dtype(input), num_threads=num_threads,
                            M_in=M_in, M_out=M_out, k_max=k_max, K=K)
            f(block=(cuda_num_threads, 1, 1),
              grid=(get_blocks(num_threads), 1, 1),
              args=[grad_output.data_ptr(), grad_input.data_ptr(),
                    grad_weight.data_ptr(), grad_b.data_ptr(),
                    input.data_ptr(), weight.data_ptr(),
                    self.amount.data_ptr(), self.index.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        # `backward` must return one gradient per `forward` argument, in the
        # order (input, values, weight). Backprop through the spline basis
        # itself is still work in progress, so `grad_b` (the gradient w.r.t.
        # the B-spline amounts) stands in for the gradient w.r.t. `values`.
        return grad_input, grad_b, grad_weight
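
Since the backward pass is still in flux, even a crude smoke test is useful. A sketch along these lines (shapes and values made up, torch-0.3-era Variable API and a CUDA device assumed; not part of the commit) exercises all three kernels:

import torch
from torch.autograd import Variable

E, dim, M_in, M_out = 8, 2, 4, 2
kernel_size = torch.cuda.LongTensor([5, 5])
is_open_spline = torch.cuda.LongTensor([1, 1])
K = 5 * 5

input = Variable(torch.cuda.FloatTensor(E, M_in).uniform_(),
                 requires_grad=True)
values = Variable(torch.cuda.FloatTensor(E, dim).uniform_())
weight = Variable(torch.cuda.FloatTensor(K, M_in, M_out).uniform_(),
                  requires_grad=True)

op = EdgewiseSplineWeightingGPU(kernel_size, is_open_spline, K)
out = op(input, values, weight)  # [E x M_out]
out.sum().backward()             # runs the backward kernel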
spline_conv_bp2adj.py (new file, mode 100644)
import torch
from torch.autograd import Variable

from .spline import spline
from .edgewise_spline_weighting_bp2adj_gpu import EdgewiseSplineWeightingGPU


def spline_conv_bp2adj(adj,  # Tensor
                       input,  # Variable
                       weight,  # Variable
                       kernel_size,
                       is_open_spline,
                       K,
                       degree=1,
                       bias=None):

    values = adj._values()
    row, col = adj._indices()

    # Get features for every end vertex with shape [|E| x M_in].
    output = input[col]

    # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
    # Note the argument order of `forward(input, values, weight)`.
    output = EdgewiseSplineWeightingGPU(kernel_size, is_open_spline, K)(
        output, values, weight[:-1])

    # Convolution via `scatter_add`. Converts the [|E| x M_out] feature
    # matrix to a [n x M_out] feature matrix.
    zero = output.data.new(adj.size(1), output.size(1)).fill_(0.0)
    zero = Variable(zero) if not torch.is_tensor(output) else zero
    r = row.view(-1, 1).expand(row.size(0), output.size(1))
    output = zero.scatter_add_(0, Variable(r), output)

    # Weight root node features by multiplying them with the root weight.
    output += torch.mm(input, weight[-1])

    # Normalize output by degree.
    ones = values.new(values.size(0)).fill_(1)
    zero = values.new(output.size(0)).fill_(0)
    degree = zero.scatter_add_(0, row, ones)
    degree = torch.clamp(degree, min=1)
    output = output / Variable(degree.view(-1, 1))

    if bias is not None:
        output += bias

    return output
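
For completeness, a hypothetical call site, not part of the commit. It assumes `adj` is a hybrid sparse tensor whose indices are the [2 x |E|] edge list and whose values are the [|E| x dim] pseudo-coordinates (matching the `_indices()`/`_values()` usage above), and that `weight` carries K + 1 slices, the last one being the root weight:

import torch
from torch.autograd import Variable

n, M_in, M_out, dim = 4, 3, 2, 2
kernel_size = torch.cuda.LongTensor([5, 5])
is_open_spline = torch.cuda.LongTensor([1, 1])
K = 5 * 5

index = torch.cuda.LongTensor([[0, 1, 2], [1, 2, 0]])  # [2 x |E|] edges
pseudo = torch.cuda.FloatTensor(3, dim).uniform_()      # [|E| x dim]
adj = torch.cuda.sparse.FloatTensor(index, pseudo, torch.Size([n, n, dim]))

input = Variable(torch.cuda.FloatTensor(n, M_in).uniform_())
weight = Variable(torch.cuda.FloatTensor(K + 1, M_in, M_out).uniform_())

out = spline_conv_bp2adj(adj, input, weight, kernel_size, is_open_spline, K)
print(out.size())  # expected: (n, M_out)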