traversal.cc 7.11 KB
Newer Older
GaiYu0's avatar
GaiYu0 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/traversal.cc
 * \brief Graph traversal implementation
 */
#include <algorithm>
#include "./traversal.h"
#include "../c_api_common.h"

using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::NDArray;

namespace dgl {
namespace traverse {
namespace {
19
// A utility view class that wraps a vector into a FIFO queue.
//
// push() appends to the wrapped vector and pop() only advances the `head`
// cursor, so no element is ever removed from the vector: after a traversal the
// vector holds every pushed element in push order. The wrapper does not own
// the vector; the caller must keep it alive for the wrapper's lifetime.
template<typename DType>
struct VectorQueueWrapper {
  /*! \brief the wrapped (non-owned) storage */
  std::vector<DType>* vec;
  /*! \brief index inside *vec of the current front element */
  size_t head = 0;

  explicit VectorQueueWrapper(std::vector<DType>* storage): vec(storage) {}

  /*! \brief Append an element to the back of the queue. */
  void push(const DType& elem) {
    vec->push_back(elem);
  }

  /*! \brief Return the front element. The queue must not be empty (unchecked). */
  DType top() const {
    return (*vec)[head];
  }

  /*! \brief Drop the front element; it remains stored in the wrapped vector. */
  void pop() {
    ++head;
  }

  /*! \brief True when every pushed element has been popped. */
  bool empty() const {
    return head == vec->size();
  }

  /*! \brief Number of pushed elements that have not been popped yet. */
  size_t size() const {
    return vec->size() - head;
  }
};

// Internal helper that interleaves several traversal traces into one flat
// int64 CPU ndarray, much like zipping them together: first the element at
// position 0 of every trace, then position 1, and so on. A trace that is
// shorter than the current position is simply skipped.
template<typename DType>
IdArray MergeMultipleTraversals(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  int64_t num_elements = 0;
  for (const auto& trace : traces) {
    const int64_t trace_len = trace.size();
    if (trace_len > longest) {
      longest = trace_len;
    }
    num_elements += trace_len;
  }

  IdArray merged = IdArray::Empty({num_elements}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(merged->data);
  for (int64_t pos = 0; pos < longest; ++pos) {
    for (const auto& trace : traces) {
      if (pos < static_cast<int64_t>(trace.size())) {
        *out++ = trace[pos];
      }
    }
  }
  return merged;
}

// Internal helper that computes the section sizes describing how multiple
// traversal traces are laid out once they are merged position by position into
// one ndarray. Entry `pos` of the result is the number of traces that have an
// element at position `pos`. The result has one entry per position of the
// longest trace.
template<typename DType>
IdArray ComputeMergedSections(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  for (const auto& trace : traces) {
    longest = std::max(longest, static_cast<int64_t>(trace.size()));
  }

  IdArray sections = IdArray::Empty({longest}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(sections->data);
  for (int64_t pos = 0; pos < longest; ++pos) {
    out[pos] = std::count_if(traces.begin(), traces.end(),
        [pos] (const std::vector<DType>& trace) {
          return pos < static_cast<int64_t>(trace.size());
        });
  }
  return sections;
}

}  // namespace

/*!
 * \brief Container for a sequence of traversal frontiers.
 *
 * Each frontier is a list of nodes or edges (specified by their ids). The ids
 * of all frontiers are stored back-to-back in `ids`, and `sections` records, in
 * order, how many ids belong to each frontier. An optional tag can be attached
 * to each node/edge (represented by an int64 value).
 */
struct Frontiers {
  /*! \brief ids of the nodes/edges of all the frontiers, concatenated */
  std::vector<dgl_id_t> ids;

  /*! \brief optional tags (presumably one per entry of `ids`); empty if there are no tags */
  std::vector<int64_t> tags;

  /*! \brief number of ids in each frontier, in frontier order */
  std::vector<int64_t> sections;
};

/*!
 * \brief Run BFSNodes and collect the nodes it enqueues as frontiers.
 *
 * The queue given to BFSNodes wraps `front.ids`, so every node BFSNodes
 * pushes lands in `front.ids` in push order. Each time BFSNodes invokes
 * `make_frontier`, the number of nodes still pending in the queue is appended
 * to `front.sections`; empty frontiers are skipped. `front.tags` stays empty.
 *
 * \param graph The graph to traverse.
 * \param source Ids of the start nodes, forwarded to BFSNodes.
 * \param reversed Forwarded to BFSNodes (presumably selects traversal along
 *        in-edges instead of out-edges; confirm in traversal.h).
 * \return The collected frontiers.
 */
Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
  Frontiers front;
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  // No per-node work is needed: the nodes are recorded via queue.push().
  auto visit = [] (const dgl_id_t) { };
  auto make_frontier = [&] () {
    if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(static_cast<int64_t>(queue.size()));
    }
  };
  BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
  return front;
}

// Packed function: (graph handle, source node ids, reversed) ->
// [node ids in the order BFS enqueued them, size of each frontier].
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const bool reversed = args[2];
    // Validate the source ids the same way the DFS entry points do before
    // they read source->data as int64_t (BFSNodes presumably reads it the same way).
    CHECK(IsValidIdArray(src)) << "Invalid source node id array.";
    const auto& front = BFSNodesFrontiers(*gptr, src, reversed);
    IdArray node_ids = CopyVectorToNDArray(front.ids);
    IdArray sections = CopyVectorToNDArray(front.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
  });

/*!
 * \brief Run TopologicalNodes and collect the nodes it enqueues as frontiers.
 *
 * The queue given to TopologicalNodes wraps `front.ids`, so every node it
 * pushes lands in `front.ids` in push order. Each time TopologicalNodes
 * invokes `make_frontier`, the number of nodes still pending in the queue is
 * appended to `front.sections`; empty frontiers are skipped. `front.tags`
 * stays empty.
 *
 * \param graph The graph to traverse.
 * \param reversed Forwarded to TopologicalNodes (presumably selects traversal
 *        along in-edges instead of out-edges; confirm in traversal.h).
 * \return The collected frontiers.
 */
Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
  Frontiers front;
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);
  // No per-node work is needed: the nodes are recorded via queue.push().
  auto visit = [] (const dgl_id_t) { };
  auto make_frontier = [&] () {
    if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(static_cast<int64_t>(queue.size()));
    }
  };
  TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
  return front;
}

// Packed function: (graph handle, reversed) ->
// [node ids in the order the topological traversal enqueued them,
//  size of each frontier].
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle handle = args[0];
    const bool reversed = args[1];
    const Graph* graph = static_cast<const Graph*>(handle);

    const Frontiers result = TopologicalNodesFrontiers(*graph, reversed);
    const IdArray ids = CopyVectorToNDArray(result.ids);
    const IdArray sections = CopyVectorToNDArray(result.sections);
    *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
  });


// Packed function: (graph handle, source node ids, reversed) ->
// [edge ids of all per-source DFS traces merged position by position,
//  number of traces contributing at each position].
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<const Graph*>(handle);
    const IdArray sources = args[1];
    const bool reversed = args[2];
    CHECK(IsValidIdArray(sources)) << "Invalid source node id array.";

    const int64_t num_sources = sources->shape[0];
    const int64_t* source_ids = static_cast<int64_t*>(sources->data);

    // One independent DFS per source node. The edge tag reported by
    // DFSLabeledEdges is ignored here.
    std::vector<std::vector<dgl_id_t>> traces(num_sources);
    for (int64_t k = 0; k < num_sources; ++k) {
      std::vector<dgl_id_t>& trace = traces[k];
      auto record_edge = [&trace] (dgl_id_t edge, int) { trace.push_back(edge); };
      // has_reverse_edge = false, has_nontree_edge = false
      DFSLabeledEdges(*graph, source_ids[k], reversed, false, false, record_edge);
    }

    IdArray merged_ids = MergeMultipleTraversals(traces);
    IdArray merged_sections = ComputeMergedSections(traces);
    *rv = ConvertNDArrayVectorToPackedFunc({merged_ids, merged_sections});
  });

// Packed function: (graph handle, source node ids, reversed, has_reverse_edge,
// has_nontree_edge, return_labels) ->
// [merged edge ids, (merged edge labels if return_labels), section sizes].
// Traces from each source are merged position by position.
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* graph = static_cast<const Graph*>(handle);
    const IdArray sources = args[1];
    const bool reversed = args[2];
    const bool has_reverse_edge = args[3];
    const bool has_nontree_edge = args[4];
    const bool return_labels = args[5];

    CHECK(IsValidIdArray(sources)) << "Invalid source node id array.";
    const int64_t num_sources = sources->shape[0];
    const int64_t* source_ids = static_cast<int64_t*>(sources->data);

    // Per-source edge traces; the label traces are only allocated on request.
    std::vector<std::vector<dgl_id_t>> edge_traces(num_sources);
    std::vector<std::vector<int64_t>> label_traces(return_labels ? num_sources : 0);
    for (int64_t k = 0; k < num_sources; ++k) {
      std::vector<dgl_id_t>& edge_trace = edge_traces[k];
      auto record = [&] (dgl_id_t edge, int label) {
        edge_trace.push_back(edge);
        if (return_labels) {
          label_traces[k].push_back(label);
        }
      };
      DFSLabeledEdges(*graph, source_ids[k], reversed,
                      has_reverse_edge, has_nontree_edge, record);
    }

    const IdArray merged_ids = MergeMultipleTraversals(edge_traces);
    const IdArray merged_sections = ComputeMergedSections(edge_traces);
    if (!return_labels) {
      *rv = ConvertNDArrayVectorToPackedFunc({merged_ids, merged_sections});
      return;
    }
    const IdArray merged_labels = MergeMultipleTraversals(label_traces);
    *rv = ConvertNDArrayVectorToPackedFunc({merged_ids, merged_labels, merged_sections});
  });

}  // namespace traverse
}  // namespace dgl