Commit ec8477ea authored by rusty1s
Browse files

cleaner scatter min/max backward implementation

parent 3922eca6
...@@ -62,26 +62,9 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -62,26 +62,9 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
}); });
} }
// CPU backward helper for the arg-tracking scatter reductions.
//
// Walks `index` along `dim`. For each position `pos`, the value stored in
// `index` selects a slot; if `arg` at that slot equals `pos`, the gradient held
// at that slot is written to `out` at `pos`. Entries of `out` whose position
// does not match are not written.
//
// Parameters:
//   grad  - gradient tensor; its scalar type selects the dispatched scalar_t.
//   index - tensor read as int64_t; provides the slot for every position.
//   arg   - tensor read as int64_t; compared against the current position.
//   out   - tensor written in place, read/written as scalar_t.
//   dim   - dimension along which the per-row loop runs.
//
// NOTE(review): the `*_data` / `*_stride` names are not declared here; they are
// presumably introduced by DIM_APPLY4, whose definition is outside this view.
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                    at::Tensor out, int64_t dim) {
  const int64_t row_length = index.size(dim);

  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward", [&] {
    DIM_APPLY4(scalar_t, grad, int64_t, index, int64_t, arg, scalar_t, out, dim,
               {
                 for (int64_t pos = 0; pos < row_length; ++pos) {
                   // Slot that position `pos` was scattered into.
                   const int64_t slot = index_data[pos * index_stride];
                   // Only the position recorded in `arg` receives the gradient.
                   if (pos == arg_data[slot * arg_stride]) {
                     out_data[pos * out_stride] = grad_data[slot * grad_stride];
                   }
                 }
               });
  });
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)"); m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
m.def("scatter_div", &scatter_div, "Scatter Div (CPU)"); m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
m.def("scatter_max", &scatter_max, "Scatter Max (CPU)"); m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
m.def("scatter_min", &scatter_min, "Scatter Min (CPU)"); m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
m.def("index_backward", &index_backward, "Index Backward (CPU)");
} }
...@@ -47,19 +47,9 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -47,19 +47,9 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
scatter_min_cuda(src, index, out, arg, dim); scatter_min_cuda(src, index, out, arg, dim);
} }
// CUDA-side binding for index_backward.
//
// Applies CHECK_CUDA to each of the four tensor arguments, in the order
// grad, index, arg, out, and then forwards every argument unchanged to
// index_backward_cuda.
//
// NOTE(review): CHECK_CUDA is defined outside this view; presumably it rejects
// tensors that do not live on a CUDA device - confirm against its definition.
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim) {
CHECK_CUDA(grad);
CHECK_CUDA(index);
CHECK_CUDA(arg);
CHECK_CUDA(out);
// All validation happens above; the real work is done by the CUDA launcher.
index_backward_cuda(grad, index, arg, out, dim);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)"); m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)");
m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)"); m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)");
m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)"); m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)");
m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)"); m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)");
m.def("index_backward", &index_backward, "Index Backward (CUDA)");
} }
...@@ -159,36 +159,3 @@ void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -159,36 +159,3 @@ void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
dim); dim);
}); });
} }
// Backward kernel for the arg-tracking scatter reductions.
//
// Each thread processes linear element ids `i` in a grid-stride loop over
// `numel` elements, so the result does not depend on the launch configuration.
// For every `i` the four tensor offsets are obtained from
// IndexToScatterOffsets4::compute. The gradient is copied to `out` only when
// the value stored in `arg` equals
// (outOffset / out.strides[dim]) % out.sizes[dim]
// (presumably the coordinate of the output element along `dim` - confirm
// against IndexToScatterOffsets4).
//
// Template parameters:
//   scalar_t - element type of `grad` and `out`.
//   Dims     - forwarded to IndexToScatterOffsets4.
//
// Parameters:
//   grad, index, arg, out - TensorInfo descriptors of the participating tensors.
//   dim                   - dimension used for the offset/coordinate math.
//   numel                 - number of linear element ids to process.
template <typename scalar_t, int64_t Dims>
__global__ void
index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                      at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                      int64_t dim, size_t numel) {
  // blockIdx/blockDim/gridDim components are 32-bit unsigned. Widen to size_t
  // BEFORE multiplying so the product cannot wrap around on very large
  // launches (the previous code multiplied in 32 bits and widened afterwards).
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t gradOffset = 0, indexOffset = 0, argOffset = 0, outOffset = 0;
    IndexToScatterOffsets4<scalar_t, int64_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, out, &outOffset, arg, &argOffset, grad,
        &gradOffset);

    // Only the element whose position matches the recorded `arg` value
    // receives the incoming gradient; every other element is left as is.
    if (arg.data[argOffset] ==
        (outOffset / out.strides[dim]) % out.sizes[dim]) {
      out.data[outOffset] = grad.data[gradOffset];
    }
  }
}
// Host-side launcher for index_backward_kernel.
//
// Selects the CUDA device that holds `grad`, dispatches on the scalar type of
// `grad`, wraps the four tensors in TensorInfo descriptors and hands them to
// KERNEL_RUN together with `dim`. The dimensionality and element count passed
// to KERNEL_RUN are taken from `index`.
//
// Parameters:
//   grad  - gradient tensor; determines the device and the dispatched scalar_t.
//   index - wrapped as TensorInfo<int64_t, int64_t>.
//   arg   - wrapped as TensorInfo<int64_t, int64_t>.
//   out   - wrapped as TensorInfo<scalar_t, int64_t>.
//   dim   - forwarded unchanged to the kernel.
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                         at::Tensor out, int64_t dim) {
  // NOTE(review): the status returned by cudaSetDevice is not inspected here.
  cudaSetDevice(grad.get_device());

  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward_kernel", [&] {
    // Build the kernel descriptors up front so the launch line stays short.
    auto grad_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto arg_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);

    KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(), grad_info,
               index_info, arg_info, out_info, dim);
  });
}
...@@ -25,8 +25,7 @@ class ScatterMax(Function): ...@@ -25,8 +25,7 @@ class ScatterMax(Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size()) grad_src = grad_out.new_zeros(index.size())
func = get_func('index_backward', grad_out) grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
func(grad_out, index, arg, grad_src, ctx.dim)
return None, grad_src, None, None return None, grad_src, None, None
......
...@@ -25,8 +25,7 @@ class ScatterMin(Function): ...@@ -25,8 +25,7 @@ class ScatterMin(Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size()) grad_src = grad_out.new_zeros(index.size())
func = get_func('index_backward', grad_out) grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
func(grad_out, index, arg, grad_src, ctx.dim)
return None, grad_src, None, None return None, grad_src, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment