Commit d85bc4fb authored by rusty1s's avatar rusty1s
Browse files

voxel grid CUDA bugfix

parent 920cc934
#include <torch/torch.h>
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end);
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes);
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
int num_nodes);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid", &grid, "Grid (CUDA)");
m.def("graclus", &graclus, "Graclus (CUDA)");
m.def("weighted_graclus", &weighted_graclus, "Weightes Graclus (CUDA)");
}
......@@ -18,7 +18,7 @@ __global__ void grid_kernel(int64_t *cluster,
for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
scalar_t p = pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
c += (int64_t)(p / size[d]) * k;
k += (int64_t)((end[d] - start[d]) / size[d]);
k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
}
cluster[i] = c;
}
......
......@@ -15,7 +15,7 @@ if torch.cuda.is_available():
CUDAExtension('grid_cuda', ['cuda/grid.cpp', 'cuda/grid_kernel.cu']),
]
__version__ = '1.1.4'
__version__ = '1.1.5'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = []
......
from .graclus import graclus_cluster
from .grid import grid_cluster
__version__ = '1.1.4'
__version__ = '1.1.5'
__all__ = ['graclus_cluster', 'grid_cluster', '__version__']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment