Unverified Commit 175f53de authored by Rhett-Ying's avatar Rhett-Ying Committed by GitHub
Browse files

[Feature] enable edge reorder in dgl.reorder_graph() (#3113)

* [Feature] enable edge reorder in dgl.reorder_graph()

* refine doc string

* refine doc string for dgl.reorder_graph

* refine doc string further
parent 0ce92a86
...@@ -132,7 +132,7 @@ under the ``dgl`` namespace. ...@@ -132,7 +132,7 @@ under the ``dgl`` namespace.
DGLGraph.add_self_loop DGLGraph.add_self_loop
DGLGraph.remove_self_loop DGLGraph.remove_self_loop
DGLGraph.to_simple DGLGraph.to_simple
DGLGraph.reorder DGLGraph.reorder_graph
Adjacency and incidence matrix Adjacency and incidence matrix
--------------------------------- ---------------------------------
......
...@@ -76,7 +76,7 @@ Operators for generating new graphs by manipulating the structure of the existin ...@@ -76,7 +76,7 @@ Operators for generating new graphs by manipulating the structure of the existin
metapath_reachable_graph metapath_reachable_graph
adj_product_graph adj_product_graph
adj_sum_graph adj_sum_graph
reorder reorder_graph
sort_csr_by_tag sort_csr_by_tag
sort_csc_by_tag sort_csc_by_tag
......
...@@ -53,7 +53,7 @@ __all__ = [ ...@@ -53,7 +53,7 @@ __all__ = [
'as_heterograph', 'as_heterograph',
'adj_product_graph', 'adj_product_graph',
'adj_sum_graph', 'adj_sum_graph',
'reorder' 'reorder_graph'
] ]
...@@ -2902,52 +2902,58 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2902,52 +2902,58 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
return new_g return new_g
def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None): def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
r"""Return a new graph which re-order and re-label the nodes store_ids=True, permute_config=None):
r"""Return a new graph with nodes and edges re-ordered/re-labeled
according to the specified permute algorithm. according to the specified permute algorithm.
Support homogeneous graph only for the moment. Support homogeneous graph only for the moment.
This API is basically implemented by leveraging :func:`~dgl.node_subgraph`, The re-ordering has two steps: first re-order nodes and then re-order edges.
so the function signature is similar and raw IDs could be stored
in ``dgl.NID`` and ``dgl.EID``.
Please note that edges are re-ordered/re-labeled according to re-ordered For node permutation, users can re-order by the :attr:`node_permute_algo`
``'src'`` nodes. This behavior is realized in :func:`dgl.node_subgraph`. argument. For edge permutation, users can re-arrange edges according to their
What's more, if user wants to re-order/re-label according to ``'dst'`` nodes source nodes or destination nodes by the :attr:`edge_permute_algo` argument.
or any other algorithms, please use :func:`dgl.edge_subgraph` with new edge Some of the permutation algorithms are only implemented on CPU, so if the
permutation. input graph is on GPU, it will be copied to CPU first. The storage order of
the node and edge features in the graph is permuted accordingly.
Parameters Parameters
---------- ----------
g : DGLGraph g : DGLGraph
The homogeneous graph. The homogeneous graph.
permute_algo: str, optional node_permute_algo: str, optional
can be ``'rcmk'`` or ``'metis'`` or ``'custom'``. ``'rcmk'`` is the default value. The permutation algorithm to re-order nodes. Options are ``rcmk`` or
``metis`` or ``custom``. ``rcmk`` is the default value.
* ``'rcmk'``: Call `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/ * ``rcmk``: Use the `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.sparse.csgraph.reverse_cuthill_mckee.html# generated/scipy.sparse.csgraph.reverse_cuthill_mckee.html#
scipy-sparse-csgraph-reverse-cuthill-mckee>`__ from ``'scipy'`` to generate nodes scipy-sparse-csgraph-reverse-cuthill-mckee>`__ from ``scipy`` to generate nodes
permutation and pass it into :func:`~dgl.node_subgraph` to generate new graph. permutation.
* ``'metis'``: Call :func:`~dgl.partition.metis_partition_assignment` from ``'DGL'`` * ``metis``: Use the :func:`~dgl.partition.metis_partition_assignment` function
to generate nodes permutation and pass it into :func:`~dgl.node_subgraph` to generate to partition the input graph, which gives a cluster assignment of each node.
new graph. DGL then sorts the assignment array so the new node order will put nodes of
* ``'custom'``: This enables user to pass in self-designed reorder algorithm. the same cluster together.
User should pass in ``'nodes_perm'`` via another argument :attr:`permute_config` with * ``custom``: Reorder the graph according to the user-provided node permutation
``'custom'`` is specified here. By this way, can the graph be reordered according to array (provided in :attr:`permute_config`).
passed in nodes permutation. edge_permute_algo: str, optional
The permutation algorithm to reorder edges. Options are ``src`` or ``dst``.
``src`` is the default value.
* ``src``: Edges are arranged according to their source nodes.
* ``dst``: Edges are arranged according to their destination nodes.
store_ids: bool, optional store_ids: bool, optional
It is passed into :func:`~dgl.node_subgraph()`. If True, it will store If True, DGL will store the original node and edge IDs in the ndata and edata
the raw IDs of the extracted nodes and edges in the ndata and edata of of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
the resulting graph under name ``'dgl.NID'`` and ``'dgl.EID'``, respectively.
permute_config: dict, optional permute_config: dict, optional
Additional config data for specified :attr:`permute_algo`. Additional key-value config data for the specified permutation algorithm.
* For ``'rcmk'``, this argument is not required. * For ``rcmk``, this argument is not required.
* For ``'metis'``, partition part number ``'k'`` is required and specified in this * For ``metis``, users should specify the number of partitions ``k`` (e.g.,
argument like this: {'k':10}. ``permute_config={'k':10}`` to partition the graph into 10 clusters).
* For ``'custom'``, ``'nodes_perm'`` should be specified in the format of * For ``custom``, users should provide a node permutation array ``nodes_perm``.
``'Int Tensor'`` or ``'iterable[int]'`` like :attr:`nodes` in :func:`~dgl.node_subgraph`. The array must be an integer list or a tensor on the same device as the
input graph.
Returns Returns
------- -------
...@@ -2976,7 +2982,7 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None): ...@@ -2976,7 +2982,7 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
Reorder according to ``'rcmk'`` permute algorithm. Reorder according to ``'rcmk'`` permute algorithm.
>>> rg = dgl.reorder(g) >>> rg = dgl.reorder_graph(g)
>>> rg.ndata >>> rg.ndata
{'h': tensor([[8, 9], {'h': tensor([[8, 9],
[6, 7], [6, 7],
...@@ -2990,9 +2996,9 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None): ...@@ -2990,9 +2996,9 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
[2], [2],
[0]]), '_ID': tensor([4, 3, 1, 2, 0])} [0]]), '_ID': tensor([4, 3, 1, 2, 0])}
Reorder with according to ``'metis'`` permute algorithm. Reorder according to ``'metis'`` permute algorithm.
>>> rg = dgl.reorder(g, 'metis', permute_config={'k':2}) >>> rg = dgl.reorder_graph(g, 'metis', permute_config={'k':2})
>>> rg.ndata >>> rg.ndata
{'h': tensor([[4, 5], {'h': tensor([[4, 5],
[2, 3], [2, 3],
...@@ -3011,7 +3017,7 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None): ...@@ -3011,7 +3017,7 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
>>> nodes_perm = torch.randperm(g.num_nodes()) >>> nodes_perm = torch.randperm(g.num_nodes())
>>> nodes_perm >>> nodes_perm
tensor([3, 2, 0, 4, 1]) tensor([3, 2, 0, 4, 1])
>>> rg = dgl.reorder(g, 'custom', permute_config={'nodes_perm':nodes_perm}) >>> rg = dgl.reorder_graph(g, 'custom', permute_config={'nodes_perm':nodes_perm})
>>> rg.ndata >>> rg.ndata
{'h': tensor([[6, 7], {'h': tensor([[6, 7],
[4, 5], [4, 5],
...@@ -3025,16 +3031,38 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None): ...@@ -3025,16 +3031,38 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
[4], [4],
[1]]), '_ID': tensor([3, 2, 0, 4, 1])} [1]]), '_ID': tensor([3, 2, 0, 4, 1])}
Reorder according to the ``dst`` edge permute algorithm and refine further
according to a self-generated edge permutation. Please make sure to specify
``relabel_nodes`` as ``False`` to keep the node order.
>>> rg = dgl.reorder_graph(g, edge_permute_algo='dst')
>>> rg.edges()
(tensor([0, 3, 1, 2, 4]), tensor([1, 1, 3, 3, 3]))
>>> eg = dgl.edge_subgraph(rg, [0, 2, 4, 1, 3], relabel_nodes=False)
>>> eg.edata
{'w': tensor([[4],
[3],
[0],
[2],
[1]]), '_ID': tensor([0, 2, 4, 1, 3])}
""" """
# sanity checks
if not g.is_homogeneous: if not g.is_homogeneous:
raise DGLError("Homograph is supported only.") raise DGLError("Homograph is supported only.")
expected_algo = ['rcmk', 'metis', 'custom'] expected_node_algo = ['rcmk', 'metis', 'custom']
if permute_algo not in expected_algo: if node_permute_algo not in expected_node_algo:
raise DGLError("Unexpected permute_algo is specified: {}. Expected algos: {}".format( raise DGLError("Unexpected node_permute_algo is specified: {}. Expected algos: {}".format(
permute_algo, expected_algo)) node_permute_algo, expected_node_algo))
if permute_algo == 'rcmk': expected_edge_algo = ['src', 'dst']
if edge_permute_algo not in expected_edge_algo:
raise DGLError("Unexpected edge_permute_algo is specified: {}. Expected algos: {}".format(
edge_permute_algo, expected_edge_algo))
# generate nodes permutation
if node_permute_algo == 'rcmk':
nodes_perm = rcmk_perm(g) nodes_perm = rcmk_perm(g)
elif permute_algo == 'metis': elif node_permute_algo == 'metis':
if permute_config is None or 'k' not in permute_config: if permute_config is None or 'k' not in permute_config:
raise DGLError( raise DGLError(
"Partition parts 'k' is required for metis. Please specify in permute_config.") "Partition parts 'k' is required for metis. Please specify in permute_config.")
...@@ -3048,10 +3076,24 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None): ...@@ -3048,10 +3076,24 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
if len(nodes_perm) != g.num_nodes(): if len(nodes_perm) != g.num_nodes():
raise DGLError("Length of passed in nodes_perm[{}] does not \ raise DGLError("Length of passed in nodes_perm[{}] does not \
match graph num_nodes[{}].".format(len(nodes_perm), g.num_nodes())) match graph num_nodes[{}].".format(len(nodes_perm), g.num_nodes()))
return subgraph.node_subgraph(g, nodes_perm, store_ids=store_ids)
# reorder nodes
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=store_ids)
# reorder edges
if edge_permute_algo == 'src':
# the output graph of dgl.node_subgraph() is ordered/labeled
# according to src already. Nothing needs to do.
pass
elif edge_permute_algo == 'dst':
edges_perm = np.argsort(F.asnumpy(rg.edges()[1]))
rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=store_ids)
return rg
DGLHeteroGraph.reorder = utils.alias_func(reorder) DGLHeteroGraph.reorder_graph = utils.alias_func(reorder_graph)
def metis_perm(g, k): def metis_perm(g, k):
......
...@@ -1482,23 +1482,41 @@ def test_remove_selfloop(idtype): ...@@ -1482,23 +1482,41 @@ def test_remove_selfloop(idtype):
@parametrize_dtype @parametrize_dtype
def test_reorder(idtype): def test_reorder_graph(idtype):
g = dgl.graph(([0, 1, 2, 3, 4], [2, 2, 3, 2, 3]), g = dgl.graph(([0, 1, 2, 3, 4], [2, 2, 3, 2, 3]),
idtype=idtype, device=F.ctx()) idtype=idtype, device=F.ctx())
g.ndata['h'] = F.copy_to(F.randn((g.num_nodes(), 3)), ctx=F.ctx()) g.ndata['h'] = F.copy_to(F.randn((g.num_nodes(), 3)), ctx=F.ctx())
g.edata['w'] = F.copy_to(F.randn((g.num_edges(), 2)), ctx=F.ctx()) g.edata['w'] = F.copy_to(F.randn((g.num_edges(), 2)), ctx=F.ctx())
# call with default args # call with default args: node_permute_algo='rcmk', edge_permute_algo='src', store_ids=True
rg = dgl.reorder(g) rg = dgl.reorder_graph(g)
assert dgl.NID in rg.ndata.keys()
assert dgl.EID in rg.edata.keys()
src = F.asnumpy(rg.edges()[0])
assert np.array_equal(src, np.sort(src))
# call with 'dst' edge_permute_algo
rg = dgl.reorder_graph(g, edge_permute_algo='dst')
dst = F.asnumpy(rg.edges()[1])
assert np.array_equal(dst, np.sort(dst))
# call with unknown edge_permute_algo
raise_error = False
try:
dgl.reorder_graph(g, edge_permute_algo='none')
except:
raise_error = True
assert raise_error
# reorder back to original according to stored ids # reorder back to original according to stored ids
rg2 = dgl.reorder(rg, 'custom', permute_config={ rg = dgl.reorder_graph(g)
'nodes_perm': np.argsort(F.asnumpy(rg.ndata[dgl.NID]))}) rg2 = dgl.reorder_graph(rg, 'custom', permute_config={
'nodes_perm': np.argsort(F.asnumpy(rg.ndata[dgl.NID]))})
assert F.array_equal(g.ndata['h'], rg2.ndata['h']) assert F.array_equal(g.ndata['h'], rg2.ndata['h'])
assert F.array_equal(g.edata['w'], rg2.edata['w']) assert F.array_equal(g.edata['w'], rg2.edata['w'])
# do not store ids # do not store ids
rg = dgl.reorder(g, store_ids=False) rg = dgl.reorder_graph(g, store_ids=False)
assert not dgl.NID in rg.ndata.keys() assert not dgl.NID in rg.ndata.keys()
assert not dgl.EID in rg.edata.keys() assert not dgl.EID in rg.edata.keys()
...@@ -1512,7 +1530,7 @@ def test_reorder(idtype): ...@@ -1512,7 +1530,7 @@ def test_reorder(idtype):
# call with metis strategy, but k is not specified # call with metis strategy, but k is not specified
raise_error = False raise_error = False
try: try:
dgl.reorder(mg, permute_algo='metis') dgl.reorder_graph(mg, node_permute_algo='metis')
except: except:
raise_error = True raise_error = True
assert raise_error assert raise_error
...@@ -1520,8 +1538,8 @@ def test_reorder(idtype): ...@@ -1520,8 +1538,8 @@ def test_reorder(idtype):
# call with metis strategy, k is specified # call with metis strategy, k is specified
raise_error = False raise_error = False
try: try:
dgl.reorder(mg, dgl.reorder_graph(mg,
permute_algo='metis', permute_config={'k': 2}) node_permute_algo='metis', permute_config={'k': 2})
except: except:
raise_error = True raise_error = True
assert not raise_error assert not raise_error
...@@ -1530,8 +1548,8 @@ def test_reorder(idtype): ...@@ -1530,8 +1548,8 @@ def test_reorder(idtype):
nodes_perm = np.random.permutation(g.num_nodes()) nodes_perm = np.random.permutation(g.num_nodes())
raise_error = False raise_error = False
try: try:
dgl.reorder(g, permute_algo='custom', permute_config={ dgl.reorder_graph(g, node_permute_algo='custom', permute_config={
'nodes_perm': nodes_perm}) 'nodes_perm': nodes_perm})
except: except:
raise_error = True raise_error = True
assert not raise_error assert not raise_error
...@@ -1539,8 +1557,8 @@ def test_reorder(idtype): ...@@ -1539,8 +1557,8 @@ def test_reorder(idtype):
# call with unqualified nodes_perm specified # call with unqualified nodes_perm specified
raise_error = False raise_error = False
try: try:
dgl.reorder(g, permute_algo='custom', permute_config={ dgl.reorder_graph(g, node_permute_algo='custom', permute_config={
'nodes_perm': nodes_perm[:g.num_nodes() - 1]}) 'nodes_perm': nodes_perm[:g.num_nodes() - 1]})
except: except:
raise_error = True raise_error = True
assert raise_error assert raise_error
...@@ -1548,7 +1566,7 @@ def test_reorder(idtype): ...@@ -1548,7 +1566,7 @@ def test_reorder(idtype):
# call with unsupported strategy # call with unsupported strategy
raise_error = False raise_error = False
try: try:
dgl.reorder(g, permute_algo='cmk') dgl.reorder_graph(g, node_permute_algo='cmk')
except: except:
raise_error = True raise_error = True
assert raise_error assert raise_error
...@@ -1558,7 +1576,7 @@ def test_reorder(idtype): ...@@ -1558,7 +1576,7 @@ def test_reorder(idtype):
try: try:
hg = dgl.heterogrpah({('user', 'follow', 'user'): ( hg = dgl.heterogrpah({('user', 'follow', 'user'): (
[0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) [0, 1], [1, 2])}, idtype=idtype, device=F.ctx())
dgl.reorder(hg) dgl.reorder_graph(hg)
except: except:
raise_error = True raise_error = True
assert raise_error assert raise_error
...@@ -1566,7 +1584,7 @@ def test_reorder(idtype): ...@@ -1566,7 +1584,7 @@ def test_reorder(idtype):
# add 'csr' format if needed # add 'csr' format if needed
fg = g.formats('csc') fg = g.formats('csc')
assert 'csr' not in sum(fg.formats().values(), []) assert 'csr' not in sum(fg.formats().values(), [])
rfg = dgl.reorder(fg) rfg = dgl.reorder_graph(fg)
assert 'csr' in sum(rfg.formats().values(), []) assert 'csr' in sum(rfg.formats().values(), [])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment