node2vec_randomwalk.h 5.87 KB
Newer Older
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
1
2
/*!
 *  Copyright (c) 2021 by Contributors
 * @file graph/sampling/randomwalks/node2vec_randomwalk.h
 * @brief DGL sampler - CPU implementation of node2vec random walk.
 */

#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_RANDOMWALK_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_RANDOMWALK_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>

#include <algorithm>
#include <cmath>
#include <functional>
17
#include <tuple>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
18
19
20
#include <utility>
#include <vector>

21
#include "metapath_randomwalk.h"  // for TerminatePredicate
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include "node2vec_impl.h"
#include "randomwalks_cpu.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

template <typename IdxType>
37
bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, dgl_id_t v) {
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
38
39
40
41
42
43
44
45
46
47
48
49
  const IdxType *offsets = csr.indptr.Ptr<IdxType>();
  const IdxType *all_succ = csr.indices.Ptr<IdxType>();
  const IdxType *u_succ = all_succ + offsets[u];
  const int64_t size = offsets[u + 1] - offsets[u];

  if (csr.sorted)
    return std::binary_search(u_succ, u_succ + size, v);
  else
    return std::find(u_succ, u_succ + size, v) != u_succ + size;
}

/*!
50
51
52
53
54
 * @brief Node2vec random walk step function
 * @param data The path generated so far, of type \c IdxType.
 * @param curr The last node ID generated.
 * @param pre The last last node ID generated
 * @param p Float, indicating likelihood of immediately revisiting a node in the
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
55
 *        walk.
56
 * @param q Float, control parameter to interpolate between breadth-first
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
57
 *        strategy and depth-first strategy.
58
 * @param len The number of nodes generated so far.  Note that the seed node is
59
60
61
62
63
 *        always included as \c data[0], and the successors start from \c
 * data[1]. \param csr The CSR matrix \param prob Transition probability \param
 * terminate Predicate for terminating the current random walk path. \return A
 * tuple of ID of next successor (-1 if not exist), the edge ID traversed, as
 * well as whether to terminate.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
64
65
 */

66
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
67
68
std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(
    IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q,
69
70
    int64_t len, const CSRMatrix &csr, bool csr_has_data,
    const FloatArray &probs, TerminatePredicate<IdxType> terminate) {
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
71
72
  const IdxType *offsets = csr.indptr.Ptr<IdxType>();
  const IdxType *all_succ = csr.indices.Ptr<IdxType>();
73
  const IdxType *all_eids = csr_has_data ? csr.data.Ptr<IdxType>() : nullptr;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  const IdxType *succ = all_succ + offsets[curr];
  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;

  const int64_t size = offsets[curr + 1] - offsets[curr];

  // Isolated node
  if (size == 0) return std::make_tuple(-1, -1, true);

  IdxType idx = 0;

  // Normalize the weights to compute rejection probabilities
  double max_prob = std::max({1 / p, 1.0, 1 / q});
  // rejection prob for back to the previous node
  double prob0 = 1 / p / max_prob;
  // rejection prob for visiting the node with the distance of 1 between the
  // previous node
  double prob1 = 1 / max_prob;
  // rejection prob for visiting the node with the distance of 2 between the
  // previous node
  double prob2 = 1 / q / max_prob;
  dgl_id_t next_node;
  double r;  // rejection probability.
  if (IsNullArray(probs)) {
    if (len == 0) {
      idx = RandomEngine::ThreadLocal()->RandInt(size);
      next_node = succ[idx];
    } else {
      while (true) {
        idx = RandomEngine::ThreadLocal()->RandInt(size);
        r = RandomEngine::ThreadLocal()->Uniform(0., 1.);
        next_node = succ[idx];
        if (next_node == pre) {
          if (r < prob0) break;
        } else if (has_edge_between<IdxType>(csr, next_node, pre)) {
          if (r < prob1) break;
        } else if (r < prob2) {
          break;
        }
      }
    }
  } else {
    FloatArray prob_selected;
    ATEN_FLOAT_TYPE_SWITCH(probs->dtype, DType, "probability", {
      prob_selected = FloatArray::Empty({size}, probs->dtype, probs->ctx);
      DType *prob_selected_data = prob_selected.Ptr<DType>();
      const DType *prob_etype_data = probs.Ptr<DType>();
      for (int64_t j = 0; j < size; ++j)
121
122
        prob_selected_data[j] =
            prob_etype_data[eids ? eids[j] : j + offsets[curr]];
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    });

    if (len == 0) {
      idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);
      next_node = succ[idx];
    } else {
      while (true) {
        idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);
        r = RandomEngine::ThreadLocal()->Uniform(0., 1.);
        next_node = succ[idx];
        if (next_node == pre) {
          if (r < prob0) break;
        } else if (has_edge_between<IdxType>(csr, next_node, pre)) {
          if (r < prob1) break;
        } else if (r < prob2) {
          break;
        }
      }
    }
  }
  dgl_id_t eid = eids ? eids[idx] : (idx + offsets[curr]);

  return std::make_tuple(next_node, eid, terminate(data, next_node, len));
}

148
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
149
std::pair<IdArray, IdArray> Node2vecRandomWalk(
150
    const HeteroGraphPtr g, const IdArray seeds, const double p, const double q,
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
151
152
153
    const int64_t max_num_steps, const FloatArray &prob,
    TerminatePredicate<IdxType> terminate) {
  const CSRMatrix &edges = g->GetCSRMatrix(0);  // homogeneous graph.
154
  bool csr_has_data = CSRHasData(edges);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
155

156
157
158
159
160
161
  StepFunc<IdxType> step = [&edges, csr_has_data, &prob, p, q, terminate](
                               IdxType *data, dgl_id_t curr, int64_t len) {
    dgl_id_t pre = (len != 0) ? data[len - 1] : curr;
    return Node2vecRandomWalkStep<XPU, IdxType>(
        data, curr, pre, p, q, len, edges, csr_has_data, prob, terminate);
  };
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
162

163
164
  return GenericRandomWalk<XPU, IdxType>(
      seeds, max_num_steps, step, g->NumVertices(0));
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
165
166
167
168
169
170
171
172
173
174
}

};  // namespace

};  // namespace impl

};  // namespace sampling

};      // namespace dgl
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_RANDOMWALK_H_