Commit d7a16f9a authored by rusty1s's avatar rusty1s
Browse files

return both node and edge ids

parent 2dd14df1
...@@ -23,9 +23,9 @@ torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, ...@@ -23,9 +23,9 @@ torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r, int64_t max_num_neighbors); torch::Tensor ptr_y, double r, int64_t max_num_neighbors);
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col, std::tuple<torch::Tensor, torch::Tensor>
torch::Tensor start, int64_t walk_length, double p, random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
double q); int64_t walk_length, double p, double q);
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr, torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor); int64_t count, double factor);
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
#include "utils.h" #include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, std::tuple<torch::Tensor, torch::Tensor>
torch::Tensor start, int64_t walk_length, random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
double p, double q) { int64_t walk_length, double p, double q) {
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
CHECK_CPU(start); CHECK_CPU(start);
...@@ -50,5 +50,5 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -50,5 +50,5 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
} }
}); });
return n_out; return std::make_tuple(n_out, e_out);
} }
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, std::tuple<torch::Tensor, torch::Tensor>
torch::Tensor start, int64_t walk_length, random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
double p, double q); int64_t walk_length, double p, double q);
...@@ -35,9 +35,9 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr, ...@@ -35,9 +35,9 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
} }
} }
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, std::tuple<torch::Tensor, torch::Tensor>
torch::Tensor start, int64_t walk_length, random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
double p, double q) { int64_t walk_length, double p, double q) {
CHECK_CUDA(rowptr); CHECK_CUDA(rowptr);
CHECK_CUDA(col); CHECK_CUDA(col);
CHECK_CUDA(start); CHECK_CUDA(start);
...@@ -60,5 +60,5 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, ...@@ -60,5 +60,5 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(), walk_length, n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(), walk_length,
start.numel()); start.numel());
return n_out.t().contiguous(); return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
} }
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, std::tuple<torch::Tensor, torch::Tensor>
torch::Tensor start, int64_t walk_length, random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
double p, double q); int64_t walk_length, double p, double q);
...@@ -11,9 +11,9 @@ ...@@ -11,9 +11,9 @@
PyMODINIT_FUNC PyInit__rw(void) { return NULL; } PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif #endif
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col, std::tuple<torch::Tensor, torch::Tensor>
torch::Tensor start, int64_t walk_length, double p, random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
double q) { int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return random_walk_cuda(rowptr, col, start, walk_length, p, q); return random_walk_cuda(rowptr, col, start, walk_length, p, q);
......
...@@ -2,12 +2,13 @@ import warnings ...@@ -2,12 +2,13 @@ import warnings
from typing import Optional from typing import Optional
import torch import torch
from torch import Tensor
@torch.jit.script @torch.jit.script
def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor, def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
walk_length: int, p: float = 1, q: float = 1, p: float = 1, q: float = 1, coalesced: bool = True,
coalesced: bool = True, num_nodes: Optional[int] = None): num_nodes: Optional[int] = None) -> Tensor:
"""Samples random walks of length :obj:`walk_length` from all node indices """Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks" `"node2vec: Scalable Feature Learning for Networks"
...@@ -49,4 +50,4 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor, ...@@ -49,4 +50,4 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
p = q = 1. p = q = 1.
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length, return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
p, q) p, q)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment