Unverified Commit 48cbea72 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Transform] Module Interface for Transform (#3636)



* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix lint

* lint

* Update

* Update

* lint fix

* Fix CI

* Fix

* Fix CI

* Update

* Fix

* Update

* Update

* resolve conflict

* Fix CI
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-31-136.us-west-2.compute.internal>
parent 2aad1c0b
...@@ -15,3 +15,4 @@ API Reference ...@@ -15,3 +15,4 @@ API Reference
dgl.sampling dgl.sampling
dgl.contrib.UnifiedTensor dgl.contrib.UnifiedTensor
udf udf
transform
.. _apitransform-namespace:
dgl.transform
=============
.. currentmodule:: dgl.transform
.. automodule:: dgl.transform
BaseTransform
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: BaseTransform
:members: __call__, __repr__
:show-inheritance:
AddSelfLoop
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: AddSelfLoop
:show-inheritance:
RemoveSelfLoop
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: RemoveSelfLoop
:show-inheritance:
AddReverse
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: AddReverse
:show-inheritance:
ToSimple
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ToSimple
:show-inheritance:
LineGraph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LineGraph
:show-inheritance:
KHopGraph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: KHopGraph
:show-inheritance:
AddMetaPaths
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: AddMetaPaths
:show-inheritance:
KNNGraph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: KNNGraph
:show-inheritance:
"""Transform for structures and features"""
from .functional import *
from .module import *
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
"""Module for graph transformation utilities.""" """Functional interface for transform"""
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from collections import defaultdict from collections import defaultdict
...@@ -21,22 +21,22 @@ import numpy as np ...@@ -21,22 +21,22 @@ import numpy as np
import scipy.sparse as sparse import scipy.sparse as sparse
import scipy.sparse.linalg import scipy.sparse.linalg
from ._ffi.function import _init_api from .._ffi.function import _init_api
from .base import dgl_warning, DGLError from ..base import dgl_warning, DGLError
from . import convert from .. import convert
from .heterograph import DGLHeteroGraph, DGLBlock from ..heterograph import DGLHeteroGraph, DGLBlock
from .heterograph_index import create_metagraph_index, create_heterograph_from_relations from ..heterograph_index import create_metagraph_index, create_heterograph_from_relations
from .frame import Frame from ..frame import Frame
from . import ndarray as nd from .. import ndarray as nd
from . import backend as F from .. import backend as F
from . import utils, batch from .. import utils, batch
from .partition import metis_partition_assignment from ..partition import metis_partition_assignment
from .partition import partition_graph_with_halo from ..partition import partition_graph_with_halo
from .partition import metis_partition from ..partition import metis_partition
from . import subgraph from .. import subgraph
# TO BE DEPRECATED # TO BE DEPRECATED
from ._deprecate.graph import DGLGraph as DGLGraphStale from .._deprecate.graph import DGLGraph as DGLGraphStale
__all__ = [ __all__ = [
'line_graph', 'line_graph',
...@@ -93,7 +93,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'): ...@@ -93,7 +93,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
columns correspond to coordinate/feature dimensions. columns correspond to coordinate/feature dimensions.
The nodes of the returned graph correspond to the points, where the predecessors The nodes of the returned graph correspond to the points, where the predecessors
of each point are its k-nearest neighbors measured by the Euclidean distance. of each point are its k-nearest neighbors measured by the chosen distance.
If :attr:`x` is a 3D tensor, then each submatrix will be transformed If :attr:`x` is a 3D tensor, then each submatrix will be transformed
into a separate graph. DGL then composes the graphs into a large into a separate graph. DGL then composes the graphs into a large
...@@ -715,7 +715,7 @@ def to_bidirected(g, copy_ndata=False, readonly=None): ...@@ -715,7 +715,7 @@ def to_bidirected(g, copy_ndata=False, readonly=None):
def add_reverse_edges(g, readonly=None, copy_ndata=True, def add_reverse_edges(g, readonly=None, copy_ndata=True,
copy_edata=False, ignore_bipartite=False): copy_edata=False, ignore_bipartite=False):
r"""Add an reversed edge for each edge in the input graph and return a new graph. r"""Add a reversed edge for each edge in the input graph and return a new graph.
For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this
function creates a new graph with edges function creates a new graph with edges
...@@ -740,14 +740,14 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -740,14 +740,14 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
(Default: True) (Default: True)
copy_edata: bool, optional copy_edata: bool, optional
If True, the features of the reversed edges will be identical to If True, the features of the reversed edges will be identical to
the original ones." the original ones.
If False, the new graph will not have any edge features. If False, the new graph will not have any edge features.
(Default: False) (Default: False)
ignore_bipartite: bool, optional ignore_bipartite: bool, optional
If True, unidirectional bipartite graphs are ignored and If True, unidirectional bipartite graphs are ignored and
no error is raised. If False, an error will be raised if no error is raised. If False, an error will be raised if
an edge type of the input heterogeneous graph is for a unidirectional an edge type of the input heterogeneous graph is for a unidirectional
bipartite graph. bipartite graph.
...@@ -865,7 +865,7 @@ def line_graph(g, backtracking=True, shared=False): ...@@ -865,7 +865,7 @@ def line_graph(g, backtracking=True, shared=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
The line graph ``L(G)`` of a given graph ``G`` is defined as another graph where The line graph ``L(G)`` of a given graph ``G`` is defined as another graph where
the nodes in ``L(G)`` maps to the edges in ``G``. For any pair of edges ``(u, v)`` the nodes in ``L(G)`` correspond to the edges in ``G``. For any pair of edges ``(u, v)``
and ``(v, w)`` in ``G``, the corresponding node of edge ``(u, v)`` in ``L(G)`` will and ``(v, w)`` in ``G``, the corresponding node of edge ``(u, v)`` in ``L(G)`` will
have an edge connecting to the corresponding node of edge ``(v, w)``. have an edge connecting to the corresponding node of edge ``(v, w)``.
...@@ -1050,7 +1050,7 @@ def khop_graph(g, k, copy_ndata=True): ...@@ -1050,7 +1050,7 @@ def khop_graph(g, k, copy_ndata=True):
col = np.repeat(adj_k.col, multiplicity) col = np.repeat(adj_k.col, multiplicity)
# TODO(zihao): we should support creating multi-graph from scipy sparse matrix # TODO(zihao): we should support creating multi-graph from scipy sparse matrix
# in the future. # in the future.
new_g = convert.graph((row, col), num_nodes=n) new_g = convert.graph((row, col), num_nodes=n, idtype=g.idtype, device=g.device)
# handle ndata # handle ndata
if copy_ndata: if copy_ndata:
...@@ -2350,7 +2350,7 @@ def to_simple(g, ...@@ -2350,7 +2350,7 @@ def to_simple(g,
(Default: "count") (Default: "count")
writeback_mapping: bool, optional writeback_mapping: bool, optional
If True, return an extra write-back mapping for each edge If True, return an extra write-back mapping for each edge
type. The write-back mapping is a tensor recording type. The write-back mapping is a tensor recording
the mapping from the edge IDs in the input graph to the mapping from the edge IDs in the input graph to
the edge IDs in the result graph. If the graph is the edge IDs in the result graph. If the graph is
heterogeneous, DGL returns a dictionary of edge types and such heterogeneous, DGL returns a dictionary of edge types and such
...@@ -3195,4 +3195,4 @@ def rcmk_perm(g): ...@@ -3195,4 +3195,4 @@ def rcmk_perm(g):
perm = sparse.csgraph.reverse_cuthill_mckee(csr_adj) perm = sparse.csgraph.reverse_cuthill_mckee(csr_adj)
return perm.copy() return perm.copy()
_init_api("dgl.transform") _init_api("dgl.transform", __name__)
This diff is collapsed.
...@@ -1800,5 +1800,361 @@ def test_reorder_graph(idtype): ...@@ -1800,5 +1800,361 @@ def test_reorder_graph(idtype):
rfg = dgl.reorder_graph(fg) rfg = dgl.reorder_graph(fg)
assert 'csr' in sum(rfg.formats().values(), []) assert 'csr' in sum(rfg.formats().values(), [])
@parametrize_dtype
def test_module_add_self_loop(idtype):
g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.randn((g.num_nodes(), 2))
g.edata['w'] = F.randn((g.num_edges(), 3))
# Case1: add self-loops with the default setting
transform = dgl.AddSelfLoop()
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_nodes()
assert new_g.num_edges() == 4
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)}
assert 'h' in new_g.ndata
assert 'w' in new_g.edata
# Case2: Remove self-loops first to avoid duplicate ones
transform = dgl.AddSelfLoop(allow_duplicate=True)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_nodes()
assert new_g.num_edges() == 5
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)}
assert 'h' in new_g.ndata
assert 'w' in new_g.edata
# Create a heterogeneous graph
g = dgl.heterograph({
('user', 'plays', 'game'): ([0], [1]),
('user', 'follows', 'user'): ([1], [3])
}, idtype=idtype, device=F.ctx())
g.nodes['user'].data['h1'] = F.randn((4, 2))
g.edges['plays'].data['w1'] = F.randn((1, 3))
g.nodes['game'].data['h2'] = F.randn((2, 4))
g.edges['follows'].data['w2'] = F.randn((1, 5))
# Case3: add self-loops for a heterogeneous graph
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes
for nty in new_g.ntypes:
assert new_g.num_nodes(nty) == g.num_nodes(nty)
assert new_g.num_edges('plays') == 1
assert new_g.num_edges('follows') == 5
assert 'h1' in new_g.nodes['user'].data
assert 'h2' in new_g.nodes['game'].data
assert 'w1' in new_g.edges['plays'].data
assert 'w2' in new_g.edges['follows'].data
# Case4: add self-etypes for a heterogeneous graph
transform = dgl.AddSelfLoop(new_etypes=True)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.ntypes == g.ntypes
assert set(new_g.canonical_etypes) == {
('user', 'plays', 'game'), ('user', 'follows', 'user'),
('user', 'self', 'user'), ('game', 'self', 'game')
}
for nty in new_g.ntypes:
assert new_g.num_nodes(nty) == g.num_nodes(nty)
assert new_g.num_edges('plays') == 1
assert new_g.num_edges('follows') == 5
assert new_g.num_edges(('user', 'self', 'user')) == 4
assert new_g.num_edges(('game', 'self', 'game')) == 2
assert 'h1' in new_g.nodes['user'].data
assert 'h2' in new_g.nodes['game'].data
assert 'w1' in new_g.edges['plays'].data
assert 'w2' in new_g.edges['follows'].data
@parametrize_dtype
def test_module_remove_self_loop(idtype):
transform = dgl.RemoveSelfLoop()
# Case1: homogeneous graph
g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.randn((g.num_nodes(), 2))
g.edata['w'] = F.randn((g.num_edges(), 3))
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_nodes()
assert new_g.num_edges() == 1
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(1, 2)}
assert 'h' in new_g.ndata
assert 'w' in new_g.edata
# Case2: heterogeneous graph
g = dgl.heterograph({
('user', 'plays', 'game'): ([0, 1], [1, 1]),
('user', 'follows', 'user'): ([1, 2], [2, 2])
}, idtype=idtype, device=F.ctx())
g.nodes['user'].data['h1'] = F.randn((3, 2))
g.edges['plays'].data['w1'] = F.randn((2, 3))
g.nodes['game'].data['h2'] = F.randn((2, 4))
g.edges['follows'].data['w2'] = F.randn((2, 5))
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes
for nty in new_g.ntypes:
assert new_g.num_nodes(nty) == g.num_nodes(nty)
assert new_g.num_edges('plays') == 2
assert new_g.num_edges('follows') == 1
assert 'h1' in new_g.nodes['user'].data
assert 'h2' in new_g.nodes['game'].data
assert 'w1' in new_g.edges['plays'].data
assert 'w2' in new_g.edges['follows'].data
@parametrize_dtype
def test_module_add_reverse(idtype):
transform = dgl.AddReverse()
# Case1: Add reverse edges for a homogeneous graph
g = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.randn((g.num_nodes(), 3))
g.edata['w'] = F.randn((g.num_edges(), 2))
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert g.num_nodes() == new_g.num_nodes()
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 0)}
assert F.allclose(g.ndata['h'], new_g.ndata['h'])
assert F.allclose(g.edata['w'], F.narrow_row(new_g.edata['w'], 0, 1))
assert F.allclose(F.narrow_row(new_g.edata['w'], 1, 2), F.zeros((1, 2), F.float32, F.ctx()))
# Case2: Add reverse edges for a homogeneous graph and copy edata
transform = dgl.AddReverse(copy_edata=True)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert g.num_nodes() == new_g.num_nodes()
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 0)}
assert F.allclose(g.ndata['h'], new_g.ndata['h'])
assert F.allclose(g.edata['w'], F.narrow_row(new_g.edata['w'], 0, 1))
assert F.allclose(g.edata['w'], F.narrow_row(new_g.edata['w'], 1, 2))
# Case3: Add reverse edges for a heterogeneous graph
g = dgl.heterograph({
('user', 'plays', 'game'): ([0, 1], [1, 1]),
('user', 'follows', 'user'): ([1, 2], [2, 2])
}, device=F.ctx())
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert g.ntypes == new_g.ntypes
assert set(new_g.canonical_etypes) == {
('user', 'plays', 'game'), ('user', 'follows', 'user'), ('game', 'rev_plays', 'user')}
for nty in g.ntypes:
assert g.num_nodes(nty) == new_g.num_nodes(nty)
src, dst = new_g.edges(etype='plays')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 1)}
src, dst = new_g.edges(etype='follows')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(1, 2), (2, 2), (2, 1)}
src, dst = new_g.edges(etype='rev_plays')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(1, 1), (1, 0)}
# Case4: Enforce reverse edge types for symmetric canonical edge types
transform = dgl.AddReverse(sym_new_etype=True)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert g.ntypes == new_g.ntypes
assert set(new_g.canonical_etypes) == {
('user', 'plays', 'game'), ('user', 'follows', 'user'),
('game', 'rev_plays', 'user'), ('user', 'rev_follows', 'user')}
for nty in g.ntypes:
assert g.num_nodes(nty) == new_g.num_nodes(nty)
src, dst = new_g.edges(etype='plays')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 1)}
src, dst = new_g.edges(etype='follows')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(1, 2), (2, 2)}
src, dst = new_g.edges(etype='rev_plays')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(1, 1), (1, 0)}
src, dst = new_g.edges(etype='rev_follows')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(2, 1), (2, 2)}
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not supported for to_simple")
@parametrize_dtype
def test_module_to_simple(idtype):
transform = dgl.ToSimple()
g = dgl.graph(([0, 1, 1], [1, 2, 2]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.randn((g.num_nodes(), 2))
g.edata['w'] = F.tensor([[0.1], [0.2], [0.3]])
sg = transform(g)
assert sg.device == g.device
assert sg.idtype == g.idtype
assert sg.num_nodes() == g.num_nodes()
assert sg.num_edges() == 2
src, dst = sg.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 2)}
assert F.allclose(sg.edata['count'], F.tensor([1, 2]))
assert F.allclose(sg.ndata['h'], g.ndata['h'])
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
('user', 'plays', 'game'): ([0, 1, 0], [1, 1, 1])
})
sg = transform(g)
assert sg.device == g.device
assert sg.idtype == g.idtype
assert sg.ntypes == g.ntypes
assert sg.canonical_etypes == g.canonical_etypes
for nty in sg.ntypes:
assert sg.num_nodes(nty) == g.num_nodes(nty)
for ety in sg.canonical_etypes:
assert sg.num_edges(ety) == 2
src, dst = sg.edges(etype='follows')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 2)}
src, dst = sg.edges(etype='plays')
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 1)}
@parametrize_dtype
def test_module_line_graph(idtype):
transform = dgl.LineGraph()
g = dgl.graph(([0, 1, 1], [1, 0, 2]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.tensor([[0.], [1.], [2.]])
g.edata['w'] = F.tensor([[0.], [0.1], [0.2]])
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_edges()
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (0, 2), (1, 0)}
transform = dgl.LineGraph(backtracking=False)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_edges()
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 2)}
@parametrize_dtype
def test_module_khop_graph(idtype):
transform = dgl.KHopGraph(2)
g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())
g.ndata['h'] = F.randn((g.num_nodes(), 2))
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_nodes()
assert F.allclose(g.ndata['h'], new_g.ndata['h'])
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 2)}
@parametrize_dtype
def test_module_add_metapaths(idtype):
g = dgl.heterograph({
('person', 'author', 'paper'): ([0, 0, 1], [1, 2, 2]),
('paper', 'accepted', 'venue'): ([1], [0]),
('paper', 'rejected', 'venue'): ([2], [1])
}, idtype=idtype, device=F.ctx())
g.nodes['venue'].data['h'] = F.randn((g.num_nodes('venue'), 2))
g.edges['author'].data['h'] = F.randn((g.num_edges('author'), 3))
# Case1: keep_orig_edges is True
metapaths = {
'accepted': [('person', 'author', 'paper'), ('paper', 'accepted', 'venue')],
'rejected': [('person', 'author', 'paper'), ('paper', 'rejected', 'venue')]
}
transform = dgl.AddMetaPaths(metapaths)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.ntypes == g.ntypes
assert set(new_g.canonical_etypes) == {
('person', 'author', 'paper'), ('paper', 'accepted', 'venue'),
('paper', 'rejected', 'venue'), ('person', 'accepted', 'venue'),
('person', 'rejected', 'venue')
}
for nty in new_g.ntypes:
assert new_g.num_nodes(nty) == g.num_nodes(nty)
for ety in g.canonical_etypes:
assert new_g.num_edges(ety) == g.num_edges(ety)
assert F.allclose(g.nodes['venue'].data['h'], new_g.nodes['venue'].data['h'])
assert F.allclose(g.edges['author'].data['h'], new_g.edges['author'].data['h'])
src, dst = new_g.edges(etype=('person', 'accepted', 'venue'))
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0)}
src, dst = new_g.edges(etype=('person', 'rejected', 'venue'))
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 1)}
# Case2: keep_orig_edges is False
transform = dgl.AddMetaPaths(metapaths, keep_orig_edges=False)
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.ntypes == g.ntypes
assert len(new_g.canonical_etypes) == 2
for nty in new_g.ntypes:
assert new_g.num_nodes(nty) == g.num_nodes(nty)
assert F.allclose(g.nodes['venue'].data['h'], new_g.nodes['venue'].data['h'])
src, dst = new_g.edges(etype=('person', 'accepted', 'venue'))
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0)}
src, dst = new_g.edges(etype=('person', 'rejected', 'venue'))
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 1)}
@parametrize_dtype
def test_module_compose(idtype):
g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())
transform = dgl.Compose([dgl.AddReverse(), dgl.AddSelfLoop()])
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_edges() == 7
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 1), (1, 2), (1, 0), (2, 1), (0, 0), (1, 1), (2, 2)}
if __name__ == '__main__': if __name__ == '__main__':
test_partition_with_halo() test_partition_with_halo()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment