Commit cb0e5f63 authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent b992389e
......@@ -25,21 +25,16 @@ __global__ void grid_cuda_kernel(
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) {
size = size.toType(pos.type());
start = start.toType(pos.type());
end = end.toType(pos.type());
const auto num_nodes = pos.size(0);
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
auto cluster_data = cluster.data<int64_t>();
auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
auto size_data = size.data<scalar_t>();
auto start_data = start.data<scalar_t>();
auto end_data = end.data<scalar_t>();
grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
cluster.data<int64_t>(),
at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
size.toType(pos.type()).data<scalar_t>(),
start..toType(pos.type()).data<scalar_t>(),
end.toType(pos.type()).data<scalar_t>(), num_nodes);
});
return cluster;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment