#include "THCBasis.h"

#include "THCBasisForward.cuh"
#include "THCBasisBackward.cuh"

// NOTE(review): in the copy of this file under review, every `<...>` span had
// been stripped (e.g. `template __global__`, `TensorInfo basis`,
// `TensorInfoweightIndex`, `BasisForward::linear`). A bare `template` followed
// by a function definition is ill-formed C++, so the template heads and
// template arguments below are reconstructed. The element types chosen for
// `weightIndex` (int64_t) and `gradBasis` (T) are a best guess -- TODO confirm
// against the host-side launchers in generic/THCBasis.cu and the TensorInfo
// definition.

// ---------------------------------------------------------------------------
// Forward basis kernels.
//
// Each kernel is a thin wrapper around THC_TENSOR_BASIS_FORWARD_KERNEL (defined
// in a header not visible here). It passes the spline degree as the first macro
// argument (1 = linear, 2 = quadratic, 3 = cubic) and the matching
// BasisForward<T> expression as the last one.
//
// The identifiers `v` and `kMod` used in that expression are not declared in
// this file; presumably the macro binds them -- confirm in THCBasisForward.cuh.
// How `n`, `kernelSize` and `isOpenSpline` are consumed (thread indexing,
// bounds checks, tensor layout) is likewise decided inside the macro and is
// not visible here.
// ---------------------------------------------------------------------------

// Degree-1 (linear) forward basis.
template<typename T>
__global__ void linearBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                         TensorInfo<T> pseudo, int64_t *kernelSize,
                                         uint8_t *isOpenSpline, ptrdiff_t n) {
  THC_TENSOR_BASIS_FORWARD_KERNEL(1, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
                                  BasisForward<T>::linear(v, kMod))
}

// Degree-2 (quadratic) forward basis.
template<typename T>
__global__ void quadraticBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                            TensorInfo<T> pseudo, int64_t *kernelSize,
                                            uint8_t *isOpenSpline, ptrdiff_t n) {
  THC_TENSOR_BASIS_FORWARD_KERNEL(2, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
                                  BasisForward<T>::quadratic(v, kMod))
}

// Degree-3 (cubic) forward basis.
template<typename T>
__global__ void cubicBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                        TensorInfo<T> pseudo, int64_t *kernelSize,
                                        uint8_t *isOpenSpline, ptrdiff_t n) {
  THC_TENSOR_BASIS_FORWARD_KERNEL(3, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
                                  BasisForward<T>::cubic(v, kMod))
}

// ---------------------------------------------------------------------------
// Backward basis kernels.
//
// Each kernel wraps THC_TENSOR_BASIS_BACKWARD_KERNEL with the same degree as
// its forward counterpart. It passes both the forward basis expression
// (BasisForward<T>::...) and its backward counterpart (BasisBackward<T>::...);
// how the macro combines them is not visible here -- see
// THCBasisBackward.cuh.
//
// The result is written through `self`, presumably, since it is the first
// tensor argument, as `basis` is in the forward kernels. TODO confirm in the
// macro definition.
// ---------------------------------------------------------------------------

// Degree-1 (linear) backward basis.
template<typename T>
__global__ void linearBasisBackwardKernel(TensorInfo<T> self, TensorInfo<T> gradBasis,
                                          TensorInfo<T> pseudo, int64_t *kernelSize,
                                          uint8_t *isOpenSpline, ptrdiff_t n) {
  THC_TENSOR_BASIS_BACKWARD_KERNEL(1, self, gradBasis, pseudo, kernelSize, isOpenSpline, n,
                                   BasisForward<T>::linear(v, kMod),
                                   BasisBackward<T>::linear(v, kMod))
}

// Degree-2 (quadratic) backward basis.
template<typename T>
__global__ void quadraticBasisBackwardKernel(TensorInfo<T> self, TensorInfo<T> gradBasis,
                                             TensorInfo<T> pseudo, int64_t *kernelSize,
                                             uint8_t *isOpenSpline, ptrdiff_t n) {
  THC_TENSOR_BASIS_BACKWARD_KERNEL(2, self, gradBasis, pseudo, kernelSize, isOpenSpline, n,
                                   BasisForward<T>::quadratic(v, kMod),
                                   BasisBackward<T>::quadratic(v, kMod))
}

// Degree-3 (cubic) backward basis.
template<typename T>
__global__ void cubicBasisBackwardKernel(TensorInfo<T> self, TensorInfo<T> gradBasis,
                                         TensorInfo<T> pseudo, int64_t *kernelSize,
                                         uint8_t *isOpenSpline, ptrdiff_t n) {
  THC_TENSOR_BASIS_BACKWARD_KERNEL(3, self, gradBasis, pseudo, kernelSize, isOpenSpline, n,
                                   BasisForward<T>::cubic(v, kMod),
                                   BasisBackward<T>::cubic(v, kMod))
}

// Pull in the type-generic host code and expand it for the floating-point
// types.
// NOTE(review): presumably generic/THCBasis.cu follows the TH/THC "generic
// file" pattern: on first inclusion it defines THC_GENERIC_FILE, and
// THCGenerateFloatTypes.h then re-includes it once per real type. Confirm
// before reordering or removing either include.
#include "generic/THCBasis.cu"
#include "THC/THCGenerateFloatTypes.h"