Commit 3fbeeabc authored by rusty1s

added backward

parent 5487e31a
@@ -28,17 +28,17 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
    return basis, weight_index


-def spline_weighting_fw(x, weight, basis, weight_index):
+def spline_weighting_forward(x, weight, basis, weight_index):
    output = x.new(x.size(0), weight.size(2))
-    func = get_func('weighting_fw', x)
+    func = get_func('weighting_forward', x)
    func(output, x, weight, basis, weight_index)
    return output


-def spline_weighting_bw(grad_output, x, weight, basis, weight_index):
+def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
    grad_input = x.new(x.size(0), weight.size(1))
    grad_weight = x.new(weight)
-    func = get_func('weighting_bw', x)
+    func = get_func('weighting_backward', x)
    func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
    return grad_input, grad_weight
@@ -52,16 +52,17 @@ class SplineWeighting(Function):
    def forward(self, x, weight):
        self.save_for_backward(x, weight)
        basis, weight_index = self.basis, self.weight_index
-        return spline_weighting_fw(x, weight, basis, weight_index)
+        return spline_weighting_forward(x, weight, basis, weight_index)

    def backward(self, grad_output):
        x, weight = self.saved_tensors
        basis, weight_index = self.basis, self.weight_index
-        return spline_weighting_bw(grad_output, x, weight, basis, weight_index)
+        return spline_weighting_backward(grad_output, x, weight, basis,
+                                         weight_index)


def spline_weighting(x, weight, basis, weight_index):
    if torch.is_tensor(x):
-        return spline_weighting_fw(x, weight, basis, weight_index)
+        return spline_weighting_forward(x, weight, basis, weight_index)
    else:
        return SplineWeighting(basis, weight_index)(x, weight)
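For orientation, the dense computation that spline_weighting_forward performs can be sketched in plain PyTorch. This is an illustrative reference only, not code from this commit; the shapes (x: N x M_in, weight: K x M_in x M_out, basis and weight_index: N x S) and the helper name dense_spline_weighting_forward are assumptions inferred from how the buffers are allocated above.

import torch


def dense_spline_weighting_forward(x, weight, basis, weight_index):
    # x:            [N, M_in]        input features
    # weight:       [K, M_in, M_out] trainable weight matrices
    # basis:        [N, S]           B-spline basis products per node
    # weight_index: [N, S]           which weight matrix each basis product selects
    N, S = basis.size()
    out = x.new(N, weight.size(2)).fill_(0)
    for s in range(S):
        # Gather one weight matrix per node: [N, M_in, M_out].
        w = weight[weight_index[:, s]]
        # Weighted matrix-vector product, scaled by the basis value.
        out += basis[:, s:s + 1] * torch.bmm(x.unsqueeze(1), w).squeeze(1)
    return out

On small random inputs, such a reference is handy for checking the compiled kernel reached through spline_weighting(x, weight, basis, weight_index).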
@@ -72,3 +72,86 @@
  } \
  THFree(TH_TENSOR_DIM_APPLY_counter); \
}
#define TH_TENSOR_DIM_APPLY5(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, TENSOR4, TYPE5, TENSOR5, DIMENSION, CODE) { \
  TYPE1 *TENSOR1##_data = NULL; \
  int64_t TENSOR1##_stride = 0, TENSOR1##_size = 0; \
  TYPE2 *TENSOR2##_data = NULL; \
  int64_t TENSOR2##_stride = 0, TENSOR2##_size = 0; \
  TYPE3 *TENSOR3##_data = NULL; \
  int64_t TENSOR3##_stride = 0, TENSOR3##_size = 0; \
  TYPE4 *TENSOR4##_data = NULL; \
  int64_t TENSOR4##_stride = 0, TENSOR4##_size = 0; \
  TYPE5 *TENSOR5##_data = NULL; \
  int64_t TENSOR5##_stride = 0, TENSOR5##_size = 0; \
\
  int64_t *TH_TENSOR_DIM_APPLY_counter = NULL; \
  int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
  int TH_TENSOR_DIM_APPLY_i; \
\
  TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
\
  for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
    TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
  } \
\
  TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \
  TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \
  TENSOR1##_size = TENSOR1->size[DIMENSION]; \
\
  TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \
  TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \
  TENSOR2##_size = TENSOR2->size[DIMENSION]; \
\
  TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \
  TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \
  TENSOR3##_size = TENSOR3->size[DIMENSION]; \
\
  TENSOR4##_data = (TENSOR4)->storage->data+(TENSOR4)->storageOffset; \
  TENSOR4##_stride = (TENSOR4)->stride[DIMENSION]; \
  TENSOR4##_size = TENSOR4->size[DIMENSION]; \
\
  TENSOR5##_data = (TENSOR5)->storage->data+(TENSOR5)->storageOffset; \
  TENSOR5##_stride = (TENSOR5)->stride[DIMENSION]; \
  TENSOR5##_size = TENSOR5->size[DIMENSION]; \
\
  while (!TH_TENSOR_DIM_APPLY_hasFinished) { \
    CODE \
\
    if (TENSOR1->nDimension == 1) break; \
\
    for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
      if (TH_TENSOR_DIM_APPLY_i == DIMENSION) { \
        if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
          TH_TENSOR_DIM_APPLY_hasFinished = 1; \
          break; \
        } \
        continue; \
      } \
\
      TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
      TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
      TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
      TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
      TENSOR4##_data += TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
      TENSOR5##_data += TENSOR5->stride[TH_TENSOR_DIM_APPLY_i]; \
\
      if (TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) { \
        if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
          TH_TENSOR_DIM_APPLY_hasFinished = 1; \
          break; \
        } \
        else { \
          TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
          TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
          TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
          TENSOR4##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
          TENSOR5##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR5->stride[TH_TENSOR_DIM_APPLY_i]; \
          TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
        } \
      } \
      else break; \
    } \
  } \
  THFree(TH_TENSOR_DIM_APPLY_counter); \
}
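In rough terms, the new TH_TENSOR_DIM_APPLY5 macro visits every position of the first tensor outside of DIMENSION and, at each position, exposes the data pointer, stride, and size of a 1-D slice of all five tensors along DIMENSION to CODE. A small Python sketch of the same traversal (illustrative only; the helper name dim_apply5 and the toy tensors are made up, and it assumes all five tensors share the same sizes outside of dim, which the macro also relies on):

import itertools

import torch


def dim_apply5(t1, t2, t3, t4, t5, dim, code):
    # Loop over every index combination of t1's dimensions except `dim` and
    # pass the matching 1-D slices of all five tensors to `code`, mirroring
    # how TH_TENSOR_DIM_APPLY5 advances raw data pointers by their strides.
    other_dims = [d for d in range(t1.dim()) if d != dim]
    for idx in itertools.product(*(range(t1.size(d)) for d in other_dims)):
        index = [slice(None)] * t1.dim()
        for d, i in zip(other_dims, idx):
            index[d] = i
        code(t1[tuple(index)], t2[tuple(index)], t3[tuple(index)],
             t4[tuple(index)], t5[tuple(index)])


# Toy example: accumulate the rows of four tensors into the first one.
a, b, c, d, e = [torch.ones(2, 3) for _ in range(5)]
dim_apply5(a, b, c, d, e, 1, lambda s1, s2, s3, s4, s5: s1.add_(s2 + s3 + s4 + s5))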
#include <TH/TH.h>
#include "THTensorDimApply4.h"
#include "THTensorDimApply.h"
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _, Real)
@@ -7,8 +7,8 @@ void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_i
void spline_basis_cubic_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_cubic_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
-void spline_weighting_fw_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
-void spline_weighting_fw_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
+void spline_weighting_forward_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
+void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
-void spline_weighting_bw_Float(THFloatTensor *grad_input, THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
-void spline_weighting_bw_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
+void spline_weighting_backward_Float(THFloatTensor *grad_input, THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
+void spline_weighting_backward_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
@@ -50,7 +50,7 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor
)
}
-void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
+void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset;
int64_t M_out = THTensor_(size)(output, 1);
int64_t M_in = THTensor_(size)(input, 1);
@@ -72,7 +72,9 @@ void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight,
)
}
-void spline_(weighting_bw)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
+void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
+  TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
+  )
}
#endif
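The weighting_backward kernel above only wires up the five-tensor traversal; its gradient body is not shown in this diff. As a purely illustrative reference for what the gradients look like (the function name dense_spline_weighting_backward and the shapes are assumptions, and this dense formulation is derived from the forward pass rather than taken from the repository):

import torch


def dense_spline_weighting_backward(grad_output, x, weight, basis, weight_index):
    # grad_output: [N, M_out], x: [N, M_in], weight: [K, M_in, M_out]
    # basis and weight_index: [N, S]
    grad_input = torch.zeros_like(x)
    grad_weight = torch.zeros_like(weight)
    N, S = basis.size()
    for n in range(N):
        for s in range(S):
            k = int(weight_index[n, s])
            b = basis[n, s]
            # Gradient w.r.t. the input features of node n.
            grad_input[n] += b * weight[k].mv(grad_output[n])
            # Gradient w.r.t. the weight matrix selected by this basis product
            # (outer product of input features and output gradient).
            grad_weight[k] += b * (x[n].unsqueeze(1) * grad_output[n].unsqueeze(0))
    return grad_input, grad_weight

A numerical gradient check on small double-precision inputs is the usual way to validate a compiled backward kernel against such a reference.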