Commit 112345ce authored by rusty1s

cuda related fixes

parent 8e464c16
@@ -89,7 +89,7 @@ spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
       const int64_t wi = weight_index[e * S + s];
       for (int64_t m_out = 0; m_out < M_out; m_out++) {
-        scalar_t tmp = weight[wi * M_in * M_out + m_out * M_out + m_in];
+        scalar_t tmp = weight[wi * M_out * M_in + m_out * M_in + m_in];
         tmp *= b * grad_out[e * M_out + m_out];
         v += tmp;
       }
@@ -116,7 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
   auto S = basis.size(1);
   auto grad_x = at::zeros({E, M_in}, grad_out.options());
-  weight = weight.transpose(1, 2).contiguous();
+  weight = weight.transpose(1, 2).contiguous(); // Contiguous memory-access.
   auto weight_index_data = weight_index.data_ptr<int64_t>();
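A note on the two hunks above: `weight` is transposed to `[K, M_out, M_in]` and made contiguous before the kernel launch, so inside `spline_weighting_bw_x_kernel` the row stride must be `M_in`, not `M_out`; the removed offset only happened to be correct when `M_in == M_out`. The transposed layout also puts `m_in`, the per-thread index, in the fastest-varying position, so neighbouring threads read neighbouring addresses. A host-side sketch of the offset arithmetic with hypothetical sizes (not code from this commit):

#include <cstdint>
#include <cstdio>

// Flat offset of weight element (wi, m_in, m_out) in the untransposed
// layout [K, M_in, M_out]: m_out varies fastest.
int64_t offset_untransposed(int64_t wi, int64_t m_in, int64_t m_out,
                            int64_t M_in, int64_t M_out) {
  return wi * M_in * M_out + m_in * M_out + m_out;
}

// Flat offset in the transposed layout [K, M_out, M_in]: m_in varies
// fastest, matching the per-thread m_in of the backward kernel.
int64_t offset_transposed(int64_t wi, int64_t m_in, int64_t m_out,
                          int64_t M_in, int64_t M_out) {
  return wi * M_out * M_in + m_out * M_in + m_in;
}

int main() {
  const int64_t M_in = 4, M_out = 8; // hypothetical channel counts
  for (int64_t m_in = 0; m_in < 4; m_in++)
    printf("m_in=%lld: untransposed %2lld, transposed %2lld\n",
           (long long)m_in,
           (long long)offset_untransposed(0, m_in, 2, M_in, M_out),
           (long long)offset_transposed(0, m_in, 2, M_in, M_out));
  // Transposed offsets for consecutive m_in are consecutive (8, 9, 10, 11),
  // while untransposed ones are strided by M_out (2, 10, 18, 26).
  return 0;
}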
@@ -137,11 +137,10 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
 }

 template <typename scalar_t>
-spline_weighting_bw_weight_kernel(const scalar_t *grad_out, const scalar_t *x,
-                                  const scalar_t *basis,
-                                  const int64_t *weight_index, scalar_t *grad_x,
-                                  int64_t E, int64_t M_in, int64_t M_out,
-                                  int64_t S, int64_t numel) {
+__global__ void spline_weighting_bw_weight_kernel(
+    const scalar_t *grad_out, const scalar_t *x, const scalar_t *basis,
+    const int64_t *weight_index, scalar_t *grad_weight, int64_t E, int64_t M_in,
+    int64_t M_out, int64_t S, int64_t numel) {
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   const int64_t e = thread_idx / M_out;
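Besides the reflow, the signature gains the previously missing `__global__ void` specifier (a CUDA kernel must be a `void` function marked `__global__`; without a return type the old definition is not even valid C++), and the output pointer is renamed from `grad_x` to `grad_weight` to match what this backward pass actually writes. A minimal self-contained sketch of the required kernel shape, using a hypothetical `scale_kernel`:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Kernels are marked __global__ and must return void; results leave the
// kernel only through pointer arguments such as `out`.
template <typename scalar_t>
__global__ void scale_kernel(const scalar_t *in, scalar_t *out,
                             scalar_t factor, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel)
    out[i] = factor * in[i];
}

int main() {
  const int64_t numel = 1 << 10, threads = 256;
  float *in, *out;
  cudaMalloc(&in, numel * sizeof(float));
  cudaMalloc(&out, numel * sizeof(float));
  cudaMemset(in, 0, numel * sizeof(float));
  scale_kernel<float>
      <<<(numel + threads - 1) / threads, threads>>>(in, out, 2.0f, numel);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  printf("kernel launched over %lld elements\n", (long long)numel);
  return 0;
}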
@@ -198,15 +197,14 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
 }

 template <typename scalar_t>
-spline_weighting_bw_basis_kernel(const scalar_t *grad_out, const scalar_t *x,
-                                 const scalar_t *weight,
-                                 const int64_t *weight_index,
-                                 scalar_t *grad_basis, int64_t E, int64_t M_in,
+__global__ void spline_weighting_bw_basis_kernel(
+    const scalar_t *grad_out, const scalar_t *x, const scalar_t *weight,
+    const int64_t *weight_index, scalar_t *grad_basis, int64_t E, int64_t M_in,
     int64_t M_out, int64_t S, int64_t numel) {
   const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  const int64_t e = i / M_out;
-  const int64_t m_out = i % M_out;
+  const int64_t e = thread_idx / M_out;
+  const int64_t m_out = thread_idx % M_out;
   if (thread_idx < numel) {
     const scalar_t g = grad_out[e * M_out + m_out];
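The removed lines read an identifier `i` that is never declared in this kernel (the index variable is `thread_idx`), so the old body could not have compiled; renaming it is presumably one of the "cuda related fixes". Each thread handles one `(e, m_out)` pair, recovered from the flat thread index by division and modulo. A host-side sketch of that decomposition with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes: E edges and M_out output channels, one thread per
  // (e, m_out) pair, numel = E * M_out threads in total.
  const int64_t E = 3, M_out = 4, numel = E * M_out;
  for (int64_t thread_idx = 0; thread_idx < numel; thread_idx++) {
    const int64_t e = thread_idx / M_out;     // which edge
    const int64_t m_out = thread_idx % M_out; // which output channel
    printf("thread %2lld -> (e=%lld, m_out=%lld)\n", (long long)thread_idx,
           (long long)e, (long long)m_out);
  }
  return 0;
}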
@@ -228,10 +226,10 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor weight,
                                              torch::Tensor weight_index) {
-  CHECK_CPU(grad_out);
-  CHECK_CPU(x);
-  CHECK_CPU(weight);
-  CHECK_CPU(weight_index);
+  CHECK_CUDA(grad_out);
+  CHECK_CUDA(x);
+  CHECK_CUDA(weight);
+  CHECK_CUDA(weight_index);
   cudaSetDevice(grad_out.get_device());
   CHECK_INPUT(x.size(1) == weight.size(1));
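The guards previously asserted that the inputs live on the CPU, which would reject every legitimate call into this CUDA entry point; they now require CUDA tensors, consistent with the `cudaSetDevice` call that follows. A sketch of how such guard macros are conventionally defined in PyTorch C++ extensions (the exact definitions in this repository may differ):

#include <torch/extension.h>

// Conventional guard macros for CUDA extension entry points; a sketch,
// not necessarily the exact definitions used here.
#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")

torch::Tensor example_cuda(torch::Tensor grad_out, torch::Tensor weight) {
  CHECK_CUDA(grad_out); // fails fast with a readable message on CPU input
  CHECK_CUDA(weight);
  CHECK_INPUT(grad_out.size(1) == weight.size(2));
  return grad_out; // placeholder body
}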
...