Commit 24bbdb74 authored by Mufei Li's avatar Mufei Li Committed by Minjie Wang
Browse files

[Feature] Reversed Graph and Transform Module (#331)

* reverse a graph

* Reverse a graph

* Fix

* Revert "Fix"

This reverts commit 1728826b4f3a083b76dfc0fdecd2bbf943a971b2.

* Fix

* Fix

* Delete vcs.xml

* Delete Project_Default.xml

* Fix

* Fix

* Fix

* Remove outdated test

* Reorg transform and update reverse (#2)

* Reorg transform and update reverse

* Fix doc and test

* Update test

* Resolve conflict

* CI oriented fix

* Remove outdated import

* Fix import

* Fix import

* define __all__ for wildcard imports

* Fix import

* Address circular imports

* Fix

* Fix test case

* Fix

* Fix

* Remove unused import

* Fix

* Fix

* Fix
parent 4bd4d6e3
...@@ -55,6 +55,7 @@ Transforming graph ...@@ -55,6 +55,7 @@ Transforming graph
DGLGraph.subgraphs DGLGraph.subgraphs
DGLGraph.edge_subgraph DGLGraph.edge_subgraph
DGLGraph.line_graph DGLGraph.line_graph
DGLGraph.reverse
Converting from/to other format Converting from/to other format
------------------------------- -------------------------------
......
...@@ -13,3 +13,4 @@ API Reference ...@@ -13,3 +13,4 @@ API Reference
udf udf
sampler sampler
data data
transform
.. _apigraph:
Transform -- Graph Transformation
=================================
.. automodule:: dgl.transform
.. autosummary::
:toctree: ../../generated/
line_graph
reverse
...@@ -59,7 +59,7 @@ The backend is controlled by ``DGLBACKEND`` environment variable, which defaults ...@@ -59,7 +59,7 @@ The backend is controlled by ``DGLBACKEND`` environment variable, which defaults
| | | `official website <https://pytorch.org>`_ | | | | `official website <https://pytorch.org>`_ |
+---------+---------+--------------------------------------------------+ +---------+---------+--------------------------------------------------+
| mxnet | MXNet | Requires nightly build; run the following | | mxnet | MXNet | Requires nightly build; run the following |
| | | command to install (TODO): | | | | command to install: |
| | | | | | | |
| | | .. code:: bash | | | | .. code:: bash |
| | | | | | | |
......
...@@ -14,3 +14,4 @@ from .graph import DGLGraph ...@@ -14,3 +14,4 @@ from .graph import DGLGraph
from .traversal import * from .traversal import *
from .propagate import * from .propagate import *
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
from .transform import *
...@@ -3,6 +3,7 @@ from __future__ import absolute_import ...@@ -3,6 +3,7 @@ from __future__ import absolute_import
from collections import defaultdict from collections import defaultdict
import dgl
from .base import ALL, is_all, DGLError from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
from . import init from . import init
...@@ -2760,22 +2761,16 @@ class DGLGraph(object): ...@@ -2760,22 +2761,16 @@ class DGLGraph(object):
def line_graph(self, backtracking=True, shared=False): def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
Parameters See :func:`~dgl.transform.line_graph`.
---------- """
backtracking : bool, optional return dgl.line_graph(self, backtracking, shared)
Whether the returned line graph is backtracking.
shared : bool, optional def reverse(self, share_ndata=False, share_edata=False):
Whether the returned line graph shares representations with `self`. """Return the reverse of this graph.
Returns See :func:`~dgl.transform.reverse`.
-------
DGLGraph
The line graph of this graph.
""" """
graph_data = self._graph.line_graph(backtracking) return dgl.reverse(self, share_ndata, share_edata)
node_frame = self._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
def filter_nodes(self, predicate, nodes=ALL): def filter_nodes(self, predicate, nodes=ALL):
"""Return a tensor of node IDs that satisfy the given predicate. """Return a tensor of node IDs that satisfy the given predicate.
......
"""Module for graph transformation methods."""
from .graph import DGLGraph
from .batched_graph import BatchedDGLGraph
__all__ = ['line_graph', 'reverse']
def line_graph(g, backtracking=True, shared=False):
"""Return the line graph of this graph.
Parameters
----------
g : dgl.DGLGraph
backtracking : bool, optional
Whether the returned line graph is backtracking.
shared : bool, optional
Whether the returned line graph shares representations with `self`.
Returns
-------
DGLGraph
The line graph of this graph.
"""
graph_data = g._graph.line_graph(backtracking)
node_frame = g._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
def reverse(g, share_ndata=False, share_edata=False):
"""Return the reverse of a graph
The reverse (also called converse, transpose) of a directed graph is another directed
graph on the same nodes with edges reversed in terms of direction.
Given a :class:`DGLGraph` object, we return another :class:`DGLGraph` object
representing its reverse.
Notes
-----
* This function does not support :class:`~dgl.BatchedDGLGraph` objects.
* We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently,
you can get a mismatched node/edge feature.
Parameters
----------
g : dgl.DGLGraph
share_ndata: bool, optional
If True, the original graph and the reversed graph share memory for node attributes.
Otherwise the reversed graph will not be initialized with node attributes.
share_edata: bool, optional
If True, the original graph and the reversed graph share memory for edge attributes.
Otherwise the reversed graph will not have edge attributes.
Examples
--------
Create a graph to reverse.
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2], [1, 2, 0])
>>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
>>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])
Reverse the graph and examine its structure.
>>> rg = g.reverse(share_ndata=True, share_edata=True)
>>> print(rg)
DGLGraph with 3 nodes and 3 edges.
Node data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
Edge data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
The edges are reversed now.
>>> rg.has_edges_between([1, 2, 0], [0, 1, 2])
tensor([1, 1, 1])
Reversed edges have the same feature as the original ones.
>>> g.edges[[0, 2], [1, 0]].data['h'] == rg.edges[[1, 0], [0, 2]].data['h']
tensor([[1],
[1]], dtype=torch.uint8)
The node/edge features of the reversed graph share memory with the original
graph, which is helpful for both forward computation and back propagation.
>>> g.ndata['h'] = g.ndata['h'] + 1
>>> rg.ndata['h']
tensor([[1.],
[2.],
[3.]])
"""
assert not isinstance(g, BatchedDGLGraph), \
'reverse is not supported for a BatchedDGLGraph object'
g_reversed = DGLGraph(multigraph=g.is_multigraph)
g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.edges()
g_reversed.add_edges(g_edges[1], g_edges[0])
if share_ndata:
g_reversed._node_frame = g._node_frame
if share_edata:
g_reversed._edge_frame = g._edge_frame
return g_reversed
...@@ -2,10 +2,12 @@ import torch as th ...@@ -2,10 +2,12 @@ import torch as th
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import dgl import dgl
import dgl.function as fn
import utils as U import utils as U
D = 5 D = 5
# line graph related
def test_line_graph(): def test_line_graph():
N = 5 N = 5
G = dgl.DGLGraph(nx.star_graph(N)) G = dgl.DGLGraph(nx.star_graph(N))
...@@ -39,6 +41,61 @@ def test_no_backtracking(): ...@@ -39,6 +41,61 @@ def test_no_backtracking():
assert not L.has_edge_between(e1, e2) assert not L.has_edge_between(e1, e2)
assert not L.has_edge_between(e2, e1) assert not L.has_edge_between(e2, e1)
# reverse graph related
def test_reverse():
g = dgl.DGLGraph()
g.add_nodes(5)
# The graph need not to be completely connected.
g.add_edges([0, 1, 2], [1, 2, 1])
g.ndata['h'] = th.tensor([[0.], [1.], [2.], [3.], [4.]])
g.edata['h'] = th.tensor([[5.], [6.], [7.]])
rg = g.reverse()
assert g.is_multigraph == rg.is_multigraph
assert g.number_of_nodes() == rg.number_of_nodes()
assert g.number_of_edges() == rg.number_of_edges()
assert U.allclose(rg.has_edges_between([1, 2, 1], [0, 1, 2]).float(), th.ones(3))
assert g.edge_id(0, 1) == rg.edge_id(1, 0)
assert g.edge_id(1, 2) == rg.edge_id(2, 1)
assert g.edge_id(2, 1) == rg.edge_id(1, 2)
def test_reverse_shared_frames():
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 1])
g.ndata['h'] = th.tensor([[0.], [1.], [2.]], requires_grad=True)
g.edata['h'] = th.tensor([[3.], [4.], [5.]], requires_grad=True)
rg = g.reverse(share_ndata=True, share_edata=True)
assert U.allclose(g.ndata['h'], rg.ndata['h'])
assert U.allclose(g.edata['h'], rg.edata['h'])
assert U.allclose(g.edges[[0, 2], [1, 1]].data['h'],
rg.edges[[1, 1], [0, 2]].data['h'])
rg.ndata['h'] = rg.ndata['h'] + 1
assert U.allclose(rg.ndata['h'], g.ndata['h'])
g.edata['h'] = g.edata['h'] - 1
assert U.allclose(rg.edata['h'], g.edata['h'])
src_msg = fn.copy_src(src='h', out='m')
sum_reduce = fn.sum(msg='m', out='h')
rg.update_all(src_msg, sum_reduce)
assert U.allclose(g.ndata['h'], rg.ndata['h'])
# Grad check
g.ndata['h'].retain_grad()
rg.ndata['h'].retain_grad()
loss_func = th.nn.MSELoss()
target = th.zeros(3, 1)
loss = loss_func(rg.ndata['h'], target)
loss.backward()
assert U.allclose(g.ndata['h'].grad, rg.ndata['h'].grad)
if __name__ == '__main__': if __name__ == '__main__':
test_line_graph() test_line_graph()
test_no_backtracking() test_no_backtracking()
test_reverse()
test_reverse_shared_frames()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment