Unverified Commit 3e72c53a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug?] Fix #1563 (#1642)

parent ba2ee7bd
"""Classes for heterogeneous graphs.""" """Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines #pylint: disable= too-many-lines
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping
from contextlib import contextmanager from contextlib import contextmanager
import copy import copy
import networkx as nx import networkx as nx
...@@ -1892,10 +1893,13 @@ class DGLHeteroGraph(object): ...@@ -1892,10 +1893,13 @@ class DGLHeteroGraph(object):
Parameters Parameters
---------- ----------
nodes : dict[str->list or iterable] nodes : list or dict[str->list or iterable]
A dictionary mapping node types to node ID array for constructing A dictionary mapping node types to node ID array for constructing
subgraph. All nodes must exist in the graph. subgraph. All nodes must exist in the graph.
If the graph only has one node type, one can just specify a list,
tensor, or any iterable of node IDs intead.
Returns Returns
------- -------
G : DGLHeteroGraph G : DGLHeteroGraph
...@@ -1952,7 +1956,11 @@ class DGLHeteroGraph(object): ...@@ -1952,7 +1956,11 @@ class DGLHeteroGraph(object):
-------- --------
edge_subgraph edge_subgraph
""" """
check_same_dtype(self._idtype_str, nodes) if not isinstance(nodes, Mapping):
assert len(self.ntypes) == 1, \
'need a dict of node type and IDs for graph with multiple node types'
nodes = {self.ntypes[0]: nodes}
check_idtype_dict(self._idtype_str, nodes)
induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str) induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str)
for ntype in self.ntypes] for ntype in self.ntypes]
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
...@@ -1975,6 +1983,9 @@ class DGLHeteroGraph(object): ...@@ -1975,6 +1983,9 @@ class DGLHeteroGraph(object):
The edge types are characterized by triplets of The edge types are characterized by triplets of
``(src type, etype, dst type)``. ``(src type, etype, dst type)``.
If the graph only has one edge type, one can just specify a list,
tensor, or any iterable of edge IDs intead.
preserve_nodes : bool preserve_nodes : bool
Whether to preserve all nodes or not. If false, all nodes Whether to preserve all nodes or not. If false, all nodes
without edges will be removed. (Default: False) without edges will be removed. (Default: False)
...@@ -2035,6 +2046,10 @@ class DGLHeteroGraph(object): ...@@ -2035,6 +2046,10 @@ class DGLHeteroGraph(object):
-------- --------
subgraph subgraph
""" """
if not isinstance(edges, Mapping):
assert len(self.canonical_etypes) == 1, \
'need a dict of edge type and IDs for graph with multiple edge types'
edges = {self.canonical_etypes[0]: edges}
check_idtype_dict(self._idtype_str, edges) check_idtype_dict(self._idtype_str, edges)
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()} edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [ induced_edges = [
......
...@@ -898,6 +898,9 @@ def test_transform(index_dtype): ...@@ -898,6 +898,9 @@ def test_transform(index_dtype):
@parametrize_dtype @parametrize_dtype
def test_subgraph(index_dtype): def test_subgraph(index_dtype):
g = create_test_heterograph(index_dtype) g = create_test_heterograph(index_dtype)
g_graph = g['follows']
g_bipartite = g['plays']
x = F.randn((3, 5)) x = F.randn((3, 5))
y = F.randn((2, 4)) y = F.randn((2, 4))
g.nodes['user'].data['h'] = x g.nodes['user'].data['h'] = x
...@@ -927,6 +930,33 @@ def test_subgraph(index_dtype): ...@@ -927,6 +930,33 @@ def test_subgraph(index_dtype):
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]}) sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2) _check_subgraph(g, sg2)
def _check_subgraph_single_ntype(g, sg):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64))
assert F.array_equal(F.tensor(sg.edges['follows'].data[dgl.EID]),
F.tensor([1], F.int64))
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'][1:2])
def _check_subgraph_single_etype(g, sg):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([0, 1], F.int64))
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]),
F.tensor([0], F.int64))
assert F.array_equal(F.tensor(sg.edges['plays'].data[dgl.EID]),
F.tensor([0, 1], F.int64))
sg1_graph = g_graph.subgraph([1, 2])
_check_subgraph_single_ntype(g_graph, sg1_graph)
sg2_bipartite = g_bipartite.edge_subgraph([0, 1])
_check_subgraph_single_etype(g_bipartite, sg2_bipartite)
def _check_typed_subgraph1(g, sg): def _check_typed_subgraph1(g, sg):
assert set(sg.ntypes) == {'user', 'game'} assert set(sg.ntypes) == {'user', 'game'}
assert set(sg.etypes) == {'follows', 'plays', 'wishes'} assert set(sg.etypes) == {'follows', 'plays', 'wishes'}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment