Commit e1e47f9b authored by rusty1s's avatar rusty1s
Browse files

fix

parent eb66a19d
......@@ -4,57 +4,3 @@
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
// Filters out self-loop edges, i.e. entries where source == target.
// Returns the (row, col) pair restricted to non-loop edges.
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                     at::Tensor col) {
  const auto keep = row != col;  // boolean mask of non-loop edges
  return std::make_tuple(row.masked_select(keep), col.masked_select(keep));
}
// Weighted overload: filters out self-loop edges and drops the matching
// weight entries so the three outputs stay aligned.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
  const auto keep = row != col;  // boolean mask of non-loop edges
  auto out_row = row.masked_select(keep);
  auto out_col = col.masked_select(keep);
  auto out_weight = weight.masked_select(keep);
  return std::make_tuple(out_row, out_col, out_weight);
}
// Applies one random permutation to both index tensors, shuffling the
// edge list while keeping (row[i], col[i]) pairs intact.
std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
  const auto shuffle = at::randperm(row.size(0), row.options());
  auto shuffled_row = row.index_select(0, shuffle);
  auto shuffled_col = col.index_select(0, shuffle);
  return std::make_tuple(shuffled_row, shuffled_col);
}
// Sorts edges by their row index; col is reordered with the same
// permutation so pairs stay aligned.
std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
  auto sorted = row.sort();  // (values, permutation indices)
  return std::make_tuple(std::get<0>(sorted),
                         col.index_select(0, std::get<1>(sorted)));
}
// Weighted overload: sorts edges by row index and reorders both col and
// weight with the same permutation.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
  auto sorted = row.sort();  // (values, permutation indices)
  const auto &perm = std::get<1>(sorted);
  return std::make_tuple(std::get<0>(sorted), col.index_select(0, perm),
                         weight.index_select(0, perm));
}
// Computes per-node degree: counts how often each node id occurs in `row`
// by scatter-adding a vector of ones into a zero-initialized histogram.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto counts = at::zeros(num_nodes, row.options());
  counts.scatter_add_(0, row, at::ones(row.size(0), row.options()));
  return counts;
}
// Converts a COO edge list into CSR form: the returned first tensor is a
// row-pointer array of length num_nodes + 1 (cumulative degrees, prefixed
// with 0), and col holds the row-sorted column indices.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  std::tie(row, col) = sort_by_row(row, col);
  auto cumdeg = degree(row, num_nodes).cumsum(0);
  // Prepend a leading zero to turn cumulative degrees into row pointers.
  auto rowptr = at::cat({at::zeros(1, cumdeg.options()), cumdeg}, 0);
  return std::make_tuple(rowptr, col);
}
// Weighted overload of the COO -> CSR conversion: weights are reordered
// together with the column indices during the row sort.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
  std::tie(row, col, weight) = sort_by_row(row, col, weight);
  auto cumdeg = degree(row, num_nodes).cumsum(0);
  // Prepend a leading zero to turn cumulative degrees into row pointers.
  auto rowptr = at::cat({at::zeros(1, cumdeg.options()), cumdeg}, 0);
  return std::make_tuple(rowptr, col, weight);
}
......@@ -3,6 +3,10 @@
#include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32
// Stub module-init hook for the `_rw` extension on Windows.
// NOTE(review): presumably the Windows DLL loader requires an exported
// PyInit__rw symbol even though the ops are registered elsewhere — the
// stub returns NULL and registers nothing. Confirm against the build setup.
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif
......@@ -12,7 +16,7 @@ torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
double q) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
return random_walk_cuda(rowptr, col, start, walk_length, p, q);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
......
......@@ -41,7 +41,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr[1:])
torch.cumsum(deg, 0, out=ptr[1:])
else:
ptr = torch.tensor([0, src.size(0)], device=src.device)
......
......@@ -54,6 +54,6 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1)
deg.cumsum(0, out=rowptr[1:])
torch.cumsum(deg, 0, out=rowptr[1:])
return torch.ops.torch_cluster.graclus(rowptr, col, weight)
......@@ -53,19 +53,19 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:])
torch.cumsum(deg, 0, out=ptr_x[1:])
else:
ptr_x = torch.tensor([0, x.size(0)], device=x.device)
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1
batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:])
torch.cumsum(deg, 0, out=ptr_y[1:])
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
......
......@@ -48,19 +48,19 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:])
torch.cumsum(deg, 0, out=ptr_x[1:])
else:
ptr_x = torch.tensor([0, x.size(0)], device=x.device)
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1
batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:])
torch.cumsum(deg, 0, out=ptr_y[1:])
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
......
......@@ -57,19 +57,18 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:])
torch.cumsum(deg, 0, out=ptr_x[1:])
else:
ptr_x = torch.tensor([0, x.size(0)], device=x.device)
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1
batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:])
torch.cumsum(deg, 0, out=ptr_y[1:])
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
......
......@@ -41,7 +41,7 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1)
deg.cumsum(0, out=rowptr[1:])
torch.cumsum(deg, 0, out=rowptr[1:])
if p != 1. or q != 1.: # pragma: no cover
warnings.warn('Parameters `p` and `q` are not supported yet and will'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment