"vscode:/vscode.git/clone" did not exist on "b77a02cdfdb4cd58be3ebc6a66d076832c309cfc"
traversal.cc 7.9 KB
Newer Older
GaiYu0's avatar
GaiYu0 committed
1
2
3
4
5
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/traversal.cc
 * \brief Graph traversal implementation
 */
6
#include <dgl/packed_func_ext.h>
GaiYu0's avatar
GaiYu0 committed
7
#include <algorithm>
Gan Quan's avatar
Gan Quan committed
8
#include <queue>
GaiYu0's avatar
GaiYu0 committed
9
10
11
#include "./traversal.h"
#include "../c_api_common.h"

12
using namespace dgl::runtime;
GaiYu0's avatar
GaiYu0 committed
13
14
15
16

namespace dgl {
namespace traverse {
namespace {
17
// A utility view class to wrap a vector into a queue.
//
// The wrapper exposes push/top/pop/empty/size over a caller-owned
// std::vector. Popping only advances the read cursor `head`; elements are
// never erased. After a traversal, the wrapped vector therefore still holds
// every element that was ever pushed, in push order.
template<typename DType>
struct VectorQueueWrapper {
  /*!\brief the wrapped storage; not owned by this view */
  std::vector<DType>* vec;
  /*!\brief index (into *vec) of the current front element */
  size_t head = 0;

  explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {}

  /*!\brief append an element to the back of the queue */
  void push(const DType& elem) {
    vec->push_back(elem);
  }

  /*!\brief return the front element; the queue must not be empty */
  DType top() const {
    return (*vec)[head];
  }

  /*!\brief discard the front element by advancing the cursor */
  void pop() {
    ++head;
  }

  /*!\brief true once every pushed element has been popped */
  bool empty() const {
    return head == vec->size();
  }

  /*!\brief number of elements pushed but not yet popped */
  size_t size() const {
    return vec->size() - head;
  }
};

// Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zipping the vectors together: the output lists the
// depth-0 element of every non-empty trace (in trace order), then the
// depth-1 element of every trace long enough to have one, and so on.
// The result is always an int64 CPU array of length sum(trace sizes).
//
// Only the traces that still have an element at the current depth are
// visited. The cost is therefore O(total_len + traces.size()) instead of
// O(max_len * traces.size()), which matters when a few traces are much
// longer than the rest.
template<typename DType>
IdArray MergeMultipleTraversals(
    const std::vector<std::vector<DType>>& traces) {
  int64_t total_len = 0;
  // Indices of traces that still have an element at the current depth.
  // They are kept in ascending order so the output matches a plain zip.
  std::vector<size_t> active;
  active.reserve(traces.size());
  for (size_t j = 0; j < traces.size(); ++j) {
    total_len += static_cast<int64_t>(traces[j].size());
    if (!traces[j].empty()) {
      active.push_back(j);
    }
  }
  IdArray ret = IdArray::Empty({total_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (size_t depth = 0; !active.empty(); ++depth) {
    for (const size_t j : active) {
      *(ret_data++) = traces[j][depth];
    }
    // Drop traces that have no element at depth + 1. std::remove_if keeps
    // the relative order of retained elements, so ascending order survives.
    active.erase(
        std::remove_if(active.begin(), active.end(),
                       [&traces, depth](size_t j) {
                         return traces[j].size() <= depth + 1;
                       }),
        active.end());
  }
  return ret;
}

// Internal function to compute sections if multiple traversal traces
// are merged into one ndarray (depth-wise, as MergeMultipleTraversals does).
// Entry i of the result is the number of traces whose length exceeds i,
// i.e. how many merged elements belong to depth i. The result is an int64
// CPU array with one entry per depth, up to the length of the longest trace.
//
// Traces are bucketed by length, so the cost is O(max_len + traces.size())
// rather than rescanning every trace at every depth.
template<typename DType>
IdArray ComputeMergedSections(
    const std::vector<std::vector<DType>>& traces) {
  int64_t max_len = 0;
  for (const auto& trace : traces) {
    max_len = std::max(max_len, static_cast<int64_t>(trace.size()));
  }
  // num_ending[l] = number of traces whose length is exactly l.
  std::vector<int64_t> num_ending(static_cast<size_t>(max_len) + 1, 0);
  for (const auto& trace : traces) {
    ++num_ending[trace.size()];
  }
  IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  // Number of traces with length > i (starting at i = 0).
  int64_t num_active = static_cast<int64_t>(traces.size()) - num_ending[0];
  for (int64_t i = 0; i < max_len; ++i) {
    ret_data[i] = num_active;
    // Traces of length exactly i + 1 have no element at depth i + 1.
    num_active -= num_ending[i + 1];
  }
  return ret;
}

}  // namespace

/*!
 * \brief Class for representing frontiers.
 *
 * Each frontier is a list of nodes/edges (specified by their ids).
 * An optional tag can be specified on each node/edge (represented by an int value).
 */
struct Frontiers {
Gan Quan's avatar
Gan Quan committed
105
  /*!\brief a vector store for the nodes/edges in all the frontiers */
GaiYu0's avatar
GaiYu0 committed
106
107
  std::vector<dgl_id_t> ids;

Gan Quan's avatar
Gan Quan committed
108
  /*!\brief a vector store for node/edge tags. Empty if no tags are requested */
GaiYu0's avatar
GaiYu0 committed
109
110
111
112
113
114
  std::vector<int64_t> tags;

  /*!\brief a section vector to indicate each frontier */
  std::vector<int64_t> sections;
};

115
/*!
 * \brief Run BFSNodes from the given sources and collect node frontiers.
 *
 * BFSNodes is handed a queue whose backing storage is the returned `ids`.
 * Assuming it enqueues every node it discovers, `ids` ends up holding the
 * nodes in enqueue order. Each call to make_frontier (presumably at every
 * frontier boundary; TODO confirm in traversal.h) appends the number of
 * queued-but-unpopped nodes to `sections`. Zero-length frontiers are skipped.
 * `tags` is left empty.
 *
 * \param graph The graph to traverse.
 * \param source Source node ids.
 * \param reversed Forwarded unchanged to BFSNodes.
 * \return The collected frontiers.
 */
Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  // Popping only advances a cursor, so enqueued ids stay in front.ids.
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  // Nothing to do per visited node: enqueueing already recorded it.
  auto visit = [] (const dgl_id_t) { };
  auto make_frontier = [&] () {
      if (!queue.empty()) {
        // do not push zero-length frontier
        front.sections.push_back(queue.size());
      }
    };
  BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
  return front;
}

129
130
// C API entry: BFS node frontiers.
// Arguments: graph (GraphRef), source node ids (IdArray), reversed (bool).
// Returns {node_ids, sections} packed via ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray src = args[1];
    bool reversed = args[2];
    const Frontiers front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
    IdArray node_ids = CopyVectorToNDArray(front.ids);
    IdArray sections = CopyVectorToNDArray(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
  });

140
/*!
 * \brief Run BFSEdges from the given sources and collect edge frontiers.
 *
 * The BFS queue holds nodes in a local buffer. `ids` receives every edge id
 * passed to the visit callback.
 *
 * At each make_frontier call after the first, the number of queued-but-
 * unpopped nodes is appended to `sections`; zero-length frontiers are skipped.
 * The first call is ignored, presumably because the first frontier is the
 * source nodes, which are not reached through any edge (TODO confirm against
 * BFSEdges). Sections match edge counts only if each queued node was reached
 * by exactly one visited edge; verify in traversal.h.
 * `tags` is left empty.
 *
 * \param graph The graph to traverse.
 * \param source Source node ids.
 * \param reversed Forwarded unchanged to BFSEdges.
 * \return The collected frontiers.
 */
Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  // NOTE: std::queue has no top() method, so a vector-backed wrapper is used.
  std::vector<dgl_id_t> nodes;
  VectorQueueWrapper<dgl_id_t> queue(&nodes);
  auto visit = [&] (const dgl_id_t e) { front.ids.push_back(e); };
  bool first_frontier = true;
  auto make_frontier = [&] {
      if (first_frontier) {
        first_frontier = false;   // do not push the first section when doing edges
      } else if (!queue.empty()) {
        // do not push zero-length frontier
        front.sections.push_back(queue.size());
      }
    };
  BFSEdges(graph, source, reversed, &queue, visit, make_frontier);
  return front;
}

159
160
// C API entry: BFS edge frontiers.
// Arguments: graph (GraphRef), source node ids (IdArray), reversed (bool).
// Returns {edge_ids, sections} packed via ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray src = args[1];
    bool reversed = args[2];
    const Frontiers front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
    IdArray edge_ids = CopyVectorToNDArray(front.ids);
    IdArray sections = CopyVectorToNDArray(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
  });

170
/*!
 * \brief Run TopologicalNodes over the graph and collect node frontiers.
 *
 * TopologicalNodes is handed a queue whose backing storage is the returned
 * `ids`. Assuming it enqueues every node it emits, `ids` ends up holding the
 * nodes in enqueue order. Each call to make_frontier appends the number of
 * queued-but-unpopped nodes to `sections`; zero-length frontiers are skipped.
 * `tags` is left empty.
 *
 * \param graph The graph to traverse.
 * \param reversed Forwarded unchanged to TopologicalNodes.
 * \return The collected frontiers.
 */
Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) {
  Frontiers front;
  // Popping only advances a cursor, so enqueued ids stay in front.ids.
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  // Nothing to do per visited node: enqueueing already recorded it.
  auto visit = [] (const dgl_id_t) { };
  auto make_frontier = [&] () {
      if (!queue.empty()) {
        // do not push zero-length frontier
        front.sections.push_back(queue.size());
      }
    };
  TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
  return front;
}

184
185
// C API entry: topological-traversal node frontiers.
// Arguments: graph (GraphRef), reversed (bool).
// Returns {node_ids, sections} packed via ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    bool reversed = args[1];
    const Frontiers front = TopologicalNodesFrontiers(*g.sptr(), reversed);
    IdArray node_ids = CopyVectorToNDArray(front.ids);
    IdArray sections = CopyVectorToNDArray(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
  });


195
196
// C API entry: per-source DFS edge traces, merged depth-wise.
// Arguments: graph (GraphRef), source node ids (IdArray), reversed (bool).
// Runs DFSLabeledEdges once per source and records the edge ids it reports.
// The traces are then interleaved by MergeMultipleTraversals, with per-depth
// counts from ComputeMergedSections.
// Returns {edge_ids, sections} packed via ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray source = args[1];
    const bool reversed = args[2];
    CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
    const int64_t len = source->shape[0];
    const int64_t* src_data = static_cast<const int64_t*>(source->data);
    // One edge trace per source node.
    std::vector<std::vector<dgl_id_t>> edges(len);
    for (int64_t i = 0; i < len; ++i) {
      // Labels are not returned by this entry point, so the tag is ignored.
      auto visit = [&] (dgl_id_t e, int /*tag*/) { edges[i].push_back(e); };
      // has_reverse_edge = false, has_nontree_edge = false
      DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
    }
    IdArray ids = MergeMultipleTraversals(edges);
    IdArray sections = ComputeMergedSections(edges);
    *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
  });

213
214
// C API entry: per-source labeled DFS edge traces, merged depth-wise.
// Arguments: graph (GraphRef), source node ids (IdArray), reversed,
// has_reverse_edge, has_nontree_edge, return_labels (all bool).
// Runs DFSLabeledEdges once per source, recording each reported edge id and,
// if return_labels is set, its integer tag.
// Returns {edge_ids, labels, sections} when return_labels is true, otherwise
// {edge_ids, sections}, packed via ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const IdArray source = args[1];
    const bool reversed = args[2];
    const bool has_reverse_edge = args[3];
    const bool has_nontree_edge = args[4];
    const bool return_labels = args[5];

    CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
    const int64_t len = source->shape[0];
    const int64_t* src_data = static_cast<const int64_t*>(source->data);

    // One edge trace, and optionally one tag trace, per source node.
    std::vector<std::vector<dgl_id_t>> edges(len);
    std::vector<std::vector<int64_t>> tags;
    if (return_labels) {
      tags.resize(len);
    }
    for (int64_t i = 0; i < len; ++i) {
      auto visit = [&] (dgl_id_t e, int tag) {
        edges[i].push_back(e);
        if (return_labels) {
          tags[i].push_back(tag);
        }
      };
      DFSLabeledEdges(*g.sptr(), src_data[i], reversed,
          has_reverse_edge, has_nontree_edge, visit);
    }

    IdArray ids = MergeMultipleTraversals(edges);
    IdArray sections = ComputeMergedSections(edges);
    if (return_labels) {
      // tags[i] grows in lockstep with edges[i], so the merged labels line
      // up element-for-element with the merged ids and share `sections`.
      IdArray labels = MergeMultipleTraversals(tags);
      *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
    } else {
      *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
    }
  });

}  // namespace traverse
}  // namespace dgl