Commit 0e7f4b8e authored by rusty1s

random walk and sampler api

parent 4607290c
@@ -38,7 +38,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
       torch::cat({torch::ones(1, num_voxels.options()), num_voxels}, 0);
   num_voxels = num_voxels.narrow(0, 0, size.size(0));
 
-  auto out = (pos / size.view({1, -1})).toType(at::kLong);
+  auto out = (pos / size.view({1, -1})).toType(torch::kLong);
   out *= num_voxels.view({1, -1});
   out = out.sum(1);
...
@@ -2,34 +2,42 @@
 #include "utils.h"
 
-at::Tensor random_walk_cpu(torch::Tensor row, torch::Tensor col,
-                           torch::Tensor start, int64_t walk_length, double p,
-                           double q, int64_t num_nodes) {
-
-  auto deg = degree(row, num_nodes);
-  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
-  auto rand = at::rand({start.size(0), (int64_t)walk_length},
-                       start.options().dtype(at::kFloat));
-  auto out =
-      at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options());
-
-  auto deg_d = deg.DATA_PTR<int64_t>();
-  auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
-  auto col_d = col.DATA_PTR<int64_t>();
-  auto start_d = start.DATA_PTR<int64_t>();
-  auto rand_d = rand.DATA_PTR<float>();
-  auto out_d = out.DATA_PTR<int64_t>();
-
-  for (ptrdiff_t n = 0; n < start.size(0); n++) {
-    int64_t cur = start_d[n];
-    auto i = n * (walk_length + 1);
-    out_d[i] = cur;
-
-    for (ptrdiff_t l = 1; l <= (int64_t)walk_length; l++) {
-      cur = col_d[cum_deg_d[cur] +
-                  int64_t(rand_d[n * walk_length + (l - 1)] * deg_d[cur])];
-      out_d[i + l] = cur;
-    }
-  }
+torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
+                              torch::Tensor start, int64_t walk_length,
+                              double p, double q) {
+  CHECK_CPU(rowptr);
+  CHECK_CPU(col);
+  CHECK_CPU(start);
+  CHECK_INPUT(rowptr.dim() == 1);
+  CHECK_INPUT(col.dim() == 1);
+  CHECK_INPUT(start.dim() == 1);
+
+  auto num_nodes = rowptr.size(0) - 1;
+  auto deg = rowptr.narrow(0, 1, num_nodes) - rowptr.narrow(0, 0, num_nodes);
+
+  auto rand = torch::rand({start.size(0), walk_length},
+                          start.options().dtype(torch::kFloat));
+
+  auto out = torch::full({start.size(0), walk_length + 1}, -1, start.options());
+
+  auto rowptr_data = rowptr.data_ptr<int64_t>();
+  auto deg_data = deg.data_ptr<int64_t>();
+  auto col_data = col.data_ptr<int64_t>();
+  auto start_data = start.data_ptr<int64_t>();
+  auto rand_data = rand.data_ptr<float>();
+  auto out_data = out.data_ptr<int64_t>();
+
+  for (auto n = 0; n < start.size(0); n++) {
+    auto cur = start_data[n];
+    auto offset = n * (walk_length + 1);
+    out_data[offset] = cur;
+
+    for (auto l = 1; l <= walk_length; l++) {
+      cur = col_data[rowptr_data[cur] +
+                     int64_t(rand_data[n * walk_length + (l - 1)] *
+                             deg_data[cur])];
+      out_data[offset + l] = cur;
+    }
+  }
...
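The rewritten kernel takes the graph in CSR form: the neighbors of node v are col[rowptr[v]:rowptr[v + 1]], so deg[v] = rowptr[v + 1] - rowptr[v], and each step jumps to a uniformly random neighbor via col[rowptr[cur] + floor(rand * deg[cur])]. A minimal Python sketch of that step (hypothetical helper name; like the kernel, it assumes every visited node has at least one neighbor):

import torch

def uniform_random_walk_reference(rowptr, col, start, walk_length):
    # CSR degree of every node: deg[v] = rowptr[v + 1] - rowptr[v].
    deg = rowptr[1:] - rowptr[:-1]
    rand = torch.rand(start.size(0), walk_length)
    out = torch.full((start.size(0), walk_length + 1), -1, dtype=torch.long)
    out[:, 0] = start
    for n in range(start.size(0)):
        cur = int(start[n])
        for step in range(1, walk_length + 1):
            # Uniformly random neighbor of `cur`, mirroring the kernel above.
            r = rand[n, step - 1].item()
            cur = int(col[int(rowptr[cur]) + int(r * int(deg[cur]))])
            out[n, step] = cur
    return out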
@@ -2,6 +2,6 @@
 #include <torch/extension.h>
 
-at::Tensor random_walk_cpu(torch::Tensor row, torch::Tensor col,
-                           torch::Tensor start, int64_t walk_length, double p,
-                           double q, int64_t num_nodes);
+torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
+                              torch::Tensor start, int64_t walk_length,
+                              double p, double q);
-#include <torch/extension.h>
-#include "compat.h"
+#include "sampler_cpu.h"
+#include "utils.h"
 
-at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
-                            float factor) {
-  auto start_ptr = start.DATA_PTR<int64_t>();
-  auto cumdeg_ptr = cumdeg.DATA_PTR<int64_t>();
+torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
+                                   int64_t count, double factor) {
+  auto start_data = start.data_ptr<int64_t>();
+  auto rowptr_data = rowptr.data_ptr<int64_t>();
   std::vector<int64_t> e_ids;
 
-  for (ptrdiff_t i = 0; i < start.size(0); i++) {
-    int64_t low = cumdeg_ptr[start_ptr[i]];
-    int64_t high = cumdeg_ptr[start_ptr[i] + 1];
-    size_t num_neighbors = high - low;
-    size_t size_i = size_t(ceil(factor * float(num_neighbors)));
-    size_i = (size_i < size) ? size_i : size;
+  for (auto i = 0; i < start.size(0); i++) {
+    auto row_start = rowptr_data[start_data[i]];
+    auto row_end = rowptr_data[start_data[i] + 1];
+    auto num_neighbors = row_end - row_start;
+
+    int64_t size = count;
+    if (count < 1) {
+      size = int64_t(ceil(factor * float(num_neighbors)));
+    }
 
     // If the number of neighbors is approximately equal to the number of
     // neighbors which are requested, we use `randperm` to sample without
-    // replacement, otherwise we sample random numbers into a set as long as
-    // necessary.
+    // replacement, otherwise we sample random numbers into a set as long
+    // as necessary.
     std::unordered_set<int64_t> set;
-    if (size_i < 0.7 * float(num_neighbors)) {
-      while (set.size() < size_i) {
-        int64_t z = rand() % num_neighbors;
-        set.insert(z + low);
+    if (size < 0.7 * float(num_neighbors)) {
+      while (int64_t(set.size()) < size) {
+        int64_t sample = (rand() % num_neighbors) + row_start;
+        set.insert(sample);
       }
       std::vector<int64_t> v(set.begin(), set.end());
       e_ids.insert(e_ids.end(), v.begin(), v.end());
     } else {
-      auto sample = at::randperm(num_neighbors, start.options());
-      auto sample_ptr = sample.DATA_PTR<int64_t>();
-      for (size_t j = 0; j < size_i; j++) {
-        e_ids.push_back(sample_ptr[j] + low);
+      auto sample = at::randperm(num_neighbors, start.options()) + row_start;
+      auto sample_data = sample.data_ptr<int64_t>();
+      for (auto j = 0; j < size; j++) {
+        e_ids.push_back(sample_data[j]);
       }
     }
   }
 
-  int64_t len = e_ids.size();
-  auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
-  return e_id;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("neighbor_sampler", &neighbor_sampler, "Neighbor Sampler (CPU)");
-}
+  int64_t length = e_ids.size();
+  return torch::from_blob(e_ids.data(), {length}, start.options()).clone();
+}
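The comment in the kernel explains the split: when only a small fraction of a neighborhood is requested (the 0.7 threshold), drawing offsets into a hash set is cheaper than permuting the whole neighborhood; otherwise randperm samples without replacement directly. A rough Python sketch of the same per-node logic (hypothetical names; the clamp to num_neighbors is added here only to keep the sketch safe):

import math
import random
import torch

def sample_neighbors_reference(start, rowptr, count=-1, factor=-1.0):
    e_ids = []
    for v in start.tolist():
        row_start, row_end = int(rowptr[v]), int(rowptr[v + 1])
        num_neighbors = row_end - row_start
        # `count >= 1` requests an absolute number of neighbors, otherwise a fraction.
        size = count if count >= 1 else math.ceil(factor * num_neighbors)
        size = min(size, num_neighbors)
        if size < 0.7 * num_neighbors:
            # Few samples needed: draw random offsets into a set until it is full.
            chosen = set()
            while len(chosen) < size:
                chosen.add(row_start + random.randrange(num_neighbors))
            e_ids.extend(chosen)
        else:
            # Close to the full neighborhood: permute once and keep the first `size`.
            perm = torch.randperm(num_neighbors) + row_start
            e_ids.extend(perm[:size].tolist())
    return torch.tensor(e_ids, dtype=torch.long)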
+#pragma once
+
+#include <torch/extension.h>
+
+torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
+                                   int64_t count, double factor);
@@ -3,27 +3,23 @@
 #include "cpu/rw_cpu.h"
 
-#ifdef WITH_CUDA
-#include "cuda/rw_cuda.h"
-#endif
-
 #ifdef _WIN32
-PyMODINIT_FUNC PyInit__grid(void) { return NULL; }
+PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
 #endif
 
-torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
-                   torch::optional<torch::Tensor> optional_start,
-                   torch::optional<torch::Tensor> optional_end) {
-  if (pos.device().is_cuda()) {
+torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
+                          torch::Tensor start, int64_t walk_length, double p,
+                          double q) {
+  if (rowptr.device().is_cuda()) {
 #ifdef WITH_CUDA
-    AT_ERROR("No CUDA version supported.")
+    AT_ERROR("No CUDA version supported");
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
   } else {
-    return grid_cpu(pos, size, optional_start, optional_end);
+    return random_walk_cpu(rowptr, col, start, walk_length, p, q);
   }
 }
 
-static auto registry =
-    torch::RegisterOperators().op("torch_cluster::grid", &grid);
+static auto registry =
+    torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk);
+#include <Python.h>
+#include <torch/script.h>
+
+#include "cpu/sampler_cpu.h"
+
+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__sampler(void) { return NULL; }
+#endif
+
+torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
+                               int64_t count, double factor) {
+  if (rowptr.device().is_cuda()) {
+#ifdef WITH_CUDA
+    AT_ERROR("No CUDA version supported");
+#else
+    AT_ERROR("Not compiled with CUDA support");
+#endif
+  } else {
+    return neighbor_sampler_cpu(start, rowptr, count, factor);
+  }
+}
+
+static auto registry = torch::RegisterOperators().op(
+    "torch_cluster::neighbor_sampler", &neighbor_sampler);
@@ -7,7 +7,9 @@ __version__ = '1.5.0'
 expected_torch_version = (1, 4)
 
 try:
-    for library in ['_version', '_grid', '_graclus', '_fps']:
+    for library in [
+            '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler'
+    ]:
         torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
             library, [osp.dirname(__file__)]).origin)
 except OSError as e:
@@ -44,8 +46,8 @@ from .fps import fps  # noqa
 # from .nearest import nearest  # noqa
 # from .knn import knn, knn_graph  # noqa
 # from .radius import radius, radius_graph  # noqa
-# from .rw import random_walk  # noqa
-# from .sampler import neighbor_sampler  # noqa
+from .rw import random_walk  # noqa
+from .sampler import neighbor_sampler  # noqa
 
 __all__ = [
     'graclus_cluster',
@@ -56,7 +58,7 @@ __all__ = [
     # 'knn_graph',
     # 'radius',
     # 'radius_graph',
-    # 'random_walk',
-    # 'neighbor_sampler',
+    'random_walk',
+    'neighbor_sampler',
     '__version__',
 ]
 import warnings
+from typing import Optional
 
 import torch
-import torch_cluster.rw_cpu
-
-if torch.cuda.is_available():
-    import torch_cluster.rw_cuda
 
 
-def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
-                num_nodes=None):
+@torch.jit.script
+def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
+                walk_length: int, p: float = 1, q: float = 1,
+                coalesced: bool = False, num_nodes: Optional[int] = None):
     """Samples random walks of length :obj:`walk_length` from all node indices
     in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
     `"node2vec: Scalable Feature Learning for Networks"
...
@@ -33,22 +32,21 @@ def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
     :rtype: :class:`LongTensor`
     """
     if num_nodes is None:
-        num_nodes = max(row.max(), col.max()).item() + 1
+        num_nodes = max(int(row.max()), int(col.max())) + 1
 
     if coalesced:
         _, perm = torch.sort(row * num_nodes + col)
         row, col = row[perm], col[perm]
 
-    if p != 1 or q != 1:  # pragma: no cover
+    deg = row.new_zeros(num_nodes)
+    deg.scatter_add_(0, row, torch.ones_like(row))
+    rowptr = row.new_zeros(num_nodes + 1)
+    deg.cumsum(0, out=rowptr[1:])
+
+    if p != 1. or q != 1.:  # pragma: no cover
         warnings.warn('Parameters `p` and `q` are not supported yet and will'
                       'be restored to their default values `p=1` and `q=1`.')
-        p = q = 1
+        p = q = 1.
+
+    start = start.flatten()
 
-    if row.is_cuda:  # pragma: no cover
-        return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
-                                        num_nodes)
-    else:
-        return torch_cluster.rw_cpu.rw(row, col, start, walk_length, p, q,
-                                       num_nodes)
+    return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
+                                               p, q)
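A usage sketch for the rewritten wrapper (the edge list below is made up; column 0 of the result equals start and each row continues for walk_length further steps):

import torch
from torch_cluster import random_walk

# Triangle graph as a (row, col) edge list; every node has two neighbors.
row = torch.tensor([0, 0, 1, 1, 2, 2])
col = torch.tensor([1, 2, 0, 2, 0, 1])
start = torch.tensor([0, 1, 2])

walk = random_walk(row, col, start, walk_length=3)
print(walk.size())  # torch.Size([3, 4])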
-import torch_cluster.sampler_cpu
+import torch
 
 
-def neighbor_sampler(start, cumdeg, size):
+@torch.jit.script
+def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
     assert not start.is_cuda
 
-    factor = 1
-    if isinstance(size, float):
+    factor: float = -1.
+    count: int = -1
+    if size <= 1:
         factor = size
-        size = 2147483647
+        assert factor > 0
+    else:
+        count = int(size)
 
-    op = torch_cluster.sampler_cpu.neighbor_sampler
-    return op(start, cumdeg, size, factor)
+    return torch.ops.torch_cluster.neighbor_sampler(start, rowptr, count,
+                                                    factor)
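And a sketch of the new size: float convention (tensors made up): a value of at most 1 is read as a fraction of each neighborhood, anything larger as an absolute number of neighbors per start node.

import torch
from torch_cluster import neighbor_sampler

rowptr = torch.tensor([0, 2, 4, 6])  # CSR row pointers, two neighbors per node
start = torch.tensor([0, 2])

e_id_all = neighbor_sampler(start, rowptr, size=1.0)  # factor = 1.0 -> every neighbor
e_id_two = neighbor_sampler(start, rowptr, size=2.0)  # count = 2 per start node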