grid.cpp 725 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
2
3
4
5
6

// Assigns every point in `pos` to a cell of a regular voxel grid and returns
// one flattened (linear) voxel index per point.
//
// Parameters (shapes inferred from the indexing below — TODO confirm with
// callers):
//   pos   - point coordinates, presumably (N, D): one point per row.
//   size  - voxel edge length per dimension, presumably (D,).
//   start - grid origin per dimension, presumably (D,).
//   end   - grid upper bound per dimension, presumably (D,).
//
// Returns a kLong tensor with one entry per row of `pos`, computed as
//   sum_d long((pos[i, d] - start[d]) / size[d]) * stride[d]
// where stride[0] = 1, stride[d] = prod_{k<d} num_voxels[k], and
//   num_voxels[k] = long((end[k] - start[k]) / size[k]) + 1.
//
// NOTE(review): coordinates are converted with toType(kLong), which truncates
// rather than floors, and points outside [start, end] are not clamped, so
// such points can yield negative or out-of-range indices.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  // Shift points so that the grid origin is at zero; `start` broadcasts
  // across the rows of `pos`.
  pos = pos - start.view({1, -1});

  // Voxel count along each dimension (the +1 keeps the cell containing `end`).
  auto num_voxels = ((end - start) / size).toType(at::kLong) + 1;
  num_voxels = num_voxels.cumprod(0);

  // Exclusive cumulative product: prepend 1 and keep the first D entries.
  // narrow() selects that prefix directly instead of building an arange
  // index via unqualified empty/arange_out and calling index_select.
  num_voxels = at::cat({at::ones(1, num_voxels.options()), num_voxels}, 0);
  num_voxels = num_voxels.narrow(0, 0, size.size(0));

  // Integer voxel coordinate per dimension, then flatten with the strides.
  auto cluster = (pos / size.view({1, -1})).toType(at::kLong);
  cluster *= num_voxels.view({1, -1});
  cluster = cluster.sum(1);

  return cluster;
}

// Python module entry point: exposes the CPU voxel-grid routine as `grid`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module_handle) {
  module_handle.def("grid", &grid, "Grid (CPU)");
}