"docs/vscode:/vscode.git/clone" did not exist on "9a0511c8e91a7f633c9c3292fccbcbad5281d1f5"
Commit d36a6784 authored by rusty1s's avatar rusty1s
Browse files

cuda related fixes

parent 6b6c39f4
......@@ -23,7 +23,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return v * v * v / 6.;
} else {
AT_ERROR("Basis degree not implemented");
return (scalar_t)-1.;
}
}
......@@ -47,7 +47,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return v * v / 2.;
} else {
AT_ERROR("Basis degree not implemented");
return (scalar_t)-1.;
}
}
};
......
......@@ -28,7 +28,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return v * v * v / 6.;
} else {
AT_ERROR("Basis degree not implemented");
return (scalar_t)-1.;
}
}
......@@ -52,7 +52,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return v * v / 2.;
} else {
AT_ERROR("Basis degree not implemented");
return (scalar_t)-1.;
}
}
};
......@@ -76,7 +76,7 @@ spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
int64_t k_mod = k % (degree + 1);
k /= degree + 1;
scalar_t v = pseudo.data[e * D + d];
scalar_t v = pseudo[e * D + d];
v *= kernel_size[d] - degree * is_open_spline[d];
wi += (((int64_t)v + k_mod) % kernel_size[d]) * wi_offset;
......@@ -87,8 +87,8 @@ spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
b *= v;
}
basis[i] = b;
weight_index[i] = wi;
basis[thread_idx] = b;
weight_index[thread_idx] = wi;
}
}
......@@ -123,7 +123,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_fw_kernel<scalar_t, DEGREE>
<<<BLOCKS(basis.numel()), THREADS, 0 stream>>>(
<<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
weight_index_data, E, D, S, basis.numel());
});
......@@ -149,7 +149,7 @@ spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
for (ptrdiff_t s = 0; s < S; s++) {
int64_t k_mod = (s / (int64_t)(powf(degree + 1, d) + 0.5)) % (degree + 1);
scalar_t v = pseudo.data[e * D + d];
scalar_t v = pseudo[e * D + d];
v *= kernel_size[d] - degree * is_open_spline[d];
v -= floor(v);
v = Basis<scalar_t, degree>::backward(v, k_mod);
......@@ -161,13 +161,13 @@ spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
v = pseudo[e * D + d_new];
v *= kernel_size[d_new] - degree * is_open_spline[d_new];
v -= floor(v);
v = BASIS<scalar_t, degree>::forward(v, k_mod);
v = Basis<scalar_t, degree>::forward(v, k_mod);
tmp *= v;
}
g += tmp * grad_basis[e * S + s];
}
g *= kernel_size[d] - degree * is_open_spline[d];
grad_pseudo[i] = g;
grad_pseudo[thread_idx] = g;
}
}
......@@ -205,7 +205,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_bw_kernel<scalar_t, DEGREE>
<<<BLOCKS(grad_pseudo.numel()), THREADS, 0 stream>>>(
<<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
grad_basis_data, pseudo_data, kernel_size_data,
is_open_spline_data, grad_pseudo_data, E, D, S,
grad_pseudo.numel());
......
#include "weighting_cpu.h"
#include "weighting_cuda.h"
#include "utils.cuh"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment