Unverified Commit 66676a54 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] Fix bug in bfs and topo (#135)

* Fix bug in bfs and topo

* remove legacy codes
parent 6d96a97f
......@@ -16,23 +16,33 @@ using tvm::runtime::NDArray;
namespace dgl {
namespace traverse {
namespace {
// A utility view class for a range of data in a vector.
// A utility view class to wrap a vector into a queue.
template<typename DType>
struct VectorView {
const std::vector<DType>* vec;
size_t range_start, range_end;
struct VectorQueueWrapper {
std::vector<DType>* vec;
size_t head = 0;
explicit VectorView(const std::vector<DType>* vec): vec(vec) {}
explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {}
auto begin() const -> decltype(vec->begin()) {
return vec->begin() + range_start;
void push(const DType& elem) {
vec->push_back(elem);
}
auto end() const -> decltype(vec->end()) {
return vec->begin() + range_end;
DType top() const {
return vec->operator[](head);
}
size_t size() const { return range_end - range_start; }
void pop() {
++head;
}
bool empty() const {
return head == vec->size();
}
size_t size() const {
return vec->size() - head;
}
};
// Internal function to merge multiple traversal traces into one ndarray.
......@@ -106,20 +116,15 @@ struct Frontiers {
Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
Frontiers front;
size_t i = 0;
VectorView<dgl_id_t> front_view(&front.ids);
auto visit = [&] (const dgl_id_t v) { front.ids.push_back(v); };
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { };
auto make_frontier = [&] () {
front_view.range_start = i;
front_view.range_end = front.ids.size();
if (front.ids.size() != i) {
if (!queue.empty()) {
// do not push zero-length frontier
front.sections.push_back(front.ids.size() - i);
front.sections.push_back(queue.size());
}
i = front.ids.size();
return front_view;
};
BFSNodes(graph, source, reversed, visit, make_frontier);
BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
return front;
}
......@@ -137,20 +142,15 @@ TVM_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
Frontiers front;
size_t i = 0;
VectorView<dgl_id_t> front_view(&front.ids);
auto visit = [&] (const dgl_id_t v) { front.ids.push_back(v); };
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { };
auto make_frontier = [&] () {
front_view.range_start = i;
front_view.range_end = front.ids.size();
if (front.ids.size() != i) {
if (!queue.empty()) {
// do not push zero-length frontier
front.sections.push_back(front.ids.size() - i);
front.sections.push_back(queue.size());
}
i = front.ids.size();
return front_view;
};
TopologicalNodes(graph, reversed, visit, make_frontier);
TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
return front;
}
......
......@@ -22,17 +22,33 @@ namespace traverse {
/*!
* \brief Traverse the graph in a breadth-first-search (BFS) order.
*
* The queue object must suffice following interface:
* Members:
* void push(dgl_id_t); // push one node
* dgl_id_t top(); // get the first node
* void pop(); // pop one node
* bool empty(); // return true if the queue is empty
* size_t size(); // return the size of the queue
* For example, std::queue<dgl_id_t> is a valid queue type.
*
* The visit function must be compatible with following interface:
* void (*visit)(dgl_id_t );
*
* The frontier function must be compatible with following interface:
* void (*make_frontier)(void);
*
* \param graph The graph.
* \param sources Source nodes.
* \param reversed If true, BFS follows the in-edge direction
* \param visit The function to call when a node is visited; the node id will be
* given as its only argument.
* \param make_frontier The function to make a new froniter; the function should return a
* node iterator to the just created frontier.
* \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited.
* \param make_frontier The function to indicate that a new froniter can be made;
*/
template<typename VisitFn, typename FrontierFn>
template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSNodes(const Graph& graph,
IdArray source,
bool reversed,
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) {
const int64_t len = source->shape[0];
......@@ -40,37 +56,59 @@ void BFSNodes(const Graph& graph,
std::vector<bool> visited(graph.NumVertices());
for (int64_t i = 0; i < len; ++i) {
visited[src_data[i]] = true;
visit(src_data[i]);
const dgl_id_t u = src_data[i];
visited[u] = true;
visit(u);
queue->push(u);
}
auto frontier = make_frontier();
make_frontier();
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec;
while (frontier.size() != 0) {
for (const dgl_id_t u : frontier) {
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
const dgl_id_t u = queue->top();
queue->pop();
for (auto v : (graph.*neighbor_iter)(u)) {
if (!visited[v]) {
visit(v);
visited[v] = true;
visit(v);
queue->push(v);
}
}
}
frontier = make_frontier();
make_frontier();
}
}
/*!
* \brief Traverse the graph in topological order.
*
* The queue object must suffice following interface:
* Members:
* void push(dgl_id_t); // push one node
* dgl_id_t top(); // get the first node
* void pop(); // pop one node
* bool empty(); // return true if the queue is empty
* size_t size(); // return the size of the queue
* For example, std::queue<dgl_id_t> is a valid queue type.
*
* The visit function must be compatible with following interface:
* void (*visit)(dgl_id_t );
*
* The frontier function must be compatible with following interface:
* void (*make_frontier)(void);
*
* \param graph The graph.
* \param reversed If true, follows the in-edge direction
* \param visit The function to call when a node is visited; the node id will be
* given as its only argument.
* \param make_frontier The function to make a new froniter; the function should return a
* node iterator to the just created frontier.
* \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited.
* \param make_frontier The function to indicate that a new froniter can be made;
*/
template<typename VisitFn, typename FrontierFn>
template<typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const Graph& graph,
bool reversed,
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) {
const auto get_degree = reversed? &Graph::OutDegree : &Graph::InDegree;
......@@ -81,23 +119,28 @@ void TopologicalNodes(const Graph& graph,
degrees[vid] = (graph.*get_degree)(vid);
if (degrees[vid] == 0) {
visit(vid);
queue->push(vid);
++num_visited_nodes;
}
}
auto frontier = make_frontier();
make_frontier();
while (frontier.size() != 0) {
for (const dgl_id_t u : frontier) {
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
const dgl_id_t u = queue->top();
queue->pop();
for (auto v : (graph.*neighbor_iter)(u)) {
if (--(degrees[v]) == 0) {
visit(v);
queue->push(v);
++num_visited_nodes;
}
}
}
// new node frointer
frontier = make_frontier();
make_frontier();
}
if (num_visited_nodes != graph.NumVertices()) {
LOG(FATAL) << "Error in topological traversal: loop detected in the given graph.";
}
......
......@@ -11,7 +11,7 @@ import utils as U
np.random.seed(42)
def test_bfs_nodes(n=100):
def test_bfs_nodes(n=1000):
g = dgl.DGLGraph()
a = sp.random(n, n, 10 / n, data_rvs=lambda n: np.ones(n))
g.from_scipy_sparse_matrix(a)
......@@ -19,8 +19,6 @@ def test_bfs_nodes(n=100):
src = random.choice(range(n))
layers_dgl = dgl.bfs_nodes_generator(g, src)
edges = nx.bfs_edges(g_nx, src)
layers_nx = [set([src])]
frontier = set()
......@@ -32,11 +30,13 @@ def test_bfs_nodes(n=100):
frontier = set([v])
layers_nx.append(frontier)
layers_dgl = dgl.bfs_nodes_generator(g, src)
toset = lambda x: set(x.tolist())
assert len(layers_dgl) == len(layers_nx)
assert all(toset(x) == y for x, y in zip(layers_dgl, layers_nx))
def test_topological_nodes(n=100):
def test_topological_nodes(n=1000):
g = dgl.DGLGraph()
a = sp.random(n, n, 10 / n, data_rvs=lambda n: np.ones(n))
b = sp.tril(a, -1).tocoo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment