traversal.cc 8.1 KB
Newer Older
1
/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/traversal.cc
 * @brief Graph traversal implementation
 */
6
7
#include "./traversal.h"

8
#include <dgl/packed_func_ext.h>
9

GaiYu0's avatar
GaiYu0 committed
10
#include <algorithm>
Gan Quan's avatar
Gan Quan committed
11
#include <queue>
12

GaiYu0's avatar
GaiYu0 committed
13
14
#include "../c_api_common.h"

15
using namespace dgl::runtime;
GaiYu0's avatar
GaiYu0 committed
16
17
18
19

namespace dgl {
namespace traverse {
namespace {
20
// Adapts a caller-owned std::vector into a FIFO queue interface
// (push/top/pop/empty/size). Popping only advances a read cursor, so every
// element ever pushed stays in the backing vector in push order; the
// *Frontiers functions in this file rely on that to read back the full
// visiting order. std::queue is not used because it exposes front(), not
// top().
template <typename DType>
struct VectorQueueWrapper {
  std::vector<DType>* vec;  // backing storage; not owned by the wrapper
  size_t head{0};           // index of the current front element in *vec

  explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}

  // Appends `elem` at the back of the queue (and of the backing vector).
  void push(const DType& elem) { vec->push_back(elem); }

  // Returns a copy of the front element; the queue must not be empty.
  DType top() const { return (*vec)[head]; }

  // Discards the front element without erasing it from the backing vector.
  void pop() { head += 1; }

  // True once every pushed element has been popped.
  bool empty() const { return vec->size() == head; }

  // Number of pushed-but-not-yet-popped elements.
  size_t size() const { return vec->size() - head; }
};

// Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together.
41
42
template <typename DType>
IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
GaiYu0's avatar
GaiYu0 committed
43
44
45
46
47
48
  int64_t max_len = 0, total_len = 0;
  for (size_t i = 0; i < traces.size(); ++i) {
    const int64_t tracelen = traces[i].size();
    max_len = std::max(max_len, tracelen);
    total_len += traces[i].size();
  }
49
50
  IdArray ret = IdArray::Empty(
      {total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
GaiYu0's avatar
GaiYu0 committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    for (size_t j = 0; j < traces.size(); ++j) {
      const int64_t tracelen = traces[j].size();
      if (i >= tracelen) {
        continue;
      }
      *(ret_data++) = traces[j][i];
    }
  }
  return ret;
}

// Internal function to compute sections if multiple traversal traces
// are merged into one ndarray.
66
67
template <typename DType>
IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
GaiYu0's avatar
GaiYu0 committed
68
69
70
71
72
  int64_t max_len = 0;
  for (size_t i = 0; i < traces.size(); ++i) {
    const int64_t tracelen = traces[i].size();
    max_len = std::max(max_len, tracelen);
  }
73
74
  IdArray ret = IdArray::Empty(
      {max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
GaiYu0's avatar
GaiYu0 committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    int64_t sec_len = 0;
    for (size_t j = 0; j < traces.size(); ++j) {
      const int64_t tracelen = traces[j].size();
      if (i < tracelen) {
        ++sec_len;
      }
    }
    *(ret_data++) = sec_len;
  }
  return ret;
}

}  // namespace

91
/**
92
 * @brief Class for representing frontiers.
GaiYu0's avatar
GaiYu0 committed
93
94
 *
 * Each frontier is a list of nodes/edges (specified by their ids).
95
96
 * An optional tag can be specified on each node/edge (represented by an int
 * value).
GaiYu0's avatar
GaiYu0 committed
97
98
 */
struct Frontiers {
99
  /** @brief a vector store for the nodes/edges in all the frontiers */
GaiYu0's avatar
GaiYu0 committed
100
101
  std::vector<dgl_id_t> ids;

102
103
104
  /**
   * @brief a vector store for node/edge tags. Empty if no tags are requested
   */
GaiYu0's avatar
GaiYu0 committed
105
106
  std::vector<int64_t> tags;

107
  /** @brief a section vector to indicate each frontier */
GaiYu0's avatar
GaiYu0 committed
108
109
110
  std::vector<int64_t> sections;
};

111
112
// Runs BFSNodes (traversal.h) from the nodes in `source` and gathers the
// enqueued node ids together with the size of every non-empty frontier.
// NOTE(review): the meaning of `reversed` is defined by BFSNodes; confirm
// there.
Frontiers BFSNodesFrontiers(
    const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  // The queue appends straight into front.ids, so every node the traversal
  // enqueues is recorded without needing a visit callback.
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  auto ignore_node = [](dgl_id_t) {};
  auto close_frontier = [&front, &queue]() {
    if (queue.empty()) {
      return;  // never record a zero-length frontier
    }
    front.sections.push_back(queue.size());
  };
  BFSNodes(graph, source, reversed, &queue, ignore_node, close_frontier);
  return front;
}

126
// C API: BFS over nodes. Args: (graph, source node id array, reversed).
// Returns the packed pair (node_ids, sections) built by BFSNodesFrontiers.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray src = args[1];
      const bool reversed = args[2];
      // Validate the source array before traversing, matching the DFS entry
      // points in this file, so malformed input fails with a clear message.
      CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
      const auto front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
      IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
    });

// Runs BFSEdges (traversal.h) from the nodes in `source` and gathers the
// visited edge ids in `ids`, with one section per non-empty frontier except
// the first.
Frontiers BFSEdgesFrontiers(
    const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  // The queue holds nodes while the result holds edges, so the queue gets
  // its own backing storage (std::queue has no top() method).
  std::vector<dgl_id_t> node_storage;
  VectorQueueWrapper<dgl_id_t> queue(&node_storage);
  auto record_edge = [&front](const dgl_id_t e) { front.ids.push_back(e); };
  bool is_first_frontier = true;
  auto close_frontier = [&]() {
    // The first section is skipped when collecting edges. Presumably this is
    // because the first frontier holds the source nodes, which no traversed
    // edge leads to. TODO confirm in BFSEdges.
    if (is_first_frontier) {
      is_first_frontier = false;
      return;
    }
    if (!queue.empty()) {  // never record a zero-length frontier
      front.sections.push_back(queue.size());
    }
  };
  BFSEdges(graph, source, reversed, &queue, record_edge, close_frontier);
  return front;
}

157
// C API: BFS over edges. Args: (graph, source node id array, reversed).
// Returns the packed pair (edge_ids, sections) built by BFSEdgesFrontiers.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray src = args[1];
      const bool reversed = args[2];
      // Validate the source array before traversing, matching the DFS entry
      // points in this file, so malformed input fails with a clear message.
      CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
      const auto front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
      IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
    });

// Runs TopologicalNodes (traversal.h) over the whole graph and gathers the
// enqueued node ids together with the size of every non-empty frontier.
// NOTE(review): the meaning of `reversed` is defined by TopologicalNodes;
// confirm there.
Frontiers TopologicalNodesFrontiers(
    const GraphInterface& graph, bool reversed) {
  Frontiers front;
  // Enqueued nodes land directly in front.ids via the wrapper.
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  auto ignore_node = [](dgl_id_t) {};
  auto close_frontier = [&front, &queue]() {
    if (queue.empty()) {
      return;  // never record a zero-length frontier
    }
    front.sections.push_back(queue.size());
  };
  TopologicalNodes(graph, reversed, &queue, ignore_node, close_frontier);
  return front;
}

183
// C API: topological traversal. Args: (graph, reversed).
// Returns the packed pair (node_ids, sections) built by
// TopologicalNodesFrontiers.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef graph_ref = args[0];
      const bool reversed = args[1];
      const auto frontiers =
          TopologicalNodesFrontiers(*graph_ref.sptr(), reversed);
      IdArray ids_out = CopyVectorToNDArray<int64_t>(frontiers.ids);
      IdArray sections_out = CopyVectorToNDArray<int64_t>(frontiers.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({ids_out, sections_out});
    });
GaiYu0's avatar
GaiYu0 committed
192

193
// C API: DFS edge traversal. Args: (graph, source node id array, reversed).
// Runs one DFS per source node and returns the packed pair (ids, sections).
// The per-source edge traces are interleaved by MergeMultipleTraversals and
// described by ComputeMergedSections.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      const int64_t num_sources = source->shape[0];
      const int64_t* src_data = static_cast<int64_t*>(source->data);
      // One edge trace per source node.
      std::vector<std::vector<dgl_id_t>> traces(num_sources);
      for (int64_t k = 0; k < num_sources; ++k) {
        std::vector<dgl_id_t>& trace = traces[k];
        auto append_edge = [&trace](dgl_id_t e, int /*tag*/) {
          trace.push_back(e);
        };
        // Both extra-edge flags are false (has_reverse_edge and
        // has_nontree_edge in the labeled variant below). Presumably only
        // tree edges are reported.
        DFSLabeledEdges(
            *g.sptr(), src_data[k], reversed, false, false, append_edge);
      }
      IdArray ids = MergeMultipleTraversals(traces);
      IdArray sections = ComputeMergedSections(traces);
      *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
    });
GaiYu0's avatar
GaiYu0 committed
210

211
// C API: labeled DFS edge traversal.
// Args: (graph, source node id array, reversed, has_reverse_edge,
//        has_nontree_edge, return_labels).
// Runs one DFS per source node. Returns (ids, labels, sections) when
// return_labels is set, and (ids, sections) otherwise. Per-source traces are
// interleaved by MergeMultipleTraversals and described by
// ComputeMergedSections.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      const bool has_reverse_edge = args[3];
      const bool has_nontree_edge = args[4];
      const bool return_labels = args[5];

      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      const int64_t len = source->shape[0];
      const int64_t* src_data = static_cast<int64_t*>(source->data);

      // One edge trace per source. Tag traces are only allocated when the
      // caller asked for labels.
      std::vector<std::vector<dgl_id_t>> edges(len);
      std::vector<std::vector<int64_t>> tags(return_labels ? len : 0);
      for (int64_t k = 0; k < len; ++k) {
        auto record = [&, k](dgl_id_t e, int tag) {
          edges[k].push_back(e);
          if (return_labels) {
            tags[k].push_back(tag);
          }
        };
        DFSLabeledEdges(
            *g.sptr(), src_data[k], reversed, has_reverse_edge,
            has_nontree_edge, record);
      }

      IdArray ids = MergeMultipleTraversals(edges);
      IdArray sections = ComputeMergedSections(edges);
      if (!return_labels) {
        *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
        return;
      }
      IdArray labels = MergeMultipleTraversals(tags);
      *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
    });
GaiYu0's avatar
GaiYu0 committed
250
251
252

}  // namespace traverse
}  // namespace dgl