Commit f5cb51ae authored by rusty1s's avatar rusty1s
Browse files

stream to scatter kernels

parent 3cf59da2
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh> #include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh> #include <ATen/cuda/detail/TensorInfo.cuh>
...@@ -8,20 +9,23 @@ ...@@ -8,20 +9,23 @@
#define THREADS 1024 #define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
auto stream = at::cuda::getCurrentCUDAStream();
#define KERNEL_RUN(NAME, DIMS, N, ...) \ #define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \ [&] { \
auto stream = at::cuda::getCurrentCUDAStream(); \
switch (DIMS) { \ switch (DIMS) { \
case 1: \ case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \ NAME<scalar_t, 1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \ break; \
case 2: \ case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \ NAME<scalar_t, 2><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \ break; \
case 3: \ case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \ NAME<scalar_t, 3><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
break; \ break; \
default: \ default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \ NAME<scalar_t, -1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N); \
} \ } \
}() }()
...@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src, ...@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out, void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) { int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(), KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src), at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
...@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src, ...@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out, void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) { int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(), KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src), at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
...@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src, ...@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out, void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) { at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src); auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index); auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
...@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src, ...@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out, void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) { at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src); auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index); auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment