randomwalks_cpu.h 2.38 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalks/randomwalks_cpu.h
 * \brief DGL sampler - templated implementation of generic random walks on CPU
 */

7
8
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_
9
10
11

#include <dgl/base_heterograph.h>
#include <dgl/array.h>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
12
13
#include <tuple>
#include <utility>
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include "randomwalks_impl.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

/*!
 * \brief Generic Random Walk.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param max_num_steps The maximum number of steps of a random walk path.
 * \param step The random walk step function with type \c StepFunc.
 * \return A 2D array of shape (len(seeds), max_num_steps + 1) with node IDs.
 * \note The graph itself should be bounded in the closure of \c step.
 */
template<DLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
37
std::pair<IdArray, IdArray> GenericRandomWalk(
38
39
40
41
42
43
    const IdArray seeds,
    int64_t max_num_steps,
    StepFunc<IdxType> step) {
  int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
44
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, seeds->ctx);
45

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
46
47
48
  const IdxType *seed_data = seeds.Ptr<IdxType>();
  IdxType *traces_data = traces.Ptr<IdxType>();
  IdxType *eids_data = eids.Ptr<IdxType>();
49
50
51
52
53
54
55
56
57

#pragma omp parallel for
  for (int64_t seed_id = 0; seed_id < num_seeds; ++seed_id) {
    int64_t i;
    dgl_id_t curr = seed_data[seed_id];
    traces_data[seed_id * trace_length] = curr;

    for (i = 0; i < max_num_steps; ++i) {
      const auto &succ = step(traces_data + seed_id * max_num_steps, curr, i);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
58
59
60
      traces_data[seed_id * trace_length + i + 1] = curr = std::get<0>(succ);
      eids_data[seed_id * max_num_steps + i] = std::get<1>(succ);
      if (std::get<2>(succ))
61
62
63
        break;
    }

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
64
    for (; i < max_num_steps; ++i) {
65
      traces_data[seed_id * trace_length + i + 1] = -1;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
66
67
      eids_data[seed_id * max_num_steps + i] = -1;
    }
68
69
  }

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
70
  return std::make_pair(traces, eids);
71
72
73
74
75
76
77
78
79
80
}

};  // namespace

};  // namespace impl

};  // namespace sampling

};  // namespace dgl

81
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_