Unverified Commit cd7b1988 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #14 from rusty1s/tracing

prepare tracing
parents d3169766 32224979
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include <cuda.h>
#endif
// On Windows the extension DLL must export a PyInit__version symbol so
// that Python can import it as a module; the real operator registration
// happens below, so the init hook itself is a no-op stub.
#ifdef _WIN32
PyMODINIT_FUNC PyInit__version(void) { return NULL; }
#endif
// Reports the CUDA_VERSION the extension was compiled against, or -1 for
// CPU-only builds (used by the Python side to detect version mismatches).
int64_t cuda_version() {
#ifndef WITH_CUDA
  return -1;
#else
  return CUDA_VERSION;
#endif
}
// Register `cuda_version` as a TorchScript operator under the
// `torch_spline_conv` namespace.
static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::cuda_version", &cuda_version);
#include <Python.h>
#include <torch/script.h>
#include "cpu/weighting_cpu.h"
#ifdef WITH_CUDA
#include "cuda/weighting_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__weighting(void) { return NULL; }
#endif
// Dispatches the forward weighting computation to the CPU or CUDA
// implementation based on the device of `x`. Raises if `x` is a CUDA
// tensor but the extension was built without CUDA support.
torch::Tensor spline_weighting_fw(torch::Tensor x, torch::Tensor weight,
                                  torch::Tensor basis,
                                  torch::Tensor weight_index) {
  if (!x.device().is_cuda())
    return spline_weighting_fw_cpu(x, weight, basis, weight_index);
#ifdef WITH_CUDA
  return spline_weighting_fw_cuda(x, weight, basis, weight_index);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Dispatches the input-gradient (d out / d x) computation to CPU or CUDA
// based on the device of `grad_out`.
torch::Tensor spline_weighting_bw_x(torch::Tensor grad_out,
                                    torch::Tensor weight, torch::Tensor basis,
                                    torch::Tensor weight_index) {
  if (!grad_out.device().is_cuda())
    return spline_weighting_bw_x_cpu(grad_out, weight, basis, weight_index);
#ifdef WITH_CUDA
  return spline_weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Dispatches the weight-gradient computation to CPU or CUDA based on the
// device of `grad_out`. `kernel_size` is the number of kernel weights K
// (the first dimension of the returned gradient).
torch::Tensor spline_weighting_bw_weight(torch::Tensor grad_out,
                                         torch::Tensor x, torch::Tensor basis,
                                         torch::Tensor weight_index,
                                         int64_t kernel_size) {
  if (!grad_out.device().is_cuda())
    return spline_weighting_bw_weight_cpu(grad_out, x, basis, weight_index,
                                          kernel_size);
#ifdef WITH_CUDA
  return spline_weighting_bw_weight_cuda(grad_out, x, basis, weight_index,
                                         kernel_size);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Dispatches the basis-gradient computation to CPU or CUDA based on the
// device of `grad_out`.
torch::Tensor spline_weighting_bw_basis(torch::Tensor grad_out, torch::Tensor x,
                                        torch::Tensor weight,
                                        torch::Tensor weight_index) {
  if (!grad_out.device().is_cuda())
    return spline_weighting_bw_basis_cpu(grad_out, x, weight, weight_index);
#ifdef WITH_CUDA
  return spline_weighting_bw_basis_cuda(grad_out, x, weight, weight_index);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// Autograd-aware spline weighting: forward saves all four inputs;
// backward routes the incoming gradient through the three
// partial-derivative dispatchers above.
class SplineWeighting : public torch::autograd::Function<SplineWeighting> {
public:
  static variable_list forward(AutogradContext *ctx, Variable x,
                               Variable weight, Variable basis,
                               Variable weight_index) {
    auto out = spline_weighting_fw(x, weight, basis, weight_index);
    // All four inputs are needed by backward (weight_index only as an
    // index tensor).
    ctx->save_for_backward({x, weight, basis, weight_index});
    return {out};
  }
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto x = saved[0], weight = saved[1], basis = saved[2],
         weight_index = saved[3];
    // Each gradient is computed only if the corresponding input requires
    // grad; otherwise an undefined Variable is returned in its slot.
    auto grad_x = Variable();
    if (torch::autograd::any_variable_requires_grad({x})) {
      grad_x = spline_weighting_bw_x(grad_out, weight, basis, weight_index);
    }
    auto grad_weight = Variable();
    if (torch::autograd::any_variable_requires_grad({weight})) {
      // weight.size(0) == K, the number of kernel weights.
      grad_weight = spline_weighting_bw_weight(grad_out, x, basis, weight_index,
                                               weight.size(0));
    }
    auto grad_basis = Variable();
    if (torch::autograd::any_variable_requires_grad({basis})) {
      grad_basis = spline_weighting_bw_basis(grad_out, x, weight, weight_index);
    }
    // weight_index is an integer tensor and never receives a gradient.
    return {grad_x, grad_weight, grad_basis, Variable()};
  }
};
// Public entry point: applies the autograd-aware spline weighting.
// All tensor inputs are made contiguous up front: the CUDA kernels read
// `basis` and `weight_index` at flat `e * S + s` offsets (ignoring
// strides), so passing non-contiguous views for those two would silently
// produce wrong results. contiguous() is a no-op for already-contiguous
// tensors, so behavior for existing callers is unchanged.
torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
                               torch::Tensor basis,
                               torch::Tensor weight_index) {
  x = x.contiguous();
  weight = weight.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();
  return SplineWeighting::apply(x, weight, basis, weight_index)[0];
}
// Register `spline_weighting` as a TorchScript operator under the
// `torch_spline_conv` namespace.
static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::spline_weighting", &spline_weighting);
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor> linear_fw_cuda(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline);
std::tuple<at::Tensor, at::Tensor> quadratic_fw_cuda(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline);
std::tuple<at::Tensor, at::Tensor> cubic_fw_cuda(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline);
at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline);
at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline);
at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline);
// Thin host-side wrappers around the CUDA basis kernels: each one only
// validates that every tensor argument lives on the GPU and then forwards
// to the corresponding *_cuda implementation declared above.
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo,
                                             at::Tensor kernel_size,
                                             at::Tensor is_open_spline) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return linear_fw_cuda(pseudo, kernel_size, is_open_spline);
}
std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo,
                                                at::Tensor kernel_size,
                                                at::Tensor is_open_spline) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return quadratic_fw_cuda(pseudo, kernel_size, is_open_spline);
}
std::tuple<at::Tensor, at::Tensor>
cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return cubic_fw_cuda(pseudo, kernel_size, is_open_spline);
}
at::Tensor linear_bw(at::Tensor grad_basis, at::Tensor pseudo,
                     at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return linear_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline);
}
at::Tensor quadratic_bw(at::Tensor grad_basis, at::Tensor pseudo,
                        at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return quadratic_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline);
}
at::Tensor cubic_bw(at::Tensor grad_basis, at::Tensor pseudo,
                    at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return cubic_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline);
}
// Python bindings for the standalone CUDA basis extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_fw", &linear_fw, "Linear Basis Forward (CUDA)");
  m.def("quadratic_fw", &quadratic_fw, "Quadratic Basis Forward (CUDA)");
  m.def("cubic_fw", &cubic_fw, "Cubic Basis Forward (CUDA)");
  m.def("linear_bw", &linear_bw, "Linear Basis Backward (CUDA)");
  m.def("quadratic_bw", &quadratic_bw, "Quadratic Basis Backward (CUDA)");
  m.def("cubic_bw", &cubic_bw, "Cubic Basis Backward (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// B-spline basis pieces evaluated on the fractional coordinate v; `k_mod`
// selects which of the (degree + 1) overlapping polynomial pieces to use.
template <typename scalar_t> struct BasisForward {
  // Degree 1: yields 1 - v for k_mod == 0 and v for k_mod == 1,
  // algebraically folded into a single branch-free expression.
  static inline __device__ scalar_t linear(scalar_t v, int64_t k_mod) {
    return 1 - v - k_mod + 2 * v * k_mod;
  }
  // Degree 2 (quadratic) pieces.
  static inline __device__ scalar_t quadratic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return 0.5 * v * v - v + 0.5;
    else if (k_mod == 1)
      return -v * v + v + 0.5;
    else
      return 0.5 * v * v;
  }
  // Degree 3 (cubic) pieces.
  static inline __device__ scalar_t cubic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return (1 - v) * (1 - v) * (1 - v) / 6.0;
    else if (k_mod == 1)
      return (3 * v * v * v - 6 * v * v + 4) / 6;
    else if (k_mod == 2)
      return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
    else
      return v * v * v / 6;
  }
};
// Host-side launcher shared by the degree-M forward implementations:
// allocates `basis` and `weight_index` of shape [E, S] with
// S = (M + 1)^D (D = number of pseudo-coordinate dimensions) and launches
// KERNEL_NAME over all E * S entries on PSEUDO's device.
// Fixed: the dispatch name was written "basis_forward_##M" inside a
// string literal, where token pasting never occurs — the error-reporting
// name was literally "basis_forward_##M". "basis_forward_" #M stringizes
// the degree and concatenates it correctly.
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, KERNEL_NAME)     \
  [&]() -> std::tuple<at::Tensor, at::Tensor> {                                \
    cudaSetDevice(PSEUDO.get_device());                                        \
    auto E = PSEUDO.size(0);                                                   \
    auto S = (int64_t)(powf(M + 1, KERNEL_SIZE.size(0)) + 0.5);                \
    auto basis = at::empty({E, S}, PSEUDO.options());                          \
    auto weight_index = at::empty({E, S}, KERNEL_SIZE.options());              \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(                                                \
        PSEUDO.scalar_type(), "basis_forward_" #M, [&] {                       \
          KERNEL_NAME<scalar_t><<<BLOCKS(basis.numel()), THREADS>>>(           \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),       \
              at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index), \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO),      \
              KERNEL_SIZE.DATA_PTR<int64_t>(),                                 \
              IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), basis.numel());              \
        });                                                                    \
                                                                               \
    return std::make_tuple(basis, weight_index);                               \
  }()
// Shared device-side body of the forward basis kernels (grid-stride loop
// over E * S output entries). For each (edge e, product s) pair it walks
// the pseudo-coordinate dimensions, multiplying the per-dimension 1-D
// basis value (CODE, evaluated with locals `v` and `k_mod` in scope) into
// `b` and accumulating the flattened kernel-weight index into `wi`.
// NOTE(review): KERNEL_SIZE/IS_OPEN_SPLINE are raw device pointers, and
// BASIS/WEIGHT_INDEX are written at flat offset `i` — i.e. they are
// assumed contiguous. (Comments live outside the macro body because a
// `//` comment would swallow the backslash continuations.)
#define BASIS_FORWARD_KERNEL(M, BASIS, WEIGHT_INDEX, PSEUDO, KERNEL_SIZE,      \
                             IS_OPEN_SPLINE, NUMEL, CODE)                      \
  [&] {                                                                        \
    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;                \
    const size_t stride = blockDim.x * gridDim.x;                              \
    for (ptrdiff_t i = index; i < NUMEL; i += stride) {                        \
      int64_t e = i / BASIS.sizes[1], s = i % BASIS.sizes[1];                  \
      int64_t k = s, wi = 0, wi_offset = 1;                                    \
      scalar_t b = 1;                                                          \
                                                                               \
      for (ptrdiff_t d = 0; d < PSEUDO.sizes[1]; d++) {                        \
        auto k_mod = k % (M + 1);                                              \
        k /= M + 1;                                                            \
                                                                               \
        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]];   \
        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                           \
                                                                               \
        wi += (((int64_t)v + k_mod) % KERNEL_SIZE[d]) * wi_offset;             \
        wi_offset *= KERNEL_SIZE[d];                                           \
                                                                               \
        v -= floor(v);                                                         \
        v = CODE;                                                              \
        b *= v;                                                                \
      }                                                                        \
                                                                               \
      BASIS.data[i] = b;                                                       \
      WEIGHT_INDEX.data[i] = wi;                                               \
    }                                                                          \
  }()
// Forward kernel / host-launcher pairs for each spline degree; the shared
// macros above do all the work, parameterized by the degree M and the
// matching BasisForward evaluator.
template <typename scalar_t>
__global__ void
linear_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                 at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                 at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                 int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_FORWARD_KERNEL(1, basis, weight_index, pseudo, kernel_size,
                       is_open_spline, numel,
                       BasisForward<scalar_t>::linear(v, k_mod));
}
std::tuple<at::Tensor, at::Tensor> linear_fw_cuda(at::Tensor pseudo,
                                                  at::Tensor kernel_size,
                                                  at::Tensor is_open_spline) {
  return BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline,
                       linear_fw_kernel);
}
template <typename scalar_t>
__global__ void
quadratic_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                    int64_t *kernel_size, uint8_t *is_open_spline,
                    size_t numel) {
  BASIS_FORWARD_KERNEL(2, basis, weight_index, pseudo, kernel_size,
                       is_open_spline, numel,
                       BasisForward<scalar_t>::quadratic(v, k_mod));
}
std::tuple<at::Tensor, at::Tensor>
quadratic_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size,
                  at::Tensor is_open_spline) {
  return BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline,
                       quadratic_fw_kernel);
}
template <typename scalar_t>
__global__ void
cubic_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_FORWARD_KERNEL(3, basis, weight_index, pseudo, kernel_size,
                       is_open_spline, numel,
                       BasisForward<scalar_t>::cubic(v, k_mod));
}
std::tuple<at::Tensor, at::Tensor> cubic_fw_cuda(at::Tensor pseudo,
                                                 at::Tensor kernel_size,
                                                 at::Tensor is_open_spline) {
  return BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic_fw_kernel);
}
// Derivatives (w.r.t. v) of the corresponding pieces in BasisForward;
// used by the backward kernels to assemble d(basis)/d(pseudo).
template <typename scalar_t> struct BasisBackward {
  // d/dv of the linear pieces: -1 for k_mod == 0, +1 for k_mod == 1.
  static inline __device__ scalar_t linear(scalar_t v, int64_t k_mod) {
    return 2 * k_mod - 1;
  }
  static inline __device__ scalar_t quadratic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return v - 1;
    else if (k_mod == 1)
      return -2 * v + 1;
    else
      return v;
  }
  static inline __device__ scalar_t cubic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return (-v * v + 2 * v - 1) / 2;
    else if (k_mod == 1)
      return (3 * v * v - 4 * v) / 2;
    else if (k_mod == 2)
      return (-3 * v * v + 2 * v + 1) / 2;
    else
      return v * v / 2;
  }
};
// Host-side launcher shared by the degree-M backward implementations:
// allocates grad_pseudo of shape [E, D] and launches KERNEL_NAME over its
// E * D entries on GRAD_BASIS's device.
// Fixed: the dispatch name was written "basis_backward_##M" inside a
// string literal, where token pasting never occurs; "basis_backward_" #M
// stringizes the degree and concatenates it correctly.
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE,     \
                       KERNEL_NAME)                                            \
  [&]() -> at::Tensor {                                                        \
    cudaSetDevice(GRAD_BASIS.get_device());                                    \
    auto E = PSEUDO.size(0);                                                   \
    auto D = PSEUDO.size(1);                                                   \
    auto grad_pseudo = at::empty({E, D}, PSEUDO.options());                    \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(                                                \
        GRAD_BASIS.scalar_type(), "basis_backward_" #M, [&] {                  \
          KERNEL_NAME<scalar_t><<<BLOCKS(grad_pseudo.numel()), THREADS>>>(     \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo), \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS),  \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO),      \
              KERNEL_SIZE.DATA_PTR<int64_t>(),                                 \
              IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), grad_pseudo.numel());        \
        });                                                                    \
                                                                               \
    return grad_pseudo;                                                        \
  }()
// Shared device-side body of the backward basis kernels (grid-stride loop
// over the E * D entries of GRAD_PSEUDO). For each (edge e, dimension d)
// it sums, over all S basis products, the derivative of the 1-D basis
// along d (GRAD_CODE) times the plain basis values of the remaining
// dimensions (CODE), weighted by the incoming GRAD_BASIS, then applies
// the chain-rule scale factor from the pseudo-coordinate rescaling.
// Fixed: the inner d_it loop read `pseudo.strides[0]` — the lowercase
// spelling of the argument at every call site — instead of the macro
// parameter PSEUDO. The unhygienic capture only worked because all
// callers happened to pass a variable literally named `pseudo`; it now
// uses the macro parameter consistently.
#define BASIS_BACKWARD_KERNEL(M, GRAD_PSEUDO, GRAD_BASIS, PSEUDO, KERNEL_SIZE, \
                              IS_OPEN_SPLINE, NUMEL, CODE, GRAD_CODE)          \
  [&] {                                                                        \
    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;                \
    const size_t stride = blockDim.x * gridDim.x;                              \
    for (ptrdiff_t i = index; i < NUMEL; i += stride) {                        \
      int64_t e = i / GRAD_PSEUDO.sizes[1], d = i % GRAD_PSEUDO.sizes[1];      \
      scalar_t g = 0, tmp;                                                     \
                                                                               \
      for (ptrdiff_t s = 0; s < GRAD_BASIS.sizes[1]; s++) {                    \
        auto k_mod = (s / (int64_t)(powf(M + 1, d) + 0.5)) % (M + 1);          \
        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]];   \
        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                           \
        v -= floor(v);                                                         \
        v = GRAD_CODE;                                                         \
        tmp = v;                                                               \
                                                                               \
        for (ptrdiff_t d_it = 1; d_it < GRAD_PSEUDO.sizes[1]; d_it++) {        \
          auto d_new = d_it - (d >= d_it);                                     \
          k_mod = (s / (int64_t)(powf(M + 1, d_new) + 0.5)) % (M + 1);         \
          v = PSEUDO.data[e * PSEUDO.strides[0] + d_new * PSEUDO.strides[1]];  \
          v *= KERNEL_SIZE[d_new] - M * IS_OPEN_SPLINE[d_new];                 \
          v -= floor(v);                                                       \
          v = CODE;                                                            \
          tmp *= v;                                                            \
        }                                                                      \
        g += tmp *                                                             \
             GRAD_BASIS                                                        \
                 .data[e * GRAD_BASIS.strides[0] + s * GRAD_BASIS.strides[1]]; \
      }                                                                        \
      g *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                             \
      GRAD_PSEUDO.data[i] = g;                                                 \
    }                                                                          \
  }()
// Backward kernel / host-launcher pairs for each spline degree; the
// shared macros above do all the work, parameterized by the degree M and
// the matching BasisForward / BasisBackward evaluators.
template <typename scalar_t>
__global__ void
linear_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
                 at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
                 at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                 int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_BACKWARD_KERNEL(1, grad_pseudo, grad_basis, pseudo, kernel_size,
                        is_open_spline, numel,
                        BasisForward<scalar_t>::linear(v, k_mod),
                        BasisBackward<scalar_t>::linear(v, k_mod));
}
at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                          at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
                        linear_bw_kernel);
}
template <typename scalar_t>
__global__ void
quadratic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                    int64_t *kernel_size, uint8_t *is_open_spline,
                    size_t numel) {
  BASIS_BACKWARD_KERNEL(2, grad_pseudo, grad_basis, pseudo, kernel_size,
                        is_open_spline, numel,
                        BasisForward<scalar_t>::quadratic(v, k_mod),
                        BasisBackward<scalar_t>::quadratic(v, k_mod));
}
at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                             at::Tensor kernel_size,
                             at::Tensor is_open_spline) {
  return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
                        quadratic_bw_kernel);
}
template <typename scalar_t>
__global__ void
cubic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
                at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
                at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_BACKWARD_KERNEL(3, grad_pseudo, grad_basis, pseudo, kernel_size,
                        is_open_spline, numel,
                        BasisForward<scalar_t>::cubic(v, k_mod),
                        BasisBackward<scalar_t>::cubic(v, k_mod));
}
at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                         at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
                        cubic_bw_kernel);
}
// Compatibility shim: the tensor raw-pointer accessor is `data_ptr<T>()`
// from PyTorch 1.3 on (VERSION_GE_1_3 is set by the build) and `data<T>()`
// before that.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index);
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index);
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor basis, at::Tensor weight_index,
int64_t K);
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor weight, at::Tensor weight_index);
// Thin host-side wrappers around the CUDA weighting kernels: each one
// only validates that every tensor argument lives on the GPU and then
// forwards to the corresponding *_cuda implementation declared above.
at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
                        at::Tensor weight_index) {
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  return weighting_fw_cuda(x, weight, basis, weight_index);
}
at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
                          at::Tensor basis, at::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  return weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
}
// K is the number of kernel weights (first dim of the returned gradient).
at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
                          at::Tensor weight_index, int64_t K) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  return weighting_bw_w_cuda(grad_out, x, basis, weight_index, K);
}
at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
                          at::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  return weighting_bw_b_cuda(grad_out, x, weight, weight_index);
}
// Python bindings for the standalone CUDA weighting extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("weighting_fw", &weighting_fw, "Weighting Forward (CUDA)");
  m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CUDA)");
  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CUDA)");
  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CUDA)");
}
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// out[e, m_out] = sum_s sum_{m_in} basis[e, s]
//     * weight[weight_index[e, s], m_in, m_out] * x[e, m_in]
// with one thread per output element (grid-stride loop).
// NOTE(review): `basis` and `weight_index` are read at flat offset
// e * S + s, i.e. assumed contiguous; `x` and `weight` are read via their
// strides.
template <typename scalar_t>
__global__ void
weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / out.sizes[1], m_out = i % out.sizes[1];
    auto S = basis.sizes[1];
    scalar_t v = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      auto b = basis.data[e * S + s];
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
        auto tmp =
            weight.data[wi * weight.strides[0] + m_in * weight.strides[1] +
                        m_out * weight.strides[2]];
        tmp *= b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        v += tmp;
      }
    }
    out.data[i] = v;
  }
}
// Allocates out of shape [E, M_out] and launches the forward kernel on
// the device of `x`.
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index) {
  cudaSetDevice(x.get_device());
  auto E = x.size(0), M_out = weight.size(2);
  auto out = at::empty({E, M_out}, x.options());
  AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
    weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        out.numel());
  });
  return out;
}
// grad_x[e, m_in] = sum_s sum_{m_out} basis[e, s]
//     * weight[wi, m_out, m_in] * grad_out[e, m_out].
// The kernel addresses `weight` as [K, M_out, M_in]; the host wrapper
// below transposes the weight tensor accordingly before the launch.
template <typename scalar_t>
__global__ void weighting_bw_x_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / grad_x.sizes[1], m_in = i % grad_x.sizes[1];
    auto S = basis.sizes[1];
    scalar_t v = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      auto b = basis.data[e * S + s];  // flat e * S + s: assumes contiguity
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_out = 0; m_out < grad_out.sizes[1]; m_out++) {
        auto tmp =
            weight.data[wi * weight.strides[0] + m_out * weight.strides[1] +
                        m_in * weight.strides[2]];
        tmp *= b *
               grad_out
                   .data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];
        v += tmp;
      }
    }
    grad_x.data[i] = v;
  }
}
// Allocates grad_x of shape [E, M_in]; transposes weight to
// [K, M_out, M_in] (contiguous copy) so the kernel reads it along its
// last dimension, then launches on the device of `grad_out`.
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index) {
  cudaSetDevice(grad_out.get_device());
  auto E = grad_out.size(0), M_in = weight.size(1);
  auto grad_x = at::empty({E, M_in}, grad_out.options());
  weight = weight.transpose(1, 2).contiguous();
  AT_DISPATCH_FLOATING_TYPES(grad_x.scalar_type(), "weighting_bw_x", [&] {
    weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_x.numel());
  });
  return grad_x;
}
// Scatters grad_out[e, m_out] * basis[e, s] * x[e, m_in] into
// grad_weight[wi, m_in, m_out] with atomicAdd, since multiple (e, s)
// pairs can map to the same kernel index wi.
// NOTE(review): grad_weight is addressed with flat offsets
// wi * M_in * M_out + m_in * M_out + m_out, matching the contiguous
// at::zeros({K, M_in, M_out}) allocation in the host wrapper below.
template <typename scalar_t>
__global__ void weighting_bw_w_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / grad_out.sizes[1], m_out = i % grad_out.sizes[1];
    int64_t S = basis.sizes[1], M_in = x.sizes[1], M_out = grad_out.sizes[1];
    auto g =
        grad_out.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];
    for (ptrdiff_t s = 0; s < S; s++) {
      auto b = basis.data[e * S + s];  // flat e * S + s: assumes contiguity
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
        auto v = g * b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        atomicAdd(&grad_weight.data[wi * M_in * M_out + m_in * M_out + m_out],
                  v);
      }
    }
  }
}
// Zero-initializes grad_weight [K, M_in, M_out] (required since the
// kernel accumulates) and launches one thread per grad_out element.
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor basis, at::Tensor weight_index,
                               int64_t K) {
  cudaSetDevice(grad_out.get_device());
  auto M_in = x.size(1), M_out = grad_out.size(1);
  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_w", [&] {
    weighting_bw_w_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_out.numel());
  });
  return grad_weight;
}
// grad_basis[e, s] += sum_{m_in} grad_out[e, m_out]
//     * weight[wi, m_in, m_out] * x[e, m_in], accumulated with atomicAdd
// because threads for different m_out of the same edge write the same
// (e, s) slot.
template <typename scalar_t>
__global__ void weighting_bw_b_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / grad_out.sizes[1], m_out = i % grad_out.sizes[1];
    auto S = grad_basis.sizes[1];
    auto g =
        grad_out.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];
    for (ptrdiff_t s = 0; s < S; s++) {
      scalar_t v = 0;
      auto wi = weight_index.data[e * S + s];  // flat: assumes contiguity
      for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
        auto w = weight.data[wi * weight.strides[0] + m_in * weight.strides[1] +
                             m_out * weight.strides[2]];
        v += g * w * x.data[e * x.strides[0] + m_in * x.strides[1]];
      }
      atomicAdd(&grad_basis.data[e * S + s], v);
    }
  }
}
// Zero-initializes grad_basis [E, S] (required since the kernel
// accumulates) and launches one thread per grad_out element.
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor weight, at::Tensor weight_index) {
  cudaSetDevice(grad_out.get_device());
  auto E = x.size(0), S = weight_index.size(1);
  auto grad_basis = at::zeros({E, S}, grad_out.options());
  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_b", [&] {
    weighting_bw_b_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_basis),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_out.numel());
  });
  return grad_basis;
}
#!/bin/bash
# Install Miniconda for the current CI platform (selected via
# TRAVIS_OS_NAME), put it on PATH, and create the `test` environment with
# the requested PYTHON_VERSION.
case "${TRAVIS_OS_NAME}" in
linux)
  wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
  chmod +x miniconda.sh
  ./miniconda.sh -b
  PATH=/home/travis/miniconda3/bin:${PATH}
  ;;
osx)
  wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh
  chmod +x miniconda.sh
  ./miniconda.sh -b
  PATH=/Users/travis/miniconda3/bin:${PATH}
  ;;
windows)
  # On Windows, conda is provisioned through Chocolatey instead.
  choco install openssl.light
  choco install miniconda3
  PATH=/c/tools/miniconda3/Scripts:$PATH
  ;;
esac
conda update --yes conda
conda create --yes -n test python="${PYTHON_VERSION}"
#!/bin/bash
# Derive the CUDA toolchain configuration for the CI job from the target
# OS (TRAVIS_OS_NAME) and the build index IDX (cpu / cu92 / cu100 / cu101),
# then install the matching CUDA toolkit.
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "$IDX" = "cpu" ]; then
  export TOOLKIT=cpuonly
fi
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "$IDX" = "cu92" ]; then
  export CUDA_SHORT=9.2
  export CUDA=9.2.148-1
  export UBUNTU_VERSION=ubuntu1604
  export CUBLAS=cuda-cublas-dev-9-2
  export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
fi
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "$IDX" = "cu100" ]; then
  export CUDA_SHORT=10.0
  export CUDA=10.0.130-1
  export UBUNTU_VERSION=ubuntu1804
  export CUBLAS=cuda-cublas-dev-10-0
  export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
fi
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "$IDX" = "cu101" ]; then
  # NOTE(review): this export is redundant — IDX already equals "cu101"
  # inside this branch.
  export IDX=cu101
  export CUDA_SHORT=10.1
  export CUDA=10.1.105-1
  export UBUNTU_VERSION=ubuntu1804
  export CUBLAS=libcublas-dev
  export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "$IDX" = "cpu" ]; then
  export TOOLKIT=cpuonly
fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "$IDX" = "cu92" ]; then
  export CUDA_SHORT=9.2
  export CUDA_URL=https://developer.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod2/local_installers2
  # NOTE(review): unlike the cu100/cu101 branches this file name carries
  # no .exe suffix — confirm against the actual download URL.
  export CUDA_FILE=cuda_${CUDA_SHORT}.148_win10
  export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "$IDX" = "cu100" ]; then
  export CUDA_SHORT=10.0
  export CUDA_URL=https://developer.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers
  export CUDA_FILE=cuda_${CUDA_SHORT}.130_411.31_win10
  export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "$IDX" = "cu101" ]; then
  export CUDA_SHORT=10.1
  export CUDA_URL=https://developer.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers
  export CUDA_FILE=cuda_${CUDA_SHORT}.105_418.96_win10.exe
  export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
fi
if [ "${TRAVIS_OS_NAME}" = "osx" ] && [ "$IDX" = "cpu" ]; then
  export TOOLKIT=""
fi
# Tell the build explicitly whether to compile the CUDA extensions.
if [ "${IDX}" = "cpu" ]; then
  export FORCE_CPU=1
else
  export FORCE_CUDA=1
fi
# Linux: install the CUDA toolkit from NVIDIA's apt repository.
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "${IDX}" != "cpu" ]; then
  INSTALLER=cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb
  wget -nv "http://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER}"
  sudo dpkg -i "${INSTALLER}"
  wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/7fa2af80.pub"
  sudo apt-key add 7fa2af80.pub
  sudo apt update -qq
  sudo apt install -y "cuda-core-${CUDA_SHORT/./-}" "cuda-cudart-dev-${CUDA_SHORT/./-}" "${CUBLAS}" "cuda-cusparse-dev-${CUDA_SHORT/./-}"
  sudo apt clean
  CUDA_HOME=/usr/local/cuda-${CUDA_SHORT}
  LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
  PATH=${CUDA_HOME}/bin:${PATH}
  nvcc --version
fi
# Windows: run the NVIDIA installer silently with only nvcc plus the
# cuBLAS/cuSPARSE dev components, then expose nvcc and MSBuild on PATH.
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${IDX}" != "cpu" ]; then
  wget -nv "${CUDA_URL}/${CUDA_FILE}"
  PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT}\" -Wait -NoNewWindow"
  CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v${CUDA_SHORT}
  PATH=${CUDA_HOME}/bin:$PATH
  PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
  nvcc --version
fi
# Fix Cuda9.2 on Windows: https://github.com/pytorch/pytorch/issues/6109
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${IDX}" = "cu92" ]; then
  sed -i.bak -e '129,141d' "${CUDA_HOME}/include/crt/host_config.h"
fi
import sys
import os
import os.path as osp
import glob
import shutil

# Rename freshly-built wheels so that the build variant ("+cpu", "+cu92",
# ...) is embedded in the version field of the file name, and also publish
# a "latest+<variant>" alias copy of each wheel.
variant = sys.argv[1]
assert variant in ['cpu', 'cu92', 'cu100', 'cu101']

# NOTE(review): kept for parity with the original script, but unused.
dist_dir = osp.join(osp.dirname(osp.abspath(__file__)), '..', 'dist')

for wheel_path in glob.glob(osp.join('dist', '**', '*.whl'), recursive=True):
    if variant in wheel_path:
        continue  # Already tagged with this variant.
    parts = wheel_path.split(osp.sep)
    # Wheel file names are dash-separated: the version is the 4th field
    # from the end (name-version-pytag-abitag-platform.whl).
    fields = parts[-1].split('-')
    alias = '-'.join(fields[:-4] + ['latest+' + variant] + fields[-3:])
    shutil.copyfile(wheel_path, osp.join(*parts[:-1], alias))
    tagged = '-'.join(fields[:-4] + [fields[-4] + '+' + variant] + fields[-3:])
    os.rename(wheel_path, osp.join(*parts[:-1], tagged))
#!/bin/bash
# Patch PyTorch/pybind11 headers in the CI conda env so the extension
# compiles with MSVC on Windows CUDA builds.
# Fix "member may not be initialized" error on Windows: https://github.com/pytorch/pytorch/issues/27958
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${IDX}" != "cpu" ]; then
  sed -i.bak -e 's/constexpr/const/g' /c/tools/miniconda3/envs/test/lib/site-packages/torch/include/torch/csrc/jit/script/module.h
  sed -i.bak -e 's/constexpr/const/g' /c/tools/miniconda3/envs/test/lib/site-packages/torch/include/torch/csrc/jit/argument_spec.h
  sed -i.bak -e 's/return \*(this->value)/return \*((type\*)this->value)/g' /c/tools/miniconda3/envs/test/lib/site-packages/torch/include/pybind11/cast.h
fi
# NOTE(review): everything below is a corrupted side-by-side diff
# rendering of setup.py — two versions of the file are interleaved on the
# same lines (note the duplicated import statements and two columns of
# code fused together, e.g. "import torch import torch"). This is NOT
# valid Python and cannot be reconstructed reliably from this view;
# restore the file from version control before editing. Left byte-for-byte
# untouched.
import os
import os.path as osp
import glob
from setuptools import setup, find_packages from setuptools import setup, find_packages
from sys import argv
import torch import torch
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
TORCH_MAJOR = int(torch.__version__.split('.')[0]) WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
TORCH_MINOR = int(torch.__version__.split('.')[1]) if os.getenv('FORCE_CUDA', '0') == '1':
WITH_CUDA = True
if os.getenv('FORCE_CPU', '0') == '1':
WITH_CUDA = False
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
if WITH_CUDA:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4]
extra_compile_args = [] sources = [main]
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3']
ext_modules = [ path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp'], if osp.exists(path):
extra_compile_args=extra_compile_args), sources += [path]
CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp'],
extra_compile_args=extra_compile_args),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
GPU = True path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
for arg in argv: if WITH_CUDA and osp.exists(path):
if arg == '--cpu': sources += [path]
GPU = False
argv.remove(arg)
if CUDA_HOME is not None and GPU: extension = Extension(
ext_modules += [ 'torch_spline_conv._' + name,
CUDAExtension('torch_spline_conv.basis_cuda', sources,
['cuda/basis.cpp', 'cuda/basis_kernel.cu'], include_dirs=[extensions_dir],
extra_compile_args=extra_compile_args), define_macros=define_macros,
CUDAExtension('torch_spline_conv.weighting_cuda', extra_compile_args=extra_compile_args,
['cuda/weighting.cpp', 'cuda/weighting_kernel.cu'], )
extra_compile_args=extra_compile_args), extensions += [extension]
]
return extensions
__version__ = '1.1.1'
url = 'https://github.com/rusty1s/pytorch_spline_conv'
install_requires = [] install_requires = []
setup_requires = ['pytest-runner'] setup_requires = ['pytest-runner']
...@@ -43,23 +63,26 @@ tests_require = ['pytest', 'pytest-cov'] ...@@ -43,23 +63,26 @@ tests_require = ['pytest', 'pytest-cov']
setup( setup(
name='torch_spline_conv', name='torch_spline_conv',
version=__version__, version='1.2.0',
description=('Implementation of the Spline-Based Convolution Operator of '
'SplineCNN in PyTorch'),
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
url=url, url='https://github.com/rusty1s/pytorch_spline_conv',
download_url='{}/archive/{}.tar.gz'.format(url, __version__), description=('Implementation of the Spline-Based Convolution Operator of '
'SplineCNN in PyTorch'),
keywords=[ keywords=[
'pytorch', 'pytorch',
'geometric-deep-learning', 'geometric-deep-learning',
'graph-neural-networks', 'graph-neural-networks',
'spline-cnn', 'spline-cnn',
], ],
license='MIT',
python_requires='>=3.6',
install_requires=install_requires, install_requires=install_requires,
setup_requires=setup_requires, setup_requires=setup_requires,
tests_require=tests_require, tests_require=tests_require,
ext_modules=ext_modules, ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass=cmdclass, cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
},
packages=find_packages(), packages=find_packages(),
) )
...@@ -2,7 +2,7 @@ from itertools import product ...@@ -2,7 +2,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_spline_conv.basis import SplineBasis from torch_spline_conv import spline_basis
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
...@@ -34,7 +34,7 @@ def test_spline_basis_forward(test, dtype, device): ...@@ -34,7 +34,7 @@ def test_spline_basis_forward(test, dtype, device):
is_open_spline = tensor(test['is_open_spline'], torch.uint8, device) is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
degree = 1 degree = 1
op = SplineBasis.apply basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
basis, weight_index = op(pseudo, kernel_size, is_open_spline, degree) degree)
assert basis.tolist() == test['basis'] assert basis.tolist() == test['basis']
assert weight_index.tolist() == test['weight_index'] assert weight_index.tolist() == test['weight_index']
...@@ -3,11 +3,12 @@ from itertools import product ...@@ -3,11 +3,12 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_spline_conv import SplineConv from torch_spline_conv import spline_conv
from torch_spline_conv.basis import implemented_degrees as degrees
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
degrees = [1, 2, 3]
tests = [{ tests = [{
'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]], 'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]], 'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
...@@ -51,12 +52,12 @@ def test_spline_conv_forward(test, dtype, device): ...@@ -51,12 +52,12 @@ def test_spline_conv_forward(test, dtype, device):
root_weight = tensor(test['root_weight'], dtype, device) root_weight = tensor(test['root_weight'], dtype, device)
bias = tensor(test['bias'], dtype, device) bias = tensor(test['bias'], dtype, device)
out = SplineConv.apply(x, edge_index, pseudo, weight, kernel_size, out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, True, root_weight, bias) is_open_spline, 1, True, root_weight, bias)
assert out.tolist() == test['expected'] assert out.tolist() == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices)) @pytest.mark.parametrize('degree,device', product(degrees, devices))
def test_spline_basis_backward(degree, device): def test_spline_basis_backward(degree, device):
x = torch.rand((3, 2), dtype=torch.double, device=device) x = torch.rand((3, 2), dtype=torch.double, device=device)
x.requires_grad_() x.requires_grad_()
...@@ -74,4 +75,4 @@ def test_spline_basis_backward(degree, device): ...@@ -74,4 +75,4 @@ def test_spline_basis_backward(degree, device):
data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline, degree, data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline, degree,
True, root_weight, bias) True, root_weight, bias)
assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True assert gradcheck(spline_conv, data, eps=1e-6, atol=1e-4) is True
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_spline_conv.weighting import SplineWeighting from torch_spline_conv import spline_basis, spline_weighting
from torch_spline_conv.basis import SplineBasis
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
...@@ -27,7 +26,7 @@ def test_spline_weighting_forward(test, dtype, device): ...@@ -27,7 +26,7 @@ def test_spline_weighting_forward(test, dtype, device):
basis = tensor(test['basis'], dtype, device) basis = tensor(test['basis'], dtype, device)
weight_index = tensor(test['weight_index'], torch.long, device) weight_index = tensor(test['weight_index'], torch.long, device)
out = SplineWeighting.apply(x, weight, basis, weight_index) out = spline_weighting(x, weight, basis, weight_index)
assert out.tolist() == test['expected'] assert out.tolist() == test['expected']
...@@ -38,8 +37,8 @@ def test_spline_weighting_backward(device): ...@@ -38,8 +37,8 @@ def test_spline_weighting_backward(device):
is_open_spline = tensor([1, 1], torch.uint8, device) is_open_spline = tensor([1, 1], torch.uint8, device)
degree = 1 degree = 1
op = SplineBasis.apply basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
basis, weight_index = op(pseudo, kernel_size, is_open_spline, degree) degree)
basis.requires_grad_() basis.requires_grad_()
x = torch.rand((4, 2), dtype=torch.double, device=device) x = torch.rand((4, 2), dtype=torch.double, device=device)
...@@ -48,4 +47,4 @@ def test_spline_weighting_backward(device): ...@@ -48,4 +47,4 @@ def test_spline_weighting_backward(device):
weight.requires_grad_() weight.requires_grad_()
data = (x, weight, basis, weight_index) data = (x, weight, basis, weight_index)
assert gradcheck(SplineWeighting.apply, data, eps=1e-6, atol=1e-4) is True assert gradcheck(spline_weighting, data, eps=1e-6, atol=1e-4) is True
...@@ -4,7 +4,7 @@ dtypes = [torch.float, torch.double] ...@@ -4,7 +4,7 @@ dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
if torch.cuda.is_available(): if torch.cuda.is_available():
devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))] devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
def tensor(x, dtype, device): def tensor(x, dtype, device):
......
from .basis import SplineBasis import importlib
from .weighting import SplineWeighting import os.path as osp
from .conv import SplineConv
__version__ = '1.1.1' import torch
__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__'] __version__ = '1.2.0'
# PyTorch release this extension is expected to be built against; used below
# to turn a cryptic load failure into an actionable error message.
expected_torch_version = (1, 4)

try:
    # Load the compiled extension libraries so their operators become
    # available under `torch.ops.torch_spline_conv.*`.  PathFinder resolves
    # the platform-specific shared-library file living next to this module.
    for library in ['_version', '_basis', '_weighting']:
        torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
            library, [osp.dirname(__file__)]).origin)
except OSError as e:
    # Loading typically fails when the installed PyTorch differs from the
    # version the extension was compiled for; report that case explicitly.
    major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
    t_major, t_minor = expected_torch_version
    if major != t_major or (major == t_major and minor != t_minor):
        raise RuntimeError(
            f'Expected PyTorch version {t_major}.{t_minor} but found '
            f'version {major}.{minor}.')
    # PyTorch version matches, so the failure has some other cause; re-raise.
    raise OSError(e)

if torch.version.cuda is not None:  # pragma: no cover
    # `cuda_version()` is registered by the `_version` library loaded above;
    # it returns the CUDA_VERSION the extension was built with (e.g. 9020,
    # 10010) or -1 for a CPU-only build.
    cuda_version = torch.ops.torch_spline_conv.cuda_version()
    if cuda_version == -1:
        major = minor = 0
    elif cuda_version < 10000:
        # CUDA < 10, e.g. 9020 -> 9.2.
        major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
    else:
        # CUDA >= 10, e.g. 10010 -> 10.1.
        major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]
    # Refuse to run when PyTorch and this extension were built against
    # different CUDA versions -- mixing them leads to undefined behavior.
    if t_major != major or t_minor != minor:
        raise RuntimeError(
            f'Detected that PyTorch and torch_spline_conv were compiled with '
            f'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_spline_conv has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_spline_conv that '
            f'matches your PyTorch install.')

# Public API: thin wrappers around the registered custom operators.
from .basis import spline_basis  # noqa
from .weighting import spline_weighting  # noqa
from .conv import spline_conv  # noqa

__all__ = [
    'spline_basis',
    'spline_weighting',
    'spline_conv',
    '__version__',
]
import torch from typing import Tuple
import torch_spline_conv.basis_cpu
if torch.cuda.is_available():
import torch_spline_conv.basis_cuda
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
def get_func(name, tensor):
if tensor.is_cuda:
return getattr(torch_spline_conv.basis_cuda, name)
else:
return getattr(torch_spline_conv.basis_cpu, name)
import torch
class SplineBasis(torch.autograd.Function):
@staticmethod
def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
ctx.save_for_backward(pseudo)
ctx.kernel_size = kernel_size
ctx.is_open_spline = is_open_spline
ctx.degree = degree
op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
basis, weight_index = op(pseudo, kernel_size, is_open_spline)
return basis, weight_index
@staticmethod
def backward(ctx, grad_basis, grad_weight_index):
pseudo, = ctx.saved_tensors
kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
degree = ctx.degree
grad_pseudo = None
if ctx.needs_input_grad[0]:
op = get_func('{}_bw'.format(implemented_degrees[degree]), pseudo)
grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
return grad_pseudo, None, None, None @torch.jit.script
def spline_basis(pseudo: torch.Tensor, kernel_size: torch.Tensor,
is_open_spline: torch.Tensor,
degree: int) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_spline_conv.spline_basis(pseudo, kernel_size,
is_open_spline, degree)
import torch from typing import Optional
from .basis import SplineBasis import torch
from .weighting import SplineWeighting
from .utils.degree import degree as node_degree from .basis import spline_basis
from .weighting import spline_weighting
class SplineConv(object): @torch.jit.script
def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
pseudo: torch.Tensor, weight: torch.Tensor,
kernel_size: torch.Tensor, is_open_spline: torch.Tensor,
degree: int = 1, norm: bool = True,
root_weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""Applies the spline-based convolution operator :math:`(f \star g)(i) = r"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
\frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)} \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph. f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
...@@ -38,37 +44,35 @@ class SplineConv(object): ...@@ -38,37 +44,35 @@ class SplineConv(object):
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
@staticmethod
def apply(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
degree=1, norm=True, root_weight=None, bias=None):
x = x.unsqueeze(-1) if x.dim() == 1 else x x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index[0], edge_index[1]
N, E, M_out = x.size(0), row.size(0), weight.size(2)
row, col = edge_index # Weight each node.
n, m_out = x.size(0), weight.size(2) basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
degree)
# Weight each node. out = spline_weighting(x[col], weight, basis, weight_index)
basis, weight_index = SplineBasis.apply(pseudo, kernel_size,
is_open_spline, degree)
weight_index = weight_index.detach()
out = SplineWeighting.apply(x[col], weight, basis, weight_index)
# Convert e x m_out to n x m_out features. # Convert E x M_out to N x M_out features.
row_expand = row.unsqueeze(-1).expand_as(out) row_expanded = row.unsqueeze(-1).expand_as(out)
out = x.new_zeros((n, m_out)).scatter_add_(0, row_expand, out) out = x.new_zeros((N, M_out)).scatter_add_(0, row_expanded, out)
# Normalize out by node degree (if wished). # Normalize out by node degree (if wished).
if norm: if norm:
deg = node_degree(row, n, out.dtype, out.device) ones = torch.ones(E, dtype=x.dtype, device=x.device)
out = out / deg.unsqueeze(-1).clamp(min=1) deg = out.new_zeros(N).scatter_add_(0, row, ones)
out = out / deg.unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
out = out + torch.mm(x, root_weight) out = out + torch.matmul(x, root_weight)
# Add bias (if wished). # Add bias (if wished).
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment