Commit 6d8410a1 authored by rusty1s's avatar rusty1s
Browse files

outsoruce

parent cb0e5f63
......@@ -3,9 +3,7 @@
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
at::Tensor col) {
auto mask = row != col;
row = row.masked_select(mask);
col = col.masked_select(mask);
return {row, col};
return {row.masked_select(mask), col.masked_select(mask)};
}
inline std::tuple<at::Tensor, at::Tensor>
......@@ -29,17 +27,19 @@ randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
return {row, col};
}
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes});
return zero.scatter_add_(0, index, at::ones_like(index));
inline at::Tensor degree(at::Tensor index, int64_t num_nodes,
at::ScalarType scalar_type) {
auto zero = at::full(torch::CPU(scalar_type), {num_nodes}, 0);
auto one = at::full(zero.type(), {index.size(0)}, 1);
return zero.scatter_add_(0, index, one);
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = randperm(row, col, num_nodes);
auto deg = degree(row, num_nodes);
auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);
auto cluster = at::full(row.type(), {num_nodes}, -1);
auto deg = degree(row, num_nodes, row.type().scalarType());
auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>();
......@@ -66,27 +66,6 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
return cluster;
}
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) {
size = size.toType(pos.type());
start = start.toType(pos.type());
end = end.toType(pos.type());
pos = pos - start.view({1, -1});
auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0);
num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels.data<int64_t>()[0] = 1;
auto cluster = pos / size.view({1, -1});
cluster = cluster.toType(at::kLong);
cluster *= num_voxels.view({1, -1});
cluster = cluster.sum(1);
return cluster;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("graclus", &graclus, "Graclus (CPU)");
m.def("grid", &grid, "Grid (CPU)");
}
#include <torch/torch.h>
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) {
size = size.toType(pos.type());
start = start.toType(pos.type());
end = end.toType(pos.type());
pos = pos - start.view({1, -1});
auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0);
num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels.data<int64_t>()[0] = 1;
auto cluster = pos / size.view({1, -1});
cluster = cluster.toType(at::kLong);
cluster *= num_voxels.view({1, -1});
cluster = cluster.sum(1);
return cluster;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("grid", &grid, "Grid (CPU)"); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment