traversal.cc 8.16 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/traversal.cc
 * @brief Graph traversal implementation
 */
sangwzh's avatar
sangwzh committed
7
#include "traversal.h"
8

9
#include <dgl/packed_func_ext.h>
10

GaiYu0's avatar
GaiYu0 committed
11
#include <algorithm>
Gan Quan's avatar
Gan Quan committed
12
#include <queue>
13

GaiYu0's avatar
GaiYu0 committed
14
15
#include "../c_api_common.h"

16
using namespace dgl::runtime;
GaiYu0's avatar
GaiYu0 committed
17
18
19
20

namespace dgl {
namespace traverse {
namespace {
21
/**
 * @brief Queue-like view over a std::vector.
 *
 * push() appends to the wrapped vector and pop() merely advances a read
 * cursor, so popped elements remain in the vector. Callers therefore keep the
 * complete visiting order while still giving the traversal a queue interface
 * that has top() (which std::queue lacks).
 */
template <typename DType>
struct VectorQueueWrapper {
  std::vector<DType>* vec;  // backing storage; not owned by this view
  size_t head = 0;          // index of the current front element

  explicit VectorQueueWrapper(std::vector<DType>* storage) : vec(storage) {}

  /** @brief Append an element at the back of the queue. */
  void push(const DType& elem) { vec->push_back(elem); }

  /** @brief Copy of the front element; only meaningful when !empty(). */
  DType top() const { return (*vec)[head]; }

  /** @brief Drop the front element by moving the cursor forward. */
  void pop() { head += 1; }

  /** @brief Number of elements pushed but not yet popped. */
  size_t size() const { return vec->size() - head; }

  /** @brief True once every pushed element has been popped. */
  bool empty() const { return size() == 0; }
};

// Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together.
42
43
template <typename DType>
IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
GaiYu0's avatar
GaiYu0 committed
44
45
46
47
48
49
  int64_t max_len = 0, total_len = 0;
  for (size_t i = 0; i < traces.size(); ++i) {
    const int64_t tracelen = traces[i].size();
    max_len = std::max(max_len, tracelen);
    total_len += traces[i].size();
  }
50
51
  IdArray ret = IdArray::Empty(
      {total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
GaiYu0's avatar
GaiYu0 committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    for (size_t j = 0; j < traces.size(); ++j) {
      const int64_t tracelen = traces[j].size();
      if (i >= tracelen) {
        continue;
      }
      *(ret_data++) = traces[j][i];
    }
  }
  return ret;
}

// Internal function to compute sections if multiple traversal traces
// are merged into one ndarray.
67
68
template <typename DType>
IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
GaiYu0's avatar
GaiYu0 committed
69
70
71
72
73
  int64_t max_len = 0;
  for (size_t i = 0; i < traces.size(); ++i) {
    const int64_t tracelen = traces[i].size();
    max_len = std::max(max_len, tracelen);
  }
74
75
  IdArray ret = IdArray::Empty(
      {max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
GaiYu0's avatar
GaiYu0 committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    int64_t sec_len = 0;
    for (size_t j = 0; j < traces.size(); ++j) {
      const int64_t tracelen = traces[j].size();
      if (i < tracelen) {
        ++sec_len;
      }
    }
    *(ret_data++) = sec_len;
  }
  return ret;
}

}  // namespace

92
/**
93
 * @brief Class for representing frontiers.
GaiYu0's avatar
GaiYu0 committed
94
95
 *
 * Each frontier is a list of nodes/edges (specified by their ids).
96
97
 * An optional tag can be specified on each node/edge (represented by an int
 * value).
GaiYu0's avatar
GaiYu0 committed
98
99
 */
struct Frontiers {
100
  /** @brief a vector store for the nodes/edges in all the frontiers */
GaiYu0's avatar
GaiYu0 committed
101
102
  std::vector<dgl_id_t> ids;

103
104
105
  /**
   * @brief a vector store for node/edge tags. Empty if no tags are requested
   */
GaiYu0's avatar
GaiYu0 committed
106
107
  std::vector<int64_t> tags;

108
  /** @brief a section vector to indicate each frontier */
GaiYu0's avatar
GaiYu0 committed
109
110
111
  std::vector<int64_t> sections;
};

112
113
/**
 * @brief Run BFSNodes from @p source and collect the pushed nodes per frontier.
 *
 * @param graph    graph to traverse
 * @param source   starting node ids, forwarded unchanged to BFSNodes
 * @param reversed forwarded unchanged to BFSNodes
 * @return every node pushed onto the queue (in push order) in `ids`, plus the
 *         size of each non-empty frontier in `sections`
 */
Frontiers BFSNodesFrontiers(
    const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers result;
  // The queue is a view over result.ids: pop() only moves a cursor, so every
  // node pushed by BFSNodes stays recorded in result.ids.
  VectorQueueWrapper<dgl_id_t> frontier_queue(&result.ids);
  // Nodes are already captured through the queue; nothing else to do per node.
  auto on_visit = [](const dgl_id_t) {};
  auto close_frontier = [&]() {
    if (frontier_queue.empty()) {
      return;  // zero-length frontiers are not recorded
    }
    result.sections.push_back(frontier_queue.size());
  };
  BFSNodes(graph, source, reversed, &frontier_queue, on_visit, close_frontier);
  return result;
}

127
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Arguments: (graph, source node ids, reversed).
      // Returns [node_ids, sections] built from BFSNodesFrontiers.
      GraphRef g = args[0];
      const IdArray src = args[1];
      const bool reversed = args[2];
      // Validate the source ids up front, as the DFS entry points in this file
      // do, so malformed input fails with a clear message instead of being
      // handed to the traversal.
      CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
      const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
      IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
    });

/**
 * @brief Run BFSEdges from @p source and collect visited edges per frontier.
 *
 * @param graph    graph to traverse
 * @param source   starting node ids, forwarded unchanged to BFSEdges
 * @param reversed forwarded unchanged to BFSEdges
 * @return edge ids passed to the visit callback (in order) in `ids`, and one
 *         section per non-empty frontier after the first
 */
Frontiers BFSEdgesFrontiers(
    const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers result;
  // BFSEdges needs a queue with top(), which std::queue lacks. The node queue
  // gets its own buffer because result.ids collects edge ids instead.
  std::vector<dgl_id_t> node_buffer;
  VectorQueueWrapper<dgl_id_t> node_queue(&node_buffer);
  auto on_edge = [&](const dgl_id_t e) { result.ids.push_back(e); };
  bool skip_next_frontier = true;
  auto close_frontier = [&] {
    // The first frontier is not recorded when doing edges.
    if (skip_next_frontier) {
      skip_next_frontier = false;
      return;
    }
    // NOTE(review): this records the node-queue size. It equals the number of
    // edges visited in the level only if BFSEdges pushes exactly one node per
    // visited edge; confirm in traversal.h.
    if (!node_queue.empty()) {
      result.sections.push_back(node_queue.size());
    }
  };
  BFSEdges(graph, source, reversed, &node_queue, on_edge, close_frontier);
  return result;
}

158
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Arguments: (graph, source node ids, reversed).
      // Returns [edge_ids, sections] built from BFSEdgesFrontiers.
      GraphRef g = args[0];
      const IdArray src = args[1];
      const bool reversed = args[2];
      // Validate the source ids up front, as the DFS entry points in this file
      // do, so malformed input fails with a clear message instead of being
      // handed to the traversal.
      CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
      const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
      IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
    });

/**
 * @brief Run TopologicalNodes and collect the pushed nodes per frontier.
 *
 * @param graph    graph to traverse
 * @param reversed forwarded unchanged to TopologicalNodes
 * @return every node pushed onto the queue (in push order) in `ids`, plus the
 *         size of each non-empty frontier in `sections`
 */
Frontiers TopologicalNodesFrontiers(
    const GraphInterface& graph, bool reversed) {
  Frontiers front;
  // Nodes are collected through the queue's backing vector (front.ids), so
  // the visit callback has nothing extra to do.
  VectorQueueWrapper<dgl_id_t> node_queue(&front.ids);
  auto ignore_node = [](const dgl_id_t) {};
  auto end_frontier = [&node_queue, &front]() {
    const size_t pending = node_queue.size();
    if (pending > 0) {  // zero-length frontiers are not recorded
      front.sections.push_back(pending);
    }
  };
  TopologicalNodes(graph, reversed, &node_queue, ignore_node, end_frontier);
  return front;
}

184
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Arguments: (graph, reversed).
      // Returns [node_ids, sections] built from TopologicalNodesFrontiers.
      GraphRef graph_ref = args[0];
      const bool reversed = args[1];
      const Frontiers front =
          TopologicalNodesFrontiers(*graph_ref.sptr(), reversed);
      *rv = ConvertNDArrayVectorToPackedFunc(
          {CopyVectorToNDArray<int64_t>(front.ids),
           CopyVectorToNDArray<int64_t>(front.sections)});
    });
GaiYu0's avatar
GaiYu0 committed
193

194
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Arguments: (graph, source node ids, reversed).
      // Runs one DFSLabeledEdges pass per source node, then returns
      // [edge_ids, sections], where the per-source edge traces are merged
      // position by position (MergeMultipleTraversals).
      GraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      const int64_t num_sources = source->shape[0];
      const int64_t* sources = static_cast<int64_t*>(source->data);

      std::vector<std::vector<dgl_id_t>> traces(num_sources);
      for (int64_t k = 0; k < num_sources; ++k) {
        std::vector<dgl_id_t>& trace = traces[k];
        // Edge labels are not needed here, so the tag argument is ignored.
        auto record_edge = [&trace](dgl_id_t e, int) { trace.push_back(e); };
        // has_reverse_edge = false, has_nontree_edge = false.
        DFSLabeledEdges(
            *g.sptr(), sources[k], reversed, false, false, record_edge);
      }

      *rv = ConvertNDArrayVectorToPackedFunc(
          {MergeMultipleTraversals(traces), ComputeMergedSections(traces)});
    });
GaiYu0's avatar
GaiYu0 committed
211

212
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Arguments: (graph, source node ids, reversed, has_reverse_edge,
      //             has_nontree_edge, return_labels).
      // Returns [edge_ids, labels, sections] when return_labels is set,
      // otherwise [edge_ids, sections]. Per-source traces are merged position
      // by position (MergeMultipleTraversals).
      GraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      const bool has_reverse_edge = args[3];
      const bool has_nontree_edge = args[4];
      const bool return_labels = args[5];

      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      const int64_t num_sources = source->shape[0];
      const int64_t* sources = static_cast<int64_t*>(source->data);

      std::vector<std::vector<dgl_id_t>> edge_traces(num_sources);
      // Label traces are only allocated when the caller asked for them.
      std::vector<std::vector<int64_t>> label_traces(
          return_labels ? num_sources : 0);
      for (int64_t k = 0; k < num_sources; ++k) {
        auto record = [&](dgl_id_t e, int tag) {
          edge_traces[k].push_back(e);
          if (return_labels) {
            label_traces[k].push_back(tag);
          }
        };
        DFSLabeledEdges(
            *g.sptr(), sources[k], reversed, has_reverse_edge,
            has_nontree_edge, record);
      }

      const IdArray ids = MergeMultipleTraversals(edge_traces);
      const IdArray sections = ComputeMergedSections(edge_traces);
      if (!return_labels) {
        *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
        return;
      }
      const IdArray labels = MergeMultipleTraversals(label_traces);
      *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
    });
GaiYu0's avatar
GaiYu0 committed
251
252
253

}  // namespace traverse
}  // namespace dgl