Commit 3a07cc5e authored by rusty1s's avatar rusty1s
Browse files

added python autograd function

parent eb8f32d7
...@@ -29,12 +29,19 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K): ...@@ -29,12 +29,19 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
return basis, weight_index return basis, weight_index
def spline_weighting_forward(x, weight, basis, weight_index): def spline_weighting_fw(x, weight, basis, weight_index):
pass output = x.new(x.size(0), weight.size(2))
func = get_func('spline_weighting_fw', x)
func(output, x, weight, basis, weight_index)
return output
def spline_weighting_backward(x, weight, basis, weight_index): def spline_weighting_bw(grad_output, x, weight, basis, weight_index):
pass grad_input = x.new(x.size(0), weight.size(1))
grad_weight = x.new(weight)
func = get_func('spline_weighting_bw', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight
class SplineWeighting(Function): class SplineWeighting(Function):
...@@ -44,14 +51,18 @@ class SplineWeighting(Function): ...@@ -44,14 +51,18 @@ class SplineWeighting(Function):
self.weight_index = weight_index self.weight_index = weight_index
def forward(self, x, weight): def forward(self, x, weight):
pass self.save_for_backward(x, weight)
basis, weight_index = self.basis, self.weight_index
return spline_weighting_fw(x, weight, basis, weight_index)
def backward(self, grad_output): def backward(self, grad_output):
pass x, weight = self.saved_tensors
basis, weight_index = self.basis, self.weight_index
return spline_weighting_bw(grad_output, x, weight, basis, weight_index)
def spline_weighting(x, weight, basis, weight_index): def spline_weighting(x, weight, basis, weight_index):
if torch.is_tensor(x): if torch.is_tensor(x):
return spline_weighting_forward(x, weight, basis, weight_index) return spline_weighting_fw(x, weight, basis, weight_index)
else: else:
return SplineWeighting(basis, weight_index)(x, weight) return SplineWeighting(basis, weight_index)(x, weight)
...@@ -7,8 +7,8 @@ void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_i ...@@ -7,8 +7,8 @@ void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_i
void spline_basis_cubic_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_basis_cubic_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_cubic_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_basis_cubic_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_edgewise_forward_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_edgewise_fw_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_edgewise_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_edgewise_fw_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_edgewise_backward_Float(THFloatTensor *grad_input, THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_bw_Float(THFloatTensor *grad_input, THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_edgewise_backward_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_bw_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
...@@ -50,10 +50,10 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor ...@@ -50,10 +50,10 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor
) )
} }
void spline_(edgewise_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
} }
void spline_(edgewise_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_bw)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
} }
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment