rw_kernel.cu 1.57 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>

rusty1s's avatar
rusty1s committed
3
#include "compat.cuh"
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "utils.cuh"

// Number of threads per block used for every kernel launch in this file.
#define THREADS 1024
// Number of blocks needed to cover N elements with THREADS threads per block
// (ceiling division). Both the argument and the whole expansion are
// parenthesized so the macro composes correctly inside larger expressions
// (e.g. `x / BLOCKS(n)`) and with non-trivial arguments (e.g. `c ? a : b`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

/**
 * Uniform random-walk kernel: performs one walk per entry of `start`.
 *
 * Each thread processes walks n = index, index + stride, ... so every launch
 * configuration with at least one thread covers all `numel` walks.
 *
 * Arguments (layouts as indexed by this kernel):
 *   row         row[v] is the offset into `col` of the first neighbor of v.
 *   col         neighbor ids; the neighbors of v are read from
 *               col[row[v]] .. col[row[v] + deg[v] - 1].
 *   deg         deg[v] is the number of neighbors of v.
 *   start       `numel` start nodes, one per walk.
 *   rand        walk_length * numel floats, read at (l - 1) * numel + n for
 *               step l of walk n; expected to lie in [0, 1).
 *   out         (walk_length + 1) * numel entries; out[l * numel + n] receives
 *               the node reached by walk n after l steps (l = 0 is the start).
 *   walk_length number of steps per walk.
 *   numel       number of walks.
 *
 * A node without neighbors (deg == 0) has nothing to sample from; the walk
 * then stays on that node instead of reading an unrelated (or out-of-bounds)
 * entry of `col`.
 */
__global__ void uniform_rw_kernel(
    const int64_t *__restrict__ row, const int64_t *__restrict__ col,
    const int64_t *__restrict__ deg, const int64_t *__restrict__ start,
    const float *__restrict__ rand, int64_t *__restrict__ out,
    const size_t walk_length, const size_t numel) {

  // Widen before multiplying: the built-in index variables are 32-bit
  // unsigned, so the product would otherwise wrap for very large grids.
  const size_t index = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = (size_t)blockDim.x * gridDim.x;

  for (size_t n = index; n < numel; n += stride) {
    // The current node is kept in a register; it always equals the value this
    // thread wrote to `out` for the previous step.
    int64_t cur = start[n];
    out[n] = cur;

    for (size_t l = 1; l <= walk_length; l++) {
      const size_t i = (l - 1) * numel + n;
      const int64_t d = deg[cur];

      if (d > 0) {
        // Scale the random number to a neighbor offset in [0, d).
        int64_t offset = int64_t(rand[i] * d);
        // Guard against float rounding pushing the offset up to d.
        if (offset > d - 1)
          offset = d - 1;
        cur = col[row[cur] + offset];
      }
      // With d == 0 the walker stays where it is.
      out[l * numel + n] = cur;
    }
  }
}

/**
 * Runs uniform random walks on the GPU and returns the visited nodes.
 *
 * row         per-edge tensor passed to `degree(row, num_nodes)`; presumably
 *             the source node of each edge, sorted so that `col` is grouped by
 *             source node -- TODO confirm against the caller.
 * col         int64 tensor indexed by the kernel as the neighbor list.
 * start       int64 tensor of start nodes; one walk is produced per element.
 * walk_length number of steps per walk.
 * p, q        accepted for interface compatibility; not used by this uniform
 *             implementation.
 * num_nodes   number of nodes, forwarded to `degree`.
 *
 * Returns a contiguous int64 tensor of shape
 * (start.size(0), walk_length + 1); column 0 holds the start nodes.
 */
at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
                   size_t walk_length, float p, float q, size_t num_nodes) {
  cudaSetDevice(row.get_device());

  // Replace `row` by per-node offsets into `col`: a leading zero followed by
  // the cumulative sum of the node degrees.
  auto deg = degree(row, num_nodes);
  row = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

  // One uniform random number in [0, 1) per (step, walk) pair.
  auto rand = at::rand({(int64_t)walk_length, start.size(0)},
                       start.options().dtype(at::kFloat));
  // Step-major result buffer, pre-filled with -1.
  auto out =
      at::full({(int64_t)walk_length + 1, start.size(0)}, -1, start.options());

  // An empty `start` would request a grid of zero blocks, which is an invalid
  // launch configuration; there is nothing to compute in that case.
  const auto num_walks = start.numel();
  if (num_walks > 0) {
    uniform_rw_kernel<<<BLOCKS(num_walks), THREADS>>>(
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
        deg.DATA_PTR<int64_t>(), start.DATA_PTR<int64_t>(),
        rand.DATA_PTR<float>(), out.DATA_PTR<int64_t>(), walk_length,
        num_walks);
  }

  // Transpose to walk-major layout for the caller.
  return out.t().contiguous();
}