Commit 24bbdb74 authored by Mufei Li, committed by Minjie Wang

[Feature] Reversed Graph and Transform Module (#331)

* reverse a graph

* Reverse a graph

* Fix

* Revert "Fix"

This reverts commit 1728826b4f3a083b76dfc0fdecd2bbf943a971b2.

* Fix

* Fix

* Delete vcs.xml

* Delete Project_Default.xml

* Fix

* Fix

* Fix

* Remove outdated test

* Reorg transform and update reverse (#2)

* Reorg transform and update reverse

* Fix doc and test

* Update test

* Resolve conflict

* CI oriented fix

* Remove outdated import

* Fix import

* Fix import

* Define __all__ for wildcard imports

* Fix import

* Address circular imports

* Fix

* Fix test case

* Fix

* Fix

* Remove unused import

* Fix

* Fix

* Fix
parent 4bd4d6e3
......@@ -55,6 +55,7 @@ Transforming graph
DGLGraph.subgraphs
DGLGraph.edge_subgraph
DGLGraph.line_graph
DGLGraph.reverse
Converting from/to other format
-------------------------------
......
......@@ -13,3 +13,4 @@ API Reference
udf
sampler
data
transform
.. _apigraph:

Transform -- Graph Transformation
=================================

.. automodule:: dgl.transform
.. autosummary::
    :toctree: ../../generated/

    line_graph
    reverse
......@@ -59,7 +59,7 @@ The backend is controlled by ``DGLBACKEND`` environment variable, which defaults
| | | `official website <https://pytorch.org>`_ |
+---------+---------+--------------------------------------------------+
| mxnet | MXNet | Requires nightly build; run the following |
| | | command to install (TODO): |
| | | command to install: |
| | | |
| | | .. code:: bash |
| | | |
......
......@@ -14,3 +14,4 @@ from .graph import DGLGraph
from .traversal import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch
from .transform import *
......@@ -3,6 +3,7 @@ from __future__ import absolute_import
from collections import defaultdict
import dgl
from .base import ALL, is_all, DGLError
from . import backend as F
from . import init
......@@ -2760,22 +2761,16 @@ class DGLGraph(object):
     def line_graph(self, backtracking=True, shared=False):
         """Return the line graph of this graph.
 
-        Parameters
-        ----------
-        backtracking : bool, optional
-            Whether the returned line graph is backtracking.
-        shared : bool, optional
-            Whether the returned line graph shares representations with `self`.
-
-        Returns
-        -------
-        DGLGraph
-            The line graph of this graph.
-        """
-        graph_data = self._graph.line_graph(backtracking)
-        node_frame = self._edge_frame if shared else None
-        return DGLGraph(graph_data, node_frame)
+        See :func:`~dgl.transform.line_graph`.
+        """
+        return dgl.line_graph(self, backtracking, shared)
+
+    def reverse(self, share_ndata=False, share_edata=False):
+        """Return the reverse of this graph.
+
+        See :func:`~dgl.transform.reverse`.
+        """
+        return dgl.reverse(self, share_ndata, share_edata)
 
     def filter_nodes(self, predicate, nodes=ALL):
         """Return a tensor of node IDs that satisfy the given predicate.
......
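For orientation, here is a small usage sketch of the two wrapper methods added above. It assumes only DGLGraph calls already shown in this commit (add_nodes, add_edges, has_edge_between) and is illustrative rather than part of the change; the new dgl.transform module these wrappers delegate to follows.

import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])   # a directed 3-cycle

lg = g.line_graph()                 # one node in the line graph per edge of g
assert lg.number_of_nodes() == g.number_of_edges()

rg = g.reverse()                    # same nodes, every edge flipped
assert rg.has_edge_between(1, 0)    # g has 0 -> 1, so rg has 1 -> 0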
"""Module for graph transformation methods."""
from .graph import DGLGraph
from .batched_graph import BatchedDGLGraph
__all__ = ['line_graph', 'reverse']
def line_graph(g, backtracking=True, shared=False):
    """Return the line graph of the input graph.

    Parameters
    ----------
    g : dgl.DGLGraph
        The input graph.
    backtracking : bool, optional
        Whether the returned line graph is backtracking.
    shared : bool, optional
        Whether the returned line graph shares node features with the edge
        features of ``g``.

    Returns
    -------
    DGLGraph
        The line graph of ``g``.
    """
    graph_data = g._graph.line_graph(backtracking)
    node_frame = g._edge_frame if shared else None
    return DGLGraph(graph_data, node_frame)
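# Illustrative sketch (not part of the original file): with ``shared=True`` the
# line graph is built with the input graph's edge frame as its node frame, so the
# edge features of ``g`` show up as node features of the line graph. Assumes
# ``import torch as th`` in addition to the imports at the top of this module.
#
#     g = DGLGraph()
#     g.add_nodes(3)
#     g.add_edges([0, 1, 2], [1, 2, 0])
#     g.edata['w'] = th.tensor([[1.], [2.], [3.]])
#     lg = line_graph(g, shared=True)
#     assert th.equal(lg.ndata['w'], g.edata['w'])   # same underlying storage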
def reverse(g, share_ndata=False, share_edata=False):
    """Return the reverse of a graph.

    The reverse (also called converse or transpose) of a directed graph is another
    directed graph on the same nodes with the direction of every edge reversed.

    Given a :class:`DGLGraph` object, this function returns another :class:`DGLGraph`
    object representing its reverse.

    Notes
    -----
    * This function does not support :class:`~dgl.BatchedDGLGraph` objects.
    * The topology of the reversed graph is a snapshot: it is not updated when the
      topology of the original graph changes, and vice versa. This is particularly
      problematic when the node/edge attributes are shared. For example, if the
      topologies of the original graph and its reverse are changed independently,
      the shared node/edge features can become mismatched.

    Parameters
    ----------
    g : dgl.DGLGraph
        The input graph.
    share_ndata : bool, optional
        If True, the original graph and the reversed graph share memory for node
        attributes. Otherwise the reversed graph will not be initialized with node
        attributes.
    share_edata : bool, optional
        If True, the original graph and the reversed graph share memory for edge
        attributes. Otherwise the reversed graph will not have edge attributes.
    Examples
    --------
    Create a graph to reverse.

    >>> import dgl
    >>> import torch as th
    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(3)
    >>> g.add_edges([0, 1, 2], [1, 2, 0])
    >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
    >>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])

    Reverse the graph and examine its structure.

    >>> rg = g.reverse(share_ndata=True, share_edata=True)
    >>> print(rg)
    DGLGraph with 3 nodes and 3 edges.
    Node data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
    Edge data: {'h': Scheme(shape=(1,), dtype=torch.float32)}

    The edges are reversed now.

    >>> rg.has_edges_between([1, 2, 0], [0, 1, 2])
    tensor([1, 1, 1])

    Reversed edges have the same features as the original ones.

    >>> g.edges[[0, 2], [1, 0]].data['h'] == rg.edges[[1, 0], [0, 2]].data['h']
    tensor([[1],
            [1]], dtype=torch.uint8)

    The node/edge features of the reversed graph share memory with the original
    graph, which is helpful for both forward computation and backpropagation.

    >>> g.ndata['h'] = g.ndata['h'] + 1
    >>> rg.ndata['h']
    tensor([[1.],
            [2.],
            [3.]])
    """
    assert not isinstance(g, BatchedDGLGraph), \
        'reverse is not supported for a BatchedDGLGraph object'
    g_reversed = DGLGraph(multigraph=g.is_multigraph)
    g_reversed.add_nodes(g.number_of_nodes())
    g_edges = g.edges()
    g_reversed.add_edges(g_edges[1], g_edges[0])
    if share_ndata:
        g_reversed._node_frame = g._node_frame
    if share_edata:
        g_reversed._edge_frame = g._edge_frame
    return g_reversed
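The Notes in ``reverse`` above warn that a graph and its reverse are not kept in sync once created. A minimal sketch of that caveat, assuming only the ``dgl.reverse`` function defined in this diff:

import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
rg = dgl.reverse(g)

g.add_edges([2], [0])               # mutate the original graph only
print(g.number_of_edges())          # 3
print(rg.number_of_edges())         # 2 -- the reverse is a snapshot, not a live view

When node or edge frames are shared, such independent topology changes are what leads to the mismatched features mentioned in the Notes.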
......@@ -2,10 +2,12 @@ import torch as th
import networkx as nx
import numpy as np
import dgl
import dgl.function as fn
import utils as U
D = 5
# line graph related
def test_line_graph():
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
......@@ -39,6 +41,61 @@ def test_no_backtracking():
assert not L.has_edge_between(e1, e2)
assert not L.has_edge_between(e2, e1)
# reverse graph related
def test_reverse():
    g = dgl.DGLGraph()
    g.add_nodes(5)
    # The graph need not be connected.
    g.add_edges([0, 1, 2], [1, 2, 1])
    g.ndata['h'] = th.tensor([[0.], [1.], [2.], [3.], [4.]])
    g.edata['h'] = th.tensor([[5.], [6.], [7.]])
    rg = g.reverse()

    assert g.is_multigraph == rg.is_multigraph
    assert g.number_of_nodes() == rg.number_of_nodes()
    assert g.number_of_edges() == rg.number_of_edges()
    assert U.allclose(rg.has_edges_between([1, 2, 1], [0, 1, 2]).float(), th.ones(3))
    assert g.edge_id(0, 1) == rg.edge_id(1, 0)
    assert g.edge_id(1, 2) == rg.edge_id(2, 1)
    assert g.edge_id(2, 1) == rg.edge_id(1, 2)
def test_reverse_shared_frames():
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 1])
    g.ndata['h'] = th.tensor([[0.], [1.], [2.]], requires_grad=True)
    g.edata['h'] = th.tensor([[3.], [4.], [5.]], requires_grad=True)

    rg = g.reverse(share_ndata=True, share_edata=True)
    assert U.allclose(g.ndata['h'], rg.ndata['h'])
    assert U.allclose(g.edata['h'], rg.edata['h'])
    assert U.allclose(g.edges[[0, 2], [1, 1]].data['h'],
                      rg.edges[[1, 1], [0, 2]].data['h'])

    rg.ndata['h'] = rg.ndata['h'] + 1
    assert U.allclose(rg.ndata['h'], g.ndata['h'])
    g.edata['h'] = g.edata['h'] - 1
    assert U.allclose(rg.edata['h'], g.edata['h'])

    src_msg = fn.copy_src(src='h', out='m')
    sum_reduce = fn.sum(msg='m', out='h')
    rg.update_all(src_msg, sum_reduce)
    assert U.allclose(g.ndata['h'], rg.ndata['h'])

    # Grad check
    g.ndata['h'].retain_grad()
    rg.ndata['h'].retain_grad()
    loss_func = th.nn.MSELoss()
    target = th.zeros(3, 1)
    loss = loss_func(rg.ndata['h'], target)
    loss.backward()
    assert U.allclose(g.ndata['h'].grad, rg.ndata['h'].grad)
if __name__ == '__main__':
    test_line_graph()
    test_no_backtracking()
    test_reverse()
    test_reverse_shared_frames()