Commit 4a61d70f authored by rusty1s

better csrc api

parent 0e7f4b8e
@@ -69,7 +69,7 @@ Then run:
 pip install torch-cluster
 ```
-When running in a docker container without nvidia driver, PyTorch needs to evaluate the compute capabilities and may fail.
+When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
 In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
 ```
......
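For context, the elided README snippet demonstrates overriding the detected compute capabilities. A minimal sketch of doing the same from Python before triggering the build; the architecture list here is an assumption, not the project's official recommendation:

```
import os
import subprocess
import sys

# Assumed example value: pick the compute capabilities of your target GPUs.
os.environ['TORCH_CUDA_ARCH_LIST'] = '6.0 6.1 7.0+PTX'
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch-cluster'])
```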
@@ -6,23 +6,16 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
   return (x - x[idx]).norm(2, 1);
 }
 
-torch::Tensor fps_cpu(torch::Tensor src,
-                      torch::optional<torch::Tensor> optional_ptr, double ratio,
+torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start) {
   CHECK_CPU(src);
-  if (optional_ptr.has_value()) {
-    CHECK_CPU(optional_ptr.value());
-    CHECK_INPUT(optional_ptr.value().dim() == 1);
-  }
+  CHECK_CPU(ptr);
+  CHECK_INPUT(ptr.dim() == 1);
   AT_ASSERTM(ratio > 0 and ratio < 1, "Invalid input");
 
-  if (!optional_ptr.has_value())
-    optional_ptr =
-        torch::tensor({0, src.size(0)}, src.options().dtype(torch::kLong));
-
   src = src.view({src.size(0), -1}).contiguous();
-  auto ptr = optional_ptr.value().contiguous();
+  ptr = ptr.contiguous();
 
   auto batch_size = ptr.size(0) - 1;
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
@@ -42,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src,
     int64_t start_idx = 0;
     if (random_start) {
-      // TODO: GET RANDOM INTEGER
+      start_idx = rand() % src.size(0);
     }
 
     out_data[out_start] = src_start + start_idx;
@@ -56,5 +49,6 @@ torch::Tensor fps_cpu(torch::Tensor src,
     src_start = src_end, out_start = out_end;
   }
 
   return out;
 }
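To illustrate the sampling-size arithmetic used by `fps_cpu` (the `deg` and `out_ptr` computation above), a small sketch with made-up tensors:

```
import torch

ptr = torch.tensor([0, 4, 10])  # two batch elements with 4 and 6 points
deg = ptr[1:] - ptr[:-1]        # tensor([4, 6])
ratio = 0.5
out_ptr = (deg.float() * ratio).ceil().long().cumsum(0)
out_ptr = torch.cat([torch.zeros(1, dtype=torch.long), out_ptr])
print(out_ptr)                  # tensor([0, 2, 5]): 2 + 3 samples
```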
@@ -2,6 +2,5 @@
 
 #include <torch/extension.h>
 
-torch::Tensor fps_cpu(torch::Tensor src,
-                      torch::optional<torch::Tensor> optional_ptr, double ratio,
+torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start);
@@ -2,57 +2,42 @@
 
 #include "utils.h"
 
-torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
-                          torch::optional<torch::Tensor> optional_weight,
-                          int64_t num_nodes) {
+torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
+                          torch::optional<torch::Tensor> optional_weight) {
-  CHECK_CPU(row);
+  CHECK_CPU(rowptr);
   CHECK_CPU(col);
-  CHECK_INPUT(row.dim() == 1 && col.dim() == 1 && row.numel() == col.numel());
+  CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
 
   if (optional_weight.has_value()) {
     CHECK_CPU(optional_weight.value());
     CHECK_INPUT(optional_weight.value().dim() == 1);
     CHECK_INPUT(optional_weight.value().numel() == col.numel());
   }
 
-  auto mask = row != col;
-  row = row.masked_select(mask), col = col.masked_select(mask);
-  if (optional_weight.has_value())
-    optional_weight = optional_weight.value().masked_select(mask);
-
-  auto perm = torch::randperm(row.size(0), row.options());
-  row = row.index_select(0, perm);
-  col = col.index_select(0, perm);
-  if (optional_weight.has_value())
-    optional_weight = optional_weight.value().index_select(0, perm);
-
-  std::tie(row, perm) = row.sort();
-  col = col.index_select(0, perm);
-  if (optional_weight.has_value())
-    optional_weight = optional_weight.value().index_select(0, perm);
-
-  auto rowptr = torch::zeros(num_nodes, row.options());
-  rowptr = rowptr.scatter_add_(0, row, torch::ones_like(row)).cumsum(0);
-  rowptr = torch::cat({torch::zeros(1, row.options()), rowptr}, 0);
-
-  perm = torch::randperm(num_nodes, row.options());
-  auto out = torch::full(num_nodes, -1, row.options());
+  int64_t num_nodes = rowptr.numel() - 1;
+  auto out = torch::full(num_nodes, -1, rowptr.options());
+  auto node_perm = torch::randperm(num_nodes, rowptr.options());
 
   auto rowptr_data = rowptr.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
-  auto perm_data = perm.data_ptr<int64_t>();
+  auto node_perm_data = node_perm.data_ptr<int64_t>();
   auto out_data = out.data_ptr<int64_t>();
 
   if (!optional_weight.has_value()) {
-    for (auto i = 0; i < num_nodes; i++) {
-      auto u = perm_data[i];
+    for (int64_t n = 0; n < num_nodes; n++) {
+      auto u = node_perm_data[n];
 
       if (out_data[u] >= 0)
         continue;
 
       out_data[u] = u;
 
-      for (auto j = rowptr_data[u]; j < rowptr_data[u + 1]; j++) {
-        auto v = col_data[j];
+      int64_t row_start = rowptr_data[u], row_end = rowptr_data[u + 1];
+      auto edge_perm = torch::randperm(row_end - row_start, rowptr.options());
+      auto edge_perm_data = edge_perm.data_ptr<int64_t>();
+      for (auto e = 0; e < row_end - row_start; e++) {
+        auto v = col_data[row_start + edge_perm_data[e]];
 
         if (out_data[v] >= 0)
           continue;
@@ -67,8 +52,8 @@ torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
     AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
       auto weight_data = weight.data_ptr<scalar_t>();
-      for (auto i = 0; i < num_nodes; i++) {
-        auto u = perm_data[i];
+      for (auto n = 0; n < num_nodes; n++) {
+        auto u = node_perm_data[n];
 
         if (out_data[u] >= 0)
           continue;
@@ -76,15 +61,15 @@ torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
         auto v_max = u;
         scalar_t w_max = (scalar_t)0.;
 
-        for (auto j = rowptr_data[u]; j < rowptr_data[u + 1]; j++) {
-          auto v = col_data[j];
+        for (auto e = rowptr_data[u]; e < rowptr_data[u + 1]; e++) {
+          auto v = col_data[e];
 
           if (out_data[v] >= 0)
             continue;
 
-          if (weight_data[j] >= w_max) {
+          if (weight_data[e] >= w_max) {
             v_max = v;
-            w_max = weight_data[j];
+            w_max = weight_data[e];
           }
         }
......
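For readers following the CSR loops above, here is a rough Python transcription of the unweighted matching; the final pairing assignment (`min(u, v)`) is an assumption, since the visible diff truncates before that point:

```
import torch

def graclus_unweighted_sketch(rowptr, col):
    # Visit nodes in random order; merge each unmatched node with a
    # randomly chosen unmatched neighbor, mirroring the loops above.
    num_nodes = rowptr.numel() - 1
    out = torch.full((num_nodes,), -1, dtype=torch.long)
    for u in torch.randperm(num_nodes).tolist():
        if out[u] >= 0:
            continue
        out[u] = u
        row_start, row_end = int(rowptr[u]), int(rowptr[u + 1])
        for e in torch.randperm(row_end - row_start).tolist():
            v = int(col[row_start + e])
            if out[v] >= 0:
                continue
            out[u] = out[v] = min(u, v)  # assumed cluster-id convention
            break
    return out
```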
@@ -2,6 +2,5 @@
 
 #include <torch/extension.h>
 
-torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
-                          torch::optional<torch::Tensor> optional_weight,
-                          int64_t num_nodes);
+torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
+                          torch::optional<torch::Tensor> optional_weight);
#include "fps_cuda.h"
#include "utils.cuh"
inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return (x - x[idx]).norm(2, 1);
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start) {
CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 and ratio < 1, "Invalid input");
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch.zeros(1, ptr.options()), out_ptr}, 0);
torch::Tensor start;
if (random_start) {
start = at::rand(batch_size, src.options());
start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
} else {
start = torch::zeros(batch_size, ptr.options());
}
auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>();
auto out_ptr_data = out_ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
return out;
}
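The `random_start` branch above draws one uniform sample per batch element and scales it to a valid offset inside that segment; the same computation in Python:

```
import torch

deg = torch.tensor([4, 6])                              # points per segment
start = (torch.rand(deg.numel()) * deg.float()).long()  # per-segment start
assert bool(((start >= 0) & (start < deg)).all())
```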
+#pragma once
+
+#include <torch/extension.h>
+
+torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
+                       bool random_start);
@@ -11,17 +11,16 @@
 PyMODINIT_FUNC PyInit__fps(void) { return NULL; }
 #endif
 
-torch::Tensor fps(torch::Tensor src,
-                  torch::optional<torch::Tensor> optional_ptr, double ratio,
+torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
                   bool random_start) {
   if (src.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return fps_cuda(src, optional_ptr, ratio, random_start);
+    return fps_cuda(src, ptr, ratio, random_start);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
   } else {
-    return fps_cpu(src, optional_ptr, ratio, random_start);
+    return fps_cpu(src, ptr, ratio, random_start);
   }
 }
......
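With this change, `ptr` is a required tensor on both code paths. A hedged usage sketch from Python (assumes torch-cluster is built so that the op is registered):

```
import torch
import torch_cluster  # registers torch.ops.torch_cluster.*

src = torch.rand(10, 3)
ptr = torch.tensor([0, 4, 10])  # CSR-style batch boundaries
idx = torch.ops.torch_cluster.fps(src, ptr, 0.5, True)
```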
@@ -11,17 +11,16 @@
 PyMODINIT_FUNC PyInit__graclus(void) { return NULL; }
 #endif
 
-torch::Tensor graclus(torch::Tensor row, torch::Tensor col,
-                      torch::optional<torch::Tensor> optional_weight,
-                      int64_t num_nodes) {
-  if (row.device().is_cuda()) {
+torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
+                      torch::optional<torch::Tensor> optional_weight) {
+  if (rowptr.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return graclus_cuda(row, col, optional_weight, num_nodes);
+    return graclus_cuda(rowptr, col, optional_weight);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
   } else {
-    return graclus_cpu(row, col, optional_weight, num_nodes);
+    return graclus_cpu(rowptr, col, optional_weight);
   }
 }
......
@@ -26,6 +26,7 @@ def get_extensions():
         define_macros += [('WITH_CUDA', None)]
         nvcc_flags = os.getenv('NVCC_FLAGS', '')
         nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
+        nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
         extra_compile_args['nvcc'] = nvcc_flags
 
     extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
......
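The new line appends the architecture flags after any user-supplied `NVCC_FLAGS`, so user flags are preserved; a standalone sketch of that flag assembly:

```
import os

nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
print(nvcc_flags)
```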
@@ -33,9 +33,8 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
         index = fps(src, batch, ratio=0.5)
     """
-    ptr: Optional[torch.Tensor] = None
     if batch is not None:
-        assert src.size(0) == batch.size(0)
+        assert src.size(0) == batch.numel()
         batch_size = int(batch.max()) + 1
         deg = src.new_zeros(batch_size, dtype=torch.long)
@@ -43,5 +42,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
         ptr = src.new_zeros(batch_size + 1, dtype=torch.long)
         deg.cumsum(0, out=ptr[1:])
+    else:
+        ptr = torch.tensor([0, src.size(0)], device=src.device)
 
     return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start)
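Concretely, for `batch = [0, 0, 1, 1, 1]` the branch above yields `ptr = [0, 2, 5]`; the same computation outside the function:

```
import torch

batch = torch.tensor([0, 0, 1, 1, 1])
batch_size = int(batch.max()) + 1
deg = torch.zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr = torch.zeros(batch_size + 1, dtype=torch.long)
deg.cumsum(0, out=ptr[1:])
print(ptr)  # tensor([0, 2, 5])
```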
@@ -32,4 +32,12 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
     if num_nodes is None:
         num_nodes = max(int(row.max()), int(col.max())) + 1
 
-    return torch.ops.torch_cluster.graclus(row, col, weight, num_nodes)
+    perm = torch.argsort(row * num_nodes + col)
+    row, col = row[perm], col[perm]
+
+    deg = row.new_zeros(num_nodes)
+    deg.scatter_add_(0, row, torch.ones_like(row))
+    rowptr = row.new_zeros(num_nodes + 1)
+    deg.cumsum(0, out=rowptr[1:])
+
+    return torch.ops.torch_cluster.graclus(rowptr, col, weight)
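A worked example of the COO-to-CSR conversion introduced above, on a made-up 3-node graph:

```
import torch

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])
num_nodes = 3

perm = torch.argsort(row * num_nodes + col)  # sort edges by (row, col)
row, col = row[perm], col[perm]

deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1)
deg.cumsum(0, out=rowptr[1:])
print(rowptr)  # tensor([0, 1, 3, 4])
```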
@@ -35,7 +35,7 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
         num_nodes = max(int(row.max()), int(col.max())) + 1
 
     if coalesced:
-        _, perm = torch.sort(row * num_nodes + col)
+        perm = torch.argsort(row * num_nodes + col)
         row, col = row[perm], col[perm]
         deg = row.new_zeros(num_nodes)
......
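The `torch.sort` to `torch.argsort` swap is behavior-preserving here, since only the permutation is used; a quick check:

```
import torch

key = torch.tensor([3, 1, 2])
_, perm_sort = torch.sort(key)
perm_argsort = torch.argsort(key)
assert torch.equal(perm_sort, perm_argsort)  # both tensor([1, 2, 0])
```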