Commit 6100aa77 authored by rusty1s's avatar rusty1s
Browse files

define for spline basis backward

parent 3faacaf3
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from torch.autograd import Variable, gradcheck from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv from torch_spline_conv import spline_conv
from torch_spline_conv.functions.spline_weighting import SplineWeighting from torch_spline_conv.functions.spline_weighting import SplineWeighting
from torch_spline_conv.functions.ffi import implemented_degrees # from torch_spline_conv.functions.ffi import implemented_degrees
from .utils import tensors, Tensor from .utils import tensors, Tensor
...@@ -49,17 +49,16 @@ def test_spline_conv_cpu(tensor): ...@@ -49,17 +49,16 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu(): def test_spline_weighting_backward_cpu():
# for degree in implemented_degrees.keys(): for degree in [1]:
degree = list(implemented_degrees.keys())[0] kernel_size = torch.LongTensor([5, 5])
kernel_size = torch.LongTensor([5, 5]) is_open_spline = torch.ByteTensor([1, 1])
is_open_spline = torch.ByteTensor([1, 1]) op = SplineWeighting(kernel_size, is_open_spline, degree)
op = SplineWeighting(kernel_size, is_open_spline, degree)
x = torch.DoubleTensor(4, 2).uniform_(-1, 1) x = torch.DoubleTensor(4, 2).uniform_(-1, 1)
x = Variable(x, requires_grad=True) x = Variable(x)
pseudo = torch.DoubleTensor(4, 2).uniform_(0, 1) pseudo = torch.DoubleTensor(4, 2).uniform_(0, 1)
pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True) pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True)
weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1) weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1)
weight = Variable(weight, requires_grad=True) weight = Variable(weight)
assert gradcheck(op, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True assert gradcheck(op, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
#define SPLINE_BASIS_FORWARD(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \ #define SPLINE_BASIS_FORWARD(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \ int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \ uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t S = THLongTensor_size(weight_index, 1); \ int64_t S = THLongTensor_size(weight_index, 1); \
int64_t s, d, k, k_mod, i, offset; real value, b; \ int64_t D = THTensor_(size)(pseudo, 1); \
int64_t s, d, k, k_mod, i, offset; real b, value; \
\ \
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \ TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \
for (s = 0; s < S; s++) { \ for (s = 0; s < S; s++) { \
...@@ -29,6 +29,38 @@ ...@@ -29,6 +29,38 @@
}) \ }) \
} }
#define SPLINE_BASIS_BACKWARD(M, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline, EVAL_CODE, GRAD_CODE) { \
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t S = THTensor_(size)(grad_basis, 1); \
int64_t d, s, d_it, quotient, k_mod; real g_out, g, value;\
\
TH_TENSOR_DIM_APPLY3(real, grad_pseudo, real, grad_basis, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \
for (d = 0; d < D; d++) { \
g_out = 0; \
quotient = pow(M + 1, d); \
for (s = 0; s < S; s++) { \
k_mod = (s / quotient) % (M + 1); \
GRAD_CODE \
g = value; \
\
for (d_it = 0; d_it < D; d_it++) { \
if (d_it != d) { \
k_mod = (s / (int64_t) pow(M + 1, d_it)) % (M + 1); \
value = *(pseudo_data + d_it * pseudo_stride) * (kernel_size_data[d_it] - M * is_open_spline_data[d_it]); \
value -= floor(value); \
EVAL_CODE \
g *= value; \
} \
} \
g_out += g * *(grad_basis_data + s * grad_basis_stride); \
} \
grad_pseudo_data[d * grad_pseudo_stride] = g_out * (kernel_size_data[d] - M * is_open_spline_data[d]); \
} \
) \
}
#define SPLINE_WEIGHTING(TENSOR1, TENSOR2, TENSOR3, weight_index, M_IN, M_OUT, M_S, CODE) { \ #define SPLINE_WEIGHTING(TENSOR1, TENSOR2, TENSOR3, weight_index, M_IN, M_OUT, M_S, CODE) { \
int64_t M_in = M_IN; int64_t M_out = M_OUT; int64_t S = M_S; \ int64_t M_in = M_IN; int64_t M_out = M_OUT; int64_t S = M_S; \
int64_t m_in, m_out, s, w_idx; real value; \ int64_t m_in, m_out, s, w_idx; real value; \
......
...@@ -10,7 +10,7 @@ void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, ...@@ -10,7 +10,7 @@ void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index,
void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS_FORWARD(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = 0.5 * (1 - value) * (1 - value); if (k_mod == 0) value = 0.5 * value * value - value + 0.5;
else if (k_mod == 1) value = -value * value + value + 0.5; else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value; else value = 0.5 * value * value;
) )
...@@ -26,39 +26,31 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T ...@@ -26,39 +26,31 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T
} }
void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; SPLINE_BASIS_BACKWARD(1, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; value = (1 - k_mod) * (1 - value) + k_mod * value;
int64_t D = THTensor_(size)(pseudo, 1); ,
int64_t S = THTensor_(size)(grad_basis, 1); value = -1 + k_mod + k_mod;
int64_t s, d, d_it;
TH_TENSOR_DIM_APPLY3(real, grad_pseudo, real, grad_basis, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (d = 0; d < D; d++) {
real g_out = 0;
int64_t quotient = (int64_t) pow(2, d);
for (s = 0; s < S; s++) {
int64_t k_mod = (s/quotient) % 2;
real a = -(1 - k_mod) + k_mod;
for (d_it = 0; d_it < D; d_it++) {
if (d_it != d) {
k_mod = (s/((int64_t) pow(2, d_it))) % 2;
real value = *(pseudo_data + d_it * pseudo_stride) * (kernel_size_data[d_it] - is_open_spline_data[d_it]);
value -= floor(value);
a *= (1 - k_mod) * (1 - value) + k_mod * value;
}
}
g_out += a * *(grad_basis_data + s * grad_basis_stride);
}
grad_pseudo_data[d * grad_pseudo_stride] = g_out * (kernel_size_data[d] - is_open_spline_data[d]);
}
) )
} }
void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(2, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
if (k_mod == 0) value = 0.5 * value * value - value + 0.5;
else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value;
,
if (k_mod == 0) value = 2 * value - 1;
else if (k_mod == 1) value = -2 * value + 1;
else value = value;
)
} }
void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(3, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
value = (1 - k_mod) * (1 - value) + k_mod * value;
,
value = -(1 - k_mod) + k_mod;
)
} }
void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment