Commit d1b3f976 authored by rusty1s's avatar rusty1s
Browse files

kernel boilerplate

parent 10e5ee86
...@@ -179,6 +179,15 @@ template <typename scalar_t> struct BasisBackward { ...@@ -179,6 +179,15 @@ template <typename scalar_t> struct BasisBackward {
return grad_pseudo; \ return grad_pseudo; \
} }
#define BASIS_BACKWARD_KERNEL(M, GRAD_PSEUDO, GRAD_BASIS, PSEUDO, KERNEL_SIZE, \
IS_OPEN_SPLINE, NUMEL, CODE, GRAD_CODE) \
[&] { \
const size_t index = blockIdx.x * blockDim.x + threadIdx.x; \
const size_t stride = blockDim.x * gridDim.x; \
for (ptrdiff_t i = index; i < NUMEL; i += stride) { \
} \
}()
at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) { at::Tensor kernel_size, at::Tensor is_open_spline) {
return grad_basis; return grad_basis;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment