Commit f5cb51ae authored by rusty1s's avatar rusty1s
Browse files

stream to scatter kernels

parent 3cf59da2
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
......@@ -8,20 +9,23 @@
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
auto stream = at::cuda::getCurrentCUDAStream();
#define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \
auto stream = at::cuda::getCurrentCUDAStream(); \
switch (DIMS) { \
case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \
default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
} \
}()
......@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
......@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
......@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
......@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment