Commit cba43ac9 authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent d1b3f976
...@@ -6,6 +6,16 @@ matrix: ...@@ -6,6 +6,16 @@ matrix:
- python: 2.7 - python: 2.7
- python: 3.5 - python: 3.5
- python: 3.6 - python: 3.6
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- gcc-4.9
- g++-4.9
before_install:
- export CC="gcc-4.9"
- export CXX="g++-4.9"
install: install:
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; fi - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi
......
...@@ -177,7 +177,7 @@ template <typename scalar_t> struct BasisBackward { ...@@ -177,7 +177,7 @@ template <typename scalar_t> struct BasisBackward {
}); \ }); \
\ \
return grad_pseudo; \ return grad_pseudo; \
} }()
#define BASIS_BACKWARD_KERNEL(M, GRAD_PSEUDO, GRAD_BASIS, PSEUDO, KERNEL_SIZE, \ #define BASIS_BACKWARD_KERNEL(M, GRAD_PSEUDO, GRAD_BASIS, PSEUDO, KERNEL_SIZE, \
IS_OPEN_SPLINE, NUMEL, CODE, GRAD_CODE) \ IS_OPEN_SPLINE, NUMEL, CODE, GRAD_CODE) \
...@@ -188,18 +188,58 @@ template <typename scalar_t> struct BasisBackward { ...@@ -188,18 +188,58 @@ template <typename scalar_t> struct BasisBackward {
} \ } \
}() }()
template <typename scalar_t>
__global__ void
linear_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
BASIS_BACKWARD_KERNEL(1, grad_pseudo, grad_basis, pseudo, kernel_size,
is_open_spline, numel,
BasisForward<scalar_t>::linear(v, k_mod),
BasisBackward<scalar_t>::linear(v, k_mod));
}
at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) { at::Tensor kernel_size, at::Tensor is_open_spline) {
return grad_basis; return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
linear_bw_kernel);
}
template <typename scalar_t>
__global__ void
quadratic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
int64_t *kernel_size, uint8_t *is_open_spline,
size_t numel) {
BASIS_BACKWARD_KERNEL(2, grad_pseudo, grad_basis, pseudo, kernel_size,
is_open_spline, numel,
BasisForward<scalar_t>::quadratic(v, k_mod),
BasisBackward<scalar_t>::quadratic(v, k_mod));
} }
at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor kernel_size,
at::Tensor is_open_spline) { at::Tensor is_open_spline) {
return grad_basis; return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
quadratic_bw_kernel);
}
template <typename scalar_t>
__global__ void
cubic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
BASIS_BACKWARD_KERNEL(3, grad_pseudo, grad_basis, pseudo, kernel_size,
is_open_spline, numel,
BasisForward<scalar_t>::cubic(v, k_mod),
BasisBackward<scalar_t>::cubic(v, k_mod));
} }
at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) { at::Tensor kernel_size, at::Tensor is_open_spline) {
return grad_basis; return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
cubic_bw_kernel);
} }
...@@ -8,8 +8,6 @@ from torch_spline_conv.basis import implemented_degrees as degrees ...@@ -8,8 +8,6 @@ from torch_spline_conv.basis import implemented_degrees as degrees
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
devices = [torch.device('cpu')]
tests = [{ tests = [{
'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]], 'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]], 'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
......
...@@ -8,8 +8,6 @@ from torch_spline_conv.basis import SplineBasis ...@@ -8,8 +8,6 @@ from torch_spline_conv.basis import SplineBasis
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
devices = [torch.device('cpu')]
tests = [{ tests = [{
'x': [[1, 2], [3, 4]], 'x': [[1, 2], [3, 4]],
'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]], 'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment