Commit cba43ac9 authored by rusty1s

bugfixes

parent d1b3f976
@@ -6,6 +6,16 @@ matrix:
   - python: 2.7
   - python: 3.5
   - python: 3.6
+addons:
+  apt:
+    sources:
+      - ubuntu-toolchain-r-test
+    packages:
+      - gcc-4.9
+      - g++-4.9
+before_install:
+  - export CC="gcc-4.9"
+  - export CXX="g++-4.9"
 install:
   - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; fi
   - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi
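Note: the CI change above pins GCC/G++ 4.9 from the ubuntu-toolchain-r PPA and exports CC/CXX before the build, presumably so the package's C++/CUDA extension compiles with a C++11-capable toolchain that matches the prebuilt PyTorch 0.4.1 wheels. A quick way to confirm which compiler the exported variables actually select is a throwaway translation unit that prints GCC's version macros (a hypothetical helper, not part of this repository):

// check_toolchain.cpp -- hypothetical one-off helper, not part of the repo.
// Build it with the compiler Travis exports: $CXX check_toolchain.cpp -o check
#include <cstdio>

int main() {
#if defined(__GNUC__) && !defined(__clang__)
  // gcc defines these macros at compile time; on this CI setup the
  // expectation is a 4.9.x report
  std::printf("gcc %d.%d.%d\n", __GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__);
#else
  std::printf("not compiled with gcc\n");
#endif
  return 0;
}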
@@ -177,7 +177,7 @@ template <typename scalar_t> struct BasisBackward {
   });                                                                        \
                                                                              \
   return grad_pseudo;                                                        \
-}
+}()

 #define BASIS_BACKWARD_KERNEL(M, GRAD_PSEUDO, GRAD_BASIS, PSEUDO, KERNEL_SIZE, \
                               IS_OPEN_SPLINE, NUMEL, CODE, GRAD_CODE)          \
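The one-character fix above (`}` becomes `}()`) is what makes the BASIS_BACKWARD macro usable as an expression: its body is a lambda that is defined and immediately invoked, so the whole expansion yields the lambda's return value (grad_pseudo). Without the trailing `()` the macro expands to an uncalled lambda object, and the `return BASIS_BACKWARD(...)` call sites added below would not compile. A minimal sketch of the pattern, with an illustrative macro name:

// Minimal sketch of the immediately-invoked-lambda macro pattern
// (the TWICE name is illustrative only). Compile as C++14 or later.
#include <cstdio>

#define TWICE(x)                                                               \
  [&] {                                                                        \
    auto result = (x) + (x); /* arbitrarily long body with locals */           \
    return result;                                                             \
  }() /* the trailing () invokes the lambda, turning the block into a value */

int main() {
  // Because the expansion is an expression, it can sit after `return` or `=`.
  int answer = TWICE(21);
  std::printf("%d\n", answer); // prints 42
  return 0;
}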
@@ -188,18 +188,58 @@ template <typename scalar_t> struct BasisBackward {
   }                                                                           \
 }()

+template <typename scalar_t>
+__global__ void
+linear_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
+                 at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
+                 at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
+                 int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
+  BASIS_BACKWARD_KERNEL(1, grad_pseudo, grad_basis, pseudo, kernel_size,
+                        is_open_spline, numel,
+                        BasisForward<scalar_t>::linear(v, k_mod),
+                        BasisBackward<scalar_t>::linear(v, k_mod));
+}
+
 at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                           at::Tensor kernel_size, at::Tensor is_open_spline) {
-  return grad_basis;
+  return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
+                        linear_bw_kernel);
 }

+template <typename scalar_t>
+__global__ void
+quadratic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
+                    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
+                    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
+                    int64_t *kernel_size, uint8_t *is_open_spline,
+                    size_t numel) {
+  BASIS_BACKWARD_KERNEL(2, grad_pseudo, grad_basis, pseudo, kernel_size,
+                        is_open_spline, numel,
+                        BasisForward<scalar_t>::quadratic(v, k_mod),
+                        BasisBackward<scalar_t>::quadratic(v, k_mod));
+}
+
 at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                              at::Tensor kernel_size,
                              at::Tensor is_open_spline) {
-  return grad_basis;
+  return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
+                        quadratic_bw_kernel);
 }

+template <typename scalar_t>
+__global__ void
+cubic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
+                at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
+                at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
+                int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
+  BASIS_BACKWARD_KERNEL(3, grad_pseudo, grad_basis, pseudo, kernel_size,
+                        is_open_spline, numel,
+                        BasisForward<scalar_t>::cubic(v, k_mod),
+                        BasisBackward<scalar_t>::cubic(v, k_mod));
+}
+
 at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                          at::Tensor kernel_size, at::Tensor is_open_spline) {
-  return grad_basis;
+  return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
+                        cubic_bw_kernel);
 }
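This hunk is the core bugfix: each `*_bw_cuda` entry point previously returned its `grad_basis` input immediately, so the CUDA backward pass for the B-spline basis was a no-op. The dead `return grad_basis;` stubs are replaced by real dispatches through BASIS_BACKWARD, backed by the new per-degree kernels. For orientation, here is a hedged, host-only sketch of the kind of forward/backward pair the degree-1 kernel plugs into the macro, together with a finite-difference check; the formulas are assumed from the usual open-B-spline formulation and are not copied from this diff, so consult the real BasisForward/BasisBackward definitions:

// Host-only sketch; formulas are assumptions, not taken from this commit.
#include <cstdio>

// Assumed degree-1 basis: B(v, k) = (1 - k)(1 - v) + k * v for k in {0, 1},
// which expands to 1 - v - k + 2 * v * k.
double linear_fw(double v, long k_mod) { return 1 - v - k_mod + 2 * v * k_mod; }
// Its derivative w.r.t. v: dB/dv = 2k - 1 (independent of v)
double linear_bw(double /*v*/, long k_mod) { return 2.0 * k_mod - 1.0; }

int main() {
  const double eps = 1e-6, v = 0.3;
  for (long k = 0; k <= 1; ++k) {
    // A central difference of the forward pass should match the analytic
    // backward value -- exactly the property the removed stubs silently broke.
    double numeric = (linear_fw(v + eps, k) - linear_fw(v - eps, k)) / (2 * eps);
    std::printf("k=%ld analytic=%+.6f numeric=%+.6f\n", k, linear_bw(v, k),
                numeric);
  }
  return 0;
}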
@@ -8,8 +8,6 @@ from torch_spline_conv.basis import implemented_degrees as degrees
 from .utils import dtypes, devices, tensor

-devices = [torch.device('cpu')]
-
 tests = [{
     'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
     'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
@@ -8,8 +8,6 @@ from torch_spline_conv.basis import SplineBasis
 from .utils import dtypes, devices, tensor

-devices = [torch.device('cpu')]
-
 tests = [{
     'x': [[1, 2], [3, 4]],
     'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],