"src/vscode:/vscode.git/clone" did not exist on "1e7f96544203724652707dc93a7671c90efd3eeb"
traversal.h 9.82 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/traversal.h
 * @brief Graph traversal routines.
 *
 * Traversal routines generate frontiers. Frontiers can be node frontiers or
 * edge frontiers depending on the traversal function. Each frontier is a list
 * of nodes/edges (specified by their ids). An optional tag can be specified for
 * each node/edge (represented by an int value).
 */
#ifndef DGL_ARRAY_CPU_TRAVERSAL_H_
#define DGL_ARRAY_CPU_TRAVERSAL_H_

#include <dgl/graph_interface.h>
15

16
17
18
19
20
21
22
23
#include <stack>
#include <tuple>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

24
/**
25
 * @brief Traverse the graph in a breadth-first-search (BFS) order.
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
 *
 * The queue object must suffice following interface:
 *   Members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();           // pop one node
 *   bool empty();         // return true if the queue is empty
 *   size_t size();        // return the size of the queue
 * For example, std::queue<IdType> is a valid queue type.
 *
 * The visit function must be compatible with following interface:
 *   void (*visit)(IdType );
 *
 * The frontier function must be compatible with following interface:
 *   void (*make_frontier)(void);
 *
42
43
44
45
46
47
 * @param graph The graph.
 * @param sources Source nodes.
 * @param reversed If true, BFS follows the in-edge direction
 * @param queue The queue used to do bfs.
 * @param visit The function to call when a node is visited.
 * @param make_frontier The function to indicate that a new froniter can be
48
 * made;
49
 */
50
51
52
53
54
template <
    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseNodes(
    const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,
    FrontierFn make_frontier) {
55
  const int64_t len = source->shape[0];
56
  const IdType *src_data = static_cast<IdType *>(source->data);
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const int64_t num_nodes = csr.num_rows;
  std::vector<bool> visited(num_nodes);
  for (int64_t i = 0; i < len; ++i) {
    const IdType u = src_data[i];
    visited[u] = true;
    visit(u);
    queue->push(u);
  }
  make_frontier();

  while (!queue->empty()) {
    const size_t size = queue->size();
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
75
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
76
77
78
79
80
81
82
83
84
85
86
87
        auto v = indices_data[idx];
        if (!visited[v]) {
          visited[v] = true;
          visit(v);
          queue->push(v);
        }
      }
    }
    make_frontier();
  }
}

88
/**
89
 * @brief Traverse the graph in a breadth-first-search (BFS) order, returning
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
 *        the edges of the BFS tree.
 *
 * The queue object must suffice following interface:
 *   Members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();           // pop one node
 *   bool empty();         // return true if the queue is empty
 *   size_t size();        // return the size of the queue
 * For example, std::queue<IdType> is a valid queue type.
 *
 * The visit function must be compatible with following interface:
 *   void (*visit)(IdType );
 *
 * The frontier function must be compatible with following interface:
 *   void (*make_frontier)(void);
 *
107
108
109
110
111
 * @param graph The graph.
 * @param sources Source nodes.
 * @param reversed If true, BFS follows the in-edge direction
 * @param queue The queue used to do bfs.
 * @param visit The function to call when a node is visited.
112
 *        The argument would be edge ID.
113
 * @param make_frontier The function to indicate that a new frontier can be
114
 * made;
115
 */
116
117
118
119
120
template <
    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseEdges(
    const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,
    FrontierFn make_frontier) {
121
  const int64_t len = source->shape[0];
122
  const IdType *src_data = static_cast<IdType *>(source->data);
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const IdType *eid_data = static_cast<IdType *>(csr.data->data);

  const int64_t num_nodes = csr.num_rows;
  std::vector<bool> visited(num_nodes);
  for (int64_t i = 0; i < len; ++i) {
    const IdType u = src_data[i];
    visited[u] = true;
    queue->push(u);
  }
  make_frontier();

  while (!queue->empty()) {
    const size_t size = queue->size();
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
142
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
143
144
145
146
147
148
149
150
151
152
153
154
155
        auto e = eid_data ? eid_data[idx] : idx;
        const IdType v = indices_data[idx];
        if (!visited[v]) {
          visited[v] = true;
          visit(e);
          queue->push(v);
        }
      }
    }
    make_frontier();
  }
}

156
/**
157
 * @brief Traverse the graph in topological order.
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
 *
 * The queue object must suffice following interface:
 *   Members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();           // pop one node
 *   bool empty();         // return true if the queue is empty
 *   size_t size();        // return the size of the queue
 * For example, std::queue<IdType> is a valid queue type.
 *
 * The visit function must be compatible with following interface:
 *   void (*visit)(IdType );
 *
 * The frontier function must be compatible with following interface:
 *   void (*make_frontier)(void);
 *
174
175
176
177
178
 * @param graph The graph.
 * @param reversed If true, follows the in-edge direction
 * @param queue The queue used to do bfs.
 * @param visit The function to call when a node is visited.
 * @param make_frontier The function to indicate that a new froniter can be
179
 * made;
180
 */
181
182
183
184
185
template <
    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(
    const CSRMatrix &csr, Queue *queue, VisitFn visit,
    FrontierFn make_frontier) {
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
  int64_t num_visited_nodes = 0;
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);

  const int64_t num_nodes = csr.num_rows;
  const int64_t num_edges = csr.indices->shape[0];
  std::vector<int64_t> degrees(num_nodes, 0);
  for (int64_t eid = 0; eid < num_edges; ++eid) {
    degrees[indices_data[eid]]++;
  }

  for (int64_t vid = 0; vid < num_nodes; ++vid) {
    if (degrees[vid] == 0) {
      visit(vid);
      queue->push(static_cast<IdType>(vid));
      ++num_visited_nodes;
    }
  }
  make_frontier();

  while (!queue->empty()) {
    const size_t size = queue->size();
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
211
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
212
213
214
215
216
217
218
219
220
221
222
223
        const IdType v = indices_data[idx];
        if (--(degrees[v]) == 0) {
          visit(v);
          queue->push(v);
          ++num_visited_nodes;
        }
      }
    }
    make_frontier();
  }

  if (num_visited_nodes != num_nodes) {
224
225
    LOG(FATAL)
        << "Error in topological traversal: loop detected in the given graph.";
226
227
228
  }
}

229
/** @brief Tags attached to the edges reported by ``DFSLabeledEdges``. */
230
231
232
233
234
enum DFSEdgeTag {
  // Edge (u, v) followed while u was visited and v was not yet visited.
  kForward = 0,
  // Edge of the DFS tree, reported again once both endpoints are visited.
  kReverse = 1,
  // Edge between two visited nodes that is not part of the DFS tree.
  kNonTree = 2,
};
235
/**
236
 * @brief Traverse the graph in a depth-first-search (DFS) order.
237
238
239
240
241
 *
 * The traversal visit edges in its DFS order. Edges have three tags:
 * FORWARD(0), REVERSE(1), NONTREE(2)
 *
 * A FORWARD edge is one in which `u` has been visisted but `v` has not.
242
243
244
 * A REVERSE edge is one in which both `u` and `v` have been visisted and the
 * edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have
 * been visisted but the edge is NOT in the DFS tree.
245
 *
246
247
248
249
250
 * @param source Source node.
 * @param reversed If true, DFS follows the in-edge direction
 * @param has_reverse_edge If true, REVERSE edges are included
 * @param has_nontree_edge If true, NONTREE edges are included
 * @param visit The function to call when an edge is visited; the edge id and
251
 * its tag will be given as the arguments.
252
 */
253
254
255
256
template <typename IdType, typename VisitFn>
void DFSLabeledEdges(
    const CSRMatrix &csr, IdType source, bool has_reverse_edge,
    bool has_nontree_edge, VisitFn visit) {
257
  const int64_t num_nodes = csr.num_rows;
258
259
  CHECK_GE(num_nodes, source)
      << "source " << source << " is out of range [0," << num_nodes << "]";
260
261
262
263
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const IdType *eid_data = static_cast<IdType *>(csr.data->data);

264
  if (indptr_data[source + 1] - indptr_data[source] == 0) {
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
    // no out-going edges from the source node
    return;
  }

  typedef std::tuple<IdType, size_t, bool> StackEntry;
  std::stack<StackEntry> stack;
  std::vector<bool> visited(num_nodes);
  visited[source] = true;
  stack.push(std::make_tuple(source, 0, false));
  IdType u = 0;
  int64_t i = 0;
  bool on_tree = false;

  while (!stack.empty()) {
    std::tie(u, i, on_tree) = stack.top();
    const IdType v = indices_data[indptr_data[u] + i];
281
282
    const IdType uv =
        eid_data ? eid_data[indptr_data[u] + i] : indptr_data[u] + i;
283
284
285
286
287
288
289
290
291
    if (visited[v]) {
      if (!on_tree && has_nontree_edge) {
        visit(uv, kNonTree);
      } else if (on_tree && has_reverse_edge) {
        visit(uv, kReverse);
      }
      stack.pop();
      // find next one.
      if (indptr_data[u] + i < indptr_data[u + 1] - 1) {
292
        stack.push(std::make_tuple(u, i + 1, false));
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
      }
    } else {
      visited[v] = true;
      std::get<2>(stack.top()) = true;
      visit(uv, kForward);
      // expand
      if (indptr_data[v] < indptr_data[v + 1]) {
        stack.push(std::make_tuple(v, 0, false));
      }
    }
  }
}

}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_TRAVERSAL_H_