Commit 8e464c16 authored by rusty1s's avatar rusty1s
Browse files

complete spline cuda

parent 751593df
...@@ -19,7 +19,7 @@ spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight, ...@@ -19,7 +19,7 @@ spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
const int64_t m_out = thread_idx % M_out; const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) { if (thread_idx < numel) {
scalar_t v = 0; scalar_t v = (scalar_t)0.;
for (ptrdiff_t s = 0; s < S; s++) { for (ptrdiff_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s]; const scalar_t b = basis[e * S + s];
...@@ -116,6 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out, ...@@ -116,6 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
auto S = basis.size(1); auto S = basis.size(1);
auto grad_x = at::zeros({E, M_in}, grad_out.options()); auto grad_x = at::zeros({E, M_in}, grad_out.options());
weight = weight.transpose(1, 2).contiguous();
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
...@@ -135,17 +136,128 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out, ...@@ -135,17 +136,128 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
return grad_x; return grad_x;
} }
template <typename scalar_t>
__global__ void
spline_weighting_bw_weight_kernel(const scalar_t *grad_out, const scalar_t *x,
                                  const scalar_t *basis,
                                  const int64_t *weight_index,
                                  scalar_t *grad_weight, int64_t E,
                                  int64_t M_in, int64_t M_out, int64_t S,
                                  int64_t numel) {
  // Gradient of the spline weighting w.r.t. `weight`.
  // One thread per element of grad_out, i.e. per (edge e, out-channel m_out)
  // pair; `numel` == E * M_out guards the grid tail.
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;
  if (thread_idx < numel) {
    const scalar_t g = grad_out[e * M_out + m_out];
    for (int64_t s = 0; s < S; s++) {
      const scalar_t b = basis[e * S + s];
      const int64_t wi = weight_index[e * S + s];
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        const scalar_t v = g * b * x[e * M_in + m_in];
        // Many edges can hit the same kernel weight entry -> atomic add.
        atomicAdd(&grad_weight[wi * M_in * M_out + m_in * M_out + m_out], v);
      }
    }
  }
}
// Computes grad_weight = d(out)/d(weight) for the spline weighting op.
// All tensors must live on the same CUDA device; returns a zero-initialized
// [kernel_size, M_in, M_out] tensor accumulated by the backward kernel.
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor basis,
                                              torch::Tensor weight_index,
                                              int64_t kernel_size) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  // zeros (not empty): the kernel accumulates via atomicAdd.
  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
    // One thread per grad_out element (E * M_out total).
    spline_weighting_bw_weight_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, basis_data, weight_index_data,
            grad_weight_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_weight;
}
template <typename scalar_t>
__global__ void
spline_weighting_bw_basis_kernel(const scalar_t *grad_out, const scalar_t *x,
                                 const scalar_t *weight,
                                 const int64_t *weight_index,
                                 scalar_t *grad_basis, int64_t E,
                                 int64_t M_in, int64_t M_out, int64_t S,
                                 int64_t numel) {
  // Gradient of the spline weighting w.r.t. `basis`.
  // One thread per grad_out element; `numel` == E * M_out guards the tail.
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;
  if (thread_idx < numel) {
    const scalar_t g = grad_out[e * M_out + m_out];
    for (int64_t s = 0; s < S; s++) {
      scalar_t v = (scalar_t)0.;
      const int64_t wi = weight_index[e * S + s];
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
        v += g * w * x[e * M_in + m_in];
      }
      // Threads sharing the same edge e (different m_out) all touch
      // grad_basis[e, s] -> atomic add.
      atomicAdd(&grad_basis[e * S + s], v);
    }
  }
}
// Computes grad_basis = d(out)/d(basis) for the spline weighting op.
// All tensors must live on the same CUDA device; returns a zero-initialized
// [E, S] tensor accumulated by the backward kernel.
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor weight,
                                             torch::Tensor weight_index) {
  // These are CUDA tensors (we select their device below) — the original
  // CHECK_CPU calls here were a copy/paste bug.
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  // zeros (not empty): the kernel accumulates via atomicAdd.
  auto grad_basis = at::zeros({E, S}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
    // One thread per grad_out element (E * M_out total).
    spline_weighting_bw_basis_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, weight_data, weight_index_data,
            grad_basis_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_basis;
}
...@@ -5,102 +5,6 @@ ...@@ -5,102 +5,6 @@
#define THREADS 1024 #define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void
weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    size_t numel) {
  // Forward pass: out[e, m_out] = sum_s sum_m_in
  //   basis[e, s] * weight[weight_index[e, s], m_in, m_out] * x[e, m_in].
  // Grid-stride loop over the flattened [E, M_out] output.
  const size_t start = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;
  const int64_t M_out = out.sizes[1];
  const int64_t S = basis.sizes[1];
  for (ptrdiff_t idx = start; idx < numel; idx += step) {
    const int64_t e = idx / M_out;
    const int64_t m_out = idx % M_out;
    scalar_t acc = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      const auto b_val = basis.data[e * S + s];
      const auto w_idx = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
        const auto w =
            weight.data[w_idx * weight.strides[0] + m_in * weight.strides[1] +
                        m_out * weight.strides[2]];
        acc += w * b_val * x.data[e * x.strides[0] + m_in * x.strides[1]];
      }
    }
    out.data[idx] = acc;
  }
}
// Host launcher for the spline-weighting forward pass.
// Returns a new [E, M_out] tensor on the same device as `x`.
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index) {
  cudaSetDevice(x.get_device());
  const auto E = x.size(0);
  const auto M_out = weight.size(2);
  // empty (not zeros): the kernel writes every output element.
  auto out = at::empty({E, M_out}, x.options());
  AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    auto x_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x);
    auto w_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight);
    auto b_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis);
    auto wi_info =
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index);
    weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        out_info, x_info, w_info, b_info, wi_info, out.numel());
  });
  return out;
}
template <typename scalar_t>
__global__ void weighting_bw_x_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  // Gradient w.r.t. x. `weight` is passed pre-transposed by the host wrapper,
  // so it is addressed as [kernel, m_out, m_in] via its strides.
  // Grid-stride loop over the flattened [E, M_in] grad_x.
  const size_t start = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;
  const int64_t M_in = grad_x.sizes[1];
  const int64_t S = basis.sizes[1];
  for (ptrdiff_t idx = start; idx < numel; idx += step) {
    const int64_t e = idx / M_in;
    const int64_t m_in = idx % M_in;
    scalar_t acc = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      const auto b_val = basis.data[e * S + s];
      const auto w_idx = weight_index.data[e * S + s];
      for (ptrdiff_t m_out = 0; m_out < grad_out.sizes[1]; m_out++) {
        const auto w =
            weight.data[w_idx * weight.strides[0] + m_out * weight.strides[1] +
                        m_in * weight.strides[2]];
        const auto g = grad_out.data[e * grad_out.strides[0] +
                                     m_out * grad_out.strides[1]];
        acc += w * b_val * g;
      }
    }
    grad_x.data[idx] = acc;
  }
}
// Host launcher for the backward pass w.r.t. x.
// Returns a new [E, M_in] tensor on the same device as `grad_out`.
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index) {
  cudaSetDevice(grad_out.get_device());
  const auto E = grad_out.size(0);
  const auto M_in = weight.size(1);
  // empty (not zeros): the kernel writes every grad_x element.
  auto grad_x = at::empty({E, M_in}, grad_out.options());
  // Transpose to [kernel, M_out, M_in] so the kernel's inner loop over m_out
  // matches the memory layout it expects.
  weight = weight.transpose(1, 2).contiguous();
  AT_DISPATCH_FLOATING_TYPES(grad_x.scalar_type(), "weighting_bw_x", [&] {
    auto gx_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_x);
    auto go_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out);
    auto w_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight);
    auto b_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis);
    auto wi_info =
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index);
    weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
        gx_info, go_info, w_info, b_info, wi_info, grad_x.numel());
  });
  return grad_x;
}
template <typename scalar_t> template <typename scalar_t>
__global__ void weighting_bw_w_kernel( __global__ void weighting_bw_w_kernel(
at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_weight, at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment