Commit 7f186c26 authored by rusty1s

random walk cpu implementation

parent c0e7bec9
......@@ -21,6 +21,7 @@ The package consists of the following clustering algorithms:
* **[Iterative Farthest Point Sampling](#farthestpointsampling)** from, *e.g.* Qi *et al.*: [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413) (NIPS 2017)
* **[k-NN](#knn-graph)** and **[Radius](#radius-graph)** graph generation
* Clustering based on **[Nearest](#nearest)** points
* **[Random Walk Sampling](#randomwalk-sampling)** from, *e.g.*, Grover and Leskovec: [node2vec: Scalable Feature Learning for Networks](https://arxiv.org/abs/1607.00653) (KDD 2016)
All included operations work on varying data types and are implemented for both CPU and GPU.
......@@ -164,6 +165,29 @@ print(cluster)
tensor([0, 0, 1, 1])
```
## RandomWalk-Sampling
Samples random walks of length `walk_length` from all node indices in `start` in the graph given by `(row, col)`.
```python
import torch
from torch_cluster import random_walk
row = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4])
col = torch.tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3])
start = torch.tensor([0, 1, 2, 3, 4])
walk = random_walk(row, col, start, walk_length=3)
```
```
print(walk)
tensor([[0, 1, 2, 1],
        [1, 3, 4, 2],
        [2, 4, 3, 1],
        [3, 4, 3, 1],
        [4, 2, 1, 0]])
```
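
Every consecutive pair of nodes in a returned walk is connected by an edge of the input graph, so the walks above can be sanity-checked against `(row, col)`. A minimal, illustrative check (reusing the variables from the snippet above):

```python
edges = set(zip(row.tolist(), col.tolist()))
for w in walk.tolist():
    for u, v in zip(w[:-1], w[1:]):
        assert (u, v) in edges  # every hop must follow an existing edge
```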
## Running tests
```
......
#include <torch/extension.h>

#include "utils.h"

at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
              size_t walk_length, float p, float q, size_t num_nodes) {
  // Out-degree of every node and its exclusive prefix sum: since `row` is
  // sorted, the neighbors of node v occupy col[cum_deg[v], cum_deg[v] + deg[v]).
  auto deg = degree(row, num_nodes);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

  // One uniform random number per (start node, step).
  auto rand = at::rand({start.size(0), (int64_t)walk_length},
                       start.options().dtype(at::kFloat));
  auto out =
      at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options());

  auto deg_d = deg.data<int64_t>();
  auto cum_deg_d = cum_deg.data<int64_t>();
  auto col_d = col.data<int64_t>();
  auto start_d = start.data<int64_t>();
  auto rand_d = rand.data<float>();
  auto out_d = out.data<int64_t>();

  for (ptrdiff_t n = 0; n < start.size(0); n++) {
    int64_t cur = start_d[n];
    auto i = n * (walk_length + 1);
    out_d[i] = cur;

    for (ptrdiff_t l = 1; l <= (int64_t)walk_length; l++) {
      // Map the uniform sample in [0, 1) to one of `cur`'s neighbors.
      cur = col_d[cum_deg_d[cur] +
                  int64_t(rand_d[n * walk_length + (l - 1)] * deg_d[cur])];
      out_d[i + l] = cur;
    }
  }

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rw", &rw, "Random Walk Sampling (CPU)");
}
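
For orientation, the kernel above amounts to the rough Python sketch below: build CSR-style pointers from the node degrees, then at every step map one uniform random number to a neighbor of the current node. This is an illustration only (the function name `rw_reference` is made up), and it assumes `row` is already sorted and that every visited node has at least one outgoing edge:

```python
import torch


def rw_reference(row, col, start, walk_length, num_nodes):
    # CSR-style indexing: with `row` sorted, the neighbors of node v are
    # col[ptr[v] : ptr[v] + deg[v]].
    deg = torch.zeros(num_nodes, dtype=torch.long).scatter_add_(
        0, row, torch.ones_like(row))
    ptr = torch.cat([torch.zeros(1, dtype=torch.long), deg.cumsum(0)])

    out = torch.full((start.size(0), walk_length + 1), -1, dtype=torch.long)
    out[:, 0] = start

    for n in range(start.size(0)):
        cur = int(start[n])
        for step in range(1, walk_length + 1):
            # Map a uniform sample in [0, 1) to one of `cur`'s neighbors.
            offset = int(torch.rand(1).item() * int(deg[cur]))
            cur = int(col[int(ptr[cur]) + offset])
            out[n, step] = cur

    return out
```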
......@@ -6,6 +6,7 @@ ext_modules = [
CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp']),
CppExtension('torch_cluster.grid_cpu', ['cpu/grid.cpp']),
CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp']),
CppExtension('torch_cluster.rw_cpu', ['cpu/rw.cpp']),
CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp']),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
......
......@@ -2,18 +2,17 @@ import pytest
import torch
from torch_cluster import random_walk
from .utils import tensor
from .utils import devices, tensor
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
def test_rw():
    device = torch.device('cuda')
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
@pytest.mark.parametrize('device', devices)
def test_rw(device):
    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
    walk_length = 10
    out = random_walk(row, col, start, walk_length)
    out = random_walk(row, col, start, walk_length, coalesced=True)
    assert out[:, 0].tolist() == start.tolist()
    for n in range(start.size(0)):
......
......@@ -4,8 +4,8 @@ from .fps import fps
from .nearest import nearest
from .knn import knn, knn_graph
from .radius import radius, radius_graph
from .sampler import neighbor_sampler
from .rw import random_walk
from .sampler import neighbor_sampler
__version__ = '1.4.3a1'
......@@ -18,7 +18,7 @@ __all__ = [
'knn_graph',
'radius',
'radius_graph',
'neighbor_sampler',
'random_walk',
'neighbor_sampler',
'__version__',
]
import warnings
import torch
import torch_cluster.rw_cpu
if torch.cuda.is_available():
import torch_cluster.rw_cuda
def random_walk(row, col, start, walk_length, p=1, q=1, num_nodes=None):
if p != 1 or q != 1:
def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
num_nodes=None):
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
<https://arxiv.org/abs/1607.00653>`_ paper.
Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
:obj:`row` (use the :obj:`coalesced` attribute to force).
Args:
row (LongTensor): Source nodes.
col (LongTensor): Target nodes.
start (LongTensor): Nodes from where random walks start.
walk_length (int): The walk length.
p (float, optional): Likelihood of immediately revisiting a node in the
walk. (default: :obj:`1`)
q (float, optional): Control parameter to interpolate between
breadth-first strategy and depth-first strategy (default: :obj:`1`)
coalesced (bool, optional): If set to :obj:`True`, will coalesce/sort
the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`False`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
:rtype: :class:`LongTensor`
"""
    if num_nodes is None:
        num_nodes = max(row.max(), col.max()).item() + 1

    if coalesced:
        _, perm = torch.sort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    if p != 1 or q != 1:  # pragma: no cover
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.')
        p = q = 1
    num_nodes = row.max().item() + 1 if num_nodes is None else num_nodes

    start = start.flatten()

    if row.is_cuda:
        return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
                                        num_nodes)
    else:
        raise NotImplementedError
        return torch_cluster.rw_cpu.rw(row, col, start, walk_length, p, q,
                                       num_nodes)
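
The wrapper only sorts `(row, col)` when `coalesced=True`, and the CPU kernel's degree-based indexing is only valid on a row-sorted edge list, so the flag is needed whenever the edges are not already ordered by `row`. A small usage sketch (illustrative; the shuffled edge list is just the README graph with its edges reordered):

```python
import torch
from torch_cluster import random_walk

# The README example graph, with its edges listed in arbitrary order.
row = torch.tensor([2, 0, 4, 1, 3, 1, 2, 3, 1, 4])
col = torch.tensor([1, 1, 2, 0, 1, 2, 4, 4, 3, 3])
start = torch.tensor([0, 1, 2, 3, 4])

# coalesced=True sorts (row, col) by row before sampling.
walk = random_walk(row, col, start, walk_length=3, coalesced=True)
print(walk.size())  # torch.Size([5, 4])
```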