Unverified Commit 175f53de authored by Rhett-Ying's avatar Rhett-Ying Committed by GitHub
Browse files

[Feature] enable edge reorder in dgl.reorder_graph() (#3113)

* [Feature] enable edge reorder in dgl.reorder_graph()

* refine doc string

* refine doc string for dgl.reorder_graph

* refine doc string further
parent 0ce92a86
......@@ -132,7 +132,7 @@ under the ``dgl`` namespace.
DGLGraph.add_self_loop
DGLGraph.remove_self_loop
DGLGraph.to_simple
DGLGraph.reorder
DGLGraph.reorder_graph
Adjacency and incidence matrix
---------------------------------
......
......@@ -76,7 +76,7 @@ Operators for generating new graphs by manipulating the structure of the existin
metapath_reachable_graph
adj_product_graph
adj_sum_graph
reorder
reorder_graph
sort_csr_by_tag
sort_csc_by_tag
......
......@@ -53,7 +53,7 @@ __all__ = [
'as_heterograph',
'adj_product_graph',
'adj_sum_graph',
'reorder'
'reorder_graph'
]
......@@ -2902,52 +2902,58 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
return new_g
def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
r"""Return a new graph which re-order and re-label the nodes
def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
store_ids=True, permute_config=None):
r"""Return a new graph with nodes and edges re-ordered/re-labeled
according to the specified permute algorithm.
Support homogeneous graph only for the moment.
This API is basically implemented by leveraging :func:`~dgl.node_subgraph`,
so the function signature is similar and raw IDs could be stored
in ``dgl.NID`` and ``dgl.EID``.
The re-ordering has two 2 steps: first re-order nodes and then re-order edges.
Please note that edges are re-ordered/re-labeled according to re-ordered
``'src'`` nodes. This behavior is realized in :func:`dgl.node_subgraph`.
What's more, if user wants to re-order/re-label according to ``'dst'`` nodes
or any other algorithms, please use :func:`dgl.edge_subgraph` with new edge
permutation.
For node permutation, users can re-order by the :attr:`node_permute_algo`
argument. For edge permutation, user can re-arrange edges according to their
source nodes or destination nodes by the :attr:`edge_permute_algo` argument.
Some of the permutation algorithms are only implemented in CPU, so if the
input graph is on GPU, it will be copied to CPU first. The storage order of
the node and edge features in the graph are permuted accordingly.
Parameters
----------
g : DGLGraph
The homogeneous graph.
permute_algo: str, optional
can be ``'rcmk'`` or ``'metis'`` or ``'custom'``. ``'rcmk'`` is the default value.
node_permute_algo: str, optional
The permutation algorithm to re-order nodes. Options are ``rcmk`` or
``metis`` or ``custom``. ``rcmk`` is the default value.
* ``'rcmk'``: Call `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/
* ``rcmk``: Use the `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.sparse.csgraph.reverse_cuthill_mckee.html#
scipy-sparse-csgraph-reverse-cuthill-mckee>`__ from ``'scipy'`` to generate nodes
permutation and pass it into :func:`~dgl.node_subgraph` to generate new graph.
* ``'metis'``: Call :func:`~dgl.partition.metis_partition_assignment` from ``'DGL'``
to generate nodes permutation and pass it into :func:`~dgl.node_subgraph` to generate
new graph.
* ``'custom'``: This enables user to pass in self-designed reorder algorithm.
User should pass in ``'nodes_perm'`` via another argument :attr:`permute_config` with
``'custom'`` is specified here. By this way, can the graph be reordered according to
passed in nodes permutation.
scipy-sparse-csgraph-reverse-cuthill-mckee>`__ from ``scipy`` to generate nodes
permutation.
* ``metis``: Use the :func:`~dgl.partition.metis_partition_assignment` function
to partition the input graph, which gives a cluster assignment of each node.
DGL then sorts the assignment array so the new node order will put nodes of
the same cluster together.
* ``custom``: Reorder the graph according to the user-provided node permutation
array (provided in :attr:`permute_config`).
edge_permute_algo: str, optional
The permutation algorithm to reorder edges. Options are ``src`` or ``dst``.
``src`` is the default value.
* ``src``: Edges are arranged according to their source nodes.
* ``dst``: Edges are arranged according to their destination nodes.
store_ids: bool, optional
It is passed into :func:`~dgl.node_subgraph()`. If True, it will store
the raw IDs of the extracted nodes and edges in the ndata and edata of
the resulting graph under name ``'dgl.NID'`` and ``'dgl.EID'``, respectively.
If True, DGL will store the original node and edge IDs in the ndata and edata
of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
permute_config: dict, optional
Additional config data for specified :attr:`permute_algo`.
Additional key-value config data for the specified permutation algorithm.
* For ``'rcmk'``, this argument is not required.
* For ``'metis'``, partition part number ``'k'`` is required and specified in this
argument like this: {'k':10}.
* For ``'custom'``, ``'nodes_perm'`` should be specified in the format of
``'Int Tensor'`` or ``'iterable[int]'`` like :attr:`nodes` in :func:`~dgl.node_subgraph`.
* For ``rcmk``, this argument is not required.
* For ``metis``, users should specify the number of partitions ``k`` (e.g.,
``permute_config={'k':10}`` to partition the graph to 10 clusters).
* For ``custom``, users should provide a node permutation array ``nodes_perm``.
The array must be an integer list or a tensor with the same device of the
input graph.
Returns
-------
......@@ -2976,7 +2982,7 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
Reorder according to ``'rcmk'`` permute algorithm.
>>> rg = dgl.reorder(g)
>>> rg = dgl.reorder_graph(g)
>>> rg.ndata
{'h': tensor([[8, 9],
[6, 7],
......@@ -2990,9 +2996,9 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
[2],
[0]]), '_ID': tensor([4, 3, 1, 2, 0])}
Reorder with according to ``'metis'`` permute algorithm.
Reorder according to ``'metis'`` permute algorithm.
>>> rg = dgl.reorder(g, 'metis', permute_config={'k':2})
>>> rg = dgl.reorder_graph(g, 'metis', permute_config={'k':2})
>>> rg.ndata
{'h': tensor([[4, 5],
[2, 3],
......@@ -3011,7 +3017,7 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
>>> nodes_perm = torch.randperm(g.num_nodes())
>>> nodes_perm
tensor([3, 2, 0, 4, 1])
>>> rg = dgl.reorder(g, 'custom', permute_config={'nodes_perm':nodes_perm})
>>> rg = dgl.reorder_graph(g, 'custom', permute_config={'nodes_perm':nodes_perm})
>>> rg.ndata
{'h': tensor([[6, 7],
[4, 5],
......@@ -3025,16 +3031,38 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
[4],
[1]]), '_ID': tensor([3, 2, 0, 4, 1])}
Reorder according to ``dst`` edge permute algorithm and refine further
according to self-generated edges permutation. Please assure to specify
``relabel_nodes`` as ``False`` to keep the nodes order.
>>> rg = dgl.reorder_graph(g, edge_permute_algo='dst')
>>> rg.edges()
(tensor([0, 3, 1, 2, 4]), tensor([1, 1, 3, 3, 3]))
>>> eg = dgl.edge_subgraph(rg, [0, 2, 4, 1, 3], relabel_nodes=False)
>>> eg.edata
{'w': tensor([[4],
[3],
[0],
[2],
[1]]), '_ID': tensor([0, 2, 4, 1, 3])}
"""
# sanity checks
if not g.is_homogeneous:
raise DGLError("Homograph is supported only.")
expected_algo = ['rcmk', 'metis', 'custom']
if permute_algo not in expected_algo:
raise DGLError("Unexpected permute_algo is specified: {}. Expected algos: {}".format(
permute_algo, expected_algo))
if permute_algo == 'rcmk':
expected_node_algo = ['rcmk', 'metis', 'custom']
if node_permute_algo not in expected_node_algo:
raise DGLError("Unexpected node_permute_algo is specified: {}. Expected algos: {}".format(
node_permute_algo, expected_node_algo))
expected_edge_algo = ['src', 'dst']
if edge_permute_algo not in expected_edge_algo:
raise DGLError("Unexpected edge_permute_algo is specified: {}. Expected algos: {}".format(
edge_permute_algo, expected_edge_algo))
# generate nodes permutation
if node_permute_algo == 'rcmk':
nodes_perm = rcmk_perm(g)
elif permute_algo == 'metis':
elif node_permute_algo == 'metis':
if permute_config is None or 'k' not in permute_config:
raise DGLError(
"Partition parts 'k' is required for metis. Please specify in permute_config.")
......@@ -3048,10 +3076,24 @@ def reorder(g, permute_algo='rcmk', store_ids=True, permute_config=None):
if len(nodes_perm) != g.num_nodes():
raise DGLError("Length of passed in nodes_perm[{}] does not \
match graph num_nodes[{}].".format(len(nodes_perm), g.num_nodes()))
return subgraph.node_subgraph(g, nodes_perm, store_ids=store_ids)
# reorder nodes
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=store_ids)
# reorder edges
if edge_permute_algo == 'src':
# the output graph of dgl.node_subgraph() is ordered/labeled
# according to src already. Nothing needs to do.
pass
elif edge_permute_algo == 'dst':
edges_perm = np.argsort(F.asnumpy(rg.edges()[1]))
rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=store_ids)
return rg
DGLHeteroGraph.reorder = utils.alias_func(reorder)
DGLHeteroGraph.reorder_graph = utils.alias_func(reorder_graph)
def metis_perm(g, k):
......
......@@ -1482,23 +1482,41 @@ def test_remove_selfloop(idtype):
@parametrize_dtype
def test_reorder(idtype):
def test_reorder_graph(idtype):
g = dgl.graph(([0, 1, 2, 3, 4], [2, 2, 3, 2, 3]),
idtype=idtype, device=F.ctx())
g.ndata['h'] = F.copy_to(F.randn((g.num_nodes(), 3)), ctx=F.ctx())
g.edata['w'] = F.copy_to(F.randn((g.num_edges(), 2)), ctx=F.ctx())
# call with default args
rg = dgl.reorder(g)
# call with default args: node_permute_algo='rcmk', edge_permute_algo='src', store_ids=True
rg = dgl.reorder_graph(g)
assert dgl.NID in rg.ndata.keys()
assert dgl.EID in rg.edata.keys()
src = F.asnumpy(rg.edges()[0])
assert np.array_equal(src, np.sort(src))
# call with 'dst' edge_permute_algo
rg = dgl.reorder_graph(g, edge_permute_algo='dst')
dst = F.asnumpy(rg.edges()[1])
assert np.array_equal(dst, np.sort(dst))
# call with unknown edge_permute_algo
raise_error = False
try:
dgl.reorder_graph(g, edge_permute_algo='none')
except:
raise_error = True
assert raise_error
# reorder back to original according to stored ids
rg2 = dgl.reorder(rg, 'custom', permute_config={
rg = dgl.reorder_graph(g)
rg2 = dgl.reorder_graph(rg, 'custom', permute_config={
'nodes_perm': np.argsort(F.asnumpy(rg.ndata[dgl.NID]))})
assert F.array_equal(g.ndata['h'], rg2.ndata['h'])
assert F.array_equal(g.edata['w'], rg2.edata['w'])
# do not store ids
rg = dgl.reorder(g, store_ids=False)
rg = dgl.reorder_graph(g, store_ids=False)
assert not dgl.NID in rg.ndata.keys()
assert not dgl.EID in rg.edata.keys()
......@@ -1512,7 +1530,7 @@ def test_reorder(idtype):
# call with metis strategy, but k is not specified
raise_error = False
try:
dgl.reorder(mg, permute_algo='metis')
dgl.reorder_graph(mg, node_permute_algo='metis')
except:
raise_error = True
assert raise_error
......@@ -1520,8 +1538,8 @@ def test_reorder(idtype):
# call with metis strategy, k is specified
raise_error = False
try:
dgl.reorder(mg,
permute_algo='metis', permute_config={'k': 2})
dgl.reorder_graph(mg,
node_permute_algo='metis', permute_config={'k': 2})
except:
raise_error = True
assert not raise_error
......@@ -1530,7 +1548,7 @@ def test_reorder(idtype):
nodes_perm = np.random.permutation(g.num_nodes())
raise_error = False
try:
dgl.reorder(g, permute_algo='custom', permute_config={
dgl.reorder_graph(g, node_permute_algo='custom', permute_config={
'nodes_perm': nodes_perm})
except:
raise_error = True
......@@ -1539,7 +1557,7 @@ def test_reorder(idtype):
# call with unqualified nodes_perm specified
raise_error = False
try:
dgl.reorder(g, permute_algo='custom', permute_config={
dgl.reorder_graph(g, node_permute_algo='custom', permute_config={
'nodes_perm': nodes_perm[:g.num_nodes() - 1]})
except:
raise_error = True
......@@ -1548,7 +1566,7 @@ def test_reorder(idtype):
# call with unsupported strategy
raise_error = False
try:
dgl.reorder(g, permute_algo='cmk')
dgl.reorder_graph(g, node_permute_algo='cmk')
except:
raise_error = True
assert raise_error
......@@ -1558,7 +1576,7 @@ def test_reorder(idtype):
try:
hg = dgl.heterogrpah({('user', 'follow', 'user'): (
[0, 1], [1, 2])}, idtype=idtype, device=F.ctx())
dgl.reorder(hg)
dgl.reorder_graph(hg)
except:
raise_error = True
assert raise_error
......@@ -1566,7 +1584,7 @@ def test_reorder(idtype):
# add 'csr' format if needed
fg = g.formats('csc')
assert 'csr' not in sum(fg.formats().values(), [])
rfg = dgl.reorder(fg)
rfg = dgl.reorder_graph(fg)
assert 'csr' in sum(rfg.formats().values(), [])
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment