Unverified Commit 42b0c38f authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Kernel] Migrate traversal on adjlist to CSR (#1650)



* traversal to new framework

* add new

* Fix compile

* Pass test

* keep old version

* lint

* lint

* Fix

* Fix

* Fix compatibility with new master

* Fix test and tutorials

* Update according to comments

* Fix test
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent fb248b67
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/graph_traversal.h
* \brief common graph traversal operations
*/
#ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_
#include "array.h"
#include "base_heterograph.h"
namespace dgl {
///////////////////////// Graph Traverse routines //////////////////////////
/*!
* \brief Class for representing frontiers.
*
* A traversal produces a sequence of frontiers. Each frontier is a list of
* nodes/edges (specified by their ids). The ids of all frontiers are kept in
* one flat array (\c ids) and \c sections records the length of each frontier.
* An optional tag can be specified on each node/edge (represented by an int value).
*/
struct Frontiers {
/*!\brief ids of the nodes/edges of all the frontiers, stored in one flat array */
IdArray ids;
/*!
* \brief a vector store for node/edge tags. Dtype is int64.
* Empty if no tags are requested.
*/
IdArray tags;
/*!\brief a section vector giving the length of each frontier. Dtype is int64. */
IdArray sections;
};
namespace aten {
/*!
* \brief Traverse the graph in a breadth-first-search (BFS) order.
*
* \param csr The input csr matrix.
* \param source Source nodes.
* \return A Frontiers object containing the search result.
*/
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
/*!
* \brief Traverse the graph in a breadth-first-search (BFS) order, returning
* the edges of the BFS tree.
*
* \param csr The input csr matrix.
* \param source Source nodes.
* \return A Frontiers object containing the search result.
*/
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
/*!
* \brief Traverse the graph in topological order.
*
* Unlike the BFS/DFS routines, no source nodes are supplied.
*
* \param csr The input csr matrix.
* \return A Frontiers object containing the search result.
*/
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
/*!
* \brief Traverse the graph in a depth-first-search (DFS) order.
*
* \param csr The input csr matrix.
* \param source Source nodes.
* \return A Frontiers object containing the search result.
*/
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
/*!
* \brief Traverse the graph in a depth-first-search (DFS) order and return the
* recorded edge tags if return_labels is specified.
*
* The traversal visits edges in their DFS order. Edges have three tags:
* FORWARD(0), REVERSE(1), NONTREE(2)
*
* A FORWARD edge is one in which `u` has been visited but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visited and the edge
* is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visited but the edge
* is NOT in the DFS tree.
*
* \param csr The input csr matrix.
* \param source Source nodes.
* \param has_reverse_edge If true, REVERSE edges are included.
* \param has_nontree_edge If true, NONTREE edges are included.
* \param return_labels If true, return the recorded edge tags.
* \return A Frontiers object containing the search result.
*/
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
} // namespace aten
} // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from . import traversal as trv from . import traversal as trv
from .heterograph import DGLHeteroGraph
__all__ = ['prop_nodes', 'prop_nodes_bfs', 'prop_nodes_topo', __all__ = ['prop_nodes', 'prop_nodes_bfs', 'prop_nodes_topo',
'prop_edges', 'prop_edges_dfs'] 'prop_edges', 'prop_edges_dfs']
...@@ -56,24 +57,24 @@ def prop_edges(graph, ...@@ -56,24 +57,24 @@ def prop_edges(graph,
def prop_nodes_bfs(graph, def prop_nodes_bfs(graph,
source, source,
message_func,
reduce_func,
reverse=False, reverse=False,
message_func='default', apply_node_func=None):
reduce_func='default',
apply_node_func='default'):
"""Message propagation using node frontiers generated by BFS. """Message propagation using node frontiers generated by BFS.
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reverse : bool, optional message_func : callable
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function. The message function.
reduce_func : callable, optional reduce_func : callable
The reduce function. The reduce function.
reverse : bool, optional
If true, traverse following the in-edge direction.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
...@@ -81,26 +82,30 @@ def prop_nodes_bfs(graph, ...@@ -81,26 +82,30 @@ def prop_nodes_bfs(graph,
-------- --------
dgl.traversal.bfs_nodes_generator dgl.traversal.bfs_nodes_generator
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'prop_nodes_bfs only support homogeneous graph'
nodes_gen = trv.bfs_nodes_generator(graph, source, reverse) nodes_gen = trv.bfs_nodes_generator(graph, source, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func) prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(graph, def prop_nodes_topo(graph,
message_func,
reduce_func,
reverse=False, reverse=False,
message_func='default', apply_node_func=None):
reduce_func='default',
apply_node_func='default'):
"""Message propagation using node frontiers generated by topological order. """Message propagation using node frontiers generated by topological order.
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
reverse : bool, optional message_func : callable
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function. The message function.
reduce_func : callable, optional reduce_func : callable
The reduce function. The reduce function.
reverse : bool, optional
If true, traverse following the in-edge direction.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
...@@ -108,31 +113,39 @@ def prop_nodes_topo(graph, ...@@ -108,31 +113,39 @@ def prop_nodes_topo(graph,
-------- --------
dgl.traversal.topological_nodes_generator dgl.traversal.topological_nodes_generator
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'prop_nodes_topo only support homogeneous graph'
nodes_gen = trv.topological_nodes_generator(graph, reverse) nodes_gen = trv.topological_nodes_generator(graph, reverse)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func) prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(graph, def prop_edges_dfs(graph,
source, source,
message_func,
reduce_func,
reverse=False, reverse=False,
has_reverse_edge=False, has_reverse_edge=False,
has_nontree_edge=False, has_nontree_edge=False,
message_func='default', apply_node_func=None):
reduce_func='default',
apply_node_func='default'):
"""Message propagation using edge frontiers generated by labeled DFS. """Message propagation using edge frontiers generated by labeled DFS.
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
reverse : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional message_func : callable, optional
The message function. The message function.
reduce_func : callable, optional reduce_func : callable, optional
The reduce function. The reduce function.
reverse : bool, optional
If true, traverse following the in-edge direction.
has_reverse_edge : bool, optional
If true, REVERSE edges are included.
has_nontree_edge : bool, optional
If true, NONTREE edges are included.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
...@@ -140,6 +153,10 @@ def prop_edges_dfs(graph, ...@@ -140,6 +153,10 @@ def prop_edges_dfs(graph,
-------- --------
dgl.traversal.dfs_labeled_edges_generator dgl.traversal.dfs_labeled_edges_generator
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'prop_edges_dfs only support homogeneous graph'
edges_gen = trv.dfs_labeled_edges_generator( edges_gen = trv.dfs_labeled_edges_generator(
graph, source, reverse, has_reverse_edge, has_nontree_edge, graph, source, reverse, has_reverse_edge, has_nontree_edge,
return_labels=False) return_labels=False)
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
from ._ffi.function import _init_api from ._ffi.function import _init_api
from . import backend as F from . import backend as F
from . import utils from . import utils
from .heterograph import DGLHeteroGraph
__all__ = ['bfs_nodes_generator', 'bfs_edges_generator', __all__ = ['bfs_nodes_generator', 'bfs_edges_generator',
'topological_nodes_generator', 'topological_nodes_generator',
...@@ -14,7 +15,7 @@ def bfs_nodes_generator(graph, source, reverse=False): ...@@ -14,7 +15,7 @@ def bfs_nodes_generator(graph, source, reverse=False):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
...@@ -35,14 +36,18 @@ def bfs_nodes_generator(graph, source, reverse=False): ...@@ -35,14 +36,18 @@ def bfs_nodes_generator(graph, source, reverse=False):
/ \\ / \\
0 - 1 - 3 - 5 0 - 1 - 3 - 5
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]) >>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.bfs_nodes_generator(g, 0)) >>> list(dgl.bfs_nodes_generator(g, 0))
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])] [tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'bfs_nodes_generator only support homogeneous graph'
gidx = graph._graph gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLBFSNodes(gidx, source.todgltensor(), reverse) ret = _CAPI_DGLBFSNodes_v2(gidx, source.todgltensor(), reverse)
all_nodes = utils.toindex(ret(0)).tousertensor() all_nodes = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
node_frontiers = F.split(all_nodes, sections, dim=0) node_frontiers = F.split(all_nodes, sections, dim=0)
...@@ -53,7 +58,7 @@ def bfs_edges_generator(graph, source, reverse=False): ...@@ -53,7 +58,7 @@ def bfs_edges_generator(graph, source, reverse=False):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
...@@ -75,14 +80,18 @@ def bfs_edges_generator(graph, source, reverse=False): ...@@ -75,14 +80,18 @@ def bfs_edges_generator(graph, source, reverse=False):
/ \\ / \\
0 - 1 - 3 - 5 0 - 1 - 3 - 5
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]) >>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.bfs_edges_generator(g, 0)) >>> list(dgl.bfs_edges_generator(g, 0))
[tensor([0]), tensor([1, 2]), tensor([4, 5])] [tensor([0]), tensor([1, 2]), tensor([4, 5])]
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'bfs_edges_generator only support homogeneous graph'
gidx = graph._graph gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLBFSEdges(gidx, source.todgltensor(), reverse) ret = _CAPI_DGLBFSEdges_v2(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
edge_frontiers = F.split(all_edges, sections, dim=0) edge_frontiers = F.split(all_edges, sections, dim=0)
...@@ -93,7 +102,7 @@ def topological_nodes_generator(graph, reverse=False): ...@@ -93,7 +102,7 @@ def topological_nodes_generator(graph, reverse=False):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
reverse : bool, optional reverse : bool, optional
If True, traverse following the in-edge direction. If True, traverse following the in-edge direction.
...@@ -112,13 +121,17 @@ def topological_nodes_generator(graph, reverse=False): ...@@ -112,13 +121,17 @@ def topological_nodes_generator(graph, reverse=False):
/ \\ / \\
0 - 1 - 3 - 5 0 - 1 - 3 - 5
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]) >>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.topological_nodes_generator(g)) >>> list(dgl.topological_nodes_generator(g))
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])] [tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'topological_nodes_generator only support homogeneous graph'
gidx = graph._graph gidx = graph._graph
ret = _CAPI_DGLTopologicalNodes(gidx, reverse) ret = _CAPI_DGLTopologicalNodes_v2(gidx, reverse)
all_nodes = utils.toindex(ret(0)).tousertensor() all_nodes = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_nodes, sections, dim=0) return F.split(all_nodes, sections, dim=0)
...@@ -133,7 +146,7 @@ def dfs_edges_generator(graph, source, reverse=False): ...@@ -133,7 +146,7 @@ def dfs_edges_generator(graph, source, reverse=False):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
...@@ -156,14 +169,18 @@ def dfs_edges_generator(graph, source, reverse=False): ...@@ -156,14 +169,18 @@ def dfs_edges_generator(graph, source, reverse=False):
Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)] Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]) >>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.dfs_edges_generator(g, 0)) >>> list(dgl.dfs_edges_generator(g, 0))
[tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])] [tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'dfs_edges_generator only support homogeneous graph'
gidx = graph._graph gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLDFSEdges(gidx, source.todgltensor(), reverse) ret = _CAPI_DGLDFSEdges_v2(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
return F.split(all_edges, sections, dim=0) return F.split(all_edges, sections, dim=0)
...@@ -195,7 +212,7 @@ def dfs_labeled_edges_generator( ...@@ -195,7 +212,7 @@ def dfs_labeled_edges_generator(
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLHeteroGraph
The graph object. The graph object.
source : list, tensor of nodes source : list, tensor of nodes
Source nodes. Source nodes.
...@@ -226,21 +243,25 @@ def dfs_labeled_edges_generator( ...@@ -226,21 +243,25 @@ def dfs_labeled_edges_generator(
Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)] Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]
>>> g = dgl.DGLGraph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]) >>> g = dgl.graph([(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)])
>>> list(dgl.dfs_labeled_edges_generator(g, 0, has_nontree_edge=True)) >>> list(dgl.dfs_labeled_edges_generator(g, 0, has_nontree_edge=True))
(tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])), (tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])),
(tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2])) (tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
""" """
assert isinstance(graph, DGLHeteroGraph), \
'DGLGraph is deprecated, Please use DGLHeteroGraph'
assert len(graph.canonical_etypes) == 1, \
'dfs_labeled_edges_generator only support homogeneous graph'
gidx = graph._graph gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source, dtype=graph._idtype_str)
ret = _CAPI_DGLDFSLabeledEdges( ret = _CAPI_DGLDFSLabeledEdges_v2(
gidx, gidx,
source.todgltensor(), source.todgltensor(),
reverse, reverse,
has_reverse_edge, has_reverse_edge,
has_nontree_edge, has_nontree_edge,
return_labels) return_labels)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
if return_labels: if return_labels:
all_labels = utils.toindex(ret(1)).tousertensor() all_labels = utils.toindex(ret(1)).tousertensor()
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief DGL array utilities implementation * \brief DGL array utilities implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/shared_mem.h> #include <dgl/runtime/shared_mem.h>
...@@ -653,6 +654,90 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) { ...@@ -653,6 +654,90 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return ret; return ret;
} }
///////////////////////// Graph Traverse routines //////////////////////////
/*!
 * \brief Validate the inputs and dispatch BFS node traversal to the
 *        implementation matching the device type and id type of \a source.
 *
 * \param csr The input csr matrix; must be square.
 * \param source Source nodes; must share device type and dtype with \a csr.
 * \return The frontiers produced by the selected implementation.
 */
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  // Input validation: same device, same id dtype, square adjacency.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
/*!
 * \brief Validate the inputs and dispatch BFS edge traversal to the
 *        implementation matching the device type and id type of \a source.
 *
 * \param csr The input csr matrix; must be square.
 * \param source Source nodes; must share device type and dtype with \a csr.
 * \return The frontiers produced by the selected implementation.
 */
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  // Input validation: same device, same id dtype, square adjacency.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSEdgesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
/*!
 * \brief Validate the input and dispatch topological traversal to the
 *        implementation matching the device type and id type of \a csr.
 *
 * \param csr The input csr matrix; must be square.
 * \return The frontiers produced by the selected implementation.
 */
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  // There is no source array here, so device and id type come from the CSR.
  Frontiers frontiers;
  ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, "TopologicalNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
      frontiers = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
    });
  });
  return frontiers;
}
/*!
 * \brief Validate the inputs and dispatch DFS edge traversal to the
 *        implementation matching the device type and id type of \a source.
 *
 * \param csr The input csr matrix; must be square.
 * \param source Source nodes; must share device type and dtype with \a csr.
 * \return The frontiers produced by the selected implementation.
 */
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  // Input validation: same device, same id dtype, square adjacency.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSEdges<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
/*!
 * \brief Validate the inputs and dispatch labeled DFS edge traversal to the
 *        implementation matching the device type and id type of \a source.
 *
 * \param csr The input csr matrix; must be square.
 * \param source Source nodes; must share device type and dtype with \a csr.
 * \param has_reverse_edge Forwarded unchanged to the implementation.
 * \param has_nontree_edge Forwarded unchanged to the implementation.
 * \param return_labels Forwarded unchanged to the implementation.
 * \return The frontiers produced by the selected implementation.
 */
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
                             IdArray source,
                             const bool has_reverse_edge,
                             const bool has_nontree_edge,
                             const bool return_labels) {
  // Input validation: same device, same id dtype, square adjacency.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
    << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
    << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
    << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSLabeledEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSLabeledEdges<XPU, IdType>(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
    });
  });
  return frontiers;
}
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_ARRAY_ARRAY_OP_H_ #define DGL_ARRAY_ARRAY_OP_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/graph_traversal.h>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
...@@ -202,6 +203,25 @@ template <DLDeviceType XPU, typename IdType, typename FloatType> ...@@ -202,6 +203,25 @@ template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk( COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending); COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
// Device- and id-type-specific graph traversal kernels. XPU selects the
// device type and IdType the integer type of the node/edge ids.

// BFS from `source`; the result holds node frontiers.
template <DLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
// BFS from `source`; the result holds edge frontiers.
template <DLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
// Topological traversal of the whole graph; the result holds node frontiers.
template <DLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
// DFS from each node in `source`; the result holds edge frontiers.
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
// DFS from each node in `source`, with optional REVERSE/NONTREE edges and
// optional per-edge tags in the result.
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/traversal.cc
* \brief Graph traversal implementation
*/
#include <dgl/graph_traversal.h>
#include <algorithm>
#include <queue>
#include "./traversal.h"
namespace dgl {
namespace aten {
namespace impl {
namespace {
// Non-owning adapter that gives a std::vector a FIFO queue interface.
// Elements are appended to the vector on push() and are never erased:
// pop() only advances a read cursor, so after a traversal the vector
// still contains every element that was ever pushed, in push order.
template<typename DType>
struct VectorQueueWrapper {
  // Backing storage, owned by the caller.
  std::vector<DType>* vec;
  // Index of the current front element inside *vec.
  size_t head = 0;

  explicit VectorQueueWrapper(std::vector<DType>* storage) : vec(storage) {}

  // Append an element at the back of the queue.
  void push(const DType& elem) { vec->push_back(elem); }

  // Return a copy of the front element; the queue must not be empty.
  DType top() const { return (*vec)[head]; }

  // Drop the front element by moving the cursor forward.
  void pop() { head += 1; }

  // True when every pushed element has been popped.
  bool empty() const { return size() == 0; }

  // Number of elements pushed but not yet popped.
  size_t size() const { return vec->size() - head; }
};
// Interleave several traversal traces into one flat CPU array.
// The output lists the first element of every trace (in trace order),
// then the second element of every trace that has one, and so on; traces
// shorter than the current step are simply skipped. The element width of
// the returned array matches sizeof(DType).
template<typename DType>
IdArray MergeMultipleTraversals(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  int64_t total = 0;
  for (const auto& trace : traces) {
    const int64_t trace_len = trace.size();
    longest = std::max(longest, trace_len);
    total += trace_len;
  }

  IdArray ret = IdArray::Empty({total},
                               DLDataType{kDLInt, sizeof(DType) * 8, 1},
                               DLContext{kDLCPU, 0});
  DType* out = static_cast<DType*>(ret->data);
  for (int64_t step = 0; step < longest; ++step) {
    for (const auto& trace : traces) {
      const int64_t trace_len = trace.size();
      if (step < trace_len) {
        *out = trace[step];
        ++out;
      }
    }
  }
  return ret;
}
// Compute the section lengths that go with MergeMultipleTraversals.
// Entry i of the returned int64 CPU array is the number of traces that
// have an element at position i, i.e. the size of the i-th merged section.
template<typename DType>
IdArray ComputeMergedSections(
    const std::vector<std::vector<DType>>& traces) {
  int64_t longest = 0;
  for (const auto& trace : traces) {
    const int64_t trace_len = trace.size();
    longest = std::max(longest, trace_len);
  }

  IdArray ret = IdArray::Empty({longest}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(ret->data);
  for (int64_t step = 0; step < longest; ++step) {
    int64_t count = 0;
    for (const auto& trace : traces) {
      const int64_t trace_len = trace.size();
      if (step < trace_len) {
        count += 1;
      }
    }
    out[step] = count;
  }
  return ret;
}
} // namespace
/*!
 * \brief BFS over nodes, collecting one frontier per BFS level.
 *
 * The BFS queue is backed directly by the output id vector, so the order in
 * which nodes are enqueued is the order in which they appear in the result.
 *
 * \param csr The input csr matrix.
 * \param source Source nodes.
 * \return Frontiers with the visited node ids and the size of each level.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  std::vector<IdType> node_ids;
  std::vector<int64_t> level_sizes;
  VectorQueueWrapper<IdType> bfs_queue(&node_ids);

  // Nothing to record per node: pushing into the queue already stores the id.
  auto on_visit = [] (const int64_t) {};
  auto on_frontier = [&] () {
    // do not push zero-length frontier
    if (bfs_queue.empty()) {
      return;
    }
    level_sizes.push_back(bfs_queue.size());
  };
  BFSTraverseNodes<IdType>(csr, source, &bfs_queue, on_visit, on_frontier);

  Frontiers result;
  result.ids = VecToIdArray(node_ids, sizeof(IdType) * 8);
  result.sections = VecToIdArray(level_sizes, sizeof(int64_t) * 8);
  return result;
}
template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
/*!
 * \brief BFS over edges, collecting one frontier per BFS level.
 *
 * Edge ids reported by the traversal are accumulated separately from the
 * node queue that drives the BFS.
 *
 * \param csr The input csr matrix.
 * \param source Source nodes.
 * \return Frontiers with the visited edge ids and the recorded section sizes.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  std::vector<IdType> edge_ids;
  std::vector<int64_t> level_sizes;
  // NOTE: std::queue has no top() method, hence the vector-backed wrapper.
  std::vector<IdType> pending_nodes;
  VectorQueueWrapper<IdType> bfs_queue(&pending_nodes);

  auto on_visit = [&] (const IdType e) { edge_ids.push_back(e); };
  bool skip_next_frontier = true;
  auto on_frontier = [&] {
    // do not push the first section when doing edges
    if (skip_next_frontier) {
      skip_next_frontier = false;
      return;
    }
    // do not push zero-length frontier
    if (!bfs_queue.empty()) {
      level_sizes.push_back(bfs_queue.size());
    }
  };
  BFSTraverseEdges<IdType>(csr, source, &bfs_queue, on_visit, on_frontier);

  Frontiers result;
  result.ids = VecToIdArray(edge_ids, sizeof(IdType) * 8);
  result.sections = VecToIdArray(level_sizes, sizeof(int64_t) * 8);
  return result;
}
template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
/*!
 * \brief Topological traversal over nodes, collecting the frontiers reported
 *        by the traversal routine.
 *
 * The traversal queue is backed directly by the output id vector, so the
 * order in which nodes are enqueued is the order of the result.
 *
 * \param csr The input csr matrix.
 * \return Frontiers with the node ids and the size of each frontier.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  std::vector<IdType> node_ids;
  std::vector<int64_t> level_sizes;
  VectorQueueWrapper<IdType> topo_queue(&node_ids);

  // Nothing to record per node: pushing into the queue already stores the id.
  auto on_visit = [] (const uint64_t) {};
  auto on_frontier = [&] () {
    // do not push zero-length frontier
    if (topo_queue.empty()) {
      return;
    }
    level_sizes.push_back(topo_queue.size());
  };
  TopologicalNodes<IdType>(csr, &topo_queue, on_visit, on_frontier);

  Frontiers result;
  result.ids = VecToIdArray(node_ids, sizeof(IdType) * 8);
  result.sections = VecToIdArray(level_sizes, sizeof(int64_t) * 8);
  return result;
}
template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&);
/*!
 * \brief Run one DFS per source node and merge the per-source edge traces.
 *
 * Each source gets its own trace. The traces are then interleaved step by
 * step, so section i of the result holds the i-th edge of every trace that
 * is long enough.
 *
 * \param csr The input csr matrix.
 * \param source Source nodes.
 * \return Frontiers with the merged edge ids and the section sizes.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  const IdType* roots = static_cast<IdType*>(source->data);
  const int64_t num_roots = source->shape[0];
  std::vector<std::vector<IdType>> traces(num_roots);

  for (int64_t r = 0; r < num_roots; ++r) {
    std::vector<IdType>& trace = traces[r];
    // The tag is ignored; reverse and non-tree edges are both disabled.
    auto on_edge = [&] (IdType e, int) { trace.push_back(e); };
    DFSLabeledEdges<IdType>(csr, roots[r], false, false, on_edge);
  }

  Frontiers result;
  result.ids = MergeMultipleTraversals(traces);
  result.sections = ComputeMergedSections(traces);
  return result;
}
template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
/*!
 * \brief Run one labeled DFS per source node and merge the per-source traces.
 *
 * Each source gets its own edge trace (and tag trace when requested). The
 * traces are then interleaved step by step, so section i of the result holds
 * the i-th edge of every trace that is long enough.
 *
 * \param csr The input csr matrix.
 * \param source Source nodes.
 * \param has_reverse_edge Passed through to the DFS routine.
 * \param has_nontree_edge Passed through to the DFS routine.
 * \param return_labels If true, the edge tags are recorded and returned in
 *        the \c tags field; otherwise \c tags is left unset.
 * \return Frontiers with the merged edge ids, section sizes and optional tags.
 */
template <DLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
                             IdArray source,
                             const bool has_reverse_edge,
                             const bool has_nontree_edge,
                             const bool return_labels) {
  const IdType* roots = static_cast<IdType*>(source->data);
  const int64_t num_roots = source->shape[0];
  std::vector<std::vector<IdType>> edge_traces(num_roots);
  // Tag traces are only allocated when the caller asked for labels.
  std::vector<std::vector<int64_t>> tag_traces(return_labels ? num_roots : 0);

  for (int64_t r = 0; r < num_roots; ++r) {
    auto on_edge = [&] (IdType e, int64_t tag) {
      edge_traces[r].push_back(e);
      if (!return_labels) {
        return;
      }
      tag_traces[r].push_back(tag);
    };
    DFSLabeledEdges<IdType>(csr, roots[r],
                            has_reverse_edge, has_nontree_edge, on_edge);
  }

  Frontiers result;
  result.ids = MergeMultipleTraversals(edge_traces);
  result.sections = ComputeMergedSections(edge_traces);
  if (return_labels) {
    result.tags = MergeMultipleTraversals(tag_traces);
  }
  return result;
}
template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&,
                                                       IdArray,
                                                       const bool,
                                                       const bool,
                                                       const bool);
template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&,
                                                       IdArray,
                                                       const bool,
                                                       const bool,
                                                       const bool);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/traversal.h
* \brief Graph traversal routines.
*
* Traversal routines generate frontiers. Frontiers can be node frontiers or edge
* frontiers depending on the traversal function. Each frontier is a
* list of nodes/edges (specified by their ids). An optional tag can be specified
* for each node/edge (represented by an int value).
*/
#ifndef DGL_ARRAY_CPU_TRAVERSAL_H_
#define DGL_ARRAY_CPU_TRAVERSAL_H_
#include <dgl/graph_interface.h>
#include <stack>
#include <tuple>
#include <vector>
namespace dgl {
namespace aten {
namespace impl {
/*!
 * \brief Traverse the graph in a breadth-first-search (BFS) order.
 *
 * The queue object must provide the following members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();         // pop one node
 *   bool empty();       // return true if the queue is empty
 *   size_t size();      // return the size of the queue
 *
 * The visit function must be callable as:
 *   void (*visit)(IdType);
 *
 * The frontier function must be callable as:
 *   void (*make_frontier)(void);
 *
 * Every source node is visited and enqueued first, then make_frontier is
 * called once. After that the queue is drained level by level, and
 * make_frontier is called again after each level has been expanded.
 *
 * \param csr The input csr matrix; row u lists the nodes reachable from u.
 * \param source Source nodes.
 * \param queue The queue used to do bfs.
 * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new frontier can be made.
 */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseNodes(const CSRMatrix& csr,
                      IdArray source,
                      Queue* queue,
                      VisitFn visit,
                      FrontierFn make_frontier) {
  const IdType* seeds = static_cast<IdType*>(source->data);
  const int64_t num_seeds = source->shape[0];
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices = static_cast<IdType*>(csr.indices->data);

  // One flag per node; set as soon as the node is enqueued.
  std::vector<bool> seen(csr.num_rows);

  // Seed the queue with all source nodes; together they form the first frontier.
  for (const IdType* it = seeds; it != seeds + num_seeds; ++it) {
    const IdType u = *it;
    seen[u] = true;
    visit(u);
    queue->push(u);
  }
  make_frontier();

  while (!queue->empty()) {
    // Expand exactly the nodes that were queued when this level started.
    for (size_t pending = queue->size(); pending > 0; --pending) {
      const IdType u = queue->top();
      queue->pop();
      const IdType row_end = indptr[u + 1];
      for (IdType pos = indptr[u]; pos < row_end; ++pos) {
        const IdType v = indices[pos];
        if (seen[v]) {
          continue;
        }
        seen[v] = true;
        visit(v);
        queue->push(v);
      }
    }
    make_frontier();
  }
}
/*!
 * \brief Traverse the graph in a breadth-first-search (BFS) order, returning
 *        the edges of the BFS tree.
 *
 * The queue object must provide the following interface:
 *   Members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();         // pop one node
 *   bool empty();       // return true if the queue is empty
 *   size_t size();      // return the size of the queue
 * Note that std::queue<IdType> does not qualify as-is: it exposes front()
 * rather than top(), so a thin adapter is required.
 *
 * The visit function must be compatible with following interface:
 *   void (*visit)(IdType );
 *
 * The frontier function must be compatible with following interface:
 *   void (*make_frontier)(void);
 *
 * make_frontier is called once after the sources have been enqueued (no edge
 * has been visited at that point) and once after every processed level,
 * including the last level that discovers no new node.
 *
 * \param csr The adjacency of the graph; row u lists the nodes reachable from u.
 * \param source Source nodes. Every id must lie in [0, csr.num_rows).
 * \param queue The queue used to do bfs.
 * \param visit The function to call when an edge is added to the BFS tree.
 *        The argument would be edge ID.
 * \param make_frontier The function to indicate that a new frontier can be made;
 */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseEdges(const CSRMatrix& csr,
                      IdArray source,
                      Queue* queue,
                      VisitFn visit,
                      FrontierFn make_frontier) {
  const int64_t len = source->shape[0];
  const IdType* src_data = static_cast<IdType*>(source->data);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  // NOTE(review): this assumes csr.data can always be dereferenced and that a
  // missing edge-id array shows up as a null data pointer -- confirm against
  // the CSRMatrix contract.
  const IdType* eid_data = static_cast<IdType*>(csr.data->data);
  const int64_t num_nodes = csr.num_rows;
  std::vector<bool> visited(num_nodes);

  // Seed the traversal with the source nodes.
  for (int64_t i = 0; i < len; ++i) {
    const IdType u = src_data[i];
    // visited has exactly num_nodes slots, so reject ids that would index
    // outside of it instead of silently corrupting memory.
    CHECK(u >= 0 && u < num_nodes) << "source " << u <<
      " is out of range [0," << num_nodes << ")";
    visited[u] = true;
    queue->push(u);
  }
  make_frontier();

  while (!queue->empty()) {
    // Only the nodes that are in the queue right now belong to the current
    // level; nodes pushed below belong to the next one.
    const size_t size = queue->size();
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
        // Without an explicit edge-id array the position in indices is the id.
        auto e = eid_data ? eid_data[idx] : idx;
        const IdType v = indices_data[idx];
        if (!visited[v]) {
          visited[v] = true;
          visit(e);
          queue->push(v);
        }
      }
    }
    make_frontier();
  }
}
/*!
 * \brief Traverse the graph in topological order.
 *
 * Nodes are emitted level by level: first every node whose counter (number of
 * occurrences in csr.indices) is zero, then every node whose counter drops to
 * zero once the previous level has been processed, and so on.
 *
 * The queue object must provide the following interface:
 *   Members:
 *   void push(IdType);  // push one node
 *   IdType top();       // get the first node
 *   void pop();         // pop one node
 *   bool empty();       // return true if the queue is empty
 *   size_t size();      // return the size of the queue
 *
 * The visit function must be compatible with following interface:
 *   void (*visit)(IdType );
 *
 * The frontier function must be compatible with following interface:
 *   void (*make_frontier)(void);
 *
 * A fatal error is logged when not every node could be emitted.
 *
 * \param csr The adjacency of the graph.
 * \param queue The queue used to hold the nodes of the current level.
 * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new frontier can be made;
 */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const CSRMatrix& csr,
                      Queue* queue,
                      VisitFn visit,
                      FrontierFn make_frontier) {
  const IdType* row_offsets = static_cast<IdType *>(csr.indptr->data);
  const IdType* col_ids = static_cast<IdType *>(csr.indices->data);
  const int64_t node_count = csr.num_rows;
  const int64_t edge_count = csr.indices->shape[0];

  // Count how many times every node shows up as an edge endpoint in indices.
  std::vector<int64_t> pending(node_count, 0);
  for (const IdType* it = col_ids; it != col_ids + edge_count; ++it) {
    ++pending[*it];
  }

  // Nodes that never show up in indices make up the first frontier.
  int64_t emitted = 0;
  for (int64_t node = 0; node < node_count; ++node) {
    if (pending[node] != 0) {
      continue;
    }
    visit(node);
    queue->push(static_cast<IdType>(node));
    ++emitted;
  }
  make_frontier();

  while (!queue->empty()) {
    // Drain exactly the nodes of the current level; newly released nodes are
    // appended behind them and handled in the next round.
    for (size_t remaining = queue->size(); remaining > 0; --remaining) {
      const IdType cur = queue->top();
      queue->pop();
      const IdType row_end = row_offsets[cur + 1];
      for (IdType pos = row_offsets[cur]; pos < row_end; ++pos) {
        const IdType next = col_ids[pos];
        pending[next] -= 1;
        if (pending[next] != 0) {
          continue;
        }
        visit(next);
        queue->push(next);
        ++emitted;
      }
    }
    make_frontier();
  }

  // Nodes whose counter never reached zero were not emitted.
  if (emitted != node_count) {
    LOG(FATAL) << "Error in topological traversal: loop detected in the given graph.";
  }
}
/*!
 * \brief Tags attached to the edges reported by ``DFSEdges``.
 */
enum DFSEdgeTag {
  /*! \brief FORWARD edge */
  kForward = 0,
  /*! \brief REVERSE edge */
  kReverse = 1,
  /*! \brief NONTREE edge */
  kNonTree = 2,
};
/*!
 * \brief Traverse the graph in a depth-first-search (DFS) order.
 *
 * The traversal visits edges in DFS order. Edges have three tags:
 * FORWARD(0), REVERSE(1), NONTREE(2)
 *
 * A FORWARD edge is one in which `u` has been visited but `v` has not.
 * A REVERSE edge is one in which both `u` and `v` have been visited and the edge
 * is in the DFS tree.
 * A NONTREE edge is one in which both `u` and `v` have been visited but the edge
 * is NOT in the DFS tree.
 *
 * Nothing is visited when the source node has no edge in \a csr.
 *
 * \param csr The adjacency of the graph; row u lists the nodes reachable from u.
 * \param source Source node. Must lie in [0, csr.num_rows).
 * \param has_reverse_edge If true, REVERSE edges are included
 * \param has_nontree_edge If true, NONTREE edges are included
 * \param visit The function to call when an edge is visited; the edge id and its
 *              tag will be given as the arguments.
 */
template<typename IdType, typename VisitFn>
void DFSLabeledEdges(const CSRMatrix& csr,
                     IdType source,
                     bool has_reverse_edge,
                     bool has_nontree_edge,
                     VisitFn visit) {
  const int64_t num_nodes = csr.num_rows;
  // The code below reads indptr_data[source + 1] and writes visited[source]
  // (a vector of num_nodes slots), so source == num_nodes and negative ids
  // have to be rejected as well.
  CHECK(source >= 0 && source < num_nodes) << "source " << source <<
    " is out of range [0," << num_nodes << ")";
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  // NOTE(review): this assumes csr.data can always be dereferenced and that a
  // missing edge-id array shows up as a null data pointer -- confirm against
  // the CSRMatrix contract.
  const IdType* eid_data = static_cast<IdType*>(csr.data->data);
  if (indptr_data[source + 1] - indptr_data[source] == 0) {
    // no out-going edges from the source node
    return;
  }

  // A stack entry is (node u, offset i of the edge of u under examination,
  // whether that edge has been taken as a tree edge).
  using StackEntry = std::tuple<IdType, size_t, bool>;
  std::stack<StackEntry> stack;
  std::vector<bool> visited(num_nodes);
  visited[source] = true;
  stack.push(std::make_tuple(source, 0, false));

  IdType u = 0;
  int64_t i = 0;
  bool on_tree = false;
  while (!stack.empty()) {
    std::tie(u, i, on_tree) = stack.top();
    const IdType v = indices_data[indptr_data[u] + i];
    // Without an explicit edge-id array the position in indices is the id.
    const IdType uv = eid_data ? eid_data[indptr_data[u] + i] : indptr_data[u] + i;
    if (visited[v]) {
      // Either the edge leads to an already known node (NONTREE), or we are
      // back at a tree edge after its subtree has been explored (REVERSE).
      if (!on_tree && has_nontree_edge) {
        visit(uv, kNonTree);
      } else if (on_tree && has_reverse_edge) {
        visit(uv, kReverse);
      }
      stack.pop();
      // find next one.
      if (indptr_data[u] + i < indptr_data[u + 1] - 1) {
        stack.push(std::make_tuple(u, i + 1, false));
      }
    } else {
      visited[v] = true;
      // Remember that this edge is a tree edge; the entry stays on the stack
      // so that the REVERSE visit can be emitted once v has been explored.
      std::get<2>(stack.top()) = true;
      visit(uv, kForward);
      // expand
      if (indptr_data[v] < indptr_data[v + 1]) {
        stack.push(std::make_tuple(v, 0, false));
      }
    }
  }
}
} // namespace impl
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_TRAVERSAL_H_
...@@ -34,13 +34,13 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc( ...@@ -34,13 +34,13 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
* *
* The element type of the vector must be convertible to int64_t. * The element type of the vector must be convertible to int64_t.
*/ */
template<typename DType> template<typename IdType, typename DType>
dgl::runtime::NDArray CopyVectorToNDArray( dgl::runtime::NDArray CopyVectorToNDArray(
const std::vector<DType>& vec) { const std::vector<DType>& vec) {
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
const int64_t len = vec.size(); const int64_t len = vec.size();
NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, sizeof(IdType), 1}, DLContext{kDLCPU, 0});
std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(a->data)); std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));
return a; return a;
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file graph/traversal.cc
* \brief Graph traversal implementation
*/
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace traverse {
// C API entry point for BFS node traversal.
// Arguments: (0) graph handle, (1) source node ids, (2) reversed flag.
// Returns the frontier ids and sections packed into a PackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    // A reversed traversal runs on the CSC matrix instead of the CSR matrix
    // (presumably so that in-edges are followed).
    const aten::CSRMatrix adj = reversed ? graph.sptr()->GetCSCMatrix(0)
                                         : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::BFSNodesFrontiers(adj, sources);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry point for BFS edge traversal.
// Arguments: (0) graph handle, (1) source node ids, (2) reversed flag.
// Returns the frontier ids and sections packed into a PackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    // A reversed traversal runs on the CSC matrix instead of the CSR matrix
    // (presumably so that in-edges are followed).
    const aten::CSRMatrix adj = reversed ? graph.sptr()->GetCSCMatrix(0)
                                         : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::BFSEdgesFrontiers(adj, sources);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry point for topological node traversal.
// Arguments: (0) graph handle, (1) reversed flag.
// Returns the frontier ids and sections packed into a PackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const bool reversed = args[1];
    // A reversed traversal runs on the CSC matrix instead of the CSR matrix
    // (presumably so that in-edges are followed).
    const aten::CSRMatrix adj = reversed ? graph.sptr()->GetCSCMatrix(0)
                                         : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::TopologicalNodesFrontiers(adj);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry point for DFS edge traversal.
// Arguments: (0) graph handle, (1) source node ids, (2) reversed flag.
// The source array is validated with aten::IsValidIdArray before use.
// Returns the frontier ids and sections packed into a PackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    CHECK(aten::IsValidIdArray(sources)) << "Invalid source node id array.";
    // A reversed traversal runs on the CSC matrix instead of the CSR matrix
    // (presumably so that in-edges are followed).
    const aten::CSRMatrix adj = reversed ? graph.sptr()->GetCSCMatrix(0)
                                         : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::DGLDFSEdges(adj, sources);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });
// C API entry point for labeled DFS edge traversal.
// Arguments: (0) graph handle, (1) source node ids, (2) reversed flag,
// (3) has_reverse_edge, (4) has_nontree_edge, (5) return_labels.
// Returns {ids, tags, sections} when return_labels is set and
// {ids, sections} otherwise, packed into a PackedFunc.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    const IdArray sources = args[1];
    const bool reversed = args[2];
    const bool with_reverse = args[3];
    const bool with_nontree = args[4];
    const bool with_labels = args[5];
    // A reversed traversal runs on the CSC matrix instead of the CSR matrix
    // (presumably so that in-edges are followed).
    const aten::CSRMatrix adj = reversed ? graph.sptr()->GetCSCMatrix(0)
                                         : graph.sptr()->GetCSRMatrix(0);
    const auto frontiers = aten::DGLDFSLabeledEdges(
        adj, sources, with_reverse, with_nontree, with_labels);
    if (!with_labels) {
      *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
      return;
    }
    *rv = ConvertNDArrayVectorToPackedFunc(
        {frontiers.ids, frontiers.tags, frontiers.sections});
  });
} // namespace traverse
} // namespace dgl
...@@ -132,8 +132,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes") ...@@ -132,8 +132,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed); const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids); IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections); IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
...@@ -162,8 +162,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges") ...@@ -162,8 +162,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed); const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray(front.ids); IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections); IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
}); });
...@@ -186,8 +186,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") ...@@ -186,8 +186,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
GraphRef g = args[0]; GraphRef g = args[0];
bool reversed = args[1]; bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed); const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids); IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections); IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
......
...@@ -291,3 +291,4 @@ void DFSLabeledEdges(const GraphInterface& graph, ...@@ -291,3 +291,4 @@ void DFSLabeledEdges(const GraphInterface& graph,
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_ #endif // DGL_GRAPH_TRAVERSAL_H_
...@@ -12,34 +12,30 @@ def rfunc(nodes): ...@@ -12,34 +12,30 @@ def rfunc(nodes):
def test_prop_nodes_bfs(): def test_prop_nodes_bfs():
g = dgl.DGLGraph(nx.path_graph(5)) g = dgl.DGLGraph(nx.path_graph(5))
g = dgl.graph(g.edges())
g.ndata['x'] = F.ones((5, 2)) g.ndata['x'] = F.ones((5, 2))
g.register_message_func(mfunc) dgl.prop_nodes_bfs(g, 0, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
g.register_reduce_func(rfunc)
dgl.prop_nodes_bfs(g, 0)
# pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1] # pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1]
assert F.allclose(g.ndata['x'], assert F.allclose(g.ndata['x'],
F.tensor([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]])) F.tensor([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def test_prop_edges_dfs(): def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5)) g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc) g = dgl.graph(g.edges())
g.register_reduce_func(rfunc)
g.ndata['x'] = F.ones((5, 2)) g.ndata['x'] = F.ones((5, 2))
dgl.prop_edges_dfs(g, 0) dgl.prop_edges_dfs(g, 0, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# snr using dfs results in a cumsum # snr using dfs results in a cumsum
assert F.allclose(g.ndata['x'], assert F.allclose(g.ndata['x'],
F.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]])) F.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]))
g.ndata['x'] = F.ones((5, 2)) g.ndata['x'] = F.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_reverse_edge=True) dgl.prop_edges_dfs(g, 0, has_reverse_edge=True, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# result is cumsum[i] + cumsum[i-1] # result is cumsum[i] + cumsum[i-1]
assert F.allclose(g.ndata['x'], assert F.allclose(g.ndata['x'],
F.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]])) F.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]))
g.ndata['x'] = F.ones((5, 2)) g.ndata['x'] = F.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_nontree_edge=True) dgl.prop_edges_dfs(g, 0, has_nontree_edge=True, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# result is cumsum[i] + cumsum[i+1] # result is cumsum[i] + cumsum[i+1]
assert F.allclose(g.ndata['x'], assert F.allclose(g.ndata['x'],
F.tensor([[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]])) F.tensor([[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]))
...@@ -47,6 +43,7 @@ def test_prop_edges_dfs(): ...@@ -47,6 +43,7 @@ def test_prop_edges_dfs():
def test_prop_nodes_topo(): def test_prop_nodes_topo():
# bi-directional chain # bi-directional chain
g = dgl.DGLGraph(nx.path_graph(5)) g = dgl.DGLGraph(nx.path_graph(5))
g = dgl.graph(g.edges())
assert U.check_fail(dgl.prop_nodes_topo, g) # has loop assert U.check_fail(dgl.prop_nodes_topo, g) # has loop
# tree # tree
...@@ -56,13 +53,12 @@ def test_prop_nodes_topo(): ...@@ -56,13 +53,12 @@ def test_prop_nodes_topo():
tree.add_edge(2, 0) tree.add_edge(2, 0)
tree.add_edge(3, 2) tree.add_edge(3, 2)
tree.add_edge(4, 2) tree.add_edge(4, 2)
tree.register_message_func(mfunc) tree = dgl.graph(tree.edges())
tree.register_reduce_func(rfunc)
# init node feature data # init node feature data
tree.ndata['x'] = F.zeros((5, 2)) tree.ndata['x'] = F.zeros((5, 2))
# set all leaf nodes to be ones # set all leaf nodes to be ones
tree.nodes[[1, 3, 4]].data['x'] = F.ones((3, 2)) tree.nodes[[1, 3, 4]].data['x'] = F.ones((3, 2))
dgl.prop_nodes_topo(tree) dgl.prop_nodes_topo(tree, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
# root node get the sum # root node get the sum
assert F.allclose(tree.nodes[0].data['x'], F.tensor([[3., 3.]])) assert F.allclose(tree.nodes[0].data['x'], F.tensor([[3., 3.]]))
......
...@@ -9,6 +9,7 @@ import scipy.sparse as sp ...@@ -9,6 +9,7 @@ import scipy.sparse as sp
import backend as F import backend as F
import itertools import itertools
from utils import parametrize_dtype
np.random.seed(42) np.random.seed(42)
...@@ -16,7 +17,8 @@ def toset(x): ...@@ -16,7 +17,8 @@ def toset(x):
# F.zerocopy_to_numpy may return a int # F.zerocopy_to_numpy may return a int
return set(F.zerocopy_to_numpy(x).tolist()) return set(F.zerocopy_to_numpy(x).tolist())
def test_bfs(n=100): @parametrize_dtype
def test_bfs(index_dtype, n=100):
def _bfs_nx(g_nx, src): def _bfs_nx(g_nx, src):
edges = nx.bfs_edges(g_nx, src) edges = nx.bfs_edges(g_nx, src)
layers_nx = [set([src])] layers_nx = [set([src])]
...@@ -41,6 +43,11 @@ def test_bfs(n=100): ...@@ -41,6 +43,11 @@ def test_bfs(n=100):
g = dgl.DGLGraph() g = dgl.DGLGraph()
a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n)) a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n))
g.from_scipy_sparse_matrix(a) g.from_scipy_sparse_matrix(a)
if index_dtype == 'int32':
g = dgl.graph(g.edges()).int()
else:
g = dgl.graph(g.edges()).long()
g_nx = g.to_networkx() g_nx = g.to_networkx()
src = random.choice(range(n)) src = random.choice(range(n))
layers_nx, _ = _bfs_nx(g_nx, src) layers_nx, _ = _bfs_nx(g_nx, src)
...@@ -51,17 +58,27 @@ def test_bfs(n=100): ...@@ -51,17 +58,27 @@ def test_bfs(n=100):
g_nx = nx.random_tree(n, seed=42) g_nx = nx.random_tree(n, seed=42)
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.from_networkx(g_nx) g.from_networkx(g_nx)
if index_dtype == 'int32':
g = dgl.graph(g.edges()).int()
else:
g = dgl.graph(g.edges()).long()
src = 0 src = 0
_, edges_nx = _bfs_nx(g_nx, src) _, edges_nx = _bfs_nx(g_nx, src)
edges_dgl = dgl.bfs_edges_generator(g, src) edges_dgl = dgl.bfs_edges_generator(g, src)
assert len(edges_dgl) == len(edges_nx) assert len(edges_dgl) == len(edges_nx)
assert all(toset(x) == y for x, y in zip(edges_dgl, edges_nx)) assert all(toset(x) == y for x, y in zip(edges_dgl, edges_nx))
def test_topological_nodes(n=100): @parametrize_dtype
def test_topological_nodes(index_dtype, n=100):
g = dgl.DGLGraph() g = dgl.DGLGraph()
a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n)) a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n))
b = sp.tril(a, -1).tocoo() b = sp.tril(a, -1).tocoo()
g.from_scipy_sparse_matrix(b) g.from_scipy_sparse_matrix(b)
if index_dtype == 'int32':
g = dgl.graph(g.edges()).int()
else:
g = dgl.graph(g.edges()).long()
layers_dgl = dgl.topological_nodes_generator(g) layers_dgl = dgl.topological_nodes_generator(g)
...@@ -84,15 +101,19 @@ def test_topological_nodes(n=100): ...@@ -84,15 +101,19 @@ def test_topological_nodes(n=100):
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv)) assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))
DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree'] DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def test_dfs_labeled_edges(example=False): @parametrize_dtype
def test_dfs_labeled_edges(index_dtype, example=False):
dgl_g = dgl.DGLGraph() dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6) dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5]) dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
if index_dtype == 'int32':
dgl_g = dgl.graph(dgl_g.edges()).int()
else:
dgl_g = dgl.graph(dgl_g.edges()).long()
dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator( dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator(
dgl_g, [0, 3], has_reverse_edge=True, has_nontree_edge=True) dgl_g, [0, 3], has_reverse_edge=True, has_nontree_edge=True)
dgl_edges = [toset(t) for t in dgl_edges] dgl_edges = [toset(t) for t in dgl_edges]
dgl_labels = [toset(t) for t in dgl_labels] dgl_labels = [toset(t) for t in dgl_labels]
g1_solutions = [ g1_solutions = [
# edges labels # edges labels
[[0, 1, 1, 0, 2], [0, 0, 1, 1, 2]], [[0, 1, 1, 0, 2], [0, 0, 1, 1, 2]],
...@@ -119,8 +140,7 @@ def test_dfs_labeled_edges(example=False): ...@@ -119,8 +140,7 @@ def test_dfs_labeled_edges(example=False):
else: else:
assert False assert False
if __name__ == '__main__': if __name__ == '__main__':
test_bfs() test_bfs(index_dtype='int32')
test_topological_nodes() test_topological_nodes(index_dtype='int32')
test_dfs_labeled_edges() test_dfs_labeled_edges(index_dtype='int32')
...@@ -212,11 +212,15 @@ class TreeLSTMCell(nn.Module): ...@@ -212,11 +212,15 @@ class TreeLSTMCell(nn.Module):
# followings: # followings:
# #
# to heterogenous graph
trv_a_tree = dgl.graph(a_tree.edges())
print('Traversing one tree:') print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree)) print(dgl.topological_nodes_generator(trv_a_tree))
# to heterogenous graph
trv_graph = dgl.graph(graph.edges())
print('Traversing many trees at the same time:') print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph)) print(dgl.topological_nodes_generator(trv_graph))
############################################################################## ##############################################################################
# Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing: # Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
...@@ -224,12 +228,11 @@ print(dgl.topological_nodes_generator(graph)) ...@@ -224,12 +228,11 @@ print(dgl.topological_nodes_generator(graph))
import dgl.function as fn import dgl.function as fn
import torch as th import torch as th
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1) trv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a')) traversal_order = dgl.topological_nodes_generator(trv_graph)
graph.register_reduce_func(fn.sum('a', 'a')) trv_graph.prop_nodes(traversal_order,
message_func=fn.copy_src('a', 'a'),
traversal_order = dgl.topological_nodes_generator(graph) reduce_func=fn.sum('a', 'a'))
graph.prop_nodes(traversal_order)
# the following is a syntax sugar that does the same # the following is a syntax sugar that does the same
# dgl.prop_nodes_topo(graph) # dgl.prop_nodes_topo(graph)
...@@ -285,16 +288,18 @@ class TreeLSTM(nn.Module): ...@@ -285,16 +288,18 @@ class TreeLSTM(nn.Module):
The prediction of each node. The prediction of each node.
""" """
g = batch.graph g = batch.graph
g.register_message_func(self.cell.message_func) # to heterogenous graph
g.register_reduce_func(self.cell.reduce_func) g = dgl.graph(g.edges())
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding # feed embedding
embeds = self.embedding(batch.wordid * batch.mask) embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1) g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h g.ndata['h'] = h
g.ndata['c'] = c g.ndata['c'] = c
# propagate # propagate
dgl.prop_nodes_topo(g) dgl.prop_nodes_topo(g,
message_func=self.cell.message_func,
reduce_func=self.cell.reduce_func,
apply_node_func=self.cell.apply_node_func)
# compute logits # compute logits
h = self.dropout(g.ndata.pop('h')) h = self.dropout(g.ndata.pop('h'))
logits = self.linear(h) logits = self.linear(h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment