/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/sampling/randomwalks/randomwalks_cpu.h
 * @brief DGL sampler - templated implementation definition of random walks on
 * CPU
 */

#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/parallel_for.h>

#include <algorithm>
#include <tuple>
#include <utility>

#include "randomwalks_impl.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

31
/**
32
33
 * @brief Generic Random Walk.
 * @param seeds A 1D array of seed nodes, with the type the source type of the
34
35
36
37
38
39
40
 * first edge type in the metapath.
 * @param max_num_steps The maximum number of steps of a random walk path.
 * @param step The random walk step function with type \c StepFunc.
 * @param max_nodes Throws an error if one of the values in \c seeds exceeds
 * this argument.
 * @return A 2D array of shape (len(seeds), max_num_steps + 1) with node IDs.
 * @note The graph itself should be bounded in the closure of \c step.
41
 */
42
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
43
std::pair<IdArray, IdArray> GenericRandomWalk(
44
    const IdArray seeds, int64_t max_num_steps, StepFunc<IdxType> step,
45
    int64_t max_nodes) {
46
47
  int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
48
49
50
51
  IdArray traces =
      IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx);
  IdArray eids =
      IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, seeds->ctx);
52

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
53
54
55
  const IdxType *seed_data = seeds.Ptr<IdxType>();
  IdxType *traces_data = traces.Ptr<IdxType>();
  IdxType *eids_data = eids.Ptr<IdxType>();
56

57
58
59
60
61
62
  runtime::parallel_for(0, num_seeds, [&](size_t seed_begin, size_t seed_end) {
    for (auto seed_id = seed_begin; seed_id < seed_end; seed_id++) {
      int64_t i;
      dgl_id_t curr = seed_data[seed_id];
      traces_data[seed_id * trace_length] = curr;

63
64
      CHECK_LT(curr, max_nodes)
          << "Seed node ID exceeds the maximum number of nodes.";
65

66
      for (i = 0; i < max_num_steps; ++i) {
67
        const auto &succ = step(traces_data + seed_id * trace_length, curr, i);
68
69
        traces_data[seed_id * trace_length + i + 1] = curr = std::get<0>(succ);
        eids_data[seed_id * max_num_steps + i] = std::get<1>(succ);
70
        if (std::get<2>(succ)) break;
71
72
73
74
75
76
      }

      for (; i < max_num_steps; ++i) {
        traces_data[seed_id * trace_length + i + 1] = -1;
        eids_data[seed_id * max_num_steps + i] = -1;
      }
77
    }
78
  });
79

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
80
  return std::make_pair(traces, eids);
81
82
83
84
85
86
87
88
89
90
}

};  // namespace

};  // namespace impl

};  // namespace sampling

};  // namespace dgl

91
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_