Unverified Commit 200340ab authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Transform] Reverse for heterogenous graph (#1784)



* reverse

* Add more test

* Fix lint

* Fix

* move to transform

* upd

* upd

* upd

* Add more test

* lint

* Fix

* Fix doc

* Fix

* test

* doc

* Fix
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent fb9d2138
...@@ -4701,6 +4701,7 @@ class DGLHeteroGraph(object): ...@@ -4701,6 +4701,7 @@ class DGLHeteroGraph(object):
self._node_frames, self._node_frames,
self._edge_frames) self._edge_frames)
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
......
...@@ -991,16 +991,21 @@ class HeteroGraphIndex(ObjectBase): ...@@ -991,16 +991,21 @@ class HeteroGraphIndex(ObjectBase):
""" """
return _CAPI_DGLHeteroGetFormatGraph(self, restrict_format) return _CAPI_DGLHeteroGetFormatGraph(self, restrict_format)
def reverse(self): def reverse(self, metagraph):
"""Reverse the heterogeneous graph adjacency """Reverse the heterogeneous graph adjacency
The node types and edge types are not changed The node types and edge types are not changed
Parameters
----------
metagraph : GraphIndex
Meta-graph.
Returns Returns
------- -------
A new graph index. A new graph index.
""" """
return _CAPI_DGLHeteroReverse(self) return _CAPI_DGLHeteroReverse(metagraph, self)
@register_object('graph.HeteroSubgraph') @register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase): class HeteroSubgraphIndex(ObjectBase):
......
...@@ -9,7 +9,7 @@ from .graph import DGLGraph ...@@ -9,7 +9,7 @@ from .graph import DGLGraph
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from . import ndarray as nd from . import ndarray as nd
from . import backend as F from . import backend as F
from .graph_index import from_coo from .graph_index import from_coo, from_edge_list
from .graph_index import _get_halo_subgraph_inner_node from .graph_index import _get_halo_subgraph_inner_node
from .graph import unbatch from .graph import unbatch
from .convert import graph, bipartite from .convert import graph, bipartite
...@@ -24,6 +24,7 @@ __all__ = [ ...@@ -24,6 +24,7 @@ __all__ = [
'khop_adj', 'khop_adj',
'khop_graph', 'khop_graph',
'reverse', 'reverse',
'reverse_heterograph',
'to_simple_graph', 'to_simple_graph',
'to_bidirected', 'to_bidirected',
'laplacian_lambda_max', 'laplacian_lambda_max',
...@@ -304,37 +305,39 @@ def khop_graph(g, k): ...@@ -304,37 +305,39 @@ def khop_graph(g, k):
# in the future. # in the future.
return DGLGraph(from_coo(n, row, col, True)) return DGLGraph(from_coo(n, row, col, True))
def reverse(g, share_ndata=False, share_edata=False): def reverse(g, copy_ndata=False, copy_edata=False):
"""Return the reverse of a graph """Return the reverse of a graph
The reverse (also called converse, transpose) of a directed graph is another directed The reverse (also called converse, transpose) of a directed graph is another directed
graph on the same nodes with edges reversed in terms of direction. graph on the same nodes with edges reversed in terms of direction.
Given a :class:`dgl.DGLGraph` object, we return another :class:`dgl.DGLGraph` object
Given a :class:`DGLGraph` object, we return another :class:`DGLGraph` object
representing its reverse. representing its reverse.
Notes
-----
* We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently,
you can get a mismatched node/edge feature.
Parameters Parameters
---------- ----------
g : dgl.DGLGraph g : dgl.DGLGraph
The input graph. The input graph.
share_ndata: bool, optional copy_ndata: bool, optional
If True, the original graph and the reversed graph share memory for node attributes. If True, node attributes are copied from the original graph to the reversed graph.
Otherwise the reversed graph will not be initialized with node attributes. Otherwise the reversed graph will not be initialized with node attributes.
share_edata: bool, optional copy_edata: bool, optional
If True, the original graph and the reversed graph share memory for edge attributes. If True, edge attributes are copied from the original graph to the reversed graph.
Otherwise the reversed graph will not have edge attributes. Otherwise the reversed graph will not have edge attributes.
Return
------
dgl.DGLGraph
The reversed graph.
Notes
-----
* We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently,
you can get a mismatched node/edge feature.
Examples Examples
-------- --------
Create a graph to reverse. Create a graph to reverse.
>>> import dgl >>> import dgl
>>> import torch as th >>> import torch as th
>>> g = dgl.DGLGraph() >>> g = dgl.DGLGraph()
...@@ -342,29 +345,21 @@ def reverse(g, share_ndata=False, share_edata=False): ...@@ -342,29 +345,21 @@ def reverse(g, share_ndata=False, share_edata=False):
>>> g.add_edges([0, 1, 2], [1, 2, 0]) >>> g.add_edges([0, 1, 2], [1, 2, 0])
>>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]]) >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
>>> g.edata['h'] = th.tensor([[3.], [4.], [5.]]) >>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])
Reverse the graph and examine its structure. Reverse the graph and examine its structure.
>>> rg = g.reverse(copy_ndata=True, copy_edata=True)
>>> rg = g.reverse(share_ndata=True, share_edata=True)
>>> print(rg) >>> print(rg)
DGLGraph with 3 nodes and 3 edges. DGLGraph with 3 nodes and 3 edges.
Node data: {'h': Scheme(shape=(1,), dtype=torch.float32)} Node data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
Edge data: {'h': Scheme(shape=(1,), dtype=torch.float32)} Edge data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
The edges are reversed now. The edges are reversed now.
>>> rg.has_edges_between([1, 2, 0], [0, 1, 2]) >>> rg.has_edges_between([1, 2, 0], [0, 1, 2])
tensor([1, 1, 1]) tensor([1, 1, 1])
Reversed edges have the same feature as the original ones. Reversed edges have the same feature as the original ones.
>>> g.edges[[0, 2], [1, 0]].data['h'] == rg.edges[[1, 0], [0, 2]].data['h'] >>> g.edges[[0, 2], [1, 0]].data['h'] == rg.edges[[1, 0], [0, 2]].data['h']
tensor([[1], tensor([[1],
[1]], dtype=torch.uint8) [1]], dtype=torch.uint8)
The node/edge features of the reversed graph share memory with the original The node/edge features of the reversed graph share memory with the original
graph, which is helpful for both forward computation and back propagation. graph, which is helpful for both forward computation and back propagation.
>>> g.ndata['h'] = g.ndata['h'] + 1 >>> g.ndata['h'] = g.ndata['h'] + 1
>>> rg.ndata['h'] >>> rg.ndata['h']
tensor([[1.], tensor([[1.],
...@@ -377,12 +372,164 @@ def reverse(g, share_ndata=False, share_edata=False): ...@@ -377,12 +372,164 @@ def reverse(g, share_ndata=False, share_edata=False):
g_reversed.add_edges(g_edges[1], g_edges[0]) g_reversed.add_edges(g_edges[1], g_edges[0])
g_reversed._batch_num_nodes = g._batch_num_nodes g_reversed._batch_num_nodes = g._batch_num_nodes
g_reversed._batch_num_edges = g._batch_num_edges g_reversed._batch_num_edges = g._batch_num_edges
if share_ndata: if copy_ndata:
g_reversed._node_frame = g._node_frame g_reversed._node_frame = g._node_frame
if share_edata: if copy_edata:
g_reversed._edge_frame = g._edge_frame g_reversed._edge_frame = g._edge_frame
return g_reversed return g_reversed
def reverse_heterograph(g, copy_ndata=True, copy_edata=False):
r"""Return the reverse of a graph.
The reverse (also called converse, transpose) of a graph with edges
:math:`(i_1, j_1), (i_2, j_2), \cdots` is a new graph with edges
:math:`(j_1, i_1), (j_2, i_2), \cdots`.
For a heterograph with multiple edge types, we can treat edges corresponding
to each type as a separate graph and compute the reverse for each of them.
If the original edge type is (A, B, C), its reverse will have edge type (C, B, A).
Given a :class:`dgl.DGLGraph` object, we return another :class:`dgl.DGLGraph`
object representing its reverse.
Parameters
----------
g : dgl.DGLGraph
The input graph.
copy_ndata: bool, optional
If True, the node features of the reversed graph are copied from the
original graph. If False, the reversed graph will not have any node features.
(Default: True)
copy_edata: bool, optional
If True, the edge features of the reversed graph are copied from the
original graph. If False, the reversed graph will not have any edge features.
(Default: False)
Return
------
dgl.DGLGraph
The reversed graph.
Notes
-----
If ``copy_ndata`` or ``copy_edata`` is ``True``, same tensors will be used for
the features of the original graph and the reversed graph to save memory cost.
As a result, users
should avoid performing in-place operations on the features of the reversed
graph, which will corrupt the features of the original graph as well. For
concrete examples, refer to the ``Examples`` section below.
Examples
--------
**Homographs or Heterographs with A Single Edge Type**
Create a graph to reverse.
>>> import dgl
>>> import torch as th
>>> g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
>>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
>>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])
Reverse the graph.
>>> rg = dgl.reverse(g, copy_edata=True)
>>> rg.ndata['h']
tensor([[0.],
[1.],
[2.]])
The i-th edge in the reversed graph corresponds to the i-th edge in the
original graph. When ``copy_edata`` is ``True``, they have the same features.
>>> rg.edges()
(tensor([1, 2, 0]), tensor([0, 1, 2]))
>>> rg.edata['h']
tensor([[3.],
[4.],
[5.]])
**In-place operations on features of one graph will be reflected on features of
its reverse, which is dangerous. Out-place operations will not be reflected.**
>>> rg.ndata['h'] += 1
>>> g.ndata['h']
tensor([[1.],
[2.],
[3.]])
>>> g.ndata['h'] += 1
>>> rg.ndata['h']
tensor([[2.],
[3.],
[4.]])
>>> rg.ndata['h2'] = th.ones(3, 1)
>>> 'h2' in g.ndata
False
**Heterographs with Multiple Edge Types**
>>> g = dgl.heterograph({
>>> ('user', 'follows', 'user'): (th.tensor([0, 2]), th.tensor([1, 2])),
>>> ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))
>>> })
>>> g.nodes['game'].data['hv'] = th.ones(3, 1)
>>> g.edges['plays'].data['he'] = th.zeros(3, 1)
The reverse of the graph above can be obtained by combining the reverse of the
subgraph corresponding to ('user', 'follows', 'user') and the subgraph corresponding
to ('user', 'plays', 'game'). The reverse for a graph with relation (h, r, t) will
have relation (t, r, h).
>>> rg = dgl.reverse(g, copy_ndata=True)
>>> rg
Graph(num_nodes={'game': 3, 'user': 3},
num_edges={('user', 'follows', 'user'): 2, ('game', 'plays', 'user'): 3},
metagraph=[('user', 'user'), ('game', 'user')])
>>> rg.edges(etype='follows')
(tensor([1, 2]), tensor([0, 2]))
>>> rg.edges(etype='plays')
(tensor([2, 1, 1]), tensor([1, 2, 1]))
>>> rg.nodes['game'].data['hv]
tensor([[1.],
[1.],
[1.]])
>>> rg.edges['plays'].data
{}
"""
# TODO(0.5 release, xiangsx) need to handle BLOCK
# currently reversing a block results in undefined behavior
canonical_etypes = g.canonical_etypes
meta_edges_src = []
meta_edges_dst = []
etypes = []
for c_etype in canonical_etypes:
meta_edges_src.append(g.get_ntype_id(c_etype[2]))
meta_edges_dst.append(g.get_ntype_id(c_etype[0]))
etypes.append(c_etype[1])
metagraph = from_edge_list((meta_edges_src, meta_edges_dst), True)
gidx = g._graph.reverse(metagraph)
new_g = DGLHeteroGraph(gidx, g.ntypes, etypes)
# handle ndata
if copy_ndata:
# for each ntype
for ntype in g.ntypes:
# for each data field
for k in g.nodes[ntype].data:
new_g.nodes[ntype].data[k] = g.nodes[ntype].data[k]
# handle edata
if copy_edata:
# for each etype
for etype in canonical_etypes:
# for each data field
for k in g.edges[etype].data:
new_g.edges[etype].data[k] = g.edges[etype].data[k]
return new_g
DGLHeteroGraph.reverse = reverse_heterograph
def to_simple_graph(g): def to_simple_graph(g):
"""Convert the graph to a simple graph with no multi-edge. """Convert the graph to a simple graph with no multi-edge.
......
...@@ -600,7 +600,8 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") ...@@ -600,7 +600,8 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; GraphRef meta_graph = args[0];
HeteroGraphRef hg = args[1];
CHECK_GT(hg->NumEdgeTypes(), 0); CHECK_GT(hg->NumEdgeTypes(), 0);
auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr()); auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
std::vector<HeteroGraphPtr> rev_ugs; std::vector<HeteroGraphPtr> rev_ugs;
...@@ -613,6 +614,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse") ...@@ -613,6 +614,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
} }
// node types are not changed // node types are not changed
const auto& num_nodes = g->NumVerticesPerType(); const auto& num_nodes = g->NumVerticesPerType();
*rv = CreateHeteroGraph(hg->meta_graph(), rev_ugs, num_nodes); auto hgptr = CreateHeteroGraph(meta_graph.sptr(), rev_ugs, num_nodes);
*rv = HeteroGraphRef(hgptr);
}); });
} // namespace dgl } // namespace dgl
...@@ -1944,7 +1944,7 @@ def test_reverse(index_dtype): ...@@ -1944,7 +1944,7 @@ def test_reverse(index_dtype):
('user', 'follows', 'user'): ([0, 1, 2, 4, 3 ,1, 3], [1, 2, 3, 2, 0, 0, 1]), ('user', 'follows', 'user'): ([0, 1, 2, 4, 3 ,1, 3], [1, 2, 3, 2, 0, 0, 1]),
}, index_dtype=index_dtype) }, index_dtype=index_dtype)
gidx = g._graph gidx = g._graph
r_gidx = gidx.reverse() r_gidx = gidx.reverse(gidx.metagraph)
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0) assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0) assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
...@@ -1956,7 +1956,7 @@ def test_reverse(index_dtype): ...@@ -1956,7 +1956,7 @@ def test_reverse(index_dtype):
# force to start with 'csr' # force to start with 'csr'
gidx = gidx.to_format('csr') gidx = gidx.to_format('csr')
gidx = gidx.to_format('any') gidx = gidx.to_format('any')
r_gidx = gidx.reverse() r_gidx = gidx.reverse(gidx.metagraph)
assert gidx.format_in_use(0)[0] == 'csr' assert gidx.format_in_use(0)[0] == 'csr'
assert r_gidx.format_in_use(0)[0] == 'csc' assert r_gidx.format_in_use(0)[0] == 'csc'
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0) assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
...@@ -1969,7 +1969,7 @@ def test_reverse(index_dtype): ...@@ -1969,7 +1969,7 @@ def test_reverse(index_dtype):
# force to start with 'csc' # force to start with 'csc'
gidx = gidx.to_format('csc') gidx = gidx.to_format('csc')
gidx = gidx.to_format('any') gidx = gidx.to_format('any')
r_gidx = gidx.reverse() r_gidx = gidx.reverse(gidx.metagraph)
assert gidx.format_in_use(0)[0] == 'csc' assert gidx.format_in_use(0)[0] == 'csc'
assert r_gidx.format_in_use(0)[0] == 'csr' assert r_gidx.format_in_use(0)[0] == 'csr'
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0) assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
...@@ -1985,7 +1985,7 @@ def test_reverse(index_dtype): ...@@ -1985,7 +1985,7 @@ def test_reverse(index_dtype):
('developer', 'develops', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]), ('developer', 'develops', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
}, index_dtype=index_dtype) }, index_dtype=index_dtype)
gidx = g._graph gidx = g._graph
r_gidx = gidx.reverse() r_gidx = gidx.reverse(gidx.metagraph)
# three node types and three edge types # three node types and three edge types
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0) assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_nodes(1) == r_gidx.number_of_nodes(1) assert gidx.number_of_nodes(1) == r_gidx.number_of_nodes(1)
...@@ -2009,7 +2009,7 @@ def test_reverse(index_dtype): ...@@ -2009,7 +2009,7 @@ def test_reverse(index_dtype):
# force to start with 'csr' # force to start with 'csr'
gidx = gidx.to_format('csr') gidx = gidx.to_format('csr')
gidx = gidx.to_format('any') gidx = gidx.to_format('any')
r_gidx = gidx.reverse() r_gidx = gidx.reverse(gidx.metagraph)
# three node types and three edge types # three node types and three edge types
assert gidx.format_in_use(0)[0] == 'csr' assert gidx.format_in_use(0)[0] == 'csr'
assert r_gidx.format_in_use(0)[0] == 'csc' assert r_gidx.format_in_use(0)[0] == 'csc'
...@@ -2039,7 +2039,7 @@ def test_reverse(index_dtype): ...@@ -2039,7 +2039,7 @@ def test_reverse(index_dtype):
# force to start with 'csc' # force to start with 'csc'
gidx = gidx.to_format('csc') gidx = gidx.to_format('csc')
gidx = gidx.to_format('any') gidx = gidx.to_format('any')
r_gidx = gidx.reverse() r_gidx = gidx.reverse(gidx.metagraph)
# three node types and three edge types # three node types and three edge types
assert gidx.format_in_use(0)[0] == 'csc' assert gidx.format_in_use(0)[0] == 'csc'
assert r_gidx.format_in_use(0)[0] == 'csr' assert r_gidx.format_in_use(0)[0] == 'csr'
...@@ -2066,6 +2066,7 @@ def test_reverse(index_dtype): ...@@ -2066,6 +2066,7 @@ def test_reverse(index_dtype):
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor()) assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor()) assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
if __name__ == '__main__': if __name__ == '__main__':
# test_create() # test_create()
# test_query() # test_query()
...@@ -2092,6 +2093,6 @@ if __name__ == '__main__': ...@@ -2092,6 +2093,6 @@ if __name__ == '__main__':
# test_isolated_ntype() # test_isolated_ntype()
# test_bipartite() # test_bipartite()
# test_dtype_cast() # test_dtype_cast()
# test_reverse("int32") test_reverse("int32")
test_format() test_format()
pass pass
...@@ -56,7 +56,7 @@ def test_hetero_linegraph(index_dtype): ...@@ -56,7 +56,7 @@ def test_hetero_linegraph(index_dtype):
np.array([0, 1, 2, 4])) np.array([0, 1, 2, 4]))
assert np.array_equal(F.asnumpy(col), assert np.array_equal(F.asnumpy(col),
np.array([4, 0, 3, 1])) np.array([4, 0, 3, 1]))
g = dgl.graph(([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]), g = dgl.graph(([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]),
'user', 'follows', restrict_format='csr', index_dtype=index_dtype) 'user', 'follows', restrict_format='csr', index_dtype=index_dtype)
lg = dgl.line_heterograph(g) lg = dgl.line_heterograph(g)
assert lg.number_of_nodes() == 5 assert lg.number_of_nodes() == 5
...@@ -67,7 +67,7 @@ def test_hetero_linegraph(index_dtype): ...@@ -67,7 +67,7 @@ def test_hetero_linegraph(index_dtype):
assert np.array_equal(F.asnumpy(col), assert np.array_equal(F.asnumpy(col),
np.array([3, 4, 0, 3, 4, 0, 1, 2])) np.array([3, 4, 0, 3, 4, 0, 1, 2]))
g = dgl.graph(([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]), g = dgl.graph(([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]),
'user', 'follows', restrict_format='csc', index_dtype=index_dtype) 'user', 'follows', restrict_format='csc', index_dtype=index_dtype)
lg = dgl.line_heterograph(g) lg = dgl.line_heterograph(g)
assert lg.number_of_nodes() == 5 assert lg.number_of_nodes() == 5
...@@ -94,8 +94,6 @@ def test_no_backtracking(): ...@@ -94,8 +94,6 @@ def test_no_backtracking():
assert not L.has_edge_between(e2, e1) assert not L.has_edge_between(e2, e1)
# reverse graph related # reverse graph related
def test_reverse(): def test_reverse():
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.add_nodes(5) g.add_nodes(5)
...@@ -115,6 +113,117 @@ def test_reverse(): ...@@ -115,6 +113,117 @@ def test_reverse():
assert g.edge_id(1, 2) == rg.edge_id(2, 1) assert g.edge_id(1, 2) == rg.edge_id(2, 1)
assert g.edge_id(2, 1) == rg.edge_id(1, 2) assert g.edge_id(2, 1) == rg.edge_id(1, 2)
# test dgl.reverse_heterograph
# test homogeneous graph
g = dgl.graph((F.tensor([0, 1, 2]), F.tensor([1, 2, 0])))
g.ndata['h'] = F.tensor([[0.], [1.], [2.]])
g.edata['h'] = F.tensor([[3.], [4.], [5.]])
g_r = dgl.reverse_heterograph(g)
assert g.number_of_nodes() == g_r.number_of_nodes()
assert g.number_of_edges() == g_r.number_of_edges()
u_g, v_g, eids_g = g.all_edges(form='all')
u_rg, v_rg, eids_rg = g_r.all_edges(form='all')
assert F.array_equal(u_g, v_rg)
assert F.array_equal(v_g, u_rg)
assert F.array_equal(eids_g, eids_rg)
assert F.array_equal(g.ndata['h'], g_r.ndata['h'])
assert len(g_r.edata) == 0
# without share ndata
g_r = dgl.reverse_heterograph(g, copy_ndata=False)
assert g.number_of_nodes() == g_r.number_of_nodes()
assert g.number_of_edges() == g_r.number_of_edges()
assert len(g_r.ndata) == 0
assert len(g_r.edata) == 0
# with share ndata and edata
g_r = dgl.reverse_heterograph(g, copy_ndata=True, copy_edata=True)
assert g.number_of_nodes() == g_r.number_of_nodes()
assert g.number_of_edges() == g_r.number_of_edges()
assert F.array_equal(g.ndata['h'], g_r.ndata['h'])
assert F.array_equal(g.edata['h'], g_r.edata['h'])
# add new node feature to g_r
g_r.ndata['hh'] = F.tensor([0, 1, 2])
assert ('hh' in g.ndata) is False
assert ('hh' in g_r.ndata) is True
# add new edge feature to g_r
g_r.edata['hh'] = F.tensor([0, 1, 2])
assert ('hh' in g.edata) is False
assert ('hh' in g_r.edata) is True
# test heterogeneous graph
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 4, 3 ,1, 3], [1, 2, 3, 2, 0, 0, 1]),
('user', 'plays', 'game'): ([0, 0, 2, 3, 3, 4, 1], [1, 0, 1, 0, 1, 0, 0]),
('developer', 'develops', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
g.nodes['user'].data['h'] = F.tensor([0, 1, 2, 3, 4])
g.nodes['user'].data['hh'] = F.tensor([1, 1, 1, 1, 1])
g.nodes['game'].data['h'] = F.tensor([0, 1])
g.edges['follows'].data['h'] = F.tensor([0, 1, 2, 4, 3 ,1, 3])
g.edges['follows'].data['hh'] = F.tensor([1, 2, 3, 2, 0, 0, 1])
g_r = dgl.reverse_heterograph(g)
for etype_g, etype_gr in zip(g.canonical_etypes, g_r.canonical_etypes):
assert etype_g[0] == etype_gr[2]
assert etype_g[1] == etype_gr[1]
assert etype_g[2] == etype_gr[0]
assert g.number_of_edges(etype_g) == g_r.number_of_edges(etype_gr)
for ntype in g.ntypes:
assert g.number_of_nodes(ntype) == g_r.number_of_nodes(ntype)
assert F.array_equal(g.nodes['user'].data['h'], g_r.nodes['user'].data['h'])
assert F.array_equal(g.nodes['user'].data['hh'], g_r.nodes['user'].data['hh'])
assert F.array_equal(g.nodes['game'].data['h'], g_r.nodes['game'].data['h'])
assert len(g_r.edges['follows'].data) == 0
u_g, v_g, eids_g = g.all_edges(form='all', etype=('user', 'follows', 'user'))
u_rg, v_rg, eids_rg = g_r.all_edges(form='all', etype=('user', 'follows', 'user'))
assert F.array_equal(u_g, v_rg)
assert F.array_equal(v_g, u_rg)
assert F.array_equal(eids_g, eids_rg)
u_g, v_g, eids_g = g.all_edges(form='all', etype=('user', 'plays', 'game'))
u_rg, v_rg, eids_rg = g_r.all_edges(form='all', etype=('game', 'plays', 'user'))
assert F.array_equal(u_g, v_rg)
assert F.array_equal(v_g, u_rg)
assert F.array_equal(eids_g, eids_rg)
u_g, v_g, eids_g = g.all_edges(form='all', etype=('developer', 'develops', 'game'))
u_rg, v_rg, eids_rg = g_r.all_edges(form='all', etype=('game', 'develops', 'developer'))
assert F.array_equal(u_g, v_rg)
assert F.array_equal(v_g, u_rg)
assert F.array_equal(eids_g, eids_rg)
# withour share ndata
g_r = dgl.reverse_heterograph(g, copy_ndata=False)
for etype_g, etype_gr in zip(g.canonical_etypes, g_r.canonical_etypes):
assert etype_g[0] == etype_gr[2]
assert etype_g[1] == etype_gr[1]
assert etype_g[2] == etype_gr[0]
assert g.number_of_edges(etype_g) == g_r.number_of_edges(etype_gr)
for ntype in g.ntypes:
assert g.number_of_nodes(ntype) == g_r.number_of_nodes(ntype)
assert len(g_r.nodes['user'].data) == 0
assert len(g_r.nodes['game'].data) == 0
g_r = dgl.reverse_heterograph(g, copy_ndata=True, copy_edata=True)
print(g_r)
for etype_g, etype_gr in zip(g.canonical_etypes, g_r.canonical_etypes):
assert etype_g[0] == etype_gr[2]
assert etype_g[1] == etype_gr[1]
assert etype_g[2] == etype_gr[0]
assert g.number_of_edges(etype_g) == g_r.number_of_edges(etype_gr)
assert F.array_equal(g.edges['follows'].data['h'], g_r.edges['follows'].data['h'])
assert F.array_equal(g.edges['follows'].data['hh'], g_r.edges['follows'].data['hh'])
# add new node feature to g_r
g_r.nodes['user'].data['hhh'] = F.tensor([0, 1, 2, 3, 4])
assert ('hhh' in g.nodes['user'].data) is False
assert ('hhh' in g_r.nodes['user'].data) is True
# add new edge feature to g_r
g_r.edges['follows'].data['hhh'] = F.tensor([1, 2, 3, 2, 0, 0, 1])
assert ('hhh' in g.edges['follows'].data) is False
assert ('hhh' in g_r.edges['follows'].data) is True
def test_reverse_shared_frames(): def test_reverse_shared_frames():
g = dgl.DGLGraph() g = dgl.DGLGraph()
...@@ -531,7 +640,7 @@ def test_compact(index_dtype): ...@@ -531,7 +640,7 @@ def test_compact(index_dtype):
g3, always_preserve=F.tensor([1, 7], dtype=getattr(F, index_dtype))) g3, always_preserve=F.tensor([1, 7], dtype=getattr(F, index_dtype)))
induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes} induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes}
induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}
assert new_g3._idtype_str == index_dtype assert new_g3._idtype_str == index_dtype
assert set(induced_nodes['user']) == set([0, 1, 2, 7]) assert set(induced_nodes['user']) == set([0, 1, 2, 7])
_check(g3, new_g3, induced_nodes) _check(g3, new_g3, induced_nodes)
...@@ -551,7 +660,7 @@ def test_compact(index_dtype): ...@@ -551,7 +660,7 @@ def test_compact(index_dtype):
new_g1, new_g2 = dgl.compact_graphs( new_g1, new_g2 = dgl.compact_graphs(
[g1, g2], always_preserve={'game': F.tensor([4, 7], dtype=getattr(F, index_dtype))}) [g1, g2], always_preserve={'game': F.tensor([4, 7], dtype=getattr(F, index_dtype))})
induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes} induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes}
induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}
assert new_g1._idtype_str == index_dtype assert new_g1._idtype_str == index_dtype
assert new_g2._idtype_str == index_dtype assert new_g2._idtype_str == index_dtype
assert set(induced_nodes['user']) == set([1, 3, 5, 2, 7, 8, 9]) assert set(induced_nodes['user']) == set([1, 3, 5, 2, 7, 8, 9])
...@@ -564,7 +673,7 @@ def test_compact(index_dtype): ...@@ -564,7 +673,7 @@ def test_compact(index_dtype):
[g3, g4], always_preserve=F.tensor([1, 7], dtype=getattr(F, index_dtype))) [g3, g4], always_preserve=F.tensor([1, 7], dtype=getattr(F, index_dtype)))
induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes} induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes}
induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}
assert new_g3._idtype_str == index_dtype assert new_g3._idtype_str == index_dtype
assert new_g4._idtype_str == index_dtype assert new_g4._idtype_str == index_dtype
assert set(induced_nodes['user']) == set([0, 1, 2, 3, 5, 7]) assert set(induced_nodes['user']) == set([0, 1, 2, 3, 5, 7])
...@@ -757,7 +866,7 @@ if __name__ == '__main__': ...@@ -757,7 +866,7 @@ if __name__ == '__main__':
test_reorder_nodes() test_reorder_nodes()
# test_line_graph() # test_line_graph()
# test_no_backtracking() # test_no_backtracking()
# test_reverse() test_reverse()
# test_reverse_shared_frames() # test_reverse_shared_frames()
# test_simple_graph() # test_simple_graph()
# test_bidirected_graph() # test_bidirected_graph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment