Commit b46459f4 authored by rusty1s

cuda kernels

parent 57f5a26e
...@@ -18,7 +18,7 @@ before_install:
   - export CXX="g++-4.9"
 install:
   - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; fi
   - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi
   - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl; fi
   - pip install pycodestyle
   - pip install flake8
......
...@@ -142,12 +142,11 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
       tmp = v;                                                        \
                                                                       \
       for (ptrdiff_t d_it = 1; d_it < D; d_it++) {                    \
-        auto d_other = d_it - (d >= d_it);                            \
-        k_mod = (s / (int64_t)(pow(M + 1, d_other) + 0.5)) % (M + 1); \
-        v = pseudo_data[e * pseudo.stride(0) +                        \
-                        d_other * pseudo.stride(1)];                  \
-        v *= kernel_size_data[d_other] -                              \
-             M * is_open_spline_data[d_other];                        \
+        auto d_new = d_it - (d >= d_it);                              \
+        k_mod = (s / (int64_t)(pow(M + 1, d_new) + 0.5)) % (M + 1);   \
+        v = pseudo_data[e * pseudo.stride(0) +                        \
+                        d_new * pseudo.stride(1)];                    \
+        v *= kernel_size_data[d_new] -                                \
+             M * is_open_spline_data[d_new];                          \
         v -= floor(v);                                                \
         v = FUNC<scalar_t>(v, k_mod);                                 \
         tmp *= v;                                                     \
......
...@@ -142,6 +142,6 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
   m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
-  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward W (CPU)");
-  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward B (CPU)");
+  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CPU)");
+  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CPU)");
 }
#pragma once

// Compute capabilities below 6.0 (and CUDA < 8) provide no native
// double-precision atomicAdd, so emulate it with a compare-and-swap loop
// on the value's 64-bit integer representation.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Retry until no other thread has updated *address between read and swap.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
}
#endif
...@@ -185,6 +185,32 @@ template <typename scalar_t> struct BasisBackward {
     const size_t index = blockIdx.x * blockDim.x + threadIdx.x;               \
     const size_t stride = blockDim.x * gridDim.x;                             \
     for (ptrdiff_t i = index; i < NUMEL; i += stride) {                       \
+      int64_t e = i / GRAD_PSEUDO.sizes[1], d = i % GRAD_PSEUDO.sizes[1];     \
+      scalar_t g = 0, tmp;                                                    \
+                                                                              \
+      for (ptrdiff_t s = 0; s < GRAD_BASIS.sizes[1]; s++) {                   \
+        auto k_mod = (s / (int64_t)(pow(M + 1, d) + 0.5)) % (M + 1);          \
+        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]];  \
+        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                          \
+        v -= floor(v);                                                        \
+        v = CODE;                                                             \
+        tmp = v;                                                              \
+                                                                              \
+        for (ptrdiff_t d_it = 1; d_it < GRAD_PSEUDO.sizes[1]; d_it++) {       \
+          auto d_new = d_it - (d >= d_it);                                    \
+          k_mod = (s / (int64_t)(pow(M + 1, d_new) + 0.5)) % (M + 1);         \
+          v = PSEUDO.data[e * PSEUDO.strides[0] + d_new * PSEUDO.strides[1]]; \
+          v *= KERNEL_SIZE[d_new] - M * IS_OPEN_SPLINE[d_new];                \
+          v -= floor(v);                                                      \
+          v = GRAD_CODE;                                                      \
+          tmp *= v;                                                           \
+        }                                                                     \
+        g += tmp *                                                            \
+             GRAD_BASIS                                                       \
+                 .data[e * GRAD_BASIS.strides[0] + s * GRAD_BASIS.strides[1]];\
+      }                                                                       \
+      g *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                            \
+      GRAD_PSEUDO.data[e * GRAD_PSEUDO.sizes[1] + d] = g;                     \
     }                                                                         \
   }()
......
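Editor's note: as a hedged sketch of what the macro above accumulates, here is a plain-Python mirror of its loop structure; the function name is hypothetical, and f_d / f_other stand in for whatever the CODE and GRAD_CODE arguments expand to at the (not shown) call sites.

import math

def basis_backward_reference(grad_basis, pseudo, kernel_size, is_open_spline,
                             m, f_d, f_other):
    # For each entry (e, d): apply one basis function (the macro's CODE) to
    # dimension d and the other (GRAD_CODE) to every remaining dimension, then
    # weight by the incoming gradient and the chain-rule scale factor.
    E, D = len(pseudo), len(pseudo[0])
    S = len(grad_basis[0])
    grad_pseudo = [[0.0] * D for _ in range(E)]
    for e in range(E):
        for d in range(D):
            g = 0.0
            for s in range(S):
                # Decode the kernel offset of dimension d from the flat index s.
                k_mod = (s // (m + 1) ** d) % (m + 1)
                v = pseudo[e][d] * (kernel_size[d] - m * is_open_spline[d])
                tmp = f_d(v - math.floor(v), k_mod)
                for j in range(D):
                    if j == d:
                        continue
                    k_mod = (s // (m + 1) ** j) % (m + 1)
                    v = pseudo[e][j] * (kernel_size[j] - m * is_open_spline[j])
                    tmp *= f_other(v - math.floor(v), k_mod)
                g += tmp * grad_basis[e][s]
            grad_pseudo[e][d] = g * (kernel_size[d] - m * is_open_spline[d])
    return grad_pseudo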
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index);
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index);
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor basis, at::Tensor weight_index,
int64_t K);
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor weight, at::Tensor weight_index);
at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index) {
CHECK_CUDA(x);
CHECK_CUDA(weight);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
return weighting_fw_cuda(x, weight, basis, weight_index);
}
at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index) {
CHECK_CUDA(grad_out);
CHECK_CUDA(weight);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
return weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
}
at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
at::Tensor weight_index, int64_t K) {
CHECK_CUDA(grad_out);
CHECK_CUDA(x);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
return weighting_bw_w_cuda(grad_out, x, basis, weight_index, K);
}
at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
at::Tensor weight_index) {
CHECK_CUDA(grad_out);
CHECK_CUDA(x);
CHECK_CUDA(weight);
CHECK_CUDA(weight_index);
return weighting_bw_b_cuda(grad_out, x, weight, weight_index);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("weighting_fw", &weighting_fw, "Weighting Forward (CUDA)");
m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CUDA)");
m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CUDA)");
m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CUDA)");
}
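Editor's note: for context, a hypothetical Python call into the resulting weighting_cuda extension. The shapes are inferred from the kernels below (x: [E, M_in], weight: [K, M_in, M_out], basis and weight_index: [E, S]); the concrete sizes are made up.

import torch
import weighting_cuda  # built via the setup.py change below

E, M_in, M_out, K, S = 4, 2, 3, 25, 8  # hypothetical sizes
x = torch.rand(E, M_in).cuda()
weight = torch.rand(K, M_in, M_out).cuda()
basis = torch.rand(E, S).cuda()
weight_index = torch.randint(0, K, (E, S)).long().cuda()

out = weighting_cuda.weighting_fw(x, weight, basis, weight_index)  # [E, M_out]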
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "atomics.cuh"
#define THREADS 1024
// Ceiling division: number of blocks needed to cover N elements.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
template <typename scalar_t>
__global__ void
weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
int64_t e = i / out.sizes[1], m_out = i % out.sizes[1];
auto S = basis.sizes[1];
scalar_t v = 0;
for (ptrdiff_t s = 0; s < S; s++) {
auto b = basis.data[e * S + s];
auto wi = weight_index.data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
auto tmp =
weight.data[wi * weight.strides[0] + m_in * weight.strides[1] +
m_out * weight.strides[2]];
tmp *= b * x.data[e * x.strides[0] + m_in * x.strides[1]];
v += tmp;
}
}
out.data[e * out.sizes[1] + m_out] = v;
}
}
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index) {
auto E = x.size(0), M_out = weight.size(2);
auto out = at::empty({E, M_out}, x.type());
AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
out.numel());
});
return out;
}
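Editor's note: as a sanity reference (not part of the commit), the forward kernel above computes the following, written as a pure-PyTorch sketch:

def weighting_fw_reference(x, weight, basis, weight_index):
    # out[e] = sum_s basis[e, s] * (x[e] @ weight[weight_index[e, s]])
    E, S = basis.size(0), basis.size(1)
    out = x.new_zeros(E, weight.size(2))
    for e in range(E):
        for s in range(S):
            wi = int(weight_index[e, s])
            out[e] += basis[e, s] * (x[e] @ weight[wi])
    return out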
template <typename scalar_t>
__global__ void weighting_bw_x_kernel(
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_x,
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
int64_t e = i / grad_x.sizes[1], m_in = i % grad_x.sizes[1];
auto S = basis.sizes[1];
scalar_t v = 0;
for (ptrdiff_t s = 0; s < S; s++) {
auto b = basis.data[e * S + s];
auto wi = weight_index.data[e * S + s];
for (ptrdiff_t m_out = 0; m_out < grad_out.sizes[1]; m_out++) {
auto tmp =
weight.data[wi * weight.strides[0] + m_out * weight.strides[1] +
m_in * weight.strides[2]];
tmp *= b *
grad_out
.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];
v += tmp;
}
}
grad_x.data[e * grad_x.sizes[1] + m_in] = v;
}
}
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index) {
auto E = grad_out.size(0), M_in = weight.size(1);
auto grad_x = at::empty({E, M_in}, grad_out.type());
  // Transpose to [K, M_out, M_in] so adjacent threads (adjacent m_in values)
  // read adjacent weight entries.
  weight = weight.transpose(1, 2).contiguous();
AT_DISPATCH_FLOATING_TYPES(grad_x.type(), "weighting_bw_x", [&] {
weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_x),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
grad_x.numel());
});
return grad_x;
}
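Editor's note: a matching pure-PyTorch sketch of the input gradient, stated against the original [K, M_in, M_out] weight layout (the CUDA path above transposes the weight up front instead):

def weighting_bw_x_reference(grad_out, weight, basis, weight_index):
    # grad_x[e] = sum_s basis[e, s] * (grad_out[e] @ weight[weight_index[e, s]].t())
    E, S = basis.size(0), basis.size(1)
    grad_x = grad_out.new_zeros(E, weight.size(1))
    for e in range(E):
        for s in range(S):
            wi = int(weight_index[e, s])
            grad_x[e] += basis[e, s] * (grad_out[e] @ weight[wi].t())
    return grad_x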
template <typename scalar_t>
__global__ void weighting_bw_w_kernel(
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_weight,
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
int64_t e = i / grad_out.sizes[1], m_out = i % grad_out.sizes[1];
int64_t S = basis.sizes[1], M_in = x.sizes[1], M_out = grad_out.sizes[1];
auto g =
grad_out.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];
for (ptrdiff_t s = 0; s < S; s++) {
auto b = basis.data[e * S + s];
auto wi = weight_index.data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
auto v = g * b * x.data[e * x.strides[0] + m_in * x.strides[1]];
atomicAdd(&grad_weight.data[wi * M_in * M_out + m_in * M_out + m_out],
v);
}
}
}
}
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor basis, at::Tensor weight_index,
int64_t K) {
auto M_in = x.size(1), M_out = grad_out.size(1);
auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());
AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
weighting_bw_w_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_weight),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
grad_out.numel());
});
return grad_weight;
}
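Editor's note: the weight gradient scatters outer products into grad_weight, and several (e, s) pairs can share the same weight_index entry; that is why the kernel above accumulates with atomicAdd. A hedged reference sketch:

import torch

def weighting_bw_w_reference(grad_out, x, basis, weight_index, K):
    # grad_weight[wi] += basis[e, s] * outer(x[e], grad_out[e])
    grad_weight = grad_out.new_zeros(K, x.size(1), grad_out.size(1))
    E, S = basis.size(0), basis.size(1)
    for e in range(E):
        for s in range(S):
            wi = int(weight_index[e, s])
            grad_weight[wi] += basis[e, s] * torch.ger(x[e], grad_out[e])
    return grad_weight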
template <typename scalar_t>
__global__ void weighting_bw_b_kernel(
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
int64_t e = i / grad_out.sizes[1], m_out = i % grad_out.sizes[1];
auto S = grad_basis.sizes[1];
auto g =
grad_out.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];
for (ptrdiff_t s = 0; s < S; s++) {
scalar_t v = 0;
auto wi = weight_index.data[e * S + s];
for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
auto w = weight.data[wi * weight.strides[0] + m_in * weight.strides[1] +
m_out * weight.strides[2]];
v += g * w * x.data[e * x.strides[0] + m_in * x.strides[1]];
}
atomicAdd(&grad_basis.data[e * S + s], v);
}
}
}
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor weight, at::Tensor weight_index) {
auto E = x.size(0), S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.type());
AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
weighting_bw_b_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_basis),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
grad_out.numel());
});
return grad_basis;
}
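Editor's note: for the basis gradient, each thread handles one (e, m_out) pair, so threads with different m_out accumulate into the same grad_basis[e, s] slot, again via atomicAdd. Reference sketch:

def weighting_bw_b_reference(grad_out, x, weight, weight_index):
    # grad_basis[e, s] = <grad_out[e], x[e] @ weight[weight_index[e, s]]>
    E, S = weight_index.size(0), weight_index.size(1)
    grad_basis = grad_out.new_zeros(E, S)
    for e in range(E):
        for s in range(S):
            wi = int(weight_index[e, s])
            grad_basis[e, s] = grad_out[e].dot(x[e] @ weight[wi])
    return grad_basis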
...@@ -10,7 +10,10 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
 if torch.cuda.is_available():
     ext_modules += [
-        CUDAExtension('basis_cuda', ['cuda/basis.cpp', 'cuda/basis_kernel.cu'])
+        CUDAExtension('basis_cuda',
+                      ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
+        CUDAExtension('weighting_cuda',
+                      ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
     ]

 __version__ = '1.0.4'
......
...@@ -7,6 +7,7 @@ from torch_spline_conv import SplineConv
 from torch_spline_conv.basis import implemented_degrees as degrees
 from .utils import dtypes, devices, tensor
+devices = [torch.device('cpu')]
 tests = [{
     'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
......
...@@ -7,6 +7,7 @@ from torch_spline_conv.weighting import SplineWeighting
 from torch_spline_conv.basis import SplineBasis
 from .utils import dtypes, devices, tensor
+devices = [torch.device('cuda')]
 tests = [{
     'x': [[1, 2], [3, 4]],
......
 import torch
 import weighting_cpu
+if torch.cuda.is_available():
+    import weighting_cuda

 def get_func(name, tensor):
-    # module = weighting_cuda if tensor.is_cuda else weighting_cpu
-    module = weighting_cpu
+    module = weighting_cuda if tensor.is_cuda else weighting_cpu
     return getattr(module, name)
......