Commit d7a16f9a authored by rusty1s's avatar rusty1s
Browse files

return both node and edge ids

parent 2dd14df1
......@@ -23,9 +23,9 @@ torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r, int64_t max_num_neighbors);
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length, double p,
double q);
std::tuple<torch::Tensor, torch::Tensor>
random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor);
......@@ -4,9 +4,9 @@
#include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q) {
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(start);
......@@ -50,5 +50,5 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
}
});
return n_out;
return std::make_tuple(n_out, e_out);
}
......@@ -2,6 +2,6 @@
#include <torch/extension.h>
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q);
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);
......@@ -35,9 +35,9 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q) {
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
......@@ -60,5 +60,5 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(), walk_length,
start.numel());
return n_out.t().contiguous();
return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}
......@@ -2,6 +2,6 @@
#include <torch/extension.h>
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q);
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);
......@@ -11,9 +11,9 @@
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length, double p,
double q) {
std::tuple<torch::Tensor, torch::Tensor>
random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return random_walk_cuda(rowptr, col, start, walk_length, p, q);
......
......@@ -2,12 +2,13 @@ import warnings
from typing import Optional
import torch
from torch import Tensor
@torch.jit.script
def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
walk_length: int, p: float = 1, q: float = 1,
coalesced: bool = True, num_nodes: Optional[int] = None):
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
p: float = 1, q: float = 1, coalesced: bool = True,
num_nodes: Optional[int] = None) -> Tensor:
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
......@@ -49,4 +50,4 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
p = q = 1.
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
p, q)
p, q)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment