/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/sampling/randomwalks/metapath_randomwalk.h
 * @brief DGL sampler - templated implementation definition of metapath-based
 * random walks on CPU
 */

8
9
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_
10
11
12
13

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>

#include <functional>
#include <tuple>
#include <utility>
#include <vector>

#include "randomwalks_cpu.h"
#include "randomwalks_impl.h"
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

// bool WhetherToTerminate(
//     IdxType *node_ids_generated_so_far,
//     dgl_id_t last_node_id_generated,
//     int64_t number_of_nodes_generated_so_far)
37
template <typename IdxType>
38
39
using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;

40
/**
41
 * @brief Select one successor of metapath-based random walk, given the path
42
 * generated so far.
43
 *
44
45
46
 * @param data The path generated so far, of type \c IdxType.
 * @param curr The last node ID generated.
 * @param len The number of nodes generated so far.  Note that the seed node is
47
 * always included as \c data[0], and the successors start from \c data[1].
48
 *
49
50
51
52
 * @param edges_by_type Vector of results from \c GetAdj() by edge type.
 * @param metapath_data Edge types of given metapath.
 * @param prob Transition probability per edge type.
 * @param terminate Predicate for terminating the current random walk path.
53
 *
54
 * @return A tuple of ID of next successor (-1 if not exist), the last traversed
55
 * edge ID, as well as whether to terminate.
56
 */
57
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
58
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
59
    IdxType *data, dgl_id_t curr, int64_t len,
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
60
    const std::vector<CSRMatrix> &edges_by_type,
61
    const std::vector<bool> &csr_has_data, const IdxType *metapath_data,
62
63
64
65
    const std::vector<FloatArray> &prob,
    TerminatePredicate<IdxType> terminate) {
  dgl_type_t etype = metapath_data[len];

66
67
68
69
  // Note that since the selection of successors is very lightweight (especially
  // in the uniform case), we want to reduce the overheads (even from object
  // copies or object construction) as much as possible. Using Successors()
  // slows down by 2x. Using OutEdges() slows down by 10x.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
70
71
72
  const CSRMatrix &csr = edges_by_type[etype];
  const IdxType *offsets = csr.indptr.Ptr<IdxType>();
  const IdxType *all_succ = csr.indices.Ptr<IdxType>();
73
74
  const IdxType *all_eids =
      csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
75
  const IdxType *succ = all_succ + offsets[curr];
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
76
  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
77
78

  const int64_t size = offsets[curr + 1] - offsets[curr];
79
  if (size == 0) return std::make_tuple(-1, -1, true);
80

81
82
83
84
  // Use a reference to the original array instead of copying
  // This avoids updating the ref counts atomically from different threads
  // and avoids cache ping-ponging in the tight loop
  const FloatArray &prob_etype = prob[etype];
Jinjing Zhou's avatar
Jinjing Zhou committed
85
  IdxType idx = 0;
86
  if (IsNullArray(prob_etype)) {
87
88
89
90
    // empty probability array; assume uniform
    idx = RandomEngine::ThreadLocal()->RandInt(size);
  } else {
    ATEN_FLOAT_TYPE_SWITCH(prob_etype->dtype, DType, "probability", {
91
92
      FloatArray prob_selected =
          FloatArray::Empty({size}, prob_etype->dtype, prob_etype->ctx);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
93
94
      DType *prob_selected_data = prob_selected.Ptr<DType>();
      const DType *prob_etype_data = prob_etype.Ptr<DType>();
95
      for (int64_t j = 0; j < size; ++j)
96
97
        prob_selected_data[j] =
            prob_etype_data[eids ? eids[j] : j + offsets[curr]];
98
99
100
      idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);
    });
  }
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
101
  dgl_id_t eid = eids ? eids[idx] : (idx + offsets[curr]);
102

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
103
  return std::make_tuple(succ[idx], eid, terminate(data, curr, len));
104
105
}

106
/**
107
 * @brief Select one successor of metapath-based random walk, given the path
108
 * generated so far specifically for the uniform probability distribution.
109
 *
110
111
112
 * @param data The path generated so far, of type \c IdxType.
 * @param curr The last node ID generated.
 * @param len The number of nodes generated so far.  Note that the seed node is
113
 * always included as \c data[0], and the successors start from \c data[1].
114
 *
115
116
117
 * @param edges_by_type Vector of results from \c GetAdj() by edge type.
 * @param metapath_data Edge types of given metapath.
 * @param prob Transition probability per edge type, for this special case this
118
119
 * will be a NullArray
 * @param terminate Predicate for terminating the current
120
 * random walk path.
121
 *
122
 * @return A pair of ID of next successor (-1 if not exist), as well as whether
123
124
 * to terminate. \note This function is called only if all the probability
 * arrays are null.
125
 */
126
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
127
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
128
    IdxType *data, dgl_id_t curr, int64_t len,
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
129
    const std::vector<CSRMatrix> &edges_by_type,
130
    const std::vector<bool> &csr_has_data, const IdxType *metapath_data,
131
132
133
134
    const std::vector<FloatArray> &prob,
    TerminatePredicate<IdxType> terminate) {
  dgl_type_t etype = metapath_data[len];

135
136
137
138
  // Note that since the selection of successors is very lightweight (especially
  // in the uniform case), we want to reduce the overheads (even from object
  // copies or object construction) as much as possible. Using Successors()
  // slows down by 2x. Using OutEdges() slows down by 10x.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
139
140
141
  const CSRMatrix &csr = edges_by_type[etype];
  const IdxType *offsets = csr.indptr.Ptr<IdxType>();
  const IdxType *all_succ = csr.indices.Ptr<IdxType>();
142
143
  const IdxType *all_eids =
      csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
144
  const IdxType *succ = all_succ + offsets[curr];
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
145
  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
146
147

  const int64_t size = offsets[curr + 1] - offsets[curr];
148
  if (size == 0) return std::make_tuple(-1, -1, true);
149
150
151
152

  IdxType idx = 0;
  // Guaranteed uniform distribution
  idx = RandomEngine::ThreadLocal()->RandInt(size);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
153
  dgl_id_t eid = eids ? eids[idx] : (idx + offsets[curr]);
154

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
155
  return std::make_tuple(succ[idx], eid, terminate(data, curr, len));
156
157
}

158
/**
159
160
161
 * @brief Metapath-based random walk.
 * @param hg The heterograph.
 * @param seeds A 1D array of seed nodes, with the type the source type of the
162
163
164
165
166
167
168
169
 * first edge type in the metapath.
 * @param metapath A 1D array of edge types representing the metapath.
 * @param prob A vector of 1D float arrays, indicating the transition
 * probability of each edge by edge type.  An empty float array assumes uniform
 * transition.
 * @param terminate Predicate for terminating a random walk path.
 * @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs,
 * and A 2D array of shape (len(seeds), len(metapath)) with edge IDs.
170
 */
171
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
172
std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
173
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
174
175
176
177
    const std::vector<FloatArray> &prob,
    TerminatePredicate<IdxType> terminate) {
  int64_t max_num_steps = metapath->shape[0];
  const IdxType *metapath_data = static_cast<IdxType *>(metapath->data);
178
179
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
180
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
181
182

  // Prefetch all edges.
183
184
185
186
  // This forces the heterograph to materialize all OutCSR's before the OpenMP
  // loop; otherwise data races will happen.
  // TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless
  // told otherwise?
187
188
189
190
191
192
193
194
  int64_t num_etypes = hg->NumEdgeTypes();
  std::vector<CSRMatrix> edges_by_type(num_etypes);
  std::vector<bool> csr_has_data(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    edges_by_type[etype] = csr;
    csr_has_data[etype] = CSRHasData(csr);
  }
195

196
197
198
199
200
201
202
203
204
205
  // Hoist the check for Uniform vs Non uniform edge distribution
  // to avoid putting it on the hot path
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }
  if (!isUniform) {
206
207
208
209
210
211
212
213
214
    StepFunc<IdxType> step = [&edges_by_type, &csr_has_data, metapath_data,
                              &prob, terminate](
                                 IdxType *data, dgl_id_t curr, int64_t len) {
      return MetapathRandomWalkStep<XPU, IdxType>(
          data, curr, len, edges_by_type, csr_has_data, metapath_data, prob,
          terminate);
    };
    return GenericRandomWalk<XPU, IdxType>(
        seeds, max_num_steps, step, max_nodes);
215
  } else {
216
217
218
219
220
221
222
223
224
    StepFunc<IdxType> step = [&edges_by_type, &csr_has_data, metapath_data,
                              &prob, terminate](
                                 IdxType *data, dgl_id_t curr, int64_t len) {
      return MetapathRandomWalkStepUniform<XPU, IdxType>(
          data, curr, len, edges_by_type, csr_has_data, metapath_data, prob,
          terminate);
    };
    return GenericRandomWalk<XPU, IdxType>(
        seeds, max_num_steps, step, max_nodes);
225
  }
226
227
228
229
230
231
232
233
234
235
}

};  // namespace

};  // namespace impl

};  // namespace sampling

};  // namespace dgl

236
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_