Commit 0e7f4b8e authored by rusty1s's avatar rusty1s
Browse files

random walk and sampler api

parent 4607290c
......@@ -38,7 +38,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
torch::cat({torch::ones(1, num_voxels.options()), num_voxels}, 0);
num_voxels = num_voxels.narrow(0, 0, size.size(0));
auto out = (pos / size.view({1, -1})).toType(at::kLong);
auto out = (pos / size.view({1, -1})).toType(torch::kLong);
out *= num_voxels.view({1, -1});
out = out.sum(1);
......
......@@ -2,34 +2,42 @@
#include "utils.h"
at::Tensor random_walk_cpu(torch::Tensor row, torch::Tensor col,
torch::Tensor start, int64_t walk_length, double p,
double q, int64_t num_nodes) {
auto deg = degree(row, num_nodes);
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
auto rand = at::rand({start.size(0), (int64_t)walk_length},
start.options().dtype(at::kFloat));
auto out =
at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options());
auto deg_d = deg.DATA_PTR<int64_t>();
auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
auto col_d = col.DATA_PTR<int64_t>();
auto start_d = start.DATA_PTR<int64_t>();
auto rand_d = rand.DATA_PTR<float>();
auto out_d = out.DATA_PTR<int64_t>();
for (ptrdiff_t n = 0; n < start.size(0); n++) {
int64_t cur = start_d[n];
auto i = n * (walk_length + 1);
out_d[i] = cur;
for (ptrdiff_t l = 1; l <= (int64_t)walk_length; l++) {
cur = col_d[cum_deg_d[cur] +
int64_t(rand_d[n * walk_length + (l - 1)] * deg_d[cur])];
out_d[i + l] = cur;
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(start);
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto num_nodes = rowptr.size(0) - 1;
auto deg = rowptr.narrow(0, 1, num_nodes) - rowptr.narrow(0, 0, num_nodes);
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto out = torch::full({start.size(0), walk_length + 1}, -1, start.options());
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto deg_data = deg.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto start_data = start.data_ptr<int64_t>();
auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) {
auto cur = start_data[n];
auto offset = n * (walk_length + 1);
out_data[offset] = cur;
for (auto l = 1; l <= walk_length; l++) {
cur = col_data[rowptr_data[cur] +
int64_t(rand_data[n * walk_length + (l - 1)] *
deg_data[cur])];
out_data[offset + l] = cur;
}
}
......
......@@ -2,6 +2,6 @@
#include <torch/extension.h>
at::Tensor random_walk_cpu(torch::Tensor row, torch::Tensor col,
torch::Tensor start, int64_t walk_length, double p,
double q, int64_t num_nodes);
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q);
#include <torch/extension.h>
#include "sampler_cpu.h"
#include "compat.h"
#include "utils.h"
at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
float factor) {
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor) {
auto start_ptr = start.DATA_PTR<int64_t>();
auto cumdeg_ptr = cumdeg.DATA_PTR<int64_t>();
auto start_data = start.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
std::vector<int64_t> e_ids;
for (ptrdiff_t i = 0; i < start.size(0); i++) {
int64_t low = cumdeg_ptr[start_ptr[i]];
int64_t high = cumdeg_ptr[start_ptr[i] + 1];
size_t num_neighbors = high - low;
size_t size_i = size_t(ceil(factor * float(num_neighbors)));
size_i = (size_i < size) ? size_i : size;
for (auto i = 0; i < start.size(0); i++) {
auto row_start = rowptr_data[start_data[i]];
auto row_end = rowptr_data[start_data[i] + 1];
auto num_neighbors = row_end - row_start;
int64_t size = count;
if (count < 1) {
size = int64_t(ceil(factor * float(num_neighbors)));
}
// If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without
// replacement, otherwise we sample random numbers into a set as long as
// necessary.
// replacement, otherwise we sample random numbers into a set as long
// as necessary.
std::unordered_set<int64_t> set;
if (size_i < 0.7 * float(num_neighbors)) {
while (set.size() < size_i) {
int64_t z = rand() % num_neighbors;
set.insert(z + low);
if (size < 0.7 * float(num_neighbors)) {
while (int64_t(set.size()) < size) {
int64_t sample = (rand() % num_neighbors) + row_start;
set.insert(sample);
}
std::vector<int64_t> v(set.begin(), set.end());
e_ids.insert(e_ids.end(), v.begin(), v.end());
} else {
auto sample = at::randperm(num_neighbors, start.options());
auto sample_ptr = sample.DATA_PTR<int64_t>();
for (size_t j = 0; j < size_i; j++) {
e_ids.push_back(sample_ptr[j] + low);
auto sample = at::randperm(num_neighbors, start.options()) + row_start;
auto sample_data = sample.data_ptr<int64_t>();
for (auto j = 0; j < size; j++) {
e_ids.push_back(sample_data[j]);
}
}
}
int64_t len = e_ids.size();
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
return e_id;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("neighbor_sampler", &neighbor_sampler, "Neighbor Sampler (CPU)");
int64_t length = e_ids.size();
return torch::from_blob(e_ids.data(), {length}, start.options()).clone();
}
#pragma once
#include <torch/extension.h>
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor);
......@@ -3,27 +3,23 @@
#include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__grid(void) { return NULL; }
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif
torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
if (pos.device().is_cuda()) {
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length, double p,
double q) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported.")
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return grid_cpu(pos, size, optional_start, optional_end);
return random_walk_cpu(rowptr, col, start, walk_length, p, q);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::grid", &grid);
torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk);
#include <Python.h>
#include <torch/script.h>
#include "cpu/sampler_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__sampler(void) { return NULL; }
#endif
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return neighbor_sampler_cpu(start, rowptr, count, factor);
}
}
static auto registry = torch::RegisterOperators().op(
"torch_cluster::neighbor_sampler", &neighbor_sampler);
......@@ -7,7 +7,9 @@ __version__ = '1.5.0'
expected_torch_version = (1, 4)
try:
for library in ['_version', '_grid', '_graclus', '_fps']:
for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
except OSError as e:
......@@ -44,8 +46,8 @@ from .fps import fps # noqa
# from .nearest import nearest # noqa
# from .knn import knn, knn_graph # noqa
# from .radius import radius, radius_graph # noqa
# from .rw import random_walk # noqa
# from .sampler import neighbor_sampler # noqa
from .rw import random_walk # noqa
from .sampler import neighbor_sampler # noqa
__all__ = [
'graclus_cluster',
......@@ -56,7 +58,7 @@ __all__ = [
# 'knn_graph',
# 'radius',
# 'radius_graph',
# 'random_walk',
# 'neighbor_sampler',
'random_walk',
'neighbor_sampler',
'__version__',
]
import warnings
from typing import Optional
import torch
import torch_cluster.rw_cpu
if torch.cuda.is_available():
import torch_cluster.rw_cuda
def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
num_nodes=None):
@torch.jit.script
def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
walk_length: int, p: float = 1, q: float = 1,
coalesced: bool = False, num_nodes: Optional[int] = None):
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
......@@ -33,22 +32,21 @@ def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
:rtype: :class:`LongTensor`
"""
if num_nodes is None:
num_nodes = max(row.max(), col.max()).item() + 1
num_nodes = max(int(row.max()), int(col.max())) + 1
if coalesced:
_, perm = torch.sort(row * num_nodes + col)
row, col = row[perm], col[perm]
if p != 1 or q != 1: # pragma: no cover
deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1)
deg.cumsum(0, out=rowptr[1:])
if p != 1. or q != 1.: # pragma: no cover
warnings.warn('Parameters `p` and `q` are not supported yet and will'
'be restored to their default values `p=1` and `q=1`.')
p = q = 1
start = start.flatten()
p = q = 1.
if row.is_cuda: # pragma: no cover
return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
num_nodes)
else:
return torch_cluster.rw_cpu.rw(row, col, start, walk_length, p, q,
num_nodes)
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
p, q)
import torch_cluster.sampler_cpu
import torch
def neighbor_sampler(start, cumdeg, size):
@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda
factor = 1
if isinstance(size, float):
factor: float = -1.
count: int = -1
if size <= 1:
factor = size
size = 2147483647
assert factor > 0
else:
count = int(size)
op = torch_cluster.sampler_cpu.neighbor_sampler
return op(start, cumdeg, size, factor)
return torch.ops.torch_cluster.neighbor_sampler(start, rowptr, count,
factor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment