Commit 1dcb3933 authored by rusty1s's avatar rusty1s
Browse files

random walk sampling

parent cf4aebef
......@@ -4,3 +4,4 @@ source=torch_cluster
exclude_lines =
pragma: no cover
cuda
raise
#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");

at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
                   size_t walk_length, float p, float q, size_t num_nodes);

// Python-facing entry point: validates the input tensors and forwards to the
// CUDA implementation. `row`/`col` hold the graph edges in COO form, `start`
// the nodes each walk begins at; `p`/`q` are forwarded untouched.
at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
              size_t walk_length, float p, float q, size_t num_nodes) {
  CHECK_CUDA(row);
  CHECK_CUDA(col);
  CHECK_CUDA(start);
  // The kernel walks raw data pointers, so a non-contiguous tensor would be
  // read with the wrong element order. The macro existed but was never used.
  IS_CONTIGUOUS(row);
  IS_CONTIGUOUS(col);
  IS_CONTIGUOUS(start);
  return rw_cuda(row, col, start, walk_length, p, q, num_nodes);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rw", &rw, "Random Walk Sampling (CUDA)");
}
#include <ATen/ATen.h>
#include "utils.cuh"
#define THREADS 1024
// Ceil-division grid size for a 1-D launch over N elements. Fully
// parenthesized so the macro stays correct when N is a compound expression
// (e.g. a ternary) or the expansion is embedded in a larger expression.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
// Uniform random-walk kernel. One thread per walk (grid-stride loop), so any
// launch configuration is valid.
//
// Layouts (all step-major):
//   row:  CSR row pointers, length num_nodes + 1 (built on the host)
//   col:  CSR column indices (neighbor lists)
//   deg:  out-degree per node, length num_nodes
//   rand: [walk_length, numel] uniform draws in [0, 1), one per step per walk
//   out:  [walk_length + 1, numel]; row 0 is the start node of each walk
__global__ void uniform_rw_kernel(
    const int64_t *__restrict__ row, const int64_t *__restrict__ col,
    const int64_t *__restrict__ deg, const int64_t *__restrict__ start,
    const float *__restrict__ rand, int64_t *__restrict__ out,
    const size_t walk_length, const size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t n = index; n < numel; n += stride) {
    out[n] = start[n];
    for (ptrdiff_t l = 1; l <= walk_length; l++) {
      auto i = l * numel + n;
      auto cur = out[(l - 1) * numel + n];
      // Step l consumes rand row l - 1: rand only has walk_length rows, so
      // indexing it with i (as out is) would read one full row past the end
      // on the final step.
      auto r = rand[(l - 1) * numel + n];
      if (deg[cur] > 0) {
        // r in [0, 1) guarantees the offset is < deg[cur].
        out[i] = col[row[cur] + int64_t(r * deg[cur])];
      } else {
        // Isolated node: stay in place instead of reading a neighbor slot
        // that belongs to another node (or lies past the end of col).
        out[i] = cur;
      }
    }
  }
}
// Samples one uniform random walk of `walk_length` steps per entry of
// `start`, returning a [start.numel(), walk_length + 1] tensor whose first
// column holds the start nodes. `p` and `q` are accepted but not used by the
// uniform kernel — presumably reserved for biased (node2vec-style) sampling.
at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
                   size_t walk_length, float p, float q, size_t num_nodes) {
  // Per-node out-degree, then a zero-prefixed cumulative sum over it: the
  // CSR row-pointer array the kernel indexes into `col` with.
  auto deg = degree(row, num_nodes);
  auto rowptr = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

  // One uniform draw in [0, 1) per (step, walk).
  auto rand = at::rand({(int64_t)walk_length, start.size(0)},
                       start.options().dtype(at::kFloat));

  // Step-major result buffer; row 0 is filled with the start nodes by the
  // kernel, rows 1..walk_length with the sampled steps.
  auto out =
      at::full({(int64_t)walk_length + 1, start.size(0)}, -1, start.options());

  uniform_rw_kernel<<<BLOCKS(start.numel()), THREADS>>>(
      rowptr.data<int64_t>(), col.data<int64_t>(), deg.data<int64_t>(),
      start.data<int64_t>(), rand.data<float>(), out.data<int64_t>(),
      walk_length, start.numel());

  // Transpose so each walk occupies one row, as callers expect.
  return out.t().contiguous();
}
......@@ -23,6 +23,8 @@ if CUDA_HOME is not None:
['cuda/knn.cpp', 'cuda/knn_kernel.cu']),
CUDAExtension('torch_cluster.radius_cuda',
['cuda/radius.cpp', 'cuda/radius_kernel.cu']),
CUDAExtension('torch_cluster.rw_cuda',
['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
]
__version__ = '1.2.3'
......
import pytest
import torch
from torch_cluster import random_walk
from .utils import tensor
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
def test_rw():
device = torch.device('cuda')
start = tensor([0, 1, 2, 3, 4], torch.long, device)
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
walk_length = 2
out = random_walk(row, col, start, walk_length)
assert out[:, 0].tolist() == start.tolist()
for n in range(start.size(0)):
cur = start[n].item()
for l in range(1, walk_length):
assert out[n, l].item() in col[row == cur].tolist()
cur = out[n, l].item()
......@@ -4,6 +4,7 @@ from .fps import fps
from .nearest import nearest
from .knn import knn, knn_graph
from .radius import radius, radius_graph
from .rw import random_walk
__version__ = '1.2.3'
......@@ -16,5 +17,6 @@ __all__ = [
'knn_graph',
'radius',
'radius_graph',
'random_walk',
'__version__',
]
import torch
if torch.cuda.is_available():
import torch_cluster.rw_cuda
def random_walk(row, col, start, walk_length, num_nodes=None, p=1, q=1):
    """Samples random walks over a graph given in COO format.

    Args:
        row (LongTensor): Source node of each edge.
        col (LongTensor): Target node of each edge.
        start (LongTensor): Nodes from which the walks start.
        walk_length (int): Number of steps per walk.
        num_nodes (int, optional): Number of nodes in the graph. Inferred
            from :obj:`row` if omitted. (default: :obj:`None`)
        p (float, optional): Return parameter, forwarded to the CUDA sampler
            (the current uniform kernel ignores it). (default: :obj:`1`)
        q (float, optional): In-out parameter, forwarded to the CUDA sampler
            (the current uniform kernel ignores it). (default: :obj:`1`)

    Returns:
        LongTensor of shape :obj:`[start.size(0), walk_length + 1]`; column 0
        holds the start nodes.

    Raises:
        NotImplementedError: If the inputs are CPU tensors (only a CUDA
            implementation exists).
    """
    if num_nodes is None:
        num_nodes = row.max().item() + 1

    if row.is_cuda:
        return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
                                        num_nodes)
    raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment