Commit 2862a818 authored by rusty1s
Browse files

multi gpu update

parent 46fe125c
......@@ -33,6 +33,7 @@ template <typename scalar_t> struct BasisForward {
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, KERNEL_NAME) \
[&]() -> std::tuple<at::Tensor, at::Tensor> { \
cudaSetDevice(PSEUDO.get_device()); \
auto E = PSEUDO.size(0); \
auto S = (int64_t)(powf(M + 1, KERNEL_SIZE.size(0)) + 0.5); \
auto basis = at::empty({E, S}, PSEUDO.options()); \
......@@ -163,6 +164,7 @@ template <typename scalar_t> struct BasisBackward {
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, \
KERNEL_NAME) \
[&]() -> at::Tensor { \
cudaSetDevice(GRAD_BASIS.get_device()); \
auto E = PSEUDO.size(0); \
auto D = PSEUDO.size(1); \
auto grad_pseudo = at::empty({E, D}, PSEUDO.options()); \
......
......@@ -39,6 +39,7 @@ weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index) {
cudaSetDevice(x.get_device());
auto E = x.size(0), M_out = weight.size(2);
auto out = at::empty({E, M_out}, x.options());
AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
......@@ -86,6 +87,7 @@ __global__ void weighting_bw_x_kernel(
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index) {
cudaSetDevice(grad_out.get_device());
auto E = grad_out.size(0), M_in = weight.size(1);
auto grad_x = at::empty({E, M_in}, grad_out.options());
weight = weight.transpose(1, 2).contiguous();
......@@ -131,6 +133,7 @@ __global__ void weighting_bw_w_kernel(
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor basis, at::Tensor weight_index,
int64_t K) {
cudaSetDevice(grad_out.get_device());
auto M_in = x.size(1), M_out = grad_out.size(1);
auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
......@@ -175,6 +178,7 @@ __global__ void weighting_bw_b_kernel(
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor weight, at::Tensor weight_index) {
cudaSetDevice(grad_out.get_device());
auto E = x.size(0), S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
......
......@@ -16,7 +16,7 @@ if CUDA_HOME is not None:
['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
]
__version__ = '1.0.5'
__version__ = '1.0.6'
url = 'https://github.com/rusty1s/pytorch_spline_conv'
install_requires = []
......
......@@ -2,6 +2,6 @@ from .basis import SplineBasis
from .weighting import SplineWeighting
from .conv import SplineConv
__version__ = '1.0.5'
__version__ = '1.0.6'
__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment