/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/traversal.cc
 * \brief Graph traversal implementation
 */
6
#include <dgl/packed_func_ext.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <queue>
#include <vector>

#include "./traversal.h"
#include "../c_api_common.h"

12
using namespace dgl::runtime;
GaiYu0's avatar
GaiYu0 committed
13
14
15
16

namespace dgl {
namespace traverse {
namespace {
17
// A utility view class to wrap a vector into a queue.
//
// The wrapper does not own the vector: it only keeps a pointer to it, so the
// vector must outlive the wrapper. Elements are never erased from the vector;
// popping merely advances the `head` cursor, which means the vector ends up
// holding every element that was ever pushed, in push order.
template<typename DType>
struct VectorQueueWrapper {
  /*!\brief the wrapped storage (not owned) */
  std::vector<DType>* vec;
  /*!\brief index of the current front element inside `vec` */
  size_t head = 0;

  explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {}

  /*!\brief append an element at the back of the queue */
  void push(const DType& elem) {
    vec->push_back(elem);
  }

  /*!\brief return a copy of the front element; the queue must not be empty */
  DType top() const {
    return (*vec)[head];
  }

  /*!\brief drop the front element; the queue must not be empty */
  void pop() {
    ++head;
  }

  /*!\brief whether every pushed element has been popped */
  bool empty() const {
    return head == vec->size();
  }

  /*!\brief number of elements pushed but not yet popped */
  size_t size() const {
    return vec->size() - head;
  }
};

// Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zipping the vectors together: the output first holds the
// 0-th element of every trace (in trace order), then the 1-st element of every
// trace that is long enough, and so on. For example, the traces
// {[a0, a1, a2], [b0]} are merged into [a0, b0, a1, a2].
//
// The result is a 1-D int64 array on CPU whose length is the total number of
// elements over all traces; each element is converted to int64_t on store.
template<typename DType>
IdArray MergeMultipleTraversals(
    const std::vector<std::vector<DType>>& traces) {
  int64_t max_len = 0, total_len = 0;
  for (const auto& trace : traces) {
    const int64_t tracelen = trace.size();
    max_len = std::max(max_len, tracelen);
    total_len += tracelen;
  }
  IdArray ret = IdArray::Empty(
      {total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    for (const auto& trace : traces) {
      const int64_t tracelen = trace.size();
      // Traces shorter than the current position contribute nothing.
      if (i < tracelen) {
        *(ret_data++) = trace[i];
      }
    }
  }
  return ret;
}

// Internal function to compute sections if multiple traversal traces
// are merged into one ndarray.
//
// Entry i of the result is the number of traces that have more than i
// elements, i.e. how many values position i contributes when the traces are
// interleaved position by position. For example, the traces
// {[a0, a1, a2], [b0]} give the sections [2, 1, 1].
//
// The result is a 1-D int64 array on CPU whose length is the length of the
// longest trace.
template<typename DType>
IdArray ComputeMergedSections(
    const std::vector<std::vector<DType>>& traces) {
  int64_t max_len = 0;
  for (const auto& trace : traces) {
    const int64_t tracelen = trace.size();
    max_len = std::max(max_len, tracelen);
  }
  IdArray ret = IdArray::Empty(
      {max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    int64_t sec_len = 0;
    for (const auto& trace : traces) {
      const int64_t tracelen = trace.size();
      if (i < tracelen) {
        ++sec_len;
      }
    }
    *(ret_data++) = sec_len;
  }
  return ret;
}

}  // namespace

/*!
 * \brief Class for representing frontiers.
 *
 * Each frontier is a list of nodes/edges (specified by their ids).
 * An optional tag can be specified on each node/edge (represented by an int value).
 */
struct Frontiers {
Gan Quan's avatar
Gan Quan committed
105
  /*!\brief a vector store for the nodes/edges in all the frontiers */
GaiYu0's avatar
GaiYu0 committed
106
107
  std::vector<dgl_id_t> ids;

Gan Quan's avatar
Gan Quan committed
108
  /*!\brief a vector store for node/edge tags. Empty if no tags are requested */
GaiYu0's avatar
GaiYu0 committed
109
110
111
112
113
114
  std::vector<int64_t> tags;

  /*!\brief a section vector to indicate each frontier */
  std::vector<int64_t> sections;
};

115
/*!
 * \brief Run BFSNodes and collect what it produces as a Frontiers object.
 *
 * The traversal queue is a view over `front.ids`, so every id that BFSNodes
 * pushes to the queue is recorded in `ids` in push order. Each time BFSNodes
 * invokes the frontier callback, the number of ids currently waiting in the
 * queue is appended to `sections`, unless that number is zero.
 *
 * \param graph The graph, forwarded to BFSNodes.
 * \param source The source id array, forwarded to BFSNodes.
 * \param reversed The direction flag, forwarded to BFSNodes.
 * \return The collected ids and sections; `tags` is left empty.
 */
Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  // Nothing to do per node: the queue itself already records the ids.
  auto visit = [] (const dgl_id_t) { };
  auto make_frontier = [&] () {
    if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(static_cast<int64_t>(queue.size()));
    }
  };
  BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
  return front;
}

129
130
// C API entry for BFS over nodes.
// Arguments: [0] graph reference, [1] source id array, [2] reversed flag.
// Returns a packed list of two arrays: {node ids, sections}.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray src = args[1];
    bool reversed = args[2];
    const Frontiers front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
    IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
    IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
  });

140
/*!
 * \brief Run BFSEdges and collect what it produces as a Frontiers object.
 *
 * Edge ids reported through the visit callback are appended to `ids`. The
 * traversal queue is backed by a separate local node vector, so the values
 * appended to `sections` are sizes of that node queue at the moment the
 * frontier callback fires. The very first frontier callback is ignored, and
 * so is any callback that finds the queue empty.
 *
 * \param graph The graph, forwarded to BFSEdges.
 * \param source The source id array, forwarded to BFSEdges.
 * \param reversed The direction flag, forwarded to BFSEdges.
 * \return The collected ids and sections; `tags` is left empty.
 */
Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  // NOTE: std::queue has no top() method.
  std::vector<dgl_id_t> nodes;
  VectorQueueWrapper<dgl_id_t> queue(&nodes);
  auto visit = [&] (const dgl_id_t e) { front.ids.push_back(e); };
  bool first_frontier = true;
  auto make_frontier = [&] {
    if (first_frontier) {
      first_frontier = false;   // do not push the first section when doing edges
    } else if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(static_cast<int64_t>(queue.size()));
    }
  };
  BFSEdges(graph, source, reversed, &queue, visit, make_frontier);
  return front;
}

159
160
// C API entry for BFS over edges.
// Arguments: [0] graph reference, [1] source id array, [2] reversed flag.
// Returns a packed list of two arrays: {edge ids, sections}.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray src = args[1];
    bool reversed = args[2];
    const Frontiers front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
    IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
    IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
  });

170
/*!
 * \brief Run TopologicalNodes and collect what it produces as a Frontiers object.
 *
 * The traversal queue is a view over `front.ids`, so every id that
 * TopologicalNodes pushes to the queue is recorded in `ids` in push order.
 * Each time the frontier callback fires, the number of ids currently waiting
 * in the queue is appended to `sections`, unless that number is zero.
 *
 * \param graph The graph, forwarded to TopologicalNodes.
 * \param reversed The direction flag, forwarded to TopologicalNodes.
 * \return The collected ids and sections; `tags` is left empty.
 */
Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) {
  Frontiers front;
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  // Nothing to do per node: the queue itself already records the ids.
  auto visit = [] (const dgl_id_t) { };
  auto make_frontier = [&] () {
    if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(static_cast<int64_t>(queue.size()));
    }
  };
  TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
  return front;
}

184
185
// C API entry for topological traversal over nodes.
// Arguments: [0] graph reference, [1] reversed flag.
// Returns a packed list of two arrays: {node ids, sections}.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    bool reversed = args[1];
    const Frontiers front = TopologicalNodesFrontiers(*g.sptr(), reversed);
    IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
    IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
  });


195
196
// C API entry for DFS over edges.
// Arguments: [0] graph reference, [1] source id array, [2] reversed flag.
// One DFS is run per source id; the per-source edge traces are then
// interleaved position by position into a single id array.
// Returns a packed list of two arrays: {edge ids, sections}.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray source = args[1];
    const bool reversed = args[2];
    CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
    // The buffer is read below as int64_t, so any other width would be
    // misinterpreted; reject it up front.
    CHECK(source->dtype.bits == 64) << "Source node id array must be 64-bit.";
    const int64_t len = source->shape[0];
    const int64_t* src_data = static_cast<int64_t*>(source->data);
    std::vector<std::vector<dgl_id_t>> edges(len);
    for (int64_t i = 0; i < len; ++i) {
      // Edge labels are not requested here, so the tag argument is ignored.
      auto visit = [&] (dgl_id_t e, int) { edges[i].push_back(e); };
      DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
    }
    IdArray ids = MergeMultipleTraversals(edges);
    IdArray sections = ComputeMergedSections(edges);
    *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
  });

213
214
// C API entry for labeled DFS over edges.
// Arguments: [0] graph reference, [1] source id array, [2] reversed flag,
// [3] has_reverse_edge, [4] has_nontree_edge, [5] return_labels.
// One DFS is run per source id; the per-source traces are then interleaved
// position by position into single arrays.
// Returns a packed list {edge ids, labels, sections} when return_labels is
// set, and {edge ids, sections} otherwise.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray source = args[1];
    const bool reversed = args[2];
    const bool has_reverse_edge = args[3];
    const bool has_nontree_edge = args[4];
    const bool return_labels = args[5];

    CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
    // The buffer is read below as int64_t, so any other width would be
    // misinterpreted; reject it up front.
    CHECK(source->dtype.bits == 64) << "Source node id array must be 64-bit.";
    const int64_t len = source->shape[0];
    const int64_t* src_data = static_cast<int64_t*>(source->data);

    std::vector<std::vector<dgl_id_t>> edges(len);
    // Label traces are only allocated when the caller asked for them.
    std::vector<std::vector<int64_t>> tags;
    if (return_labels) {
      tags.resize(len);
    }
    for (int64_t i = 0; i < len; ++i) {
      auto visit = [&] (dgl_id_t e, int tag) {
        edges[i].push_back(e);
        if (return_labels) {
          tags[i].push_back(tag);
        }
      };
      DFSLabeledEdges(*g.sptr(), src_data[i], reversed,
          has_reverse_edge, has_nontree_edge, visit);
    }

    IdArray ids = MergeMultipleTraversals(edges);
    IdArray sections = ComputeMergedSections(edges);
    if (return_labels) {
      IdArray labels = MergeMultipleTraversals(tags);
      *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
    } else {
      *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
    }
  });

}  // namespace traverse
}  // namespace dgl