Commit 4bc9a767 authored by rusty1s's avatar rusty1s
Browse files
parents 92105bf1 05877042
...@@ -6,3 +6,4 @@ dist/ ...@@ -6,3 +6,4 @@ dist/
.eggs/ .eggs/
*.egg-info/ *.egg-info/
.coverage .coverage
*.so
#include <torch/torch.h> #include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, #include "graclus.cpp"
at::Tensor col) { #include "grid.cpp"
auto mask = row != col;
row = row.masked_select(mask);
col = col.masked_select(mask);
return {row, col};
}
// Shuffles the edge list, temporarily relabels source nodes with a random
// node permutation, sorts the edges by the relabeled source (so every node's
// edges become contiguous), and finally restores the original node ids.
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // Shuffle the edge order.
  auto edge_perm = at::randperm(torch::CPU(at::kLong), row.size(0));
  row = row.index_select(0, edge_perm);
  col = col.index_select(0, edge_perm);
  // Relabel source nodes via a random node permutation ...
  auto node_perm = at::randperm(torch::CPU(at::kLong), num_nodes);
  row = node_perm.index_select(0, row);
  // ... group edges by the relabeled source node ...
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);
  // ... and map the sorted row values back to their original ids.
  auto inverse_perm = std::get<1>(node_perm.sort());
  row = inverse_perm.index_select(0, row);
  return {row, col};
}
// Out-degree histogram: counts how often each node id in [0, num_nodes)
// occurs in `index`.
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
  auto ones = at::ones_like(index);
  auto counts = at::zeros(torch::CPU(at::kLong), {num_nodes});
  return counts.scatter_add_(0, index, ones);
}
// Greedy graclus matching: visits source nodes in (randomized, row-sorted)
// edge order; each unvisited node is matched with its first unvisited
// neighbor, and both are labeled with the smaller of the two node ids.
// Nodes that find no free partner become singleton clusters; nodes that
// appear in no edge at all keep the initial label -1.
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
// Self-loops must not produce matches, so drop them first.
std::tie(row, col) = remove_self_loops(row, col);
// Randomize edge/node order; afterwards `row` is sorted, so each node's
// edges occupy one contiguous slice of the edge list.
std::tie(row, col) = randperm(row, col, num_nodes);
// deg[r] = length of node r's contiguous edge slice.
auto deg = degree(row, num_nodes);
// -1 marks "not yet assigned to a cluster".
auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);
// Raw int64 pointers for the tight matching loop below.
auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>();
auto *deg_data = deg.data<int64_t>();
auto *cluster_data = cluster.data<int64_t>();
int64_t e_idx = 0, d_idx, r, c;
// `e_idx` always points at the first edge of the current source node `r`.
while (e_idx < row.size(0)) {
r = row_data[e_idx];
if (cluster_data[r] < 0) {
// Tentatively make r a singleton; overwritten if a partner is found.
cluster_data[r] = r;
// Scan r's edge slice for an unmatched neighbor.
for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
c = col_data[e_idx + d_idx];
if (cluster_data[c] < 0) {
// Merge r and c under the smaller id and stop looking.
cluster_data[r] = std::min(r, c);
cluster_data[c] = std::min(r, c);
break;
}
}
}
// Advance to the first edge of the next source node.
e_idx += deg_data[r];
}
return cluster;
}
// Maps each point in `pos` (rows = points, cols = dimensions) to the flat
// index of the voxel with edge lengths `size` it falls into, inside the
// bounding box spanned by `start` and `end`.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) {
// Work in the coordinate type of `pos`.
size = size.toType(pos.type());
start = start.toType(pos.type());
end = end.toType(pos.type());
// Shift positions so the bounding box starts at the origin.
pos = pos - start.view({1, -1});
// Voxel count per dimension (before the +1 below).
auto num_voxels = ((end - start) / size).toType(at::kLong);
// Turn per-dimension counts into per-dimension multipliers via a
// cumulative sum, rebased so the first multiplier is 1.
// NOTE(review): row-major flattening normally needs a cumulative *product*
// of the counts; this cumsum only yields collision-free indices when the
// per-dimension voxel counts coincide — verify intended behavior.
num_voxels = (num_voxels + 1).cumsum(0);
num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels.data<int64_t>()[0] = 1;
// Per-dimension voxel coordinate of each point, then flatten to one id.
auto cluster = pos / size.view({1, -1});
cluster = cluster.toType(at::kLong);
cluster *= num_voxels.view({1, -1});
cluster = cluster.sum(1);
return cluster;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("graclus", &graclus, "Graclus (CPU)"); m.def("graclus", &graclus, "Graclus (CPU)");
......
...@@ -13,14 +13,14 @@ def graclus_cluster(row, col, num_nodes): ...@@ -13,14 +13,14 @@ def graclus_cluster(row, col, num_nodes):
return cluster_cpu.graclus(row, col, num_nodes) return cluster_cpu.graclus(row, col, num_nodes)
# pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]]) pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
# size = torch.tensor([2, 2]) size = torch.tensor([2, 2])
# start = torch.tensor([0, 0]) start = torch.tensor([0, 0])
# end = torch.tensor([7, 7]) end = torch.tensor([7, 7])
# print('pos', pos.tolist()) print('pos', pos.tolist())
# print('size', size.tolist()) print('size', size.tolist())
# cluster = grid_cluster(pos, size) cluster = grid_cluster(pos, size)
# print('result', cluster.tolist(), cluster.dtype) print('result', cluster.tolist(), cluster.dtype)
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3]) row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2]) col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
......
#ifndef DEGREE_CPP
#define DEGREE_CPP

#include <torch/torch.h>

// Counts, for every node id in [0, num_nodes), how often it occurs in
// `index`. The result lives on the same backend as `index` but uses
// `scalar_type` as its element type.
inline at::Tensor degree(at::Tensor index, int num_nodes,
                         at::ScalarType scalar_type) {
  auto out_type = index.type().toScalarType(scalar_type);
  auto counts = at::full(out_type, {num_nodes}, 0);
  auto ones = at::full(counts.type(), {index.size(0)}, 1);
  return counts.scatter_add_(0, index, ones);
}

#endif // DEGREE_CPP
#include <torch/torch.h>
#include "degree.cpp"
#include "loop.cpp"
#include "perm.cpp"
// Greedy graclus matching: visits source nodes in (randomized, row-sorted)
// edge order; each unvisited node is matched with its first unvisited
// neighbor, and both are labeled with the smaller of the two node ids.
// Nodes that find no free partner become singleton clusters; nodes that
// appear in no edge at all keep the initial label -1.
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
// Self-loops must not produce matches, so drop them first.
std::tie(row, col) = remove_self_loops(row, col);
// Randomize edge/node order; afterwards `row` is sorted, so each node's
// edges occupy one contiguous slice of the edge list.
std::tie(row, col) = randperm(row, col, num_nodes);
// -1 marks "not yet assigned to a cluster".
auto cluster = at::full(row.type(), {num_nodes}, -1);
// deg[r] = length of node r's contiguous edge slice.
auto deg = degree(row, num_nodes, row.type().scalarType());
// Raw int64 pointers for the tight matching loop below.
// NOTE(review): assumes `row`/`col` hold int64 values — confirm at callers.
auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>();
auto *deg_data = deg.data<int64_t>();
auto *cluster_data = cluster.data<int64_t>();
int64_t e_idx = 0, d_idx, r, c;
// `e_idx` always points at the first edge of the current source node `r`.
while (e_idx < row.size(0)) {
r = row_data[e_idx];
if (cluster_data[r] < 0) {
// Tentatively make r a singleton; overwritten if a partner is found.
cluster_data[r] = r;
// Scan r's edge slice for an unmatched neighbor.
for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
c = col_data[e_idx + d_idx];
if (cluster_data[c] < 0) {
// Merge r and c under the smaller id and stop looking.
cluster_data[r] = std::min(r, c);
cluster_data[c] = std::min(r, c);
break;
}
}
}
// Advance to the first edge of the next source node.
e_idx += deg_data[r];
}
return cluster;
}
#include <torch/torch.h>
// Maps each point in `pos` (rows = points, cols = dimensions) to the flat
// index of the voxel with edge lengths `size` it falls into, inside the
// bounding box spanned by `start` and `end`.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  // Work in the coordinate type of `pos`.
  size = size.toType(pos.type());
  start = start.toType(pos.type());
  end = end.toType(pos.type());
  // Shift positions so the bounding box starts at the origin.
  pos = pos - start.view({1, -1});
  // Voxel count per dimension: floor((end - start) / size) + 1.
  auto num_voxels = ((end - start) / size).toType(at::kLong) + 1;
  // Row-major strides: stride[0] = 1, stride[d] = prod_{i<d} num_voxels[i].
  // The previous cumsum-based offsets produced colliding flat indices
  // whenever the per-dimension voxel counts differed; an exclusive
  // cumulative *product* guarantees a unique id per voxel.
  auto stride = at::ones_like(num_voxels);
  auto *num_voxels_data = num_voxels.data<int64_t>();
  auto *stride_data = stride.data<int64_t>();
  for (int64_t d = 1; d < stride.size(0); d++) {
    stride_data[d] = stride_data[d - 1] * num_voxels_data[d - 1];
  }
  // Per-dimension voxel coordinate of each point, then flatten to one id.
  auto cluster = pos / size.view({1, -1});
  cluster = cluster.toType(at::kLong);
  cluster *= stride.view({1, -1});
  cluster = cluster.sum(1);
  return cluster;
}
#ifndef LOOP_CPP
#define LOOP_CPP

#include <torch/torch.h>

// Drops every edge whose source equals its target (self-loop) and returns
// the filtered (row, col) pair.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  auto keep = row != col;
  row = row.masked_select(keep);
  col = col.masked_select(keep);
  return {row, col};
}

#endif // LOOP_CPP
#ifndef PERM_CPP
#define PERM_CPP

#include <torch/torch.h>

// Shuffles the edge list, temporarily relabels source nodes with a random
// node permutation, sorts the edges by the relabeled source (so every node's
// edges become contiguous), and finally restores the original node ids.
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes) {
  // Shuffle the edge order.
  auto edge_perm = at::randperm(row.type(), row.size(0));
  row = row.index_select(0, edge_perm);
  col = col.index_select(0, edge_perm);
  // Relabel source nodes via a random node permutation ...
  auto node_perm = at::randperm(row.type(), num_nodes);
  row = node_perm.index_select(0, row);
  // ... group edges by the relabeled source node ...
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);
  // ... and map the sorted row values back to their original ids.
  auto inverse_perm = std::get<1>(node_perm.sort());
  row = inverse_perm.index_select(0, row);
  return {row, col};
}

#endif // PERM_CPP
...@@ -5,10 +5,11 @@ ...@@ -5,10 +5,11 @@
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t> template <typename scalar_t>
__global__ void grid_cuda_kernel( __global__ void
int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos, grid_cuda_kernel(int64_t *cluster,
const scalar_t *__restrict__ size, const scalar_t *__restrict__ start, at::cuda::detail::TensorInfo<scalar_t, int> pos,
const scalar_t *__restrict__ end, const size_t num_nodes) { scalar_t *__restrict__ size, scalar_t *__restrict__ start,
scalar_t *__restrict__ end, size_t num_nodes) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x; const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x; const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < num_nodes; i += stride) { for (ptrdiff_t i = index; i < num_nodes; i += stride) {
...@@ -25,7 +26,7 @@ __global__ void grid_cuda_kernel( ...@@ -25,7 +26,7 @@ __global__ void grid_cuda_kernel(
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) { at::Tensor end) {
const auto num_nodes = pos.size(0); auto num_nodes = pos.size(0);
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes}); auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] { AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment