Commit 676aa072 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.3 support

parent 0dd1186d
......@@ -17,7 +17,7 @@ before_install:
- export CXX="g++-4.9"
install:
- pip install numpy
- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html
- pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install pycodestyle
- pip install flake8
- pip install codecov
......
#include <torch/extension.h>
#include "compat.h"
template <typename scalar_t> inline scalar_t linear(scalar_t v, int64_t k_mod) {
return 1 - v - k_mod + 2 * v * k_mod;
}
......@@ -34,11 +36,11 @@ template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
\
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_forward_##M", [&] { \
auto pseudo_data = PSEUDO.data<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.data<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.data<uint8_t>(); \
auto basis_data = basis.data<scalar_t>(); \
auto weight_index_data = weight_index.data<int64_t>(); \
auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
auto basis_data = basis.DATA_PTR<scalar_t>(); \
auto weight_index_data = weight_index.DATA_PTR<int64_t>(); \
\
int64_t k, wi, wi_offset; \
scalar_t b; \
......@@ -126,11 +128,11 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
\
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_backward_##M", [&] { \
auto grad_basis_data = GRAD_BASIS.data<scalar_t>(); \
auto pseudo_data = PSEUDO.data<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.data<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.data<uint8_t>(); \
auto grad_pseudo_data = grad_pseudo.data<scalar_t>(); \
auto grad_basis_data = GRAD_BASIS.DATA_PTR<scalar_t>(); \
auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
auto grad_pseudo_data = grad_pseudo.DATA_PTR<scalar_t>(); \
\
scalar_t g, tmp; \
\
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#include "compat.h"
at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index) {
auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
......@@ -7,11 +9,11 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
auto out = at::empty({E, M_out}, x.options());
AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data<scalar_t>();
auto weight_data = weight.data<scalar_t>();
auto basis_data = basis.data<scalar_t>();
auto weight_index_data = weight_index.data<int64_t>();
auto out_data = out.data<scalar_t>();
auto x_data = x.DATA_PTR<scalar_t>();
auto weight_data = weight.DATA_PTR<scalar_t>();
auto basis_data = basis.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto out_data = out.DATA_PTR<scalar_t>();
scalar_t v;
......@@ -44,11 +46,11 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
auto grad_x = at::zeros({E, M_in}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data<scalar_t>();
auto weight_data = weight.data<scalar_t>();
auto basis_data = basis.data<scalar_t>();
auto weight_index_data = weight_index.data<int64_t>();
auto grad_x_data = grad_x.data<scalar_t>();
auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
auto weight_data = weight.DATA_PTR<scalar_t>();
auto basis_data = basis.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto grad_x_data = grad_x.DATA_PTR<scalar_t>();
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
......@@ -78,11 +80,11 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_w", [&] {
auto grad_out_data = grad_out.data<scalar_t>();
auto x_data = x.data<scalar_t>();
auto basis_data = basis.data<scalar_t>();
auto weight_index_data = weight_index.data<int64_t>();
auto grad_weight_data = grad_weight.data<scalar_t>();
auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
auto x_data = x.DATA_PTR<scalar_t>();
auto basis_data = basis.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto grad_weight_data = grad_weight.DATA_PTR<scalar_t>();
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
......@@ -110,11 +112,11 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
auto grad_basis = at::zeros({E, S}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_b", [&] {
auto grad_out_data = grad_out.data<scalar_t>();
auto x_data = x.data<scalar_t>();
auto weight_data = weight.data<scalar_t>();
auto weight_index_data = weight_index.data<int64_t>();
auto grad_basis_data = grad_basis.data<scalar_t>();
auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
auto x_data = x.DATA_PTR<scalar_t>();
auto weight_data = weight.DATA_PTR<scalar_t>();
auto weight_index_data = weight_index.DATA_PTR<int64_t>();
auto grad_basis_data = grad_basis.DATA_PTR<scalar_t>();
for (ptrdiff_t e = 0; e < E; e++) {
for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
......
......@@ -2,6 +2,8 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
......@@ -45,8 +47,8 @@ template <typename scalar_t> struct BasisForward {
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis), \
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
KERNEL_SIZE.data<int64_t>(), IS_OPEN_SPLINE.data<uint8_t>(), \
basis.numel()); \
KERNEL_SIZE.DATA_PTR<int64_t>(), \
IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), basis.numel()); \
}); \
\
return std::make_tuple(basis, weight_index); \
......@@ -176,8 +178,8 @@ template <typename scalar_t> struct BasisBackward {
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
KERNEL_SIZE.data<int64_t>(), IS_OPEN_SPLINE.data<uint8_t>(), \
grad_pseudo.numel()); \
KERNEL_SIZE.DATA_PTR<int64_t>(), \
IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), grad_pseudo.numel()); \
}); \
\
return grad_pseudo; \
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
......@@ -2,21 +2,32 @@ from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
extra_compile_args = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3']
ext_modules = [
CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp']),
CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp']),
CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp'],
extra_compile_args=extra_compile_args),
CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp'],
extra_compile_args=extra_compile_args),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if CUDA_HOME is not None:
ext_modules += [
CUDAExtension('torch_spline_conv.basis_cuda',
['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
['cuda/basis.cpp', 'cuda/basis_kernel.cu'],
extra_compile_args=extra_compile_args),
CUDAExtension('torch_spline_conv.weighting_cuda',
['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
['cuda/weighting.cpp', 'cuda/weighting_kernel.cu'],
extra_compile_args=extra_compile_args),
]
__version__ = '1.1.0'
__version__ = '1.1.1'
url = 'https://github.com/rusty1s/pytorch_spline_conv'
install_requires = []
......
......@@ -32,7 +32,7 @@ def test_spline_weighting_forward(test, dtype, device):
@pytest.mark.parametrize('device', devices)
def test_spline_basis_backward(device):
def test_spline_weighting_backward(device):
pseudo = torch.rand((4, 2), dtype=torch.double, device=device)
kernel_size = tensor([5, 5], torch.long, device)
is_open_spline = tensor([1, 1], torch.uint8, device)
......
......@@ -2,6 +2,6 @@ from .basis import SplineBasis
from .weighting import SplineWeighting
from .conv import SplineConv
__version__ = '1.1.0'
__version__ = '1.1.1'
__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
......@@ -18,7 +18,8 @@ class SplineBasis(torch.autograd.Function):
@staticmethod
def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
ctx.save_for_backward(pseudo)
ctx.kernel_size, ctx.is_open_spline = kernel_size, is_open_spline
ctx.kernel_size = kernel_size
ctx.is_open_spline = is_open_spline
ctx.degree = degree
op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
......
......@@ -38,18 +38,9 @@ class SplineConv(object):
:rtype: :class:`Tensor`
"""
@staticmethod
def apply(x,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree=1,
norm=True,
root_weight=None,
bias=None):
def apply(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
degree=1, norm=True, root_weight=None, bias=None):
x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
......@@ -58,8 +49,10 @@ class SplineConv(object):
n, m_out = x.size(0), weight.size(2)
# Weight each node.
data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
out = SplineWeighting.apply(x[col], weight, *data)
basis, weight_index = SplineBasis.apply(pseudo, kernel_size,
is_open_spline, degree)
weight_index = weight_index.detach()
out = SplineWeighting.apply(x[col], weight, basis, weight_index)
# Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment