Commit 112345ce authored by rusty1s

cuda related fixes

parent 8e464c16
@@ -89,7 +89,7 @@ spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
       const int64_t wi = weight_index[e * S + s];
       for (int64_t m_out = 0; m_out < M_out; m_out++) {
-        scalar_t tmp = weight[wi * M_in * M_out + m_out * M_out + m_in];
+        scalar_t tmp = weight[wi * M_out * M_in + m_out * M_in + m_in];
         tmp *= b * grad_out[e * M_out + m_out];
         v += tmp;
       }
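The index change above matches the layout produced by the weight.transpose(1, 2).contiguous() call in the next hunk: once weight is stored as (K, M_out, M_in), the element for (wi, m_in, m_out) sits at wi * M_out * M_in + m_out * M_in + m_in. A minimal host-side check of that equivalence, illustrative only and not part of the commit (assumes libtorch and a float weight tensor of shape (K, M_in, M_out)):

// Not part of this commit: verifies that the flat index used in the fixed
// kernel line addresses weight[k][m_in][m_out] of the original tensor once
// it has been permuted to (K, M_out, M_in) and made contiguous.
#include <torch/torch.h>
#include <cassert>

int main() {
  const int64_t K = 3, M_in = 4, M_out = 5;
  auto weight = torch::rand({K, M_in, M_out});
  auto weight_t = weight.transpose(1, 2).contiguous(); // shape (K, M_out, M_in)
  const float *ptr = weight_t.data_ptr<float>();
  for (int64_t k = 0; k < K; k++)
    for (int64_t m_in = 0; m_in < M_in; m_in++)
      for (int64_t m_out = 0; m_out < M_out; m_out++)
        assert(ptr[k * M_out * M_in + m_out * M_in + m_in] ==
               weight[k][m_in][m_out].item<float>());
  return 0;
}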
@@ -116,7 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
   auto S = basis.size(1);
   auto grad_x = at::zeros({E, M_in}, grad_out.options());
-  weight = weight.transpose(1, 2).contiguous();
+  weight = weight.transpose(1, 2).contiguous(); // Contiguous memory-access.
   auto weight_index_data = weight_index.data_ptr<int64_t>();
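The comment added above concerns the kernel's read pattern: after the transpose, the stride-1 dimension of weight is M_in. Assuming the backward-x kernel assigns one thread per (e, m_in) output element (the same one-thread-per-output scheme the other kernels in this diff use), neighbouring threads then read neighbouring weight entries and the loads coalesce. A small sketch of the stride change, illustrative only (assumes libtorch):

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t K = 3, M_in = 4, M_out = 5;
  auto weight = torch::rand({K, M_in, M_out});
  // Original layout: stride-1 dimension is M_out.
  std::cout << weight.strides() << std::endl;                              // [20, 5, 1]
  // Layout after the transpose in spline_weighting_bw_x_cuda: stride-1 dimension is M_in.
  std::cout << weight.transpose(1, 2).contiguous().strides() << std::endl; // [20, 4, 1]
  return 0;
}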
@@ -137,11 +137,10 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
 }
 template <typename scalar_t>
-spline_weighting_bw_weight_kernel(const scalar_t *grad_out, const scalar_t *x,
-                                  const scalar_t *basis,
-                                  const int64_t *weight_index, scalar_t *grad_x,
-                                  int64_t E, int64_t M_in, int64_t M_out,
-                                  int64_t S, int64_t numel) {
+__global__ void spline_weighting_bw_weight_kernel(
+    const scalar_t *grad_out, const scalar_t *x, const scalar_t *basis,
+    const int64_t *weight_index, scalar_t *grad_weight, int64_t E, int64_t M_in,
+    int64_t M_out, int64_t S, int64_t numel) {
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   const int64_t e = thread_idx / M_out;
@@ -198,15 +197,14 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
 }
 template <typename scalar_t>
-spline_weighting_bw_basis_kernel(const scalar_t *grad_out, const scalar_t *x,
-                                 const scalar_t *weight,
-                                 const int64_t *weight_index,
-                                 scalar_t *grad_basis, int64_t E, int64_t M_in,
-                                 int64_t M_out, int64_t S, int64_t numel) {
+__global__ void spline_weighting_bw_basis_kernel(
+    const scalar_t *grad_out, const scalar_t *x, const scalar_t *weight,
+    const int64_t *weight_index, scalar_t *grad_basis, int64_t E, int64_t M_in,
+    int64_t M_out, int64_t S, int64_t numel) {
   const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  const int64_t e = i / M_out;
-  const int64_t m_out = i % M_out;
+  const int64_t e = thread_idx / M_out;
+  const int64_t m_out = thread_idx % M_out;
   if (thread_idx < numel) {
     const scalar_t g = grad_out[e * M_out + m_out];
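The second half of this hunk replaces a stale loop variable i with thread_idx, the flat global index each kernel derives from its launch configuration. A self-contained sketch of that indexing scheme, with illustrative names not taken from the extension:

// One thread per (e, m_out) pair; indices past numel are masked by the guard.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void index_demo_kernel(int64_t E, int64_t M_out, int64_t numel) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx < numel) {
    const int64_t e = thread_idx / M_out;
    const int64_t m_out = thread_idx % M_out;
    if (thread_idx < 4) // print only the first few mappings
      printf("thread %lld -> e = %lld, m_out = %lld\n", (long long)thread_idx,
             (long long)e, (long long)m_out);
  }
}

int main() {
  const int64_t E = 8, M_out = 3, numel = E * M_out;
  const int threads = 256;
  const int blocks = (numel + threads - 1) / threads;
  index_demo_kernel<<<blocks, threads>>>(E, M_out, numel);
  cudaDeviceSynchronize();
  return 0;
}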
@@ -228,10 +226,10 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor weight,
                                              torch::Tensor weight_index) {
-  CHECK_CPU(grad_out);
-  CHECK_CPU(x);
-  CHECK_CPU(weight);
-  CHECK_CPU(weight_index);
+  CHECK_CUDA(grad_out);
+  CHECK_CUDA(x);
+  CHECK_CUDA(weight);
+  CHECK_CUDA(weight_index);
   cudaSetDevice(grad_out.get_device());
   CHECK_INPUT(x.size(1) == weight.size(1));
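CHECK_CUDA, CHECK_INPUT, and the old CHECK_CPU are macros defined elsewhere in the repository; this hunk only swaps the CPU assertion for the CUDA one before the cudaSetDevice call. For readers writing a similar extension, a plausible stand-in looks like the following (illustrative only, using TORCH_CHECK; the repository's actual macros may be spelled differently):

#include <torch/extension.h>

// Illustrative definitions only; not the extension's actual header.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) TORCH_CHECK(x, "Input mismatch")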