Commit 5765a062 authored by rusty1s

update cpu build

parent 5f93cd74
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>

#include "compat.h"
#include "utils.h"

at::Tensor get_dist(at::Tensor x, ptrdiff_t index) {
  return (x - x[index]).norm(2, 1);
}
at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  auto batch_size = batch[-1].DATA_PTR<int64_t>()[0] + 1;
  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

  auto out = at::empty(cum_k[-1].DATA_PTR<int64_t>()[0], batch.options());

  auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
  auto k_d = k.DATA_PTR<int64_t>();
  auto cum_k_d = cum_k.DATA_PTR<int64_t>();
  auto out_d = out.DATA_PTR<int64_t>();

  for (ptrdiff_t b = 0; b < batch_size; b++) {
    auto index = at::range(cum_deg_d[b], cum_deg_d[b + 1] - 1, out.options());
    auto y = x.index_select(0, index);

    ptrdiff_t start = 0;
    if (random) {
      start = at::randperm(y.size(0), batch.options()).DATA_PTR<int64_t>()[0];
    }

    out_d[cum_k_d[b]] = cum_deg_d[b] + start;
    auto dist = get_dist(y, start);

    for (ptrdiff_t i = 1; i < k_d[b]; i++) {
      ptrdiff_t argmax = dist.argmax().DATA_PTR<int64_t>()[0];
      out_d[cum_k_d[b] + i] = cum_deg_d[b] + argmax;
      dist = at::min(dist, get_dist(y, argmax));
    }
  }

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fps", &fps, "Farthest Point Sampling (CPU)");
}
#include "fps_cpu.h"
#include "utils.h"
inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return (x - x[idx]).norm(2, 1);
}
torch::Tensor fps_cpu(torch::Tensor src,
                      torch::optional<torch::Tensor> optional_ptr, double ratio,
                      bool random_start) {
  CHECK_CPU(src);
  if (optional_ptr.has_value()) {
    CHECK_CPU(optional_ptr.value());
    CHECK_INPUT(optional_ptr.value().dim() == 1);
  }
  AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid ratio (expected 0 < ratio < 1)");

  if (!optional_ptr.has_value())
    optional_ptr =
        torch::tensor({0, src.size(0)}, src.options().dtype(torch::kLong));

  src = src.view({src.size(0), -1}).contiguous();
  auto ptr = optional_ptr.value().contiguous();
  auto batch_size = ptr.size(0) - 1;

  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
  auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
  out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);

  auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options());

  auto ptr_data = ptr.data_ptr<int64_t>();
  auto out_ptr_data = out_ptr.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  int64_t src_start = 0, out_start = 0, src_end, out_end;
  for (auto b = 0; b < batch_size; b++) {
    src_end = ptr_data[b + 1], out_end = out_ptr_data[b];
    auto y = src.narrow(0, src_start, src_end - src_start);

    int64_t start_idx = 0;
    if (random_start) {
      // TODO: GET RANDOM INTEGER
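      // Left unimplemented in this commit. One option (an assumption, not
      // the author's code), mirroring the old randperm-based start:
      // start_idx = torch::randint(y.size(0), {1}, ptr.options()).item<int64_t>();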
    }

    out_data[out_start] = src_start + start_idx;
    auto dist = get_dist(y, start_idx);

    // Greedily pick the point farthest from the already selected set;
    // dist keeps each point's distance to its nearest selected point.
    for (auto i = 1; i < out_end - out_start; i++) {
      int64_t argmax = dist.argmax().data_ptr<int64_t>()[0];
      out_data[out_start + i] = src_start + argmax;
      dist = torch::min(dist, get_dist(y, argmax));
    }

    src_start = src_end, out_start = out_end;
  }

  return out;
}
#pragma once

#include <torch/extension.h>

torch::Tensor fps_cpu(torch::Tensor src,
                      torch::optional<torch::Tensor> optional_ptr, double ratio,
                      bool random_start);
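For orientation, a minimal sketch (hypothetical, not part of this commit) of calling fps_cpu: two concatenated point clouds are delimited by a CSR-style ptr tensor, and ratio = 0.5 keeps ceil(deg * 0.5) points per cloud.

// Hypothetical driver, assuming fps_cpu.h is on the include path.
#include <iostream>
#include <torch/torch.h>
#include "fps_cpu.h"

int main() {
  auto src = torch::rand({8, 2});                     // two clouds of 4 points
  auto ptr = torch::tensor({0, 4, 8}, torch::kLong);  // cloud boundaries
  // ratio = 0.5 -> ceil(4 * 0.5) = 2 samples per cloud; start at point 0.
  auto idx = fps_cpu(src, ptr, 0.5, /*random_start=*/false);
  std::cout << idx << std::endl;  // 4 indices into src
  return 0;
}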
#include <torch/extension.h>

#include "compat.h"
#include "utils.h"

at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = rand(row, col);
  std::tie(row, col) = to_csr(row, col, num_nodes);
  auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>();

  auto perm = at::randperm(num_nodes, row.options());
  auto perm_data = perm.DATA_PTR<int64_t>();

  auto cluster = at::full(num_nodes, -1, row.options());
  auto cluster_data = cluster.DATA_PTR<int64_t>();

  for (int64_t i = 0; i < num_nodes; i++) {
    auto u = perm_data[i];

    if (cluster_data[u] >= 0)
      continue;

    cluster_data[u] = u;

    for (int64_t j = row_data[u]; j < row_data[u + 1]; j++) {
      auto v = col_data[j];

      if (cluster_data[v] >= 0)
        continue;

      cluster_data[u] = std::min(u, v);
      cluster_data[v] = std::min(u, v);
      break;
    }
  }

  return cluster;
}

at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int64_t num_nodes) {
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
  auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>();

  auto perm = at::randperm(num_nodes, row.options());
  auto perm_data = perm.DATA_PTR<int64_t>();

  auto cluster = at::full(num_nodes, -1, row.options());
  auto cluster_data = cluster.DATA_PTR<int64_t>();

  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
    auto weight_data = weight.DATA_PTR<scalar_t>();
    for (int64_t i = 0; i < num_nodes; i++) {
      auto u = perm_data[i];

      if (cluster_data[u] >= 0)
        continue;

      int64_t v_max = u;
      scalar_t w_max = 0;
      for (int64_t j = row_data[u]; j < row_data[u + 1]; j++) {
        auto v = col_data[j];

        if (cluster_data[v] >= 0)
          continue;

        if (weight_data[j] >= w_max) {
          v_max = v;
          w_max = weight_data[j];
        }
      }

      cluster_data[u] = std::min(u, v_max);
      cluster_data[v_max] = std::min(u, v_max);
    }
  });

  return cluster;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("graclus", &graclus, "Graclus (CPU)");
  m.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CPU)");
}
#include "graclus_cpu.h"
#include "utils.h"
torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight,
int64_t num_nodes) {
CHECK_CPU(row);
CHECK_CPU(col);
CHECK_INPUT(row.dim() == 1 && col.dim() == 1 && row.numel() == col.numel());
if (optional_weight.has_value()) {
CHECK_CPU(optional_weight.value());
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
auto mask = row != col;
row = row.masked_select(mask), col = col.masked_select(mask);
if (optional_weight.has_value())
optional_weight = optional_weight.value().masked_select(mask);
auto perm = torch::randperm(row.size(0), row.options());
row = row.index_select(0, perm);
col = col.index_select(0, perm);
if (optional_weight.has_value())
optional_weight = optional_weight.value().index_select(0, perm);
std::tie(row, perm) = row.sort();
col = col.index_select(0, perm);
if (optional_weight.has_value())
optional_weight = optional_weight.value().index_select(0, perm);
auto rowptr = torch::zeros(num_nodes, row.options());
rowptr = rowptr.scatter_add_(0, row, torch::ones_like(row)).cumsum(0);
rowptr = torch::cat({torch::zeros(1, row.options()), rowptr}, 0);
perm = torch::randperm(num_nodes, row.options());
auto out = torch::full(num_nodes, -1, row.options());
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto perm_data = perm.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
if (!optional_weight.has_value()) {
for (auto i = 0; i < num_nodes; i++) {
auto u = perm_data[i];
if (out_data[u] >= 0)
continue;
out_data[u] = u;
for (auto j = rowptr_data[u]; j < rowptr_data[u + 1]; j++) {
auto v = col_data[j];
if (out_data[v] >= 0)
continue;
out_data[u] = std::min(u, v);
out_data[v] = std::min(u, v);
break;
}
}
} else {
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
auto weight_data = weight.data_ptr<scalar_t>();
for (auto i = 0; i < num_nodes; i++) {
auto u = perm_data[i];
if (out_data[u] >= 0)
continue;
auto v_max = u;
scalar_t w_max = (scalar_t)0.;
for (auto j = rowptr_data[u]; j < rowptr_data[u + 1]; j++) {
auto v = col_data[j];
if (out_data[v] >= 0)
continue;
if (weight_data[j] >= w_max) {
v_max = v;
w_max = weight_data[j];
}
}
out_data[u] = std::min(u, v_max);
out_data[v_max] = std::min(u, v_max);
}
});
}
return out;
}
#pragma once

#include <torch/extension.h>

torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
                          torch::optional<torch::Tensor> optional_weight,
                          int64_t num_nodes);
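A similar hypothetical driver for graclus_cpu; an undirected edge must be stored in both directions, since the matching only scans a node's outgoing CSR neighbors.

// Hypothetical driver, assuming graclus_cpu.h is on the include path.
#include <iostream>
#include <torch/torch.h>
#include "graclus_cpu.h"

int main() {
  // Path graph 0 - 1 - 2 with both edge directions stored.
  auto row = torch::tensor({0, 1, 1, 2}, torch::kLong);
  auto col = torch::tensor({1, 0, 2, 1}, torch::kLong);
  auto cluster = graclus_cpu(row, col, torch::nullopt, /*num_nodes=*/3);
  std::cout << cluster << std::endl;  // one cluster id per node
  return 0;
}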
#include <torch/extension.h>

at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  pos = pos - start.view({1, -1});

  auto num_voxels = ((end - start) / size).toType(at::kLong) + 1;
  num_voxels = num_voxels.cumprod(0);
  num_voxels = at::cat({at::ones(1, num_voxels.options()), num_voxels}, 0);
  auto index = at::empty(size.size(0), num_voxels.options());
  at::arange_out(index, size.size(0));
  num_voxels = num_voxels.index_select(0, index);

  auto cluster = (pos / size.view({1, -1})).toType(at::kLong);
  cluster *= num_voxels.view({1, -1});
  cluster = cluster.sum(1);

  return cluster;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("grid", &grid, "Grid (CPU)"); }
#include "grid_cpu.h"
#include "utils.h"
torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
CHECK_CPU(pos);
CHECK_CPU(size);
if (optional_start.has_value())
CHECK_CPU(optional_start.value());
if (optional_start.has_value())
CHECK_CPU(optional_start.value());
pos = pos.view({pos.size(0), -1});
CHECK_INPUT(size.numel() == pos.size(1));
if (!optional_start.has_value())
optional_start = std::get<0>(pos.min(0));
else
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
if (!optional_end.has_value())
optional_end = std::get<0>(pos.max(0));
else
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
auto start = optional_start.value();
auto end = optional_end.value();
pos = pos - start.unsqueeze(0);
auto num_voxels = ((end - start) / size).toType(torch::kLong) + 1;
num_voxels = num_voxels.cumprod(0);
num_voxels =
torch::cat({torch::ones(1, num_voxels.options()), num_voxels}, 0);
num_voxels = num_voxels.narrow(0, 0, size.size(0));
auto out = (pos / size.view({1, -1})).toType(at::kLong);
out *= num_voxels.view({1, -1});
out = out.sum(1);
return out;
}
#pragma once

#include <torch/extension.h>

torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
                       torch::optional<torch::Tensor> optional_start,
                       torch::optional<torch::Tensor> optional_end);
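And a hypothetical driver for grid_cpu: each point is shifted by start, divided per dimension by size, and the resulting voxel coordinates are flattened into a single id via the cumulative product of voxel counts.

// Hypothetical driver, assuming grid_cpu.h is on the include path.
#include <iostream>
#include <torch/torch.h>
#include "grid_cpu.h"

int main() {
  // Four 2D points on a unit grid; start/end default to the per-dimension
  // min/max of pos when not supplied.
  auto pos = torch::tensor({0.5, 0.5, 0.5, 1.5, 1.5, 0.5, 1.5, 1.5}).view({4, 2});
  auto size = torch::tensor({1.0, 1.0});
  auto cluster = grid_cpu(pos, size, torch::nullopt, torch::nullopt);
  std::cout << cluster << std::endl;  // one voxel id per point
  return 0;
}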
#include <torch/extension.h> #include "rw_cpu.h"
#include "compat.h"
#include "utils.h" #include "utils.h"
at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start, at::Tensor random_walk_cpu(torch::Tensor row, torch::Tensor col,
size_t walk_length, float p, float q, size_t num_nodes) { torch::Tensor start, int64_t walk_length, double p,
double q, int64_t num_nodes) {
auto deg = degree(row, num_nodes); auto deg = degree(row, num_nodes);
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0); auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
...@@ -34,7 +35,3 @@ at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start, ...@@ -34,7 +35,3 @@ at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
return out; return out;
} }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rw", &rw, "Random Walk Sampling (CPU)");
}
#pragma once

#include <torch/extension.h>

at::Tensor random_walk_cpu(torch::Tensor row, torch::Tensor col,
                           torch::Tensor start, int64_t walk_length, double p,
                           double q, int64_t num_nodes);
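The walk loop itself is collapsed in this diff; for context, a hypothetical call against the new signature (p and q being node2vec-style return and in-out parameters):

// Hypothetical driver, assuming rw_cpu.h is on the include path and that
// row is sorted so that degree/cumsum yield valid CSR offsets.
#include <iostream>
#include <torch/torch.h>
#include "rw_cpu.h"

int main() {
  // Triangle graph with both directions per edge, sorted by row.
  auto row = torch::tensor({0, 0, 1, 1, 2, 2}, torch::kLong);
  auto col = torch::tensor({1, 2, 0, 2, 0, 1}, torch::kLong);
  auto start = torch::tensor({0, 1, 2}, torch::kLong);
  auto walks = random_walk_cpu(row, col, start, /*walk_length=*/4,
                               /*p=*/1.0, /*q=*/1.0, /*num_nodes=*/3);
  std::cout << walks << std::endl;  // one walk per start node
  return 0;
}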
...
#include <torch/extension.h>

#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")

std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                     at::Tensor col) {
  auto mask = row != col;
  ...
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <Python.h>
#include <torch/script.h>

#include "cpu/fps_cpu.h"

#ifdef WITH_CUDA
#include "cuda/fps_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__fps(void) { return NULL; }
#endif

torch::Tensor fps(torch::Tensor src,
                  torch::optional<torch::Tensor> optional_ptr, double ratio,
                  bool random_start) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return fps_cuda(src, optional_ptr, ratio, random_start);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return fps_cpu(src, optional_ptr, ratio, random_start);
  }
}

static auto registry =
    torch::RegisterOperators().op("torch_cluster::fps", &fps);
#include <Python.h>
#include <torch/script.h>

#include "cpu/graclus_cpu.h"

#ifdef WITH_CUDA
#include "cuda/graclus_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__graclus(void) { return NULL; }
#endif

torch::Tensor graclus(torch::Tensor row, torch::Tensor col,
                      torch::optional<torch::Tensor> optional_weight,
                      int64_t num_nodes) {
  if (row.device().is_cuda()) {
#ifdef WITH_CUDA
    return graclus_cuda(row, col, optional_weight, num_nodes);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return graclus_cpu(row, col, optional_weight, num_nodes);
  }
}

static auto registry =
    torch::RegisterOperators().op("torch_cluster::graclus", &graclus);
#include <Python.h>
#include <torch/script.h>

#include "cpu/grid_cpu.h"

#ifdef WITH_CUDA
#include "cuda/grid_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__grid(void) { return NULL; }
#endif

torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
                   torch::optional<torch::Tensor> optional_start,
                   torch::optional<torch::Tensor> optional_end) {
  if (pos.device().is_cuda()) {
#ifdef WITH_CUDA
    return grid_cuda(pos, size, optional_start, optional_end);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return grid_cpu(pos, size, optional_start, optional_end);
  }
}

static auto registry =
    torch::RegisterOperators().op("torch_cluster::grid", &grid);
#include <Python.h>
#include <torch/script.h>

#include "cpu/rw_cpu.h"

#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif

torch::Tensor random_walk(torch::Tensor row, torch::Tensor col,
                          torch::Tensor start, int64_t walk_length, double p,
                          double q, int64_t num_nodes) {
  if (row.device().is_cuda()) {
#ifdef WITH_CUDA
    AT_ERROR("No CUDA version supported");
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return random_walk_cpu(row, col, start, walk_length, p, q, num_nodes);
  }
}

static auto registry = torch::RegisterOperators().op(
    "torch_cluster::random_walk", &random_walk);