OpenDAS / torch-spline-conv / Commits / 2de0f11f

Commit 2de0f11f authored Mar 07, 2018 by Jan Eric Lenssen
Parent: 9c208e8e

    bugfixes adj gradient

Showing 3 changed files with 108 additions and 33 deletions (+108 / -33):

    spline_conv.py         +1   -0
    spline_conv_gpu.py     +34  -22
    spline_conv_test.py    +73  -11
--- a/spline_conv.py
+++ b/spline_conv.py
@@ -29,6 +29,7 @@ def spline_conv(
     output = input[col]

     # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
     if output.is_cuda:
         output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
                                basis_kernel, basis_backward_kernel,
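For context: output = input[col] gathers the source-node features of every edge, turning the [N x M_in] node feature matrix into the [|E| x M_in] edge feature matrix that the GPU op consumes. A minimal sketch of that gather, with toy tensors rather than the library's fixtures:

    import torch

    input = torch.Tensor([[9, 10], [1, 2], [3, 4]])  # [N x M_in] node features
    col = torch.LongTensor([1, 2, 2])                # source node of each edge
    edge_features = input[col]                       # [|E| x M_in]
    print(edge_features)                             # rows [1, 2], [3, 4], [3, 4]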
--- a/spline_conv_gpu.py
+++ b/spline_conv_gpu.py
@@ -144,28 +144,29 @@ const ${Dtype}* amount, const long* index, int num_threads) {
         // Calculate B-spline basis tensor product gradient
         adj_g += g * f * w;
       }
-      atomicAdd(&(grad_amount[e_idx*${k_max} +k_idx]), adj_g);
+      atomicAdd(&(grad_amount[e_idx*${k_max} + k_idx]), adj_g);
     }
   }
 }
 '''


-def get_weighting_forward_kernel(M_in, M_out, k_max):
+def get_weighting_forward_kernel(M_in, M_out, k_max, dtype='float'):
     cuda_tensor = torch.FloatTensor([1]).cuda()
     kernel = _edgewise_spline_weighting_forward_kernel
     with torch.cuda.device_of(cuda_tensor):
         f_fw = load_kernel(
             'edgewise_spline_weighting_forward_kernel',
             kernel,
-            Dtype='float',
+            Dtype=dtype,
             M_in=M_in,
             M_out=M_out,
             k_max=k_max)
     return f_fw


-def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False):
+def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False,
+                                  dtype='float'):
     cuda_tensor = torch.FloatTensor([1]).cuda()
     if bp_to_adj:
         kernel = _edgewise_spline_weighting_backward_kernel_bp2adj
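The dtype parameter added to these kernel factories is what makes double-precision compilation of the CUDA kernels possible: gradcheck perturbs inputs by eps=1e-6, which is below float32 resolution, so the gradient test only makes sense in float64. Usage as in the new test further down:

    fw_k = get_weighting_forward_kernel(in_features, out_features, k_max,
                                        dtype='double')
    bw_k = get_weighting_backward_kernel(in_features, out_features, k_max, K,
                                         True, dtype='double')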
@@ -175,7 +176,7 @@ def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False):
         f_bw = load_kernel(
             'edgewise_spline_weighting_backward_kernel',
             kernel,
-            Dtype='float',
+            Dtype=dtype,
             M_in=M_in,
             M_out=M_out,
             k_max=k_max,
@@ -341,27 +342,42 @@ int num_threads) {
       ${Dtype} grad_out = 0.0;
       int quotient = (int)pow(2.0,(double)d_idx);
+      value = input[e_idx * ${dim} + d_idx];
+      value *= kernel_size[d_idx] - is_open_spline[d_idx];
+      frac = value - floor(value);
       for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
         k_idx_mod = (k_idx/quotient) % 2;
-        value = input[e_idx * ${dim} + d_idx];
-        value *= kernel_size[d_idx] - is_open_spline[d_idx];
-        frac = value - floor(value);
         ${Dtype} residual = (1 - k_idx_mod) * (frac - 1) + k_idx_mod * frac;
         int a_idx = e_idx*${k_max} + k_idx;
         grad_out += grad_amount[a_idx]*amount[a_idx]/residual;
       }
-      grad_adj[e_idx*${dim} + d_idx] = grad_out;
+      grad_adj[idx] = grad_out*(kernel_size[d_idx] - is_open_spline[d_idx]);
     }
   }
+  /*
+  ${Dtype} a = -(1 - k_idx_mod) + k_idx_mod;
+  for (int d_it = 0; d_it < ${dim}; d_it++) {
+    if(d_it!=d_idx)
+    {
+      value = input[e_idx * ${dim} + d_it];
+      value *= kernel_size[d_it] - is_open_spline[d_it];
+      frac = value - floor(value);
+      a *= (1 - k_idx_mod) * (1 - frac) + k_idx_mod * frac;
+    }
+  }
+  grad_out += a*grad_amount[a_idx];
+  */
 '''


-def get_basis_kernel(k_max, K, dim, degree):
+def get_basis_kernel(k_max, K, dim, degree, dtype='float'):
     if degree == 3:
         _spline_kernel = _spline_kernel_cubic
     elif degree == 2:
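The grad_adj change above is the "adj gradient" bugfix named in the commit message. The pseudo-coordinate is rescaled as value = input * (kernel_size - is_open_spline) before its fractional part enters the basis functions, so by the chain rule the gradient with respect to the raw adjacency value has to carry that same factor; the old line dropped it. The hunk also hoists the value/frac computation out of the k_idx loop, since it does not depend on k_idx. A minimal sketch of the missing factor, assuming a hypothetical 1-D linear-basis setup and checking against central differences:

    import torch

    kernel_size, is_open_spline = 4.0, 0.0  # hypothetical 1-D configuration

    def basis(adj):
        value = adj * (kernel_size - is_open_spline)
        frac = value - value.floor()        # floor() contributes zero gradient
        return frac * (1 - frac)            # toy product of two linear basis terms

    adj = torch.tensor([0.37], dtype=torch.float64, requires_grad=True)
    basis(adj).backward()

    eps = 1e-6
    fd = (basis(torch.tensor([0.37 + eps], dtype=torch.float64)) -
          basis(torch.tensor([0.37 - eps], dtype=torch.float64))) / (2 * eps)
    # Both gradients carry the factor (kernel_size - is_open_spline) = 4.
    print(adj.grad.item(), fd.item())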
@@ -374,14 +390,14 @@ def get_basis_kernel(k_max, K, dim, degree):
     f = load_kernel(
         'spline_kernel',
         _spline_kernel,
-        Dtype='float',
+        Dtype=dtype,
         k_max=k_max,
         dim=dim,
         K=K)
     return f


-def get_basis_backward_kernel(k_max, K, dim, degree):
+def get_basis_backward_kernel(k_max, K, dim, degree, dtype='float'):
     if degree == 3:
         raise NotImplementedError
     elif degree == 2:
@@ -394,7 +410,7 @@ def get_basis_backward_kernel(k_max, K, dim, degree):
     f = load_kernel(
         'spline_kernel',
         _spline_kernel,
-        Dtype='float',
+        Dtype=dtype,
         k_max=k_max,
         dim=dim,
         K=K)
@@ -431,11 +447,10 @@ class SplineConvGPU(Function):
         self.save_for_backward(input, weight)

         num_edges, dim = adj_values.size()
-        k_max = 2 ** dim
+        k_max = (self.degree + 1) ** dim
         amount = adj_values.new(num_edges, k_max)
         index = adj_values.new(num_edges, k_max).long()
         num_threads = amount.numel()

         with torch.cuda.device_of(input):
             self.f_basis_fw(
                 block=(cuda_num_threads, 1, 1),
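The corrected k_max follows from B-spline basics: a degree-d B-spline basis has d + 1 functions that are nonzero at any evaluation point, so the tensor-product basis over dim dimensions has (degree + 1) ** dim nonzero entries per edge. The old 2 ** dim was only correct for the linear case:

    # (degree + 1) ** dim vs. the old hard-coded 2 ** dim
    for degree in (1, 2, 3):
        for dim in (1, 2, 3):
            print(degree, dim, (degree + 1) ** dim)
    # degree=1, dim=2 -> 4, which matches the old 2 ** dim;
    # degree=2, dim=2 -> 9, which the old formula got wrong.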
@@ -452,8 +467,8 @@ class SplineConvGPU(Function):
         # Weight features
         output = input.new(input.size(0), self.M_out)
         num_threads = output.numel()

         with torch.cuda.device_of(input):
             self.f_weighting_fw(
                 block=(cuda_num_threads, 1, 1),
@@ -468,15 +483,12 @@ class SplineConvGPU(Function):
             ],
             stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

         self.amount = amount
         self.index = index

         return output

     def backward(self, grad_output):
-        print('grad_output:', grad_output.min(), grad_output.max())
         grad_input = grad_output.new(grad_output.size(0), self.M_in).fill_(0)
         grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)
         num_threads = grad_output.numel()
@@ -488,7 +500,6 @@ class SplineConvGPU(Function):
         index = self.index
         grad_amount = grad_output.new(amount.size(0),
                                       amount.size(1)).fill_(0)

         with torch.cuda.device_of(grad_output):
             self.f_weighting_bw(
                 block=(cuda_num_threads, 1, 1),
@@ -529,6 +540,7 @@ class SplineConvGPU(Function):
             #print('grad_weight:',grad_weight[:,:,-1].min(), grad_weight[:,:,-1].max())
             #print('grad_amount:',grad_amount.min(), grad_amount.max())
             #print('grad_adj:',grad_adj.min(), grad_adj.max())
             return grad_input, grad_weight, grad_adj
         else:
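For context on the three return values: with the old-style torch.autograd.Function used here, backward must return one gradient per forward input, so grad_input, grad_weight and the newly fixed grad_adj line up with (input, weight, adj_values). A minimal sketch of that contract with a toy op, not the library's:

    import torch
    from torch.autograd import Function

    class Scale(Function):  # hypothetical three-input op
        def forward(self, input, weight, adj_values):
            self.save_for_backward(input, weight, adj_values)
            return input * weight * adj_values

        def backward(self, grad_output):
            input, weight, adj_values = self.saved_tensors
            # One gradient per forward input, in the same order.
            return (grad_output * weight * adj_values,
                    grad_output * input * adj_values,
                    grad_output * input * weight)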
--- a/spline_conv_test.py
+++ b/spline_conv_test.py
@@ -11,11 +11,12 @@ from .spline_conv_gpu import get_basis_kernel,get_basis_backward_kernel, \
 class SplineConvTest(unittest.TestCase):
+    '''
     @unittest.skipIf(not torch.cuda.is_available(), 'no GPU')
     def test_forward_gpu(self):
         edges = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]])
         values = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
-        values = torch.FloatTensor(values)
+        values = torch.FloatTensor(values).double()
         adj = {'indices': edges.cuda(), 'values': Variable(values.cuda()),
                'size': torch.Size([5, 5, 2])}
@@ -23,11 +24,12 @@ class SplineConvTest(unittest.TestCase):
         kernel_size = torch.cuda.LongTensor([3, 4])
         is_open_spline = torch.cuda.LongTensor([1, 0])
-        input = torch.FloatTensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]])
-        weight = torch.arange(0.5, 0.5 * 27, step=0.5).view(13, 2, 1)
+        input = torch.FloatTensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]]).double()
+        weight = torch.arange(0.5, 0.5 * 27, step=0.5).view(13, 2, 1).double()
         input, weight = input.cuda(), weight.cuda()
         input, weight = Variable(input), Variable(weight)
+        row, col = adj['indices']
+        output = input[col]

         K = 12
         in_features = 2
         out_features = 1
@@ -43,9 +45,29 @@ class SplineConvTest(unittest.TestCase):
         basis_bw_k = get_basis_backward_kernel(k_max, K, dim, degree)

-        output = spline_conv(
-            adj, input, weight, kernel_size, is_open_spline, K, fw_k, bw_k,
-            basis_fw_k, basis_bw_k, bp_to_adj=True)
+        #output = spline_conv(
+        #    adj, input, weight, kernel_size, is_open_spline, K, fw_k, bw_k,
+        #    basis_fw_k, basis_bw_k,bp_to_adj=True)
+        values = adj['values']
+        output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
+                               basis_fw_k, basis_bw_k, fw_k, bw_k, bp_to_adj=True) \
+            (output, weight, values)
+        zero = output.data.new(adj['size'][1], output.size(1)).fill_(0.0)
+        zero = Variable(zero) if not torch.is_tensor(output) else zero
+        r = row.view(-1, 1).expand(row.size(0), output.size(1))
+        output = zero.scatter_add_(0, Variable(r), output)
+
+        # Weighten root node features by multiplying with root weight.
+        output += torch.mm(input, weight[-1])
+
+        # Normalize output by degree.
+        ones = values.data.new(values.size(0)).fill_(1)
+        zero = values.data.new(output.size(0)).fill_(0)
+        degree = zero.scatter_add_(0, row, ones)
+        degree = torch.clamp(degree, min=1)
+        output = output / Variable(degree.view(-1, 1))

         expected_output = [
             [(12.5 * 9 + 13 * 10 + 266) / 4],
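The added block reimplements, inline in the test, the aggregation that spline_conv otherwise performs (the test now calls SplineConvGPU directly): edge-wise outputs are summed into their target nodes with scatter_add_ and divided by node degree, i.e. a mean over incoming edges, with clamp(min=1) guarding isolated nodes. The pattern in isolation, with toy numbers:

    import torch

    row = torch.LongTensor([0, 0, 1])            # target node of each edge
    edge_out = torch.Tensor([[1.], [3.], [5.]])  # one output row per edge

    num_nodes = 2
    out = torch.zeros(num_nodes, 1)
    r = row.view(-1, 1).expand(row.size(0), edge_out.size(1))
    out.scatter_add_(0, r, edge_out)             # node 0: 1 + 3, node 1: 5

    deg = torch.zeros(num_nodes).scatter_add_(0, row, torch.ones(row.size(0)))
    deg = torch.clamp(deg, min=1)                # isolated nodes divide by 1
    print(out / deg.view(-1, 1))                 # [[2.], [5.]]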
@@ -56,14 +78,16 @@ class SplineConvTest(unittest.TestCase):
         ]
         assert_almost_equal(output.cpu().data.numpy(), expected_output, 1)

     @unittest.skipIf(not torch.cuda.is_available(), 'no GPU')
     def test_backward(self):
         kernel_size = torch.cuda.LongTensor([3, 4])
-        is_open_spline = torch.cuda.LongTensor([1, 0])
+        is_open_spline = torch.cuda.LongTensor([1, 1])
         input = torch.randn(4, 2).double().cuda()
         weight = torch.randn(12, 2, 1).double().cuda()
-        values = torch.randn(4, 2).double().cuda()
+        values = torch.FloatTensor(4, 2).uniform_(0, 1).double().cuda()
+        print(values)
         input = Variable(input, requires_grad=True)
         weight = Variable(weight, requires_grad=True)
         values = Variable(values, requires_grad=True)
@@ -84,7 +108,45 @@ class SplineConvTest(unittest.TestCase):
         op = SplineConvGPU(kernel_size, is_open_spline, K, degree,
                            basis_fw_k, basis_bw_k, fw_k, bw_k, bp_to_adj=True)
-        print(op(input, weight, values))
-        #test = gradcheck(op, (input, weight, values), eps=1e-6, atol=1e-4)
-        #self.assertTrue(test)
+    '''
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no GPU')
+    def test_backward(self):
+        input = torch.randn(4, 2).double().cuda()
+        weight = torch.randn(9, 2, 1).double().cuda()
+        values = torch.FloatTensor(4, 2).uniform_(0, 1).double().cuda()
+        print(values)
+        input = Variable(input, requires_grad=True)
+        weight = Variable(weight, requires_grad=True)
+        values = Variable(values, requires_grad=True)
+
+        K = 9
+        in_features = 2
+        out_features = 1
+        degree = 1
+        dim = 2
+        k_max = (degree + 1) ** dim
+        kernel_size = torch.cuda.LongTensor([3, 3])
+        is_open_spline = torch.cuda.LongTensor([1, 0])
+
+        fw_k = get_weighting_forward_kernel(in_features, out_features, k_max,
+                                            dtype='double')
+        bw_k = get_weighting_backward_kernel(in_features, out_features, k_max,
+                                             K, True, dtype='double')
+        basis_fw_k = get_basis_kernel(k_max, K, dim, degree, dtype='double')
+        basis_bw_k = get_basis_backward_kernel(k_max, K, dim, degree,
+                                               dtype='double')
+
+        op = SplineConvGPU(kernel_size, is_open_spline, K, degree,
+                           basis_fw_k, basis_bw_k, fw_k, bw_k, bp_to_adj=True)
+        #print(op(input, weight, values))
         test = gradcheck(op, (input, weight, values), eps=1e-6, atol=1e-4)
+        print(test)
         self.assertTrue(test)
\ No newline at end of file
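The re-enabled gradcheck is the regression test for this commit: it compares the analytic grad_input / grad_weight / grad_adj against central finite differences in double precision. A minimal sketch of the same pattern, written with current PyTorch syntax and a toy differentiable function, since the real op needs a GPU and the compiled kernels:

    import torch
    from torch.autograd import gradcheck

    def op(input, weight):
        return (input @ weight).tanh()

    # float64 throughout: eps=1e-6 perturbations drown in float32 rounding.
    input = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)
    weight = torch.randn(2, 1, dtype=torch.float64, requires_grad=True)
    print(gradcheck(op, (input, weight), eps=1e-6, atol=1e-4))  # True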