Commit e1e47f9b authored by rusty1s's avatar rusty1s
Browse files

fix

parent eb66a19d
...@@ -4,57 +4,3 @@ ...@@ -4,57 +4,3 @@
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
at::Tensor col) {
auto mask = row != col;
return std::make_tuple(row.masked_select(mask), col.masked_select(mask));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
auto mask = row != col;
return std::make_tuple(row.masked_select(mask), col.masked_select(mask),
weight.masked_select(mask));
}
std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
auto perm = at::randperm(row.size(0), row.options());
return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
}
std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
at::Tensor perm;
std::tie(row, perm) = row.sort();
return std::make_tuple(row, col.index_select(0, perm));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
at::Tensor perm;
std::tie(row, perm) = row.sort();
return std::make_tuple(row, col.index_select(0, perm),
weight.index_select(0, perm));
}
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
}
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
int64_t num_nodes) {
std::tie(row, col) = sort_by_row(row, col);
row = degree(row, num_nodes).cumsum(0);
row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
std::tie(row, col, weight) = sort_by_row(row, col, weight);
row = degree(row, num_nodes).cumsum(0);
row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col, weight);
}
...@@ -3,6 +3,10 @@ ...@@ -3,6 +3,10 @@
#include "cpu/rw_cpu.h" #include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32 #ifdef _WIN32
PyMODINIT_FUNC PyInit__rw(void) { return NULL; } PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif #endif
...@@ -12,7 +16,7 @@ torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col, ...@@ -12,7 +16,7 @@ torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
double q) { double q) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
AT_ERROR("No CUDA version supported"); return random_walk_cuda(rowptr, col, start, walk_length, p, q);
#else #else
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
......
...@@ -41,7 +41,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, ...@@ -41,7 +41,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
deg.scatter_add_(0, batch, torch.ones_like(batch)) deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr = deg.new_zeros(batch_size + 1) ptr = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr[1:]) torch.cumsum(deg, 0, out=ptr[1:])
else: else:
ptr = torch.tensor([0, src.size(0)], device=src.device) ptr = torch.tensor([0, src.size(0)], device=src.device)
......
...@@ -54,6 +54,6 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor, ...@@ -54,6 +54,6 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
deg = row.new_zeros(num_nodes) deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row)) deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1) rowptr = row.new_zeros(num_nodes + 1)
deg.cumsum(0, out=rowptr[1:]) torch.cumsum(deg, 0, out=rowptr[1:])
return torch.ops.torch_cluster.graclus(rowptr, col, weight) return torch.ops.torch_cluster.graclus(rowptr, col, weight)
...@@ -53,19 +53,19 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -53,19 +53,19 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x)) deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1) ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:]) torch.cumsum(deg, 0, out=ptr_x[1:])
else: else:
ptr_x = torch.tensor([0, x.size(0)], device=x.device) ptr_x = torch.tensor([0, x.size(0)], device=x.device)
if batch_y is not None: if batch_y is not None:
assert y.size(0) == batch_y.numel() assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1 batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long) deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y)) deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1) ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:]) torch.cumsum(deg, 0, out=ptr_y[1:])
else: else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device) ptr_y = torch.tensor([0, y.size(0)], device=y.device)
......
...@@ -48,19 +48,19 @@ def nearest(x: torch.Tensor, y: torch.Tensor, ...@@ -48,19 +48,19 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x)) deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1) ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:]) torch.cumsum(deg, 0, out=ptr_x[1:])
else: else:
ptr_x = torch.tensor([0, x.size(0)], device=x.device) ptr_x = torch.tensor([0, x.size(0)], device=x.device)
if batch_y is not None: if batch_y is not None:
assert y.size(0) == batch_y.numel() assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1 batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long) deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y)) deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1) ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:]) torch.cumsum(deg, 0, out=ptr_y[1:])
else: else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device) ptr_y = torch.tensor([0, y.size(0)], device=y.device)
......
...@@ -57,19 +57,18 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -57,19 +57,18 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x)) deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1) ptr_x = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_x[1:]) torch.cumsum(deg, 0, out=ptr_x[1:])
else: else:
ptr_x = torch.tensor([0, x.size(0)], device=x.device) ptr_x = torch.tensor([0, x.size(0)], device=x.device)
if batch_y is not None: if batch_y is not None:
assert y.size(0) == batch_y.numel() assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.may()) + 1 batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long) deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y)) deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1) ptr_y = deg.new_zeros(batch_size + 1)
deg.cumsum(0, out=ptr_y[1:]) torch.cumsum(deg, 0, out=ptr_y[1:])
else: else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device) ptr_y = torch.tensor([0, y.size(0)], device=y.device)
......
...@@ -41,7 +41,7 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor, ...@@ -41,7 +41,7 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
deg = row.new_zeros(num_nodes) deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row)) deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1) rowptr = row.new_zeros(num_nodes + 1)
deg.cumsum(0, out=rowptr[1:]) torch.cumsum(deg, 0, out=rowptr[1:])
if p != 1. or q != 1.: # pragma: no cover if p != 1. or q != 1.: # pragma: no cover
warnings.warn('Parameters `p` and `q` are not supported yet and will' warnings.warn('Parameters `p` and `q` are not supported yet and will'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment