Commit 676aa072 authored by rusty1s

pytorch 1.3 support

parent 0dd1186d
@@ -17,7 +17,7 @@ before_install:
 - export CXX="g++-4.9"
 install:
 - pip install numpy
-- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html
+- pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 - pip install pycodestyle
 - pip install flake8
 - pip install codecov
......
 #include <torch/extension.h>
+#include "compat.h"

 template <typename scalar_t> inline scalar_t linear(scalar_t v, int64_t k_mod) {
   return 1 - v - k_mod + 2 * v * k_mod;
 }
@@ -34,11 +36,11 @@ template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
   \
   AT_DISPATCH_FLOATING_TYPES( \
       PSEUDO.scalar_type(), "basis_forward_##M", [&] { \
-        auto pseudo_data = PSEUDO.data<scalar_t>(); \
-        auto kernel_size_data = KERNEL_SIZE.data<int64_t>(); \
-        auto is_open_spline_data = IS_OPEN_SPLINE.data<uint8_t>(); \
-        auto basis_data = basis.data<scalar_t>(); \
-        auto weight_index_data = weight_index.data<int64_t>(); \
+        auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
+        auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
+        auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
+        auto basis_data = basis.DATA_PTR<scalar_t>(); \
+        auto weight_index_data = weight_index.DATA_PTR<int64_t>(); \
   \
         int64_t k, wi, wi_offset; \
         scalar_t b; \
@@ -126,11 +128,11 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
   \
   AT_DISPATCH_FLOATING_TYPES( \
       PSEUDO.scalar_type(), "basis_backward_##M", [&] { \
-        auto grad_basis_data = GRAD_BASIS.data<scalar_t>(); \
-        auto pseudo_data = PSEUDO.data<scalar_t>(); \
-        auto kernel_size_data = KERNEL_SIZE.data<int64_t>(); \
-        auto is_open_spline_data = IS_OPEN_SPLINE.data<uint8_t>(); \
-        auto grad_pseudo_data = grad_pseudo.data<scalar_t>(); \
+        auto grad_basis_data = GRAD_BASIS.DATA_PTR<scalar_t>(); \
+        auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
+        auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
+        auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
+        auto grad_pseudo_data = grad_pseudo.DATA_PTR<scalar_t>(); \
   \
         scalar_t g, tmp; \
   \
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
 #include <torch/extension.h>
+#include "compat.h"

 at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
                         at::Tensor weight_index) {
   auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
@@ -7,11 +9,11 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
   auto out = at::empty({E, M_out}, x.options());

   AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
-    auto x_data = x.data<scalar_t>();
-    auto weight_data = weight.data<scalar_t>();
-    auto basis_data = basis.data<scalar_t>();
-    auto weight_index_data = weight_index.data<int64_t>();
-    auto out_data = out.data<scalar_t>();
+    auto x_data = x.DATA_PTR<scalar_t>();
+    auto weight_data = weight.DATA_PTR<scalar_t>();
+    auto basis_data = basis.DATA_PTR<scalar_t>();
+    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();

     scalar_t v;
@@ -44,11 +46,11 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
   auto grad_x = at::zeros({E, M_in}, grad_out.options());

   AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
-    auto grad_out_data = grad_out.data<scalar_t>();
-    auto weight_data = weight.data<scalar_t>();
-    auto basis_data = basis.data<scalar_t>();
-    auto weight_index_data = weight_index.data<int64_t>();
-    auto grad_x_data = grad_x.data<scalar_t>();
+    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
+    auto weight_data = weight.DATA_PTR<scalar_t>();
+    auto basis_data = basis.DATA_PTR<scalar_t>();
+    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
+    auto grad_x_data = grad_x.DATA_PTR<scalar_t>();

     for (ptrdiff_t e = 0; e < E; e++) {
       for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
@@ -78,11 +80,11 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
   auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());

   AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_w", [&] {
-    auto grad_out_data = grad_out.data<scalar_t>();
-    auto x_data = x.data<scalar_t>();
-    auto basis_data = basis.data<scalar_t>();
-    auto weight_index_data = weight_index.data<int64_t>();
-    auto grad_weight_data = grad_weight.data<scalar_t>();
+    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
+    auto x_data = x.DATA_PTR<scalar_t>();
+    auto basis_data = basis.DATA_PTR<scalar_t>();
+    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
+    auto grad_weight_data = grad_weight.DATA_PTR<scalar_t>();

     for (ptrdiff_t e = 0; e < E; e++) {
       for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
@@ -110,11 +112,11 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
   auto grad_basis = at::zeros({E, S}, grad_out.options());

   AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_b", [&] {
-    auto grad_out_data = grad_out.data<scalar_t>();
-    auto x_data = x.data<scalar_t>();
-    auto weight_data = weight.data<scalar_t>();
-    auto weight_index_data = weight_index.data<int64_t>();
-    auto grad_basis_data = grad_basis.data<scalar_t>();
+    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
+    auto x_data = x.DATA_PTR<scalar_t>();
+    auto weight_data = weight.DATA_PTR<scalar_t>();
+    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
+    auto grad_basis_data = grad_basis.DATA_PTR<scalar_t>();

     for (ptrdiff_t e = 0; e < E; e++) {
       for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
......
@@ -2,6 +2,8 @@
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+
+#include "compat.cuh"

 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
@@ -45,8 +47,8 @@ template <typename scalar_t> struct BasisForward {
             at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis), \
             at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index), \
             at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
-            KERNEL_SIZE.data<int64_t>(), IS_OPEN_SPLINE.data<uint8_t>(), \
-            basis.numel()); \
+            KERNEL_SIZE.DATA_PTR<int64_t>(), \
+            IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), basis.numel()); \
       }); \
   \
   return std::make_tuple(basis, weight_index); \
@@ -176,8 +178,8 @@ template <typename scalar_t> struct BasisBackward {
             at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo), \
             at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS), \
             at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
-            KERNEL_SIZE.data<int64_t>(), IS_OPEN_SPLINE.data<uint8_t>(), \
-            grad_pseudo.numel()); \
+            KERNEL_SIZE.DATA_PTR<int64_t>(), \
+            IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), grad_pseudo.numel()); \
       }); \
   \
   return grad_pseudo; \
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
@@ -2,21 +2,32 @@ from setuptools import setup, find_packages
 import torch
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+extra_compile_args = []
+if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
+    extra_compile_args += ['-DVERSION_GE_1_3']
+
 ext_modules = [
-    CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp']),
-    CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp']),
+    CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp'],
+                 extra_compile_args=extra_compile_args),
+    CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp'],
+                 extra_compile_args=extra_compile_args),
 ]
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

 if CUDA_HOME is not None:
     ext_modules += [
         CUDAExtension('torch_spline_conv.basis_cuda',
-                      ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
+                      ['cuda/basis.cpp', 'cuda/basis_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
         CUDAExtension('torch_spline_conv.weighting_cuda',
-                      ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
+                      ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
     ]

-__version__ = '1.1.0'
+__version__ = '1.1.1'
 url = 'https://github.com/rusty1s/pytorch_spline_conv'
 install_requires = []
......
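For orientation: the flag added above is what drives the new compat headers. When setup.py detects PyTorch 1.3 or newer it passes -DVERSION_GE_1_3, so DATA_PTR expands to the new data_ptr accessor instead of the deprecated data accessor. A minimal, self-contained sketch of the same gate follows; the module name 'demo_ext' and source 'demo.cpp' are placeholders, not part of this commit.

# Hypothetical stand-alone build script mirroring the version gate above.
# 'demo_ext' and 'demo.cpp' are placeholders; the real extensions and source
# files are the ones listed in setup.py.
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

# Parse "major.minor" from the installed torch and set the define only on >= 1.3.
major, minor = (int(v) for v in torch.__version__.split('.')[:2])
extra_compile_args = ['-DVERSION_GE_1_3'] if (major, minor) >= (1, 3) else []

setup(
    name='demo_ext',
    ext_modules=[
        CppExtension('demo_ext', ['demo.cpp'],
                     extra_compile_args=extra_compile_args),
    ],
    cmdclass={'build_ext': BuildExtension},
)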
@@ -32,7 +32,7 @@ def test_spline_weighting_forward(test, dtype, device):

 @pytest.mark.parametrize('device', devices)
-def test_spline_basis_backward(device):
+def test_spline_weighting_backward(device):
     pseudo = torch.rand((4, 2), dtype=torch.double, device=device)
     kernel_size = tensor([5, 5], torch.long, device)
     is_open_spline = tensor([1, 1], torch.uint8, device)
......
@@ -2,6 +2,6 @@ from .basis import SplineBasis
 from .weighting import SplineWeighting
 from .conv import SplineConv

-__version__ = '1.1.0'
+__version__ = '1.1.1'
 __all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
@@ -18,7 +18,8 @@ class SplineBasis(torch.autograd.Function):
     @staticmethod
     def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
         ctx.save_for_backward(pseudo)
-        ctx.kernel_size, ctx.is_open_spline = kernel_size, is_open_spline
+        ctx.kernel_size = kernel_size
+        ctx.is_open_spline = is_open_spline
         ctx.degree = degree
         op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
......
@@ -38,18 +38,9 @@ class SplineConv(object):
     :rtype: :class:`Tensor`
     """

     @staticmethod
-    def apply(x,
-              edge_index,
-              pseudo,
-              weight,
-              kernel_size,
-              is_open_spline,
-              degree=1,
-              norm=True,
-              root_weight=None,
-              bias=None):
+    def apply(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
+              degree=1, norm=True, root_weight=None, bias=None):

         x = x.unsqueeze(-1) if x.dim() == 1 else x
         pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
@@ -58,8 +49,10 @@ class SplineConv(object):
         n, m_out = x.size(0), weight.size(2)

         # Weight each node.
-        data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
-        out = SplineWeighting.apply(x[col], weight, *data)
+        basis, weight_index = SplineBasis.apply(pseudo, kernel_size,
+                                                is_open_spline, degree)
+        weight_index = weight_index.detach()
+        out = SplineWeighting.apply(x[col], weight, basis, weight_index)

         # Convert e x m_out to n x m_out features.
         row_expand = row.unsqueeze(-1).expand_as(out)
......
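As a usage reference for the flattened SplineConv.apply signature above, a small hedged sketch: the argument order matches the new signature in this commit, the tensor shapes follow the library's documented convention ([N, M_in] features, [2, E] edge indices, [E, D] pseudo-coordinates, [K, M_in, M_out] weights with K = kernel_size.prod()), and the concrete numbers are illustrative only.

# Illustrative call of SplineConv.apply with the new flat argument list;
# the values are made up, only the argument order and shapes matter here.
import torch
from torch_spline_conv import SplineConv

x = torch.rand(4, 2)                                  # node features, [N, M_in]
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],        # COO connectivity, [2, E]
                           [1, 0, 2, 1, 3, 2]])
pseudo = torch.rand(6, 2)                             # edge pseudo-coordinates in [0, 1], [E, D]
weight = torch.rand(25, 2, 4)                         # [kernel_size.prod(), M_in, M_out]
kernel_size = torch.tensor([5, 5])                    # kernel size per pseudo-dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)

out = SplineConv.apply(x, edge_index, pseudo, weight, kernel_size,
                       is_open_spline, degree=1, norm=True)
print(out.size())                                     # torch.Size([4, 4]), i.e. [N, M_out]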