Commit dff93289 authored by rusty1s's avatar rusty1s
Browse files

backward boilerplate

parent 0b777f0d
......@@ -4,3 +4,10 @@ void THFloatTensor_quadraticBasisForward( THFloatTensor *basis, THLongTensor *w
void THDoubleTensor_quadraticBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisForward( THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_linearBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_linearBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_quadraticBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_quadraticBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
......@@ -40,4 +40,16 @@ void THTensor_(cubicBasisForward)(THTensor *basis, THLongTensor *weightIndex, TH
)
}
void THTensor_(linearBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}
void THTensor_(quadraticBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}
void THTensor_(cubicBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}
#endif // TH_GENERIC_FILE
......@@ -4,3 +4,10 @@ void THCCFloatTensor_quadraticBasisForward( THCudaTensor *basis, THCudaLon
void THCCDoubleTensor_quadraticBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_cubicBasisForward( THCudaTensor *basis, THCudaLongTensor *weightIndex, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_cubicBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_linearBasisBackward( THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_linearBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_quadraticBasisBackward( THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_quadraticBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_cubicBasisBackward( THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_cubicBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
......@@ -20,4 +20,19 @@ void THCCTensor_(cubicBasisForward)(THCTensor *basis, THCudaLongTensor *weightIn
THCTensor_(cubicBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline);
}
void THCCTensor_(linearBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
}
void THCCTensor_(quadraticBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
}
void THCCTensor_(cubicBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
}
#endif // THC_GENERIC_FILE
import torch
from torch.autograd import Function
from .utils.ffi import basis_forward as ffi_basis_forward
from .utils.ffi import basis_backward as ffi_basis_backward
def basis_forward(degree, pseudo, kernel_size, is_open_spline):
......@@ -9,3 +13,32 @@ def basis_forward(degree, pseudo, kernel_size, is_open_spline):
ffi_basis_forward(degree, basis, weight_index, pseudo, kernel_size,
is_open_spline)
return basis, weight_index
def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
grad_pseudo = pseudo.new(pseudo.size())
ffi_basis_backward(degree, grad_pseudo, pseudo, kernel_size,
is_open_spline)
class Basis(Function):
def __init__(self, degree, kernel_size, is_open_spline):
super(Basis, self).__init__()
self.degree = degree
self.kernel_size = kernel_size
self.is_open_spline = is_open_spline
def forward(self, pseudo):
self.save_for_backawrd(pseudo)
return basis_forward(self.degree, pseudo, self.kernel_size,
self.is_open_spline)
def backward(self, grad_basis, grad_weight_index):
pass
def basis(degree, pseudo, kernel_size, is_open_spline):
if torch.is_tensor(pseudo):
return basis_forward(degree, pseudo, kernel_size, is_open_spline)
else:
return Basis(degree, kernel_size, is_open_spline)(pseudo)
......@@ -21,3 +21,10 @@ def basis_forward(degree, basis, weight_index, pseudo, kernel_size,
name = '{}BasisForward'.format(get_degree_str(degree))
func = get_func(name, basis.is_cuda, basis)
func(basis, weight_index, pseudo, kernel_size, is_open_spline)
def basis_backward(degree, self, grad_basis, pseudo, kernel_size,
is_open_spline):
name = '{}BasisBackward'.format(get_degree_str(degree))
func = get_func(name, self.is_cuda, self)
func(self, grad_basis, pseudo, kernel_size, is_open_spline)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment