traversal.cc 8.25 KB
Newer Older
GaiYu0's avatar
GaiYu0 committed
1
2
3
4
5
6
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/traversal.cc
 * \brief Graph traversal implementation
 */
#include <algorithm>
Gan Quan's avatar
Gan Quan committed
7
#include <queue>
GaiYu0's avatar
GaiYu0 committed
8
9
10
#include "./traversal.h"
#include "../c_api_common.h"

11
12
13
14
15
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
GaiYu0's avatar
GaiYu0 committed
16
17
18
19

namespace dgl {
namespace traverse {
namespace {
20
/*!
 * \brief Queue-style view over an externally owned std::vector.
 *
 * New elements are appended to the back of the vector, while the front of the
 * queue is tracked by a read cursor. Popping only advances the cursor and
 * never erases anything, so the wrapped vector keeps every element that was
 * ever pushed, in push order.
 */
template<typename DType>
struct VectorQueueWrapper {
  /*! \brief the wrapped vector; the wrapper only stores the pointer */
  std::vector<DType>* vec;
  /*! \brief index in the vector of the current front element */
  size_t head = 0;

  explicit VectorQueueWrapper(std::vector<DType>* storage): vec(storage) {}

  /*! \brief append an element at the back of the queue */
  void push(const DType& elem) { vec->push_back(elem); }

  /*! \brief copy of the front element; the index is not range-checked */
  DType top() const { return (*vec)[head]; }

  /*! \brief drop the front element by moving the cursor forward */
  void pop() { ++head; }

  /*! \brief true when the cursor has reached the end of the vector */
  bool empty() const { return vec->size() == head; }

  /*! \brief number of elements between the cursor and the end of the vector */
  size_t size() const { return vec->size() - head; }
};

// Interleaves several traversal traces into one 64-bit integer CPU ndarray.
// The output is written round-robin: element 0 of every trace (in trace
// order), then element 1 of every trace that is long enough, and so on.
// The result holds as many entries as all traces have together.
template<typename DType>
IdArray MergeMultipleTraversals(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  int64_t total = 0;
  for (const auto& trace : traces) {
    const int64_t n = trace.size();
    if (longest < n) {
      longest = n;
    }
    total += n;
  }
  IdArray ret = IdArray::Empty(
      {total}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(ret->data);
  int64_t written = 0;
  for (int64_t pos = 0; pos < longest; ++pos) {
    for (const auto& trace : traces) {
      // Traces shorter than the current position contribute nothing.
      if (pos < static_cast<int64_t>(trace.size())) {
        out[written++] = trace[pos];
      }
    }
  }
  return ret;
}

// Computes the section lengths that go with a round-robin merge of the traces.
// Entry i of the returned 64-bit integer CPU ndarray is the number of traces
// that have an element at position i; the array has one entry per position of
// the longest trace.
template<typename DType>
IdArray ComputeMergedSections(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  for (const auto& trace : traces) {
    const int64_t n = trace.size();
    if (longest < n) {
      longest = n;
    }
  }
  IdArray ret = IdArray::Empty(
      {longest}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(ret->data);
  for (int64_t pos = 0; pos < longest; ++pos) {
    int64_t count = 0;
    for (const auto& trace : traces) {
      if (pos < static_cast<int64_t>(trace.size())) {
        ++count;
      }
    }
    out[pos] = count;
  }
  return ret;
}

}  // namespace

/*!
 * \brief Class for representing frontiers.
 *
 * Each frontier is a list of nodes/edges (specified by their ids).
 * An optional tag can be specified on each node/edge (represented by an int value).
 */
struct Frontiers {
Gan Quan's avatar
Gan Quan committed
108
  /*!\brief a vector store for the nodes/edges in all the frontiers */
GaiYu0's avatar
GaiYu0 committed
109
110
  std::vector<dgl_id_t> ids;

Gan Quan's avatar
Gan Quan committed
111
  /*!\brief a vector store for node/edge tags. Empty if no tags are requested */
GaiYu0's avatar
GaiYu0 committed
112
113
114
115
116
117
118
119
  std::vector<int64_t> tags;

  /*!\brief a section vector to indicate each frontier */
  std::vector<int64_t> sections;
};

/*!
 * \brief Run BFSNodes and collect the outcome as Frontiers.
 *
 * The ids vector of the result doubles as the backing store of the queue
 * handed to BFSNodes, so everything pushed on that queue ends up in ids.
 * Each time BFSNodes signals a frontier, the current queue length is
 * recorded as a section unless the queue is empty.
 */
Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
  Frontiers result;
  VectorQueueWrapper<dgl_id_t> node_queue(&result.ids);
  // No per-node work is needed here.
  auto on_visit = [&] (const dgl_id_t) { };
  auto on_frontier = [&] () {
      // zero-length frontiers are not recorded
      if (node_queue.empty()) {
        return;
      }
      result.sections.push_back(node_queue.size());
    };
  BFSNodes(graph, source, reversed, &node_queue, on_visit, on_frontier);
  return result;
}

132
133
// Packed-function entry for BFS over nodes.
// Arguments: [0] graph handle, [1] source id array, [2] reversed flag.
// The ids and sections produced by BFSNodesFrontiers are copied into
// ndarrays and handed back through rv.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<Graph*>(handle);
    const IdArray sources = args[1];
    const bool reverse = args[2];
    const Frontiers frontiers = BFSNodesFrontiers(*graph, sources, reverse);
    IdArray ids_array = CopyVectorToNDArray(frontiers.ids);
    IdArray sections_array = CopyVectorToNDArray(frontiers.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({ids_array, sections_array});
  });

Gan Quan's avatar
Gan Quan committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/*!
 * \brief Run BFSEdges and collect the outcome as Frontiers.
 *
 * Edge ids reported through the visit callback are appended to the ids of
 * the result. The node queue lives in a separate local vector. The first
 * frontier signal is ignored; every later one records the current queue
 * length as a section unless the queue is empty.
 */
Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
  Frontiers result;
  // std::queue has no top() method, hence the vector-backed wrapper.
  std::vector<dgl_id_t> node_store;
  VectorQueueWrapper<dgl_id_t> node_queue(&node_store);
  auto on_edge = [&] (const dgl_id_t e) { result.ids.push_back(e); };
  bool seen_first_signal = false;
  auto on_frontier = [&] {
      if (!seen_first_signal) {
        // the first section is not pushed when traversing edges
        seen_first_signal = true;
        return;
      }
      // zero-length frontiers are not recorded
      if (!node_queue.empty()) {
        result.sections.push_back(node_queue.size());
      }
    };
  BFSEdges(graph, source, reversed, &node_queue, on_edge, on_frontier);
  return result;
}

163
164
// Packed-function entry for BFS over edges.
// Arguments: [0] graph handle, [1] source id array, [2] reversed flag.
// The ids and sections produced by BFSEdgesFrontiers are copied into
// ndarrays and handed back through rv.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<Graph*>(handle);
    const IdArray sources = args[1];
    const bool reverse = args[2];
    const Frontiers frontiers = BFSEdgesFrontiers(*graph, sources, reverse);
    IdArray ids_array = CopyVectorToNDArray(frontiers.ids);
    IdArray sections_array = CopyVectorToNDArray(frontiers.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({ids_array, sections_array});
  });

GaiYu0's avatar
GaiYu0 committed
175
176
/*!
 * \brief Run TopologicalNodes and collect the outcome as Frontiers.
 *
 * The ids vector of the result doubles as the backing store of the queue
 * handed to TopologicalNodes, so everything pushed on that queue ends up in
 * ids. Each frontier signal records the current queue length as a section
 * unless the queue is empty.
 */
Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
  Frontiers result;
  VectorQueueWrapper<dgl_id_t> node_queue(&result.ids);
  // No per-node work is needed here.
  auto on_visit = [&] (const dgl_id_t) { };
  auto on_frontier = [&] () {
      // zero-length frontiers are not recorded
      if (node_queue.empty()) {
        return;
      }
      result.sections.push_back(node_queue.size());
    };
  TopologicalNodes(graph, reversed, &node_queue, on_visit, on_frontier);
  return result;
}

189
190
// Packed-function entry for topological traversal of nodes.
// Arguments: [0] graph handle, [1] reversed flag.
// The ids and sections produced by TopologicalNodesFrontiers are copied into
// ndarrays and handed back through rv.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<Graph*>(handle);
    const bool reverse = args[1];
    const Frontiers frontiers = TopologicalNodesFrontiers(*graph, reverse);
    IdArray ids_array = CopyVectorToNDArray(frontiers.ids);
    IdArray sections_array = CopyVectorToNDArray(frontiers.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({ids_array, sections_array});
  });


201
202
// Packed-function entry for DFS over edges.
// Arguments: [0] graph handle, [1] source id array, [2] reversed flag.
// One DFSLabeledEdges run is made per source with both extra-edge flags set
// to false. The per-source edge traces are then merged into one id array,
// and the matching sections array is computed.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<Graph*>(handle);
    const IdArray source = args[1];
    const bool reversed = args[2];
    CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
    const int64_t num_sources = source->shape[0];
    const int64_t* roots = static_cast<int64_t*>(source->data);
    std::vector<std::vector<dgl_id_t>> traces(num_sources);
    for (int64_t k = 0; k < num_sources; ++k) {
      std::vector<dgl_id_t>& trace = traces[k];
      // The tag passed to the callback is ignored here.
      auto on_edge = [&] (dgl_id_t e, int) { trace.push_back(e); };
      DFSLabeledEdges(*graph, roots[k], reversed, false, false, on_edge);
    }
    IdArray merged_ids = MergeMultipleTraversals(traces);
    IdArray merged_sections = ComputeMergedSections(traces);
    *rv = ConvertNDArrayVectorToPackedFunc({merged_ids, merged_sections});
  });

220
221
// Packed-function entry for labeled DFS over edges.
// Arguments: [0] graph handle, [1] source id array, [2] reversed flag,
// [3] has_reverse_edge, [4] has_nontree_edge, [5] return_labels.
// One DFSLabeledEdges run is made per source. The per-source edge traces are
// merged into one id array together with its sections array. When
// return_labels is set, the per-edge tags are merged the same way and
// returned between the ids and the sections.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<Graph*>(handle);
    const IdArray source = args[1];
    const bool reversed = args[2];
    const bool has_reverse_edge = args[3];
    const bool has_nontree_edge = args[4];
    const bool return_labels = args[5];

    CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
    const int64_t num_sources = source->shape[0];
    const int64_t* roots = static_cast<int64_t*>(source->data);

    std::vector<std::vector<dgl_id_t>> edge_traces(num_sources);
    // Tag traces are only allocated when the caller asked for labels.
    std::vector<std::vector<int64_t>> tag_traces(
        return_labels ? num_sources : 0);
    for (int64_t k = 0; k < num_sources; ++k) {
      std::vector<dgl_id_t>& edge_trace = edge_traces[k];
      std::vector<int64_t>* tag_trace =
          return_labels ? &tag_traces[k] : nullptr;
      auto on_edge = [&] (dgl_id_t e, int tag) {
        edge_trace.push_back(e);
        if (tag_trace != nullptr) {
          tag_trace->push_back(tag);
        }
      };
      DFSLabeledEdges(*graph, roots[k], reversed,
          has_reverse_edge, has_nontree_edge, on_edge);
    }

    IdArray merged_ids = MergeMultipleTraversals(edge_traces);
    IdArray merged_sections = ComputeMergedSections(edge_traces);
    if (!return_labels) {
      *rv = ConvertNDArrayVectorToPackedFunc({merged_ids, merged_sections});
      return;
    }
    IdArray merged_labels = MergeMultipleTraversals(tag_traces);
    *rv = ConvertNDArrayVectorToPackedFunc(
        {merged_ids, merged_labels, merged_sections});
  });

}  // namespace traverse
}  // namespace dgl