Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
6100aa77
Commit
6100aa77
authored
Mar 13, 2018
by
rusty1s
Browse files
define for spline basis backward
parent
3faacaf3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
42 deletions
+65
-42
test/test_spline_conv.py
test/test_spline_conv.py
+12
-13
torch_spline_conv/src/cpu.c
torch_spline_conv/src/cpu.c
+34
-2
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+19
-27
No files found.
test/test_spline_conv.py
View file @
6100aa77
...
...
@@ -3,7 +3,7 @@ import torch
from
torch.autograd
import
Variable
,
gradcheck
from
torch_spline_conv
import
spline_conv
from
torch_spline_conv.functions.spline_weighting
import
SplineWeighting
from
torch_spline_conv.functions.ffi
import
implemented_degrees
#
from torch_spline_conv.functions.ffi import implemented_degrees
from
.utils
import
tensors
,
Tensor
...
...
@@ -49,17 +49,16 @@ def test_spline_conv_cpu(tensor):
def
test_spline_weighting_backward_cpu
():
# for degree in implemented_degrees.keys():
degree
=
list
(
implemented_degrees
.
keys
())[
0
]
kernel_size
=
torch
.
LongTensor
([
5
,
5
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
1
])
op
=
SplineWeighting
(
kernel_size
,
is_open_spline
,
degree
)
for
degree
in
[
1
]:
kernel_size
=
torch
.
LongTensor
([
5
,
5
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
1
])
op
=
SplineWeighting
(
kernel_size
,
is_open_spline
,
degree
)
x
=
torch
.
DoubleTensor
(
4
,
2
).
uniform_
(
-
1
,
1
)
x
=
Variable
(
x
,
requires_grad
=
True
)
pseudo
=
torch
.
DoubleTensor
(
4
,
2
).
uniform_
(
0
,
1
)
pseudo
=
Variable
(
torch
.
DoubleTensor
(
pseudo
),
requires_grad
=
True
)
weight
=
torch
.
DoubleTensor
(
25
,
2
,
4
).
uniform_
(
-
1
,
1
)
weight
=
Variable
(
weight
,
requires_grad
=
True
)
x
=
torch
.
DoubleTensor
(
4
,
2
).
uniform_
(
-
1
,
1
)
x
=
Variable
(
x
)
pseudo
=
torch
.
DoubleTensor
(
4
,
2
).
uniform_
(
0
,
1
)
pseudo
=
Variable
(
torch
.
DoubleTensor
(
pseudo
),
requires_grad
=
True
)
weight
=
torch
.
DoubleTensor
(
25
,
2
,
4
).
uniform_
(
-
1
,
1
)
weight
=
Variable
(
weight
)
assert
gradcheck
(
op
,
(
x
,
pseudo
,
weight
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
assert
gradcheck
(
op
,
(
x
,
pseudo
,
weight
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
torch_spline_conv/src/cpu.c
View file @
6100aa77
...
...
@@ -7,9 +7,9 @@
#define SPLINE_BASIS_FORWARD(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t S = THLongTensor_size(weight_index, 1); \
int64_t s, d, k, k_mod, i, offset; real value, b; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t s, d, k, k_mod, i, offset; real b, value; \
\
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \
for (s = 0; s < S; s++) { \
...
...
@@ -29,6 +29,38 @@
}) \
}
#define SPLINE_BASIS_BACKWARD(M, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline, EVAL_CODE, GRAD_CODE) { \
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t S = THTensor_(size)(grad_basis, 1); \
int64_t d, s, d_it, quotient, k_mod; real g_out, g, value;\
\
TH_TENSOR_DIM_APPLY3(real, grad_pseudo, real, grad_basis, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \
for (d = 0; d < D; d++) { \
g_out = 0; \
quotient = pow(M + 1, d); \
for (s = 0; s < S; s++) { \
k_mod = (s / quotient) % (M + 1); \
GRAD_CODE \
g = value; \
\
for (d_it = 0; d_it < D; d_it++) { \
if (d_it != d) { \
k_mod = (s / (int64_t) pow(M + 1, d_it)) % (M + 1); \
value = *(pseudo_data + d_it * pseudo_stride) * (kernel_size_data[d_it] - M * is_open_spline_data[d_it]); \
value -= floor(value); \
EVAL_CODE \
g *= value; \
} \
} \
g_out += g * *(grad_basis_data + s * grad_basis_stride); \
} \
grad_pseudo_data[d * grad_pseudo_stride] = g_out * (kernel_size_data[d] - M * is_open_spline_data[d]); \
} \
) \
}
#define SPLINE_WEIGHTING(TENSOR1, TENSOR2, TENSOR3, weight_index, M_IN, M_OUT, M_S, CODE) { \
int64_t M_in = M_IN; int64_t M_out = M_OUT; int64_t S = M_S; \
int64_t m_in, m_out, s, w_idx; real value; \
...
...
torch_spline_conv/src/generic/cpu.c
View file @
6100aa77
...
...
@@ -10,7 +10,7 @@ void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index,
void
spline_
(
quadratic_basis_forward
)(
THTensor
*
basis
,
THLongTensor
*
weight_index
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
)
{
SPLINE_BASIS_FORWARD
(
2
,
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
,
if
(
k_mod
==
0
)
value
=
0
.
5
*
(
1
-
value
)
*
(
1
-
value
)
;
if
(
k_mod
==
0
)
value
=
0
.
5
*
value
*
value
-
value
+
0
.
5
;
else
if
(
k_mod
==
1
)
value
=
-
value
*
value
+
value
+
0
.
5
;
else
value
=
0
.
5
*
value
*
value
;
)
...
...
@@ -26,39 +26,31 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T
}
void
spline_
(
linear_basis_backward
)(
THTensor
*
grad_pseudo
,
THTensor
*
grad_basis
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
)
{
int64_t
*
kernel_size_data
=
kernel_size
->
storage
->
data
+
kernel_size
->
storageOffset
;
uint8_t
*
is_open_spline_data
=
is_open_spline
->
storage
->
data
+
is_open_spline
->
storageOffset
;
int64_t
D
=
THTensor_
(
size
)(
pseudo
,
1
);
int64_t
S
=
THTensor_
(
size
)(
grad_basis
,
1
);
int64_t
s
,
d
,
d_it
;
TH_TENSOR_DIM_APPLY3
(
real
,
grad_pseudo
,
real
,
grad_basis
,
real
,
pseudo
,
1
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
d
=
0
;
d
<
D
;
d
++
)
{
real
g_out
=
0
;
int64_t
quotient
=
(
int64_t
)
pow
(
2
,
d
);
for
(
s
=
0
;
s
<
S
;
s
++
)
{
int64_t
k_mod
=
(
s
/
quotient
)
%
2
;
real
a
=
-
(
1
-
k_mod
)
+
k_mod
;
for
(
d_it
=
0
;
d_it
<
D
;
d_it
++
)
{
if
(
d_it
!=
d
)
{
k_mod
=
(
s
/
((
int64_t
)
pow
(
2
,
d_it
)))
%
2
;
real
value
=
*
(
pseudo_data
+
d_it
*
pseudo_stride
)
*
(
kernel_size_data
[
d_it
]
-
is_open_spline_data
[
d_it
]);
value
-=
floor
(
value
);
a
*=
(
1
-
k_mod
)
*
(
1
-
value
)
+
k_mod
*
value
;
}
}
g_out
+=
a
*
*
(
grad_basis_data
+
s
*
grad_basis_stride
);
}
grad_pseudo_data
[
d
*
grad_pseudo_stride
]
=
g_out
*
(
kernel_size_data
[
d
]
-
is_open_spline_data
[
d
]);
}
SPLINE_BASIS_BACKWARD
(
1
,
grad_pseudo
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
,
value
=
(
1
-
k_mod
)
*
(
1
-
value
)
+
k_mod
*
value
;
,
value
=
-
1
+
k_mod
+
k_mod
;
)
}
void
spline_
(
quadratic_basis_backward
)(
THTensor
*
grad_pseudo
,
THTensor
*
grad_basis
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
)
{
SPLINE_BASIS_BACKWARD
(
2
,
grad_pseudo
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
,
if
(
k_mod
==
0
)
value
=
0
.
5
*
value
*
value
-
value
+
0
.
5
;
else
if
(
k_mod
==
1
)
value
=
-
value
*
value
+
value
+
0
.
5
;
else
value
=
0
.
5
*
value
*
value
;
,
if
(
k_mod
==
0
)
value
=
2
*
value
-
1
;
else
if
(
k_mod
==
1
)
value
=
-
2
*
value
+
1
;
else
value
=
value
;
)
}
void
spline_
(
cubic_basis_backward
)(
THTensor
*
grad_pseudo
,
THTensor
*
grad_basis
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
)
{
SPLINE_BASIS_BACKWARD
(
3
,
grad_pseudo
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
,
value
=
(
1
-
k_mod
)
*
(
1
-
value
)
+
k_mod
*
value
;
,
value
=
-
(
1
-
k_mod
)
+
k_mod
;
)
}
void
spline_
(
weighting_forward
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment