Unverified Commit 42b0c38f authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Kernel] Migrate traversal on adjlist to CSR (#1650)



* traversal to new framework

* add new

* Fix compile

* Pass test

* keep old version

* lint

* lint

* Fix

* Fix

* Fix compatibility with new master

* Fix test and tutorials

* Update according to comments

* Fix test
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent fb248b67
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/graph_traversal.h
* \brief common graph traversal operations
*/
#ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_
#include "array.h"
#include "base_heterograph.h"
namespace dgl {
///////////////////////// Graph Traverse routines //////////////////////////
/*!
* \brief Result container for the graph traversal routines.
*
* A traversal produces a sequence of frontiers. Each frontier is a list of
* nodes/edges (specified by their ids). All frontiers are stored back to back
* in one flat array; a separate section array records how many entries belong
* to each frontier, in order.
* An optional tag can be recorded for each node/edge (represented by an int value).
*/
struct Frontiers {
/*!
* \brief Flat storage of the node/edge ids of all the frontiers, concatenated
* in frontier order.
*/
IdArray ids;
/*!
* \brief Flat storage of the node/edge tags. Dtype is int64.
* Left as a default-constructed array if no tags are requested.
*/
IdArray tags;
/*!
* \brief Section vector: entry i is the number of ids that make up frontier i.
* Dtype is int64.
*/
IdArray sections;
};
namespace aten {
/*!
* \brief Traverse the graph in a breadth-first-search (BFS) order.
*
* \param csr The input csr matrix. It is expected to be square
*            (num_rows == num_cols).
* \param source Source nodes. Expected to share the device type and the id
*               dtype of \a csr.
* \return A Frontiers object containing the search result: the visited node
*         ids and the size of each BFS level.
*/
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
/*!
* \brief Traverse the graph in a breadth-first-search (BFS) order, returning
* the edges of the BFS tree.
*
* \param csr The input csr matrix. It is expected to be square
*            (num_rows == num_cols).
* \param source Source nodes. Expected to share the device type and the id
*               dtype of \a csr.
* \return A Frontiers object containing the search result: the ids of the
*         edges that discovered a new node, grouped per BFS level.
*/
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
/*!
* \brief Traverse the graph in topological order.
*
* \param csr The input csr matrix. It is expected to be square
*            (num_rows == num_cols).
* \return A Frontiers object containing the search result: the node ids in
*         topological order, grouped into frontiers.
*/
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
/*!
* \brief Traverse the graph in a depth-first-search (DFS) order.
*
* \param csr The input csr matrix. It is expected to be square
*            (num_rows == num_cols).
* \param source Source nodes. Expected to share the device type and the id
*               dtype of \a csr.
* \return A Frontiers object containing the search result (edge ids).
*/
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
/*!
* \brief Traverse the graph in a depth-first-search (DFS) order and return the
* recorded edge tags if return_labels is specified.
*
* The traversal visits edges in DFS order. Edges have three tags:
* FORWARD(0), REVERSE(1), NONTREE(2)
*
* A FORWARD edge is one in which `u` has been visited but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visited and the edge
* is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visited but the edge
* is NOT in the DFS tree.
*
* \param csr The input csr matrix. It is expected to be square
*            (num_rows == num_cols).
* \param source Source nodes. Expected to share the device type and the id
*               dtype of \a csr.
* \param has_reverse_edge If true, REVERSE edges are included
* \param has_nontree_edge If true, NONTREE edges are included
* \param return_labels If true, return the recorded edge tags.
* \return A Frontiers object containing the search result. Its tags are only
*         filled in when \a return_labels is true.
*/
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
} // namespace aten
} // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_
......@@ -2,6 +2,7 @@
from __future__ import absolute_import
from . import traversal as trv
from .heterograph import DGLHeteroGraph
__all__ = ['prop_nodes', 'prop_nodes_bfs', 'prop_nodes_topo',
'prop_edges', 'prop_edges_dfs']
......@@ -56,24 +57,24 @@ def prop_edges(graph,
def prop_nodes_bfs(graph,
source,
message_func,
reduce_func,
reverse=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
apply_node_func=None):
"""Message propagation using node frontiers generated by BFS.
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
source : list, tensor of nodes
Source nodes.
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
message_func : callable
The message function.
reduce_func : callable, optional
reduce_func : callable
The reduce function.
reverse : bool, optional
If true, traverse following the in-edge direction.
apply_node_func : callable, optional
The update function.
......@@ -81,26 +82,30 @@ def prop_nodes_bfs(graph,
--------
dgl.traversal.bfs_nodes_generator
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'prop_nodes_bfs only support homogeneous graph'
nodes_gen = trv.bfs_nodes_generator(graph, source, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(graph,
message_func,
reduce_func,
reverse=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
apply_node_func=None):
"""Message propagation using node frontiers generated by topological order.
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
message_func : callable
The message function.
reduce_func : callable, optional
reduce_func : callable
The reduce function.
reverse : bool, optional
If true, traverse following the in-edge direction.
apply_node_func : callable, optional
The update function.
......@@ -108,31 +113,39 @@ def prop_nodes_topo(graph,
--------
dgl.traversal.topological_nodes_generator
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'prop_nodes_topo only support homogeneous graph'
nodes_gen = trv.topological_nodes_generator(graph, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(graph,
source,
message_func,
reduce_func,
reverse=False,
has_reverse_edge=False,
has_nontree_edge=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
apply_node_func=None):
"""Message propagation using edge frontiers generated by labeled DFS.
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
source : list, tensor of nodes
Source nodes.
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
reverse : bool, optional
If true, traverse following the in-edge direction.
has_reverse_edge : bool, optional
If true, REVERSE edges are included.
has_nontree_edge : bool, optional
If true, NONTREE edges are included.
apply_node_func : callable, optional
The update function.
......@@ -140,6 +153,10 @@ def prop_edges_dfs(graph,
--------
dgl.traversal.dfs_labeled_edges_generator
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'prop_edges_dfs only support homogeneous graph'
edges_gen = trv.dfs_labeled_edges_generator(
graph, source, reverse, has_reverse_edge, has_nontree_edge,
return_labels=False)
......
......@@ -4,6 +4,7 @@ from __future__ import absolute_import
from ._ffi.function import _init_api
from . import backend as F
from . import utils
from .heterograph import DGLHeteroGraph
__all__ = ['bfs_nodes_generator', 'bfs_edges_generator',
'topological_nodes_generator',
......@@ -14,7 +15,7 @@ def bfs_nodes_generator(graph, source, reverse=False):
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
source : list, tensor of nodes
Source nodes.
......@@ -35,14 +36,18 @@ def bfs_nodes_generator(graph, source, reverse=False):
/ \\
0 - 1 - 3 - 5
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.bfs_nodes_generator(g, 0))
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'bfs_nodes_generator only support homogeneous graph'
gidx = graph._graph
source = utils.toindex(source)
ret = _CAPI_DGLBFSNodes(gidx, source.todgltensor(), reverse)
all_nodes = utils.toindex(ret(0)).tousertensor()
source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLBFSNodes_v2(gidx, source.todgltensor(), reverse)
all_nodes = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
node_frontiers = F.split(all_nodes, sections, dim=0)
......@@ -53,7 +58,7 @@ def bfs_edges_generator(graph, source, reverse=False):
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
source : list, tensor of nodes
Source nodes.
......@@ -75,14 +80,18 @@ def bfs_edges_generator(graph, source, reverse=False):
/ \\
0 - 1 - 3 - 5
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.bfs_edges_generator(g, 0))
[tensor([0]), tensor([1, 2]), tensor([4, 5])]
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'bfs_edges_generator only support homogeneous graph'
gidx = graph._graph
source = utils.toindex(source)
ret = _CAPI_DGLBFSEdges(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor()
source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLBFSEdges_v2(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
edge_frontiers = F.split(all_edges, sections, dim=0)
......@@ -93,7 +102,7 @@ def topological_nodes_generator(graph, reverse=False):
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
reverse : bool, optional
If True, traverse following the in-edge direction.
......@@ -112,13 +121,17 @@ def topological_nodes_generator(graph, reverse=False):
/ \\
0 - 1 - 3 - 5
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.topological_nodes_generator(g))
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'topological_nodes_generator only support homogeneous graph'
gidx = graph._graph
ret = _CAPI_DGLTopologicalNodes(gidx, reverse)
all_nodes = utils.toindex(ret(0)).tousertensor()
ret = _CAPI_DGLTopologicalNodes_v2(gidx, reverse)
all_nodes = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_nodes, sections, dim=0)
......@@ -133,7 +146,7 @@ def dfs_edges_generator(graph, source, reverse=False):
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
source : list, tensor of nodes
Source nodes.
......@@ -156,14 +169,18 @@ def dfs_edges_generator(graph, source, reverse=False):
Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.dfs_edges_generator(g, 0))
[tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'dfs_edges_generator only support homogeneous graph'
gidx = graph._graph
source = utils.toindex(source)
ret = _CAPI_DGLDFSEdges(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor()
source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLDFSEdges_v2(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_edges, sections, dim=0)
......@@ -195,7 +212,7 @@ def dfs_labeled_edges_generator(
Parameters
----------
graph : DGLGraph
graph : DGLHeteroGraph
The graph object.
source : list, tensor of nodes
Source nodes.
......@@ -226,21 +243,25 @@ def dfs_labeled_edges_generator(
Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.dfs_labeled_edges_generator(g, 0, has_nontree_edge=True))
(tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])),
(tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
"""
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'dfs_labeled_edges_generator only support homogeneous graph'
gidx = graph._graph
source = utils.toindex(source)
ret = _CAPI_DGLDFSLabeledEdges(
source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLDFSLabeledEdges_v2(
gidx,
source.todgltensor(),
reverse,
has_reverse_edge,
has_nontree_edge,
return_labels)
all_edges = utils.toindex(ret(0)).tousertensor()
all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list
if return_labels:
all_labels = utils.toindex(ret(1)).tousertensor()
......
......@@ -4,6 +4,7 @@
* \brief DGL array utilities implementation
*/
#include <dgl/array.h>
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/shared_mem.h>
......@@ -653,6 +654,90 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return ret;
}
///////////////////////// Graph Traverse routines //////////////////////////
/*!
 * \brief Node-level BFS entry point: validates the inputs and forwards to the
 *        implementation selected by the device type and id dtype of `source`.
 *
 * The guards require the graph and the sources to share a device type and an
 * id dtype, and the CSR to have as many rows as columns.
 */
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
/*!
 * \brief Edge-level BFS entry point: validates the inputs and forwards to the
 *        implementation selected by the device type and id dtype of `source`.
 *
 * The guards require the graph and the sources to share a device type and an
 * id dtype, and the CSR to have as many rows as columns.
 */
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSEdgesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
/*!
 * \brief Topological traversal entry point: checks that the CSR is square and
 *        forwards to the implementation selected by the device type of
 *        `csr.indptr` and the dtype of `csr.indices`.
 */
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, "TopologicalNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
      frontiers = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
    });
  });
  return frontiers;
}
/*!
 * \brief DFS entry point: validates the inputs and forwards to the
 *        implementation selected by the device type and id dtype of `source`.
 *
 * The guards require the graph and the sources to share a device type and an
 * id dtype, and the CSR to have as many rows as columns.
 */
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSEdges<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
/*!
 * \brief Labeled DFS entry point: validates the inputs and forwards all the
 *        options to the implementation selected by the device type and id
 *        dtype of `source`.
 *
 * The guards require the graph and the sources to share a device type and an
 * id dtype, and the CSR to have as many rows as columns.
 */
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
                             IdArray source,
                             const bool has_reverse_edge,
                             const bool has_nontree_edge,
                             const bool return_labels) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSLabeledEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSLabeledEdges<XPU, IdType>(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
    });
  });
  return frontiers;
}
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
......
......@@ -7,6 +7,7 @@
#define DGL_ARRAY_ARRAY_OP_H_
#include <dgl/array.h>
#include <dgl/graph_traversal.h>
#include <vector>
#include <tuple>
#include <utility>
......@@ -202,6 +203,25 @@ template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
// Graph traversal kernels. Each one is the per-device, per-id-type
// counterpart of the aten-level function with the same name; XPU is the
// device type and IdType the integer type of the graph ids.

/*! \brief Node-level BFS starting from the given source nodes. */
template <DLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
/*! \brief Edge-level BFS starting from the given source nodes. */
template <DLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
/*! \brief Topological traversal of the whole graph. */
template <DLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
/*! \brief DFS starting from the given source nodes. */
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
/*!
* \brief DFS starting from the given source nodes, optionally including
* REVERSE/NONTREE edges and optionally returning the edge tags.
*/
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
} // namespace impl
} // namespace aten
} // namespace dgl
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/traversal.cc
* \brief Graph traversal implementation
*/
#include <dgl/graph_traversal.h>
#include <algorithm>
#include <queue>
#include "./traversal.h"
namespace dgl {
namespace aten {
namespace impl {
namespace {
// Queue-like view over a caller-owned vector.
//
// Pushing appends to the vector; popping only advances a read cursor, so the
// vector keeps every element that ever went through the queue, in push order.
template<typename DType>
struct VectorQueueWrapper {
  std::vector<DType>* vec;  // backing storage, not owned
  size_t head = 0;          // index of the current front element

  explicit VectorQueueWrapper(std::vector<DType>* storage): vec(storage) {}

  // Append one element at the back.
  void push(const DType& value) {
    vec->push_back(value);
  }

  // Return a copy of the front element.
  DType top() const {
    return (*vec)[head];
  }

  // Drop the front element; the backing vector is left untouched.
  void pop() {
    head += 1;
  }

  // True when the cursor has caught up with the end of the vector.
  bool empty() const {
    return size() == 0;
  }

  // Number of elements that were pushed but not yet popped.
  size_t size() const {
    const size_t total = vec->size();
    return total - head;
  }
};
// Interleave several traversal traces into one flat ndarray.
//
// The output takes the first element of every trace, then the second element
// of every trace that has one, and so on (like zipping the vectors, skipping
// the traces that are already exhausted).
template<typename DType>
IdArray MergeMultipleTraversals(
    const std::vector<std::vector<DType>>& traces) {
  // The longest trace fixes the number of rounds; the sum of all lengths is
  // the size of the output.
  int64_t longest = 0, total = 0;
  for (const auto& trace : traces) {
    const int64_t trace_len = trace.size();
    longest = std::max(longest, trace_len);
    total += trace_len;
  }

  IdArray merged = IdArray::Empty({total},
                                  DLDataType{kDLInt, sizeof(DType) * 8, 1},
                                  DLContext{kDLCPU, 0});
  DType* out = static_cast<DType*>(merged->data);
  int64_t written = 0;
  for (int64_t step = 0; step < longest; ++step) {
    for (const auto& trace : traces) {
      const int64_t trace_len = trace.size();
      if (step < trace_len) {
        out[written++] = trace[step];
      }
    }
  }
  return merged;
}
// Compute the section lengths that go with an ndarray built by interleaving
// several traversal traces.
//
// Section i is the number of traces that have more than i entries, i.e. how
// many elements round i contributes to the merged array.
template<typename DType>
IdArray ComputeMergedSections(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  for (const auto& trace : traces) {
    const int64_t trace_len = trace.size();
    longest = std::max(longest, trace_len);
  }

  IdArray sections = IdArray::Empty({longest}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(sections->data);
  for (int64_t step = 0; step < longest; ++step) {
    int64_t alive = 0;
    for (const auto& trace : traces) {
      const int64_t trace_len = trace.size();
      if (step < trace_len) {
        alive += 1;
      }
    }
    out[step] = alive;
  }
  return sections;
}
} // namespace
/*!
 * \brief Node-level BFS from the given sources.
 *
 * The vector that backs the BFS queue doubles as the output: since popping
 * never removes elements from it, it ends up holding every node in the order
 * it was queued. The size of each non-empty frontier is recorded as a section.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  std::vector<IdType> queued_nodes;
  std::vector<int64_t> frontier_sizes;
  VectorQueueWrapper<IdType> pending(&queued_nodes);

  // Nothing to do per node: the queue storage already records the visit order.
  auto on_visit = [] (const int64_t) { };
  auto on_frontier = [&pending, &frontier_sizes] () {
    // do not push zero-length frontier
    if (pending.empty()) {
      return;
    }
    frontier_sizes.push_back(pending.size());
  };
  BFSTraverseNodes<IdType>(csr, source, &pending, on_visit, on_frontier);

  Frontiers result;
  result.ids = VecToIdArray(queued_nodes, sizeof(IdType) * 8);
  result.sections = VecToIdArray(frontier_sizes, sizeof(int64_t) * 8);
  return result;
}

template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
/*!
 * \brief Edge-level BFS from the given sources.
 *
 * Edge ids reported by the traversal are collected in visit order. The BFS
 * queue works on nodes, so it gets its own backing vector. The first frontier
 * notification (the one for the sources) is skipped; afterwards the size of
 * each non-empty node frontier is recorded as a section.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  std::vector<IdType> edge_ids;
  std::vector<int64_t> frontier_sizes;
  // NOTE: std::queue has no top() method.
  std::vector<IdType> queued_nodes;
  VectorQueueWrapper<IdType> pending(&queued_nodes);

  auto on_visit = [&edge_ids] (const IdType eid) { edge_ids.push_back(eid); };
  bool seen_source_frontier = false;
  auto on_frontier = [&] () {
    if (!seen_source_frontier) {
      // do not push the first section when doing edges
      seen_source_frontier = true;
      return;
    }
    // do not push zero-length frontier
    if (!pending.empty()) {
      frontier_sizes.push_back(pending.size());
    }
  };
  BFSTraverseEdges<IdType>(csr, source, &pending, on_visit, on_frontier);

  Frontiers result;
  result.ids = VecToIdArray(edge_ids, sizeof(IdType) * 8);
  result.sections = VecToIdArray(frontier_sizes, sizeof(int64_t) * 8);
  return result;
}

template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
/*!
 * \brief Topological traversal of the whole graph.
 *
 * The vector that backs the queue doubles as the output, so it ends up
 * holding every node in the order it was queued. The size of each non-empty
 * frontier is recorded as a section.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  std::vector<IdType> queued_nodes;
  std::vector<int64_t> frontier_sizes;
  VectorQueueWrapper<IdType> pending(&queued_nodes);

  // Nothing to do per node: the queue storage already records the visit order.
  auto on_visit = [] (const uint64_t) { };
  auto on_frontier = [&pending, &frontier_sizes] () {
    // do not push zero-length frontier
    if (pending.empty()) {
      return;
    }
    frontier_sizes.push_back(pending.size());
  };
  TopologicalNodes<IdType>(csr, &pending, on_visit, on_frontier);

  Frontiers result;
  result.ids = VecToIdArray(queued_nodes, sizeof(IdType) * 8);
  result.sections = VecToIdArray(frontier_sizes, sizeof(int64_t) * 8);
  return result;
}

template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&);
/*!
 * \brief Run one DFS per source node and merge the resulting edge traces.
 *
 * Only the edges reported with REVERSE and NONTREE edges disabled are
 * collected. The per-source traces are interleaved into a single id array,
 * with sections describing how many edges each round contributes.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  const IdType* sources = static_cast<IdType*>(source->data);
  const int64_t num_sources = source->shape[0];

  std::vector<std::vector<IdType>> traces(num_sources);
  for (int64_t k = 0; k < num_sources; ++k) {
    std::vector<IdType>& trace = traces[k];
    // The tag is ignored here; only the edge id is kept.
    auto record = [&trace] (IdType eid, int) { trace.push_back(eid); };
    DFSLabeledEdges<IdType>(csr, sources[k], false, false, record);
  }

  Frontiers result;
  result.ids = MergeMultipleTraversals(traces);
  result.sections = ComputeMergedSections(traces);
  return result;
}

template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
/*!
 * \brief Run one labeled DFS per source node and merge the resulting traces.
 *
 * For every source the visited edge ids are collected and, when
 * return_labels is set, so are their tags. The per-source traces are
 * interleaved into a single array, with sections describing how many edges
 * each round contributes. Tags are merged the same way as the edge ids and
 * are only filled in when return_labels is set.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
                             IdArray source,
                             const bool has_reverse_edge,
                             const bool has_nontree_edge,
                             const bool return_labels) {
  const IdType* sources = static_cast<IdType*>(source->data);
  const int64_t num_sources = source->shape[0];

  std::vector<std::vector<IdType>> edge_traces(num_sources);
  // Tag traces are only allocated when the caller asked for labels.
  std::vector<std::vector<int64_t>> tag_traces(return_labels ? num_sources : 0);

  for (int64_t k = 0; k < num_sources; ++k) {
    auto record = [&] (IdType eid, int64_t tag) {
      edge_traces[k].push_back(eid);
      if (return_labels) {
        tag_traces[k].push_back(tag);
      }
    };
    DFSLabeledEdges<IdType>(
        csr, sources[k], has_reverse_edge, has_nontree_edge, record);
  }

  Frontiers result;
  result.ids = MergeMultipleTraversals(edge_traces);
  result.sections = ComputeMergedSections(edge_traces);
  if (return_labels) {
    result.tags = MergeMultipleTraversals(tag_traces);
  }
  return result;
}

template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&,
                                                       IdArray,
                                                       const bool,
                                                       const bool,
                                                       const bool);
template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&,
                                                       IdArray,
                                                       const bool,
                                                       const bool,
                                                       const bool);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/traversal.h
* \brief Graph traversal routines.
*
* Traversal routines generate frontiers. Frontiers can be node frontiers or edge
* frontiers depending on the traversal function. Each frontier is a
* list of nodes/edges (specified by their ids). An optional tag can be specified
* for each node/edge (represented by an int value).
*/
#ifndef DGL_ARRAY_CPU_TRAVERSAL_H_
#define DGL_ARRAY_CPU_TRAVERSAL_H_
#include <dgl/graph_interface.h>
#include <stack>
#include <tuple>
#include <vector>
namespace dgl {
namespace aten {
namespace impl {
/*!
 * \brief Traverse the graph in a breadth-first-search (BFS) order.
 *
 * The queue object must provide the following members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();         // pop one node
 *   bool empty();       // return true if the queue is empty
 *   size_t size();      // return the size of the queue
 *
 * The visit function must be callable as:
 *   void (*visit)(IdType);
 *
 * The frontier function must be callable as:
 *   void (*make_frontier)(void);
 *
 * Every entry of \a source is marked, visited and queued, then
 * \a make_frontier is called once. After that the queue is drained level by
 * level: the column ids stored in the CSR row of each dequeued node are
 * scanned, every node not seen before is marked, visited and queued, and
 * \a make_frontier is called at the end of each level.
 *
 * \param csr The CSR matrix of the graph.
 * \param source Source nodes.
 * \param queue The queue used to do bfs.
 * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new frontier can be made;
 */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseNodes(const CSRMatrix& csr,
                      IdArray source,
                      Queue* queue,
                      VisitFn visit,
                      FrontierFn make_frontier) {
  const IdType* seeds = static_cast<IdType*>(source->data);
  const int64_t num_seeds = source->shape[0];
  const IdType* row_offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType* col_ids = static_cast<IdType*>(csr.indices->data);
  const int64_t num_nodes = csr.num_rows;

  std::vector<bool> seen(num_nodes);

  // The sources form the very first frontier.
  for (int64_t k = 0; k < num_seeds; ++k) {
    const IdType seed = seeds[k];
    seen[seed] = true;
    visit(seed);
    queue->push(seed);
  }
  make_frontier();

  while (!queue->empty()) {
    // Only drain the nodes that were queued before this level started.
    for (size_t remaining = queue->size(); remaining > 0; --remaining) {
      const IdType u = queue->top();
      queue->pop();
      const IdType row_end = row_offsets[u + 1];
      for (IdType pos = row_offsets[u]; pos < row_end; ++pos) {
        const IdType v = col_ids[pos];
        if (seen[v]) {
          continue;
        }
        seen[v] = true;
        visit(v);
        queue->push(v);
      }
    }
    make_frontier();
  }
}
/*!
 * \brief Traverse the graph in a breadth-first-search (BFS) order, reporting
 *        the edges of the BFS tree.
 *
 * The queue object must provide the following members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();         // pop one node
 *   bool empty();       // return true if the queue is empty
 *   size_t size();      // return the size of the queue
 *
 * The visit function must be callable as:
 *   void (*visit)(IdType);
 *
 * The frontier function must be callable as:
 *   void (*make_frontier)(void);
 *
 * Every entry of \a source is marked and queued (without calling \a visit),
 * then \a make_frontier is called once. After that the queue is drained level
 * by level: for each dequeued node the entries of its CSR row are scanned and
 * every entry that leads to a node not seen before is reported through
 * \a visit; \a make_frontier is called at the end of each level.
 *
 * \param csr The CSR matrix of the graph.
 * \param source Source nodes.
 * \param queue The queue used to do bfs.
 * \param visit The function to call when a new node is discovered.
 *        The argument is the id of the discovering edge.
 * \param make_frontier The function to indicate that a new frontier can be made;
 */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseEdges(const CSRMatrix& csr,
                      IdArray source,
                      Queue* queue,
                      VisitFn visit,
                      FrontierFn make_frontier) {
  const IdType* seeds = static_cast<IdType*>(source->data);
  const int64_t num_seeds = source->shape[0];
  const IdType* row_offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType* col_ids = static_cast<IdType*>(csr.indices->data);
  const IdType* edge_ids = static_cast<IdType*>(csr.data->data);
  const int64_t num_nodes = csr.num_rows;

  std::vector<bool> seen(num_nodes);

  for (int64_t k = 0; k < num_seeds; ++k) {
    const IdType seed = seeds[k];
    seen[seed] = true;
    queue->push(seed);
  }
  make_frontier();

  while (!queue->empty()) {
    // Only drain the nodes that were queued before this level started.
    for (size_t remaining = queue->size(); remaining > 0; --remaining) {
      const IdType u = queue->top();
      queue->pop();
      const IdType row_end = row_offsets[u + 1];
      for (IdType pos = row_offsets[u]; pos < row_end; ++pos) {
        // Without an explicit edge id array the position in the row is the id.
        const IdType eid = edge_ids ? edge_ids[pos] : pos;
        const IdType v = col_ids[pos];
        if (seen[v]) {
          continue;
        }
        seen[v] = true;
        visit(eid);
        queue->push(v);
      }
    }
    make_frontier();
  }
}
/*!
 * \brief Traverse the graph in topological order.
 *
 * The queue object must provide the following members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();         // pop one node
 *   bool empty();       // return true if the queue is empty
 *   size_t size();      // return the size of the queue
 *
 * The visit function must be callable as:
 *   void (*visit)(IdType);
 *
 * The frontier function must be callable as:
 *   void (*make_frontier)(void);
 *
 * A counter per node is initialised with the number of times the node appears
 * in the CSR column indices. Nodes whose counter is zero form the first
 * frontier. Dequeuing a node decrements the counter of every node listed in
 * its row, and a node is visited and queued when its counter reaches zero.
 * If some nodes were never visited at the end, a fatal error is logged.
 *
 * \param csr The CSR matrix of the graph.
 * \param queue The queue used to do the traversal.
 * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new frontier can be made;
 */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const CSRMatrix& csr,
                      Queue* queue,
                      VisitFn visit,
                      FrontierFn make_frontier) {
  const IdType* row_offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType* col_ids = static_cast<IdType*>(csr.indices->data);
  const int64_t num_nodes = csr.num_rows;
  const int64_t num_edges = csr.indices->shape[0];

  // Count how often each node shows up as a column index.
  std::vector<int64_t> pending_inputs(num_nodes, 0);
  for (int64_t pos = 0; pos < num_edges; ++pos) {
    pending_inputs[col_ids[pos]] += 1;
  }

  int64_t num_emitted = 0;
  for (int64_t vid = 0; vid < num_nodes; ++vid) {
    if (pending_inputs[vid] != 0) {
      continue;
    }
    visit(vid);
    queue->push(static_cast<IdType>(vid));
    ++num_emitted;
  }
  make_frontier();

  while (!queue->empty()) {
    // Only drain the nodes that were queued before this level started.
    for (size_t remaining = queue->size(); remaining > 0; --remaining) {
      const IdType u = queue->top();
      queue->pop();
      const IdType row_end = row_offsets[u + 1];
      for (IdType pos = row_offsets[u]; pos < row_end; ++pos) {
        const IdType v = col_ids[pos];
        pending_inputs[v] -= 1;
        if (pending_inputs[v] == 0) {
          visit(v);
          queue->push(v);
          ++num_emitted;
        }
      }
    }
    make_frontier();
  }

  if (num_emitted != num_nodes) {
    LOG(FATAL) << "Error in topological traversal: loop detected in the given graph.";
  }
}
/*!
 * \brief Tags attached to the edges reported by ``DFSLabeledEdges``.
 *
 * The numeric values match the FORWARD(0), REVERSE(1), NONTREE(2) convention
 * used in the traversal documentation.
 */
enum DFSEdgeTag {
  kForward = 0,
  kReverse = 1,
  kNonTree = 2,
};
/*!
* \brief Traverse the graph in a depth-first-search (DFS) order.
*
* The traversal visit edges in its DFS order. Edges have three tags:
* FORWARD(0), REVERSE(1), NONTREE(2)
*
* A FORWARD edge is one in which `u` has been visisted but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the edge
* is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the edge
* is NOT in the DFS tree.
*
* \param source Source node.
* \param reversed If true, DFS follows the in-edge direction
* \param has_reverse_edge If true, REVERSE edges are included
* \param has_nontree_edge If true, NONTREE edges are included
* \param visit The function to call when an edge is visited; the edge id and its
* tag will be given as the arguments.
*/
template<typename IdType, typename VisitFn>
void DFSLabeledEdges(const CSRMatrix& csr,
IdType source,
bool has_reverse_edge,
bool has_nontree_edge,
VisitFn visit) {
const int64_t num_nodes = csr.num_rows;
CHECK_GE(num_nodes, source) << "source " << source <<
" is out of range [0," << num_nodes << "]";
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
const IdType *eid_data = static_cast<IdType *>(csr.data->data);
if (indptr_data[source+1]-indptr_data[source] == 0) {
// no out-going edges from the source node
return;
}
typedef std::tuple<IdType, size_t, bool> StackEntry;
std::stack<StackEntry> stack;
std::vector<bool> visited(num_nodes);
visited[source] = true;
stack.push(std::make_tuple(source, 0, false));
IdType u = 0;
int64_t i = 0;
bool on_tree = false;
while (!stack.empty()) {
std::tie(u, i, on_tree) = stack.top();
const IdType v = indices_data[indptr_data[u] + i];
const IdType uv = eid_data ? eid_data[indptr_data[u] + i] : indptr_data[u] + i;
if (visited[v]) {
if (!on_tree && has_nontree_edge) {
visit(uv, kNonTree);
} else if (on_tree && has_reverse_edge) {
visit(uv, kReverse);
}
stack.pop();
// find next one.
if (indptr_data[u] + i < indptr_data[u + 1] - 1) {
stack.push(std::make_tuple(u, i+1, false));
}
} else {
visited[v] = true;
std::get<2>(stack.top()) = true;
visit(uv, kForward);
// expand
if (indptr_data[v] < indptr_data[v + 1]) {
stack.push(std::make_tuple(v, 0, false));
}
}
}
}
} // namespace impl
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_TRAVERSAL_H_
......@@ -34,13 +34,13 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
*
* The element type of the vector must be convertible to IdType.
*/
template<typename IdType, typename DType>
dgl::runtime::NDArray CopyVectorToNDArray(
    const std::vector<DType>& vec) {
  using dgl::runtime::NDArray;
  const int64_t len = vec.size();
  // DLDataType::bits is a width in bits, whereas sizeof() yields bytes.
  NDArray a = NDArray::Empty(
      {len}, DLDataType{kDLInt, sizeof(IdType) * 8, 1}, DLContext{kDLCPU, 0});
  // Each element is converted from DType to IdType while being copied.
  std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));
  return a;
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file graph/traversal.cc
* \brief Graph traversal implementation
*/
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace traverse {
// C API entry: BFS node frontiers.
// args: (graph, source node ids, reversed flag)
// returns: packed function over {frontier ids, frontier sections}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    // The reversed flag selects the CSC matrix instead of the CSR matrix.
    const aten::CSRMatrix adj = reversed
      ? graph.sptr()->GetCSCMatrix(0)
      : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::BFSNodesFrontiers(adj, sources);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry: BFS edge frontiers.
// args: (graph, source node ids, reversed flag)
// returns: packed function over {frontier ids, frontier sections}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    // The reversed flag selects the CSC matrix instead of the CSR matrix.
    const aten::CSRMatrix adj = reversed
      ? graph.sptr()->GetCSCMatrix(0)
      : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::BFSEdgesFrontiers(adj, sources);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry: topological node frontiers.
// args: (graph, reversed flag)
// returns: packed function over {frontier ids, frontier sections}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const bool reversed = args[1];
    // The reversed flag selects the CSC matrix instead of the CSR matrix.
    const aten::CSRMatrix adj = reversed
      ? graph.sptr()->GetCSCMatrix(0)
      : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::TopologicalNodesFrontiers(adj);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry: DFS edge frontiers.
// args: (graph, source node ids, reversed flag)
// returns: packed function over {frontier ids, frontier sections}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    // Reject malformed id arrays before any matrix is fetched.
    CHECK(aten::IsValidIdArray(sources)) << "Invalid source node id array.";
    // The reversed flag selects the CSC matrix instead of the CSR matrix.
    const aten::CSRMatrix adj = reversed
      ? graph.sptr()->GetCSCMatrix(0)
      : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::DGLDFSEdges(adj, sources);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry: labeled DFS edge frontiers.
// args: (graph, source node ids, reversed, has_reverse_edge,
//        has_nontree_edge, return_labels)
// returns: packed function over {ids, tags, sections} when return_labels is
//          set, otherwise over {ids, sections}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    const bool with_reverse = args[3];
    const bool with_nontree = args[4];
    const bool with_labels = args[5];
    // The reversed flag selects the CSC matrix instead of the CSR matrix.
    const aten::CSRMatrix adj = reversed
      ? graph.sptr()->GetCSCMatrix(0)
      : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::DGLDFSLabeledEdges(
      adj, sources, with_reverse, with_nontree, with_labels);
    if (!with_labels) {
      *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
      return;
    }
    *rv = ConvertNDArrayVectorToPackedFunc(
      {frontiers.ids, frontiers.tags, frontiers.sections});
  });
} // namespace traverse
} // namespace dgl
......@@ -132,8 +132,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
});
......@@ -162,8 +162,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
});
......@@ -186,8 +186,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
GraphRef g = args[0];
bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
});
......
......@@ -291,3 +291,4 @@ void DFSLabeledEdges(const GraphInterface& graph,
} // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_
......@@ -12,34 +12,30 @@ def rfunc(nodes):
def test_prop_nodes_bfs():
g = dgl.DGLGraph(nx.path_graph(5))
g = dgl.graph(g.edges())
g.ndata['x'] = F.ones((5, 2))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
dgl.prop_nodes_bfs(g, 0)
dgl.prop_nodes_bfs(g, 0, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1]
assert F.allclose(g.ndata['x'],
F.tensor([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
g = dgl.graph(g.edges())
g.ndata['x'] = F.ones((5, 2))
dgl.prop_edges_dfs(g, 0)
dgl.prop_edges_dfs(g, 0, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# snr using dfs results in a cumsum
assert F.allclose(g.ndata['x'],
F.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]))
g.ndata['x'] = F.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_reverse_edge=True)
dgl.prop_edges_dfs(g, 0, has_reverse_edge=True, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# result is cumsum[i] + cumsum[i-1]
assert F.allclose(g.ndata['x'],
F.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]))
g.ndata['x'] = F.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_nontree_edge=True)
dgl.prop_edges_dfs(g, 0, has_nontree_edge=True, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# result is cumsum[i] + cumsum[i+1]
assert F.allclose(g.ndata['x'],
F.tensor([[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]))
......@@ -47,6 +43,7 @@ def test_prop_edges_dfs():
def test_prop_nodes_topo():
# bi-directional chain
g = dgl.DGLGraph(nx.path_graph(5))
g = dgl.graph(g.edges())
assert U.check_fail(dgl.prop_nodes_topo, g) # has loop
# tree
......@@ -56,13 +53,12 @@ def test_prop_nodes_topo():
tree.add_edge(2, 0)
tree.add_edge(3, 2)
tree.add_edge(4, 2)
tree.register_message_func(mfunc)
tree.register_reduce_func(rfunc)
tree = dgl.graph(tree.edges())
# init node feature data
tree.ndata['x'] = F.zeros((5, 2))
# set all leaf nodes to be ones
tree.nodes[[1, 3, 4]].data['x'] = F.ones((3, 2))
dgl.prop_nodes_topo(tree)
dgl.prop_nodes_topo(tree, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# root node get the sum
assert F.allclose(tree.nodes[0].data['x'], F.tensor([[3., 3.]]))
......
......@@ -9,6 +9,7 @@ import scipy.sparse as sp
import backend as F
import itertools
from utils import parametrize_dtype
np.random.seed(42)
......@@ -16,7 +17,8 @@ def toset(x):
# F.zerocopy_to_numpy may return a int
return set(F.zerocopy_to_numpy(x).tolist())
def test_bfs(n=100):
@parametrize_dtype
def test_bfs(index_dtype, n=100):
def _bfs_nx(g_nx, src):
edges = nx.bfs_edges(g_nx, src)
layers_nx = [set([src])]
......@@ -41,6 +43,11 @@ def test_bfs(n=100):
g = dgl.DGLGraph()
a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n))
g.from_scipy_sparse_matrix(a)
if index_dtype == 'int32':
g = dgl.graph(g.edges()).int()
else:
g = dgl.graph(g.edges()).long()
g_nx = g.to_networkx()
src = random.choice(range(n))
layers_nx, _ = _bfs_nx(g_nx, src)
......@@ -51,17 +58,27 @@ def test_bfs(n=100):
g_nx = nx.random_tree(n, seed=42)
g = dgl.DGLGraph()
g.from_networkx(g_nx)
if index_dtype == 'int32':
g = dgl.graph(g.edges()).int()
else:
g = dgl.graph(g.edges()).long()
src = 0
_, edges_nx = _bfs_nx(g_nx, src)
edges_dgl = dgl.bfs_edges_generator(g, src)
assert len(edges_dgl) == len(edges_nx)
assert all(toset(x) == y for x, y in zip(edges_dgl, edges_nx))
def test_topological_nodes(n=100):
@parametrize_dtype
def test_topological_nodes(index_dtype, n=100):
g = dgl.DGLGraph()
a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n))
b = sp.tril(a, -1).tocoo()
g.from_scipy_sparse_matrix(b)
if index_dtype == 'int32':
g = dgl.graph(g.edges()).int()
else:
g = dgl.graph(g.edges()).long()
layers_dgl = dgl.topological_nodes_generator(g)
......@@ -84,15 +101,19 @@ def test_topological_nodes(n=100):
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))
DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def test_dfs_labeled_edges(example=False):
@parametrize_dtype
def test_dfs_labeled_edges(index_dtype, example=False):
dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
if index_dtype == 'int32':
dgl_g = dgl.graph(dgl_g.edges()).int()
else:
dgl_g = dgl.graph(dgl_g.edges()).long()
dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator(
dgl_g, [0, 3], has_reverse_edge=True, has_nontree_edge=True)
dgl_edges = [toset(t) for t in dgl_edges]
dgl_labels = [toset(t) for t in dgl_labels]
g1_solutions = [
# edges labels
[[0, 1, 1, 0, 2], [0, 0, 1, 1, 2]],
......@@ -119,8 +140,7 @@ def test_dfs_labeled_edges(example=False):
else:
assert False
if __name__ == '__main__':
test_bfs()
test_topological_nodes()
test_dfs_labeled_edges()
test_bfs(index_dtype='int32')
test_topological_nodes(index_dtype='int32')
test_dfs_labeled_edges(index_dtype='int32')
......@@ -212,11 +212,15 @@ class TreeLSTMCell(nn.Module):
# followings:
#
# convert to a heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree))
print(dgl.topological_nodes_generator(trv_a_tree))
# convert to a heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))
print(dgl.topological_nodes_generator(trv_graph))
##############################################################################
# Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
......@@ -224,12 +228,11 @@ print(dgl.topological_nodes_generator(graph))
import dgl.function as fn
import torch as th
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))
traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order)
trv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(traversal_order,
message_func=fn.copy_src('a', 'a'),
reduce_func=fn.sum('a', 'a'))
# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph)
......@@ -285,16 +288,18 @@ class TreeLSTM(nn.Module):
The prediction of each node.
"""
g = batch.graph
g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_node_func)
# convert to a heterogeneous graph
g = dgl.graph(g.edges())
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
dgl.prop_nodes_topo(g)
dgl.prop_nodes_topo(g,
message_func=self.cell.message_func,
reduce_func=self.cell.reduce_func,
apply_node_func=self.cell.apply_node_func)
# compute logits
h = self.dropout(g.ndata.pop('h'))
logits = self.linear(h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment