Commit 7f186c26 authored by rusty1s's avatar rusty1s
Browse files

random walk cpu implementation

parent c0e7bec9
...@@ -21,6 +21,7 @@ The package consists of the following clustering algorithms: ...@@ -21,6 +21,7 @@ The package consists of the following clustering algorithms:
* **[Iterative Farthest Point Sampling](#farthestpointsampling)** from, *e.g.* Qi *et al.*: [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413) (NIPS 2017) * **[Iterative Farthest Point Sampling](#farthestpointsampling)** from, *e.g.* Qi *et al.*: [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413) (NIPS 2017)
* **[k-NN](#knn-graph)** and **[Radius](#radius-graph)** graph generation * **[k-NN](#knn-graph)** and **[Radius](#radius-graph)** graph generation
* Clustering based on **[Nearest](#nearest)** points * Clustering based on **[Nearest](#nearest)** points
* **[Random Walk Sampling](#randomwalk-sampling)** from, *e.g.*, Grover and Leskovec: [node2vec: Scalable Feature Learning for Networks](https://arxiv.org/abs/1607.00653) (KDD 2016)
All included operations work on varying data types and are implemented both for CPU and GPU. All included operations work on varying data types and are implemented both for CPU and GPU.
...@@ -164,6 +165,29 @@ print(cluster) ...@@ -164,6 +165,29 @@ print(cluster)
tensor([0, 0, 1, 1]) tensor([0, 0, 1, 1])
``` ```
## RandomWalk-Sampling
Samples random walks of length `walk_length` from all node indices in `start` in the graph given by `(row, col)`.
```python
import torch
from torch_cluster import random_walk
row = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4])
col = torch.tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3])
start = torch.tensor([0, 1, 2, 3, 4])
walk = random_walk(row, col, start, walk_length=3)
```
```
print(walk)
tensor([[0, 1, 2, 1],
        [1, 3, 4, 2],
        [2, 4, 3, 1],
        [3, 4, 3, 1],
        [4, 2, 1, 0]])
```
## Running tests ## Running tests
``` ```
......
#include <torch/extension.h>
#include "utils.h"
// Samples uniform random walks of length `walk_length` starting from every
// node index in `start`, on the graph given by the edge list `(row, col)`.
// `row` must be sorted (coalesced) so that `degree`/`cumsum` yield per-node
// CSR-style offsets into `col`. Parameters `p` and `q` (node2vec) are
// accepted for interface compatibility but not used yet.
// Returns a `[start.size(0), walk_length + 1]` LongTensor; entries past a
// dead-end node remain `-1`.
at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
              size_t walk_length, float p, float q, size_t num_nodes) {
  // Per-node out-degree and exclusive prefix sum: cum_deg[v] is the offset of
  // node v's first neighbor in `col`.
  auto deg = degree(row, num_nodes);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

  // One uniform sample in [0, 1) per (walk, step) to pick a neighbor.
  auto rand = at::rand({start.size(0), (int64_t)walk_length},
                       start.options().dtype(at::kFloat));

  // Pre-fill with -1 so steps that cannot be taken stay marked as invalid.
  auto out =
      at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options());

  auto deg_d = deg.data<int64_t>();
  auto cum_deg_d = cum_deg.data<int64_t>();
  auto col_d = col.data<int64_t>();
  auto start_d = start.data<int64_t>();
  auto rand_d = rand.data<float>();
  auto out_d = out.data<int64_t>();

  for (ptrdiff_t n = 0; n < start.size(0); n++) {
    int64_t cur = start_d[n];
    auto i = n * (walk_length + 1);
    out_d[i] = cur;
    for (ptrdiff_t l = 1; l <= (int64_t)walk_length; l++) {
      // Dead-end node: without this check, `rand * 0 == 0` would index `col`
      // at `cum_deg_d[cur]`, i.e. another node's edge (or out of bounds for
      // the last node). Stop and leave the remaining entries at -1.
      if (deg_d[cur] == 0)
        break;
      // rand in [0, 1) guarantees the neighbor index is in [0, deg).
      cur = col_d[cum_deg_d[cur] +
                  int64_t(rand_d[n * walk_length + (l - 1)] * deg_d[cur])];
      out_d[i + l] = cur;
    }
  }

  return out;
}
// Expose the CPU random walk kernel to Python; the extension is imported as
// `torch_cluster.rw_cpu` and called via `torch_cluster.rw_cpu.rw(...)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rw", &rw, "Random Walk Sampling (CPU)");
}
...@@ -6,6 +6,7 @@ ext_modules = [ ...@@ -6,6 +6,7 @@ ext_modules = [
CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp']), CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp']),
CppExtension('torch_cluster.grid_cpu', ['cpu/grid.cpp']), CppExtension('torch_cluster.grid_cpu', ['cpu/grid.cpp']),
CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp']), CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp']),
CppExtension('torch_cluster.rw_cpu', ['cpu/rw.cpp']),
CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp']), CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp']),
] ]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
......
...@@ -2,18 +2,17 @@ import pytest ...@@ -2,18 +2,17 @@ import pytest
import torch import torch
from torch_cluster import random_walk from torch_cluster import random_walk
from .utils import tensor from .utils import devices, tensor
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.parametrize('device', devices)
def test_rw(): def test_rw(device):
device = torch.device('cuda')
start = tensor([0, 1, 2, 3, 4], torch.long, device)
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10 walk_length = 10
out = random_walk(row, col, start, walk_length) out = random_walk(row, col, start, walk_length, coalesced=True)
assert out[:, 0].tolist() == start.tolist() assert out[:, 0].tolist() == start.tolist()
for n in range(start.size(0)): for n in range(start.size(0)):
......
...@@ -4,8 +4,8 @@ from .fps import fps ...@@ -4,8 +4,8 @@ from .fps import fps
from .nearest import nearest from .nearest import nearest
from .knn import knn, knn_graph from .knn import knn, knn_graph
from .radius import radius, radius_graph from .radius import radius, radius_graph
from .sampler import neighbor_sampler
from .rw import random_walk from .rw import random_walk
from .sampler import neighbor_sampler
__version__ = '1.4.3a1' __version__ = '1.4.3a1'
...@@ -18,7 +18,7 @@ __all__ = [ ...@@ -18,7 +18,7 @@ __all__ = [
'knn_graph', 'knn_graph',
'radius', 'radius',
'radius_graph', 'radius_graph',
'neighbor_sampler',
'random_walk', 'random_walk',
'neighbor_sampler',
'__version__', '__version__',
] ]
import warnings import warnings
import torch import torch
import torch_cluster.rw_cpu
if torch.cuda.is_available(): if torch.cuda.is_available():
import torch_cluster.rw_cuda import torch_cluster.rw_cuda
def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
                num_nodes=None):
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
    :obj:`row` (use the :obj:`coalesced` attribute to force).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will coalesce/sort
            the graph given by :obj:`(row, col)` according to :obj:`row`.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    if num_nodes is None:
        num_nodes = max(row.max(), col.max()).item() + 1

    if coalesced:
        # Sort edges by (row, col) so the kernel can treat them as CSR data.
        _, perm = torch.sort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    if p != 1 or q != 1:  # pragma: no cover
        # Fix: the two string parts previously concatenated without a space
        # ("willbe restored").
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.')
        p = q = 1

    start = start.flatten()

    if row.is_cuda:
        return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
                                        num_nodes)
    else:
        return torch_cluster.rw_cpu.rw(row, col, start, walk_length, p, q,
                                       num_nodes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment