Commit 3faacaf3 authored by rusty1s's avatar rusty1s
Browse files

added pseudo backwards for linear

parent 26327cf5
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
from torch.autograd import Variable, gradcheck from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv from torch_spline_conv import spline_conv
from torch_spline_conv.functions.spline_weighting import SplineWeighting from torch_spline_conv.functions.spline_weighting import SplineWeighting
from torch_spline_conv.functions.ffi import implemented_degrees
from .utils import tensors, Tensor from .utils import tensors, Tensor
...@@ -48,15 +49,16 @@ def test_spline_conv_cpu(tensor): ...@@ -48,15 +49,16 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu(): def test_spline_weighting_backward_cpu():
# for degree in implemented_degrees.keys():
degree = list(implemented_degrees.keys())[0]
kernel_size = torch.LongTensor([5, 5]) kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1]) is_open_spline = torch.ByteTensor([1, 1])
op = SplineWeighting(kernel_size, is_open_spline, 1) op = SplineWeighting(kernel_size, is_open_spline, degree)
x = torch.DoubleTensor([[1, 2], [3, 4], [5, 6], [7, 8]]) x = torch.DoubleTensor(4, 2).uniform_(-1, 1)
x = Variable(x, requires_grad=True) x = Variable(x, requires_grad=True)
pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]] pseudo = torch.DoubleTensor(4, 2).uniform_(0, 1)
# pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True) pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True)
pseudo = Variable(torch.DoubleTensor(pseudo))
weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1) weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1)
weight = Variable(weight, requires_grad=True) weight = Variable(weight, requires_grad=True)
......
...@@ -51,7 +51,7 @@ def spline_weighting_backward_input(grad_output, weight, basis, ...@@ -51,7 +51,7 @@ def spline_weighting_backward_input(grad_output, weight, basis,
def spline_weighting_backward_basis(grad_output, x, weight, def spline_weighting_backward_basis(grad_output, x, weight,
weight_index): # pragma: no cover weight_index): # pragma: no cover
grad_basis = x.new(weight_index.size()) grad_basis = x.new(weight_index.size()).fill_(0)
func = get_func('weighting_backward_basis', x) func = get_func('weighting_backward_basis', x)
func(grad_basis, grad_output, x, weight, weight_index) func(grad_basis, grad_output, x, weight, weight_index)
return grad_basis return grad_basis
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _, Real) #define spline_(NAME) TH_CONCAT_4(spline_, NAME, _, Real)
#define SPLINE_BASIS(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \ #define SPLINE_BASIS_FORWARD(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \ int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \ uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \ int64_t D = THTensor_(size)(pseudo, 1); \
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
}) \ }) \
} }
#define SPLINE_WEIGHTING_BACKWARD(TENSOR1, TENSOR2, TENSOR3, weight_index, M_IN, M_OUT, M_S, CODE) { \ #define SPLINE_WEIGHTING(TENSOR1, TENSOR2, TENSOR3, weight_index, M_IN, M_OUT, M_S, CODE) { \
int64_t M_in = M_IN; int64_t M_out = M_OUT; int64_t S = M_S; \ int64_t M_in = M_IN; int64_t M_out = M_OUT; int64_t S = M_S; \
int64_t m_in, m_out, s, w_idx; real value; \ int64_t m_in, m_out, s, w_idx; real value; \
TH_TENSOR_DIM_APPLY4(real, TENSOR1, real, TENSOR2, real, TENSOR3, int64_t, weight_index, 1, CODE) \ TH_TENSOR_DIM_APPLY4(real, TENSOR1, real, TENSOR2, real, TENSOR3, int64_t, weight_index, 1, CODE) \
......
...@@ -5,12 +5,12 @@ void spline_quadratic_basis_forward_Double(THDoubleTensor *basis, THLongTensor * ...@@ -5,12 +5,12 @@ void spline_quadratic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *
void spline_cubic_basis_forward_Float( THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_cubic_basis_forward_Float( THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_cubic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_linear_basis_backward_Float( THFloatTensor *grad_pseudo, THLongTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline); void spline_linear_basis_backward_Float( THFloatTensor *grad_pseudo, THFloatTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_linear_basis_backward_Double(THDoubleTensor *grad_pseudo, THLongTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline); void spline_linear_basis_backward_Double(THDoubleTensor *grad_pseudo, THDoubleTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_quadratic_basis_backward_Float( THFloatTensor *grad_pseudo, THLongTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline); void spline_quadratic_basis_backward_Float( THFloatTensor *grad_pseudo, THFloatTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_quadratic_basis_backward_Double(THDoubleTensor *grad_pseudo, THLongTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline); void spline_quadratic_basis_backward_Double(THDoubleTensor *grad_pseudo, THDoubleTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_cubic_basis_backward_Float( THFloatTensor *grad_pseudo, THLongTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline); void spline_cubic_basis_backward_Float( THFloatTensor *grad_pseudo, THFloatTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_cubic_basis_backward_Double(THDoubleTensor *grad_pseudo, THLongTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline); void spline_cubic_basis_backward_Double(THDoubleTensor *grad_pseudo, THDoubleTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_weighting_forward_Float( THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_forward_Float( THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
......
...@@ -3,13 +3,13 @@ ...@@ -3,13 +3,13 @@
#else #else
void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS_FORWARD(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
value = (1 - k_mod) * (1 - value) + k_mod * value; value = (1 - k_mod) * (1 - value) + k_mod * value;
) )
} }
void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS_FORWARD(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = 0.5 * (1 - value) * (1 - value); if (k_mod == 0) value = 0.5 * (1 - value) * (1 - value);
else if (k_mod == 1) value = -value * value + value + 0.5; else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value; else value = 0.5 * value * value;
...@@ -17,7 +17,7 @@ void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_inde ...@@ -17,7 +17,7 @@ void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_inde
} }
void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS_FORWARD(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = (1 - value) * (1 - value) * (1 - value) / 6.0; if (k_mod == 0) value = (1 - value) * (1 - value) * (1 - value) / 6.0;
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6.0; else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6.0;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6.0; else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6.0;
...@@ -25,23 +25,45 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T ...@@ -25,23 +25,45 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T
) )
} }
void spline_(linear_basis_backward)(THTensor *grad_pseudo, THLongTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset;
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset;
int64_t D = THTensor_(size)(pseudo, 1);
int64_t S = THTensor_(size)(grad_basis, 1);
int64_t s, d, d_it;
TH_TENSOR_DIM_APPLY3(real, grad_pseudo, real, grad_basis, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (d = 0; d < D; d++) {
real g_out = 0;
int64_t quotient = (int64_t) pow(2, d);
for (s = 0; s < S; s++) {
int64_t k_mod = (s/quotient) % 2;
real a = -(1 - k_mod) + k_mod;
for (d_it = 0; d_it < D; d_it++) {
if (d_it != d) {
k_mod = (s/((int64_t) pow(2, d_it))) % 2;
real value = *(pseudo_data + d_it * pseudo_stride) * (kernel_size_data[d_it] - is_open_spline_data[d_it]);
value -= floor(value);
a *= (1 - k_mod) * (1 - value) + k_mod * value;
}
}
g_out += a * *(grad_basis_data + s * grad_basis_stride);
}
grad_pseudo_data[d * grad_pseudo_stride] = g_out * (kernel_size_data[d] - is_open_spline_data[d]);
}
)
} }
void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THLongTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
} }
void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THLongTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
} }
void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real *weight_data = weight->storage->data + weight->storageOffset; real b;
int64_t M_out = THTensor_(size)(output, 1); SPLINE_WEIGHTING(output, input, basis, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, w_idx; real b, value;
TH_TENSOR_DIM_APPLY4(real, output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
value = 0; value = 0;
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
...@@ -58,7 +80,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei ...@@ -58,7 +80,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real b; real *weight_data = weight->storage->data + weight->storageOffset; real b;
SPLINE_WEIGHTING_BACKWARD(grad_input, grad_output, basis, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1), SPLINE_WEIGHTING(grad_input, grad_output, basis, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
value = 0; value = 0;
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
...@@ -68,21 +90,21 @@ void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_outp ...@@ -68,21 +90,21 @@ void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_outp
value += b * *(grad_output_data + m_out * grad_output_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out); value += b * *(grad_output_data + m_out * grad_output_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
} }
} }
grad_input_data[m_in] = value; grad_input_data[m_in * grad_input_stride] = value;
} }
) )
} }
void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_output, THTensor *input, THTensor *weight, THLongTensor *weight_index) { void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_output, THTensor *input, THTensor *weight, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real *weight_data = weight->storage->data + weight->storageOffset;
SPLINE_WEIGHTING_BACKWARD(grad_basis, grad_output, input, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1), SPLINE_WEIGHTING(grad_basis, grad_output, input, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
w_idx = *(weight_index_data + s * weight_index_stride); value = 0; w_idx = *(weight_index_data + s * weight_index_stride); value = 0;
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
value += *(input_data + m_in * input_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out); value += *(input_data + m_in * input_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
} }
grad_basis_data[s] += value * *(grad_output_data + m_out * grad_output_stride); grad_basis_data[s * grad_basis_stride] += value * *(grad_output_data + m_out * grad_output_stride);
} }
} }
) )
...@@ -90,7 +112,7 @@ void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_outp ...@@ -90,7 +112,7 @@ void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_outp
void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *basis, THLongTensor *weight_index) {
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real b; real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real b;
SPLINE_WEIGHTING_BACKWARD(grad_output, input, basis, weight_index, THTensor_(size)(input, 1), THTensor_(size)(grad_output, 1), THLongTensor_size(weight_index, 1), SPLINE_WEIGHTING(grad_output, input, basis, weight_index, THTensor_(size)(input, 1), THTensor_(size)(grad_output, 1), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
value = *(grad_output_data + m_out * grad_output_stride); value = *(grad_output_data + m_out * grad_output_stride);
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
...@@ -104,5 +126,4 @@ void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_ou ...@@ -104,5 +126,4 @@ void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_ou
) )
} }
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment