Commit 8d6acb03 authored by rusty1s's avatar rusty1s
Browse files

clean up and bugfixes

parent b46459f4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCBasis.c"
#else
void THCCTensor_(linearBasisForward)(THCTensor *basis, THCudaLongTensor *weightIndex,
THCTensor *pseudo, THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
THCTensor_(linearBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline);
}
void THCCTensor_(quadraticBasisForward)(THCTensor *basis, THCudaLongTensor *weightIndex,
THCTensor *pseudo, THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
THCTensor_(quadraticBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline);
}
void THCCTensor_(cubicBasisForward)(THCTensor *basis, THCudaLongTensor *weightIndex,
THCTensor *pseudo, THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
THCTensor_(cubicBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline);
}
void THCCTensor_(linearBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
THCTensor_(linearBasisBackward)(state, self, gradBasis, pseudo, kernelSize, isOpenSpline);
}
void THCCTensor_(quadraticBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
THCTensor_(quadraticBasisBackward)(state, self, gradBasis, pseudo, kernelSize, isOpenSpline);
}
void THCCTensor_(cubicBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo,
THCudaLongTensor *kernelSize,
THCudaByteTensor *isOpenSpline) {
THCTensor_(cubicBasisBackward)(state, self, gradBasis, pseudo, kernelSize, isOpenSpline);
}
#endif // THC_GENERIC_FILE
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCWeighting.c"
#else
void THCCTensor_(weightingForward)(THCTensor *self, THCTensor *src, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(weightingForward)(state, self, src, weight, basis, weightIndex);
}
void THCCTensor_(weightingBackwardSrc)(THCTensor *self, THCTensor *gradOutput, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackwardSrc)(state, self, gradOutput, weight, basis, weightIndex);
}
void THCCTensor_(weightingBackwardWeight)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackwardWeight)(state, self, gradOutput, src, basis, weightIndex);
}
void THCCTensor_(weightingBackwardBasis)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackwardBasis)(state, self, gradOutput, src, weight, weightIndex);
}
#endif // THC_GENERIC_FILE
......@@ -193,7 +193,7 @@ template <typename scalar_t> struct BasisBackward {
auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]]; \
v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d]; \
v -= floor(v); \
v = CODE; \
v = GRAD_CODE; \
tmp = v; \
\
for (ptrdiff_t d_it = 1; d_it < GRAD_PSEUDO.sizes[1]; d_it++) { \
......@@ -202,7 +202,7 @@ template <typename scalar_t> struct BasisBackward {
v = PSEUDO.data[e * pseudo.strides[0] + d_new * PSEUDO.strides[1]]; \
v *= KERNEL_SIZE[d_new] - M * IS_OPEN_SPLINE[d_new]; \
v -= floor(v); \
v = GRAD_CODE; \
v = CODE; \
tmp *= v; \
} \
g += tmp * \
......
......@@ -7,7 +7,7 @@ from torch_spline_conv import SplineConv
from torch_spline_conv.basis import implemented_degrees as degrees
from .utils import dtypes, devices, tensor
devices = [torch.device('cpu')]
devices = [torch.device('cuda')]
tests = [{
'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
......
......@@ -7,7 +7,6 @@ from torch_spline_conv.weighting import SplineWeighting
from torch_spline_conv.basis import SplineBasis
from .utils import dtypes, devices, tensor
devices = [torch.device('cuda')]
tests = [{
'x': [[1, 2], [3, 4]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment