rw.cpp 1.21 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <torch/extension.h>

rusty1s's avatar
rusty1s committed
3
#include "compat.h"
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
12
13
14
15
#include "utils.h"

// Uniform random-walk sampling on a graph given as an edge list (CPU).
//
// For every start node, performs `walk_length` steps; at each step the next
// node is picked uniformly among the current node's neighbors in `col`.
//
// Args:
//   row, col:    edge endpoints, read as int64. Neighbors of `cur` are looked
//                up as `col[cum_deg[cur] + k]` for k < deg[cur], so edges must
//                be grouped by source node in ascending `row` order (e.g.
//                sorted by row) — TODO confirm callers guarantee this.
//   start:       int64 tensor of start node indices (dim 0 = number of walks).
//                Each index is assumed to lie in [0, num_nodes); not checked.
//   walk_length: number of steps per walk.
//   p, q:        node2vec return / in-out parameters. NOTE: currently ignored;
//                every walk is a uniform (p = q = 1) walk.
//   num_nodes:   number of nodes, forwarded to `degree`.
//
// Returns:
//   Tensor with `start`'s options (int64) of shape
//   [start.size(0), walk_length + 1]. Column 0 holds the start node and
//   column l the node reached after l steps. A node with no outgoing edges
//   keeps the walk in place, i.e. that node is repeated.
at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
              size_t walk_length, float p, float q, size_t num_nodes) {
  (void)p;  // Intentionally unused: only uniform walks are implemented.
  (void)q;

  // The raw-pointer indexing below assumes dense, row-major storage.
  col = col.contiguous();
  start = start.contiguous();

  auto deg = degree(row, num_nodes).contiguous();
  // cum_deg[v] is the offset of v's first outgoing edge in `col`.
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

  const int64_t num_walks = start.size(0);
  const int64_t len = static_cast<int64_t>(walk_length);

  // One uniform [0, 1) draw per step, generated up front.
  auto rand = at::rand({num_walks, len}, start.options().dtype(at::kFloat));
  auto out = at::full({num_walks, len + 1}, -1, start.options());

  auto deg_d = deg.DATA_PTR<int64_t>();
  auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
  auto col_d = col.DATA_PTR<int64_t>();
  auto start_d = start.DATA_PTR<int64_t>();
  auto rand_d = rand.DATA_PTR<float>();
  auto out_d = out.DATA_PTR<int64_t>();

  for (int64_t n = 0; n < num_walks; n++) {
    int64_t cur = start_d[n];
    const float *walk_rand = rand_d + n * len;
    int64_t *walk_out = out_d + n * (len + 1);
    walk_out[0] = cur;

    for (int64_t l = 0; l < len; l++) {
      const int64_t d = deg_d[cur];
      // With d == 0 the offset cum_deg[cur] would point at another node's
      // edges (or past the end of `col`), so the walk stays at `cur`.
      if (d > 0) {
        int64_t k = static_cast<int64_t>(walk_rand[l] * d);
        // Guard against float rounding pushing r * d up to d.
        if (k >= d)
          k = d - 1;
        cur = col_d[cum_deg_d[cur] + k];
      }
      walk_out[l + 1] = cur;
    }
  }

  return out;
}

// Python bindings: exposes the CPU sampler `rw` in this extension module
// under the Python name "rw".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  const char *rw_doc = "Random Walk Sampling (CPU)";
  module.def("rw", &rw, rw_doc);
}