Unverified Commit 3e72c53a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug?] Fix #1563 (#1642)

parent ba2ee7bd
"""Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
from collections import defaultdict
from collections.abc import Mapping
from contextlib import contextmanager
import copy
import networkx as nx
......@@ -1892,10 +1893,13 @@ class DGLHeteroGraph(object):
Parameters
----------
nodes : dict[str->list or iterable]
nodes : list or dict[str->list or iterable]
A dictionary mapping node types to node ID array for constructing
subgraph. All nodes must exist in the graph.
If the graph only has one node type, one can just specify a list,
tensor, or any iterable of node IDs intead.
Returns
-------
G : DGLHeteroGraph
......@@ -1952,7 +1956,11 @@ class DGLHeteroGraph(object):
--------
edge_subgraph
"""
check_same_dtype(self._idtype_str, nodes)
if not isinstance(nodes, Mapping):
assert len(self.ntypes) == 1, \
'need a dict of node type and IDs for graph with multiple node types'
nodes = {self.ntypes[0]: nodes}
check_idtype_dict(self._idtype_str, nodes)
induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str)
for ntype in self.ntypes]
sgi = self._graph.node_subgraph(induced_nodes)
......@@ -1975,6 +1983,9 @@ class DGLHeteroGraph(object):
The edge types are characterized by triplets of
``(src type, etype, dst type)``.
If the graph only has one edge type, one can just specify a list,
tensor, or any iterable of edge IDs intead.
preserve_nodes : bool
Whether to preserve all nodes or not. If false, all nodes
without edges will be removed. (Default: False)
......@@ -2035,6 +2046,10 @@ class DGLHeteroGraph(object):
--------
subgraph
"""
if not isinstance(edges, Mapping):
assert len(self.canonical_etypes) == 1, \
'need a dict of edge type and IDs for graph with multiple edge types'
edges = {self.canonical_etypes[0]: edges}
check_idtype_dict(self._idtype_str, edges)
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [
......
......@@ -898,6 +898,9 @@ def test_transform(index_dtype):
@parametrize_dtype
def test_subgraph(index_dtype):
g = create_test_heterograph(index_dtype)
g_graph = g['follows']
g_bipartite = g['plays']
x = F.randn((3, 5))
y = F.randn((2, 4))
g.nodes['user'].data['h'] = x
......@@ -927,6 +930,33 @@ def test_subgraph(index_dtype):
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2)
def _check_subgraph_single_ntype(g, sg):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64))
assert F.array_equal(F.tensor(sg.edges['follows'].data[dgl.EID]),
F.tensor([1], F.int64))
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'][1:2])
def _check_subgraph_single_etype(g, sg):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([0, 1], F.int64))
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]),
F.tensor([0], F.int64))
assert F.array_equal(F.tensor(sg.edges['plays'].data[dgl.EID]),
F.tensor([0, 1], F.int64))
sg1_graph = g_graph.subgraph([1, 2])
_check_subgraph_single_ntype(g_graph, sg1_graph)
sg2_bipartite = g_bipartite.edge_subgraph([0, 1])
_check_subgraph_single_etype(g_bipartite, sg2_bipartite)
def _check_typed_subgraph1(g, sg):
assert set(sg.ntypes) == {'user', 'game'}
assert set(sg.etypes) == {'follows', 'plays', 'wishes'}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment