Commit b992389e authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent b1d9a365
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh> #include <ATen/cuda/detail/IndexUtils.cuh>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t> template <typename scalar_t>
__global__ void grid_cuda_kernel( __global__ void grid_cuda_kernel(
int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos, int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
...@@ -29,16 +32,13 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start, ...@@ -29,16 +32,13 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
const auto num_nodes = pos.size(0); const auto num_nodes = pos.size(0);
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes}); auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
const int threads = 1024;
const dim3 blocks((num_nodes + threads - 1) / threads);
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] { AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
auto cluster_data = cluster.data<int64_t>(); auto cluster_data = cluster.data<int64_t>();
auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos); auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
auto size_data = size.data<scalar_t>(); auto size_data = size.data<scalar_t>();
auto start_data = start.data<scalar_t>(); auto start_data = start.data<scalar_t>();
auto end_data = end.data<scalar_t>(); auto end_data = end.data<scalar_t>();
grid_cuda_kernel<scalar_t><<<blocks, threads>>>( grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
cluster_data, pos_info, size_data, start_data, end_data, num_nodes); cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment