Commit d7a83c01 authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent 67904212
...@@ -48,7 +48,6 @@ def test_spline_conv_cpu(tensor): ...@@ -48,7 +48,6 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu(): def test_spline_weighting_backward_cpu():
return
kernel_size = torch.LongTensor([5, 5]) kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1]) is_open_spline = torch.ByteTensor([1, 1])
op = SplineWeighting(kernel_size, is_open_spline, 1) op = SplineWeighting(kernel_size, is_open_spline, 1)
......
...@@ -40,17 +40,17 @@ def spline_weighting_backward_input(grad_output, weight, basis, weight_index): ...@@ -40,17 +40,17 @@ def spline_weighting_backward_input(grad_output, weight, basis, weight_index):
return grad_input return grad_input
# pragma: no cover
def spline_weighting_backward_weight(grad_output, x, basis, weight_index, K):
grad_weight = x.new(K, x.size(1), grad_output.size(1)).fill_(0)
func = get_func('weighting_backward_weight', x)
func(grad_weight, grad_output, x, basis, weight_index)
return grad_weight
# pragma: no cover # pragma: no cover
def spline_weighting_backward_basis(grad_output, x, weight, weight_index): def spline_weighting_backward_basis(grad_output, x, weight, weight_index):
grad_basis = x.new(weight_index.size()) grad_basis = x.new(weight_index.size())
func = get_func('weighting_backward_basis', x) func = get_func('weighting_backward_basis', x)
func(grad_basis, grad_output, x, weight, weight_index) func(grad_basis, grad_output, x, weight, weight_index)
return grad_basis return grad_basis
# pragma: no cover
def spline_weighting_backward_weight(grad_output, x, basis, weight_index, K):
grad_weight = x.new(K, x.size(1), grad_output.size(1)).fill_(0)
func = get_func('weighting_backward_weight', x)
func(grad_weight, grad_output, x, basis, weight_index)
return grad_weight
...@@ -3,8 +3,8 @@ from torch.autograd import Function ...@@ -3,8 +3,8 @@ from torch.autograd import Function
from .ffi import (spline_basis_forward, spline_weighting_forward, from .ffi import (spline_basis_forward, spline_weighting_forward,
spline_weighting_backward_input, spline_weighting_backward_input,
spline_weighting_backward_weight, spline_weighting_backward_basis,
spline_weighting_backward_basis) spline_weighting_backward_weight)
class SplineWeighting(Function): class SplineWeighting(Function):
...@@ -20,17 +20,31 @@ class SplineWeighting(Function): ...@@ -20,17 +20,31 @@ class SplineWeighting(Function):
self.degree, pseudo, self.kernel_size, self.is_open_spline, K) self.degree, pseudo, self.kernel_size, self.is_open_spline, K)
output = spline_weighting_forward(x, weight, basis, weight_index) output = spline_weighting_forward(x, weight, basis, weight_index)
# self.save_for_backward(x, weight) self.save_for_backward(x, weight)
# self.basis, self.weight_index = basis, weight_index self.basis, self.weight_index = basis, weight_index
return output return output
def backward(self, grad_output): # pragma: no cover def backward(self, grad_output): # pragma: no cover
pass x, weight = self.saved_tensors
# x, weight = self.saved_tensors basis, weight_index = self.basis, self.weight_index
# grad_input, grad_weight = spline_weighting_backward( grad_input, grad_pseudo, grad_weight = None, None, None
# grad_output, x, weight, self.basis, self.weight_index)
# return grad_input, None, grad_weight if self.needs_input_grad[0]:
grad_input = spline_weighting_backward_input(
grad_output, weight, basis, weight_index)
if self.needs_input_grad[1]:
grad_basis = spline_weighting_backward_basis(
grad_output, x, weight, weight_index)
print('pseudo needs grad')
if self.needs_input_grad[2]:
K = weight.size(0)
grad_weight = spline_weighting_backward_weight(
grad_output, x, basis, weight_index, K)
return grad_input, grad_pseudo, grad_weight
def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree): def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree):
......
...@@ -13,8 +13,9 @@ void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *inp ...@@ -13,8 +13,9 @@ void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *inp
void spline_weighting_backward_input_Float(THFloatTensor *grad_input, THFloatTensor *grad_output, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_input_Float(THFloatTensor *grad_input, THFloatTensor *grad_output, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_input_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_output, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_input_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_output, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_basis_Float(THFloatTensor *grad_basis, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THLongTensor *weight_index);
void spline_weighting_backward_basis_Double(THDoubleTensor *grad_basis, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THLongTensor *weight_index);
void spline_weighting_backward_weight_Float(THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_weight_Float(THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_weight_Double(THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_weight_Double(THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_basis_Float(THFloatTensor *grad_basis, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THLongTensor *weight_index);
void spline_weighting_backward_basis_Double(THDoubleTensor *grad_basis, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THLongTensor *weight_index);
...@@ -49,7 +49,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei ...@@ -49,7 +49,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real b; real *weight_data = weight->storage->data + weight->storageOffset; real b;
SPLINE_WEIGHTING_BACKWARD(grad_input, grad_output, basis, weight_index, THTensor_(size)(grad_input, 1), THTensor_(size)(grad_output, 1), THLongTensor_size(weight_index, 1), SPLINE_WEIGHTING_BACKWARD(grad_input, grad_output, basis, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
value = 0; value = 0;
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
...@@ -64,35 +64,36 @@ void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_outp ...@@ -64,35 +64,36 @@ void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_outp
) )
} }
void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_output, THTensor *input, THTensor *weight, THLongTensor *weight_index) {
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real b; real *weight_data = weight->storage->data + weight->storageOffset;
SPLINE_WEIGHTING_BACKWARD(grad_output, input, basis, weight_index, THTensor_(size)(grad_output, 1), THTensor_(size)(input, 1), THLongTensor_size(weight_index, 1), SPLINE_WEIGHTING_BACKWARD(grad_basis, grad_output, input, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
value = *(grad_output_data + m_out * grad_output_stride);
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride); w_idx = *(weight_index_data + s * weight_index_stride); value = 0;
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
grad_weight_data[w_idx * M_in * M_out + m_in * M_out + m_out] += b * value * *(input_data + m_in * input_stride); value += *(input_data + m_in * input_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
} }
grad_basis_data[s] += value * *(grad_output_data + m_out * grad_output_stride);
} }
} }
) )
} }
void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_output, THTensor *input, THTensor *weight, THLongTensor *weight_index) { void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real b;
SPLINE_WEIGHTING_BACKWARD(grad_basis, grad_output, input, weight_index, THTensor_(size)(grad_output, 1), THTensor_(size)(input, 1), THLongTensor_size(weight_index, 1), SPLINE_WEIGHTING_BACKWARD(grad_output, input, basis, weight_index, THTensor_(size)(input, 1), THTensor_(size)(grad_output, 1), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
value = *(grad_output_data + m_out * grad_output_stride);
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
w_idx = *(weight_index_data + s * weight_index_stride); value = 0; b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
value += *(input_data + m_in * input_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out); grad_weight_data[w_idx * M_in * M_out + m_in * M_out + m_out] += b * value * *(input_data + m_in * input_stride);
} }
grad_basis_data[s] += value * *(grad_output_data + m_out * grad_output_stride);
} }
} }
) )
} }
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment