"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3e4c5707c3e6e0e363ef93a6a60bad7245f05e46"
Commit dff93289 authored by rusty1s's avatar rusty1s
Browse files

backward boilerplate

parent 0b777f0d
...@@ -4,3 +4,10 @@ void THFloatTensor_quadraticBasisForward( THFloatTensor *basis, THLongTensor *w ...@@ -4,3 +4,10 @@ void THFloatTensor_quadraticBasisForward( THFloatTensor *basis, THLongTensor *w
void THDoubleTensor_quadraticBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline); void THDoubleTensor_quadraticBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisForward( THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline); void THFloatTensor_cubicBasisForward( THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline); void THDoubleTensor_cubicBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_linearBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_linearBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_quadraticBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_quadraticBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
...@@ -40,4 +40,16 @@ void THTensor_(cubicBasisForward)(THTensor *basis, THLongTensor *weightIndex, TH ...@@ -40,4 +40,16 @@ void THTensor_(cubicBasisForward)(THTensor *basis, THLongTensor *weightIndex, TH
) )
} }
void THTensor_(linearBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}
void THTensor_(quadraticBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}
void THTensor_(cubicBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}
#endif // TH_GENERIC_FILE #endif // TH_GENERIC_FILE
...@@ -4,3 +4,10 @@ void THCCFloatTensor_quadraticBasisForward( THCudaTensor *basis, THCudaLon ...@@ -4,3 +4,10 @@ void THCCFloatTensor_quadraticBasisForward( THCudaTensor *basis, THCudaLon
void THCCDoubleTensor_quadraticBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline); void THCCDoubleTensor_quadraticBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_cubicBasisForward( THCudaTensor *basis, THCudaLongTensor *weightIndex, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline); void THCCFloatTensor_cubicBasisForward( THCudaTensor *basis, THCudaLongTensor *weightIndex, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_cubicBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline); void THCCDoubleTensor_cubicBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_linearBasisBackward( THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_linearBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_quadraticBasisBackward( THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_quadraticBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_cubicBasisBackward( THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_cubicBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
...@@ -20,4 +20,19 @@ void THCCTensor_(cubicBasisForward)(THCTensor *basis, THCudaLongTensor *weightIn ...@@ -20,4 +20,19 @@ void THCCTensor_(cubicBasisForward)(THCTensor *basis, THCudaLongTensor *weightIn
THCTensor_(cubicBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline); THCTensor_(cubicBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline);
} }
void THCCTensor_(linearBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
}
void THCCTensor_(quadraticBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
}
void THCCTensor_(cubicBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
}
#endif // THC_GENERIC_FILE #endif // THC_GENERIC_FILE
import torch
from torch.autograd import Function
from .utils.ffi import basis_forward as ffi_basis_forward from .utils.ffi import basis_forward as ffi_basis_forward
from .utils.ffi import basis_backward as ffi_basis_backward
def basis_forward(degree, pseudo, kernel_size, is_open_spline): def basis_forward(degree, pseudo, kernel_size, is_open_spline):
...@@ -9,3 +13,32 @@ def basis_forward(degree, pseudo, kernel_size, is_open_spline): ...@@ -9,3 +13,32 @@ def basis_forward(degree, pseudo, kernel_size, is_open_spline):
ffi_basis_forward(degree, basis, weight_index, pseudo, kernel_size, ffi_basis_forward(degree, basis, weight_index, pseudo, kernel_size,
is_open_spline) is_open_spline)
return basis, weight_index return basis, weight_index
def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
grad_pseudo = pseudo.new(pseudo.size())
ffi_basis_backward(degree, grad_pseudo, pseudo, kernel_size,
is_open_spline)
class Basis(Function):
def __init__(self, degree, kernel_size, is_open_spline):
super(Basis, self).__init__()
self.degree = degree
self.kernel_size = kernel_size
self.is_open_spline = is_open_spline
def forward(self, pseudo):
self.save_for_backawrd(pseudo)
return basis_forward(self.degree, pseudo, self.kernel_size,
self.is_open_spline)
def backward(self, grad_basis, grad_weight_index):
pass
def basis(degree, pseudo, kernel_size, is_open_spline):
if torch.is_tensor(pseudo):
return basis_forward(degree, pseudo, kernel_size, is_open_spline)
else:
return Basis(degree, kernel_size, is_open_spline)(pseudo)
...@@ -21,3 +21,10 @@ def basis_forward(degree, basis, weight_index, pseudo, kernel_size, ...@@ -21,3 +21,10 @@ def basis_forward(degree, basis, weight_index, pseudo, kernel_size,
name = '{}BasisForward'.format(get_degree_str(degree)) name = '{}BasisForward'.format(get_degree_str(degree))
func = get_func(name, basis.is_cuda, basis) func = get_func(name, basis.is_cuda, basis)
func(basis, weight_index, pseudo, kernel_size, is_open_spline) func(basis, weight_index, pseudo, kernel_size, is_open_spline)
def basis_backward(degree, self, grad_basis, pseudo, kernel_size,
is_open_spline):
name = '{}BasisBackward'.format(get_degree_str(degree))
func = get_func(name, self.is_cuda, self)
func(self, grad_basis, pseudo, kernel_size, is_open_spline)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment