Commit a315a06d authored by rusty1s's avatar rusty1s
Browse files

first aten try

parent de54ccec
#include <torch/torch.h>
at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) {
return row;
}
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
if (!start.defined()) start = std::get<0>(pos.min(1));
if (!end.defined()) end = std::get<0>(pos.max(1));
size = size.toType(pos.type());
start = start.toType(pos.type());
end = end.toType(pos.type());
pos = pos - start.view({ 1, -1 });
auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0);
num_voxels = num_voxels - num_voxels[0];
num_voxels[0] = 1;
auto cluster = pos / size.view({ 1, -1 });
cluster = cluster.toType(at::kLong);
cluster *= num_voxels.view({ 1, -1 });
cluster = cluster.sum(1);
return cluster;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("graclus", &graclus, "Graclus (CPU)", py::arg("row"), py::arg("col"), py::arg("weight"));
m.def("grid", &grid, "Grid (CPU)", py::arg("pos"), py::arg("size"), py::arg("start"),
py::arg("end"));
}
import torch
import cluster_cpu
def grid_cluster(pos, size, start, end):
return cluster_cpu.grid(pos, size, start, end)
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], dtype=torch.uint8)
size = torch.tensor([2, 2])
start = torch.tensor([0, 0])
end = torch.tensor([7, 7])
print('pos', pos.tolist())
print('size', size.tolist())
cluster = grid_cluster(pos, size, start, end)
print('result', cluster.tolist(), cluster.dtype)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='cluster',
ext_modules=[CppExtension('cluster_cpu', ['cluster.cpp'])],
cmdclass={'build_ext': BuildExtension},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment