rw_cpu.cpp 1.75 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include "rw_cpu.h"
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
#include <ATen/Parallel.h>
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
6
#include "utils.h"

rusty1s's avatar
rusty1s committed
7
8
9
// Samples one random walk per entry of `start` over a CSR graph on the CPU.
//
// Args:
//   rowptr:      1-D int64 CPU tensor of CSR row offsets; the outgoing edges
//                of node v are col[rowptr[v] .. rowptr[v + 1]).
//   col:         1-D int64 CPU tensor of neighbour (column) indices.
//   start:       1-D int64 CPU tensor of walk start nodes.
//   walk_length: number of steps per walk; 0 yields walks that contain only
//                the start node.
//   p, q:        accepted for interface compatibility but NOT used by this
//                implementation: every step picks an outgoing edge uniformly
//                at random.
//
// Returns:
//   (n_out, e_out):
//     n_out: [start.size(0), walk_length + 1] visited nodes; column 0 is the
//            start node.
//     e_out: [start.size(0), walk_length] index into `col` of the edge taken
//            at each step, or -1 when the current node has no outgoing edges
//            (the walk then stays on that node for that step).
//
// NOTE(review): node ids read from `start` / `col` are not bounds-checked
// against `rowptr`; callers are assumed to pass a well-formed graph.
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
                int64_t walk_length, double p, double q) {
  (void)p; // unused: walks are uniform (see header comment)
  (void)q;

  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(start);

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  // The raw-pointer indexing below assumes dense storage. contiguous() is a
  // no-op for already-contiguous tensors and makes strided views safe.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  start = start.contiguous();

  const int64_t num_walks = start.size(0);

  // One uniform [0, 1) draw per (walk, step), generated up front.
  auto rand = torch::rand({num_walks, walk_length},
                          start.options().dtype(torch::kFloat));

  auto n_out = torch::empty({num_walks, walk_length + 1}, start.options());
  auto e_out = torch::empty({num_walks, walk_length}, start.options());

  const int64_t *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t *col_data = col.data_ptr<int64_t>();
  const int64_t *start_data = start.data_ptr<int64_t>();
  const float *rand_data = rand.data_ptr<float>();

  int64_t *n_out_data = n_out.data_ptr<int64_t>();
  int64_t *e_out_data = e_out.data_ptr<int64_t>();

  // Each walk costs O(walk_length) work, so shrink the grain size
  // accordingly; guard the division for walk_length == 0.
  const int64_t grain_size = walk_length > 0
                                 ? at::internal::GRAIN_SIZE / walk_length
                                 : at::internal::GRAIN_SIZE;

  at::parallel_for(0, num_walks, grain_size, [&](int64_t begin, int64_t end) {
    for (int64_t n = begin; n < end; n++) {
      int64_t *n_row = n_out_data + n * (walk_length + 1);
      int64_t *e_row = e_out_data + n * walk_length;
      const float *r_row = rand_data + n * walk_length;

      int64_t n_cur = start_data[n];
      n_row[0] = n_cur;

      for (int64_t l = 0; l < walk_length; l++) {
        const int64_t row_start = rowptr_data[n_cur];
        const int64_t deg = rowptr_data[n_cur + 1] - row_start;

        int64_t e_cur = -1; // no outgoing edge: stay on the current node
        if (deg > 0) {
          // Scale in double precision: in float, large degrees are rounded
          // and the product can reach `deg`. Clamp defensively so the
          // sampled edge always lies in [row_start, row_start + deg).
          int64_t rnd = static_cast<int64_t>(static_cast<double>(r_row[l]) *
                                             static_cast<double>(deg));
          if (rnd >= deg)
            rnd = deg - 1;
          e_cur = row_start + rnd;
          n_cur = col_data[e_cur];
        }

        n_row[l + 1] = n_cur;
        e_row[l] = e_cur;
      }
    }
  });

  return std::make_tuple(n_out, e_out);
}