".github/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "dda5bfac2f7900684cc491b1d3a4aafca2eeccbf"
Unverified Commit 55072b4e authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Feature] dgl.subgraph support Bool mask tensor as index. (#1709)



* dgl.subgraph support using Bool mask tensor

* support numpy boo tensor and add some test

* Update

* Fix docstring
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 870da747
...@@ -2061,6 +2061,13 @@ class DGLHeteroGraph(object): ...@@ -2061,6 +2061,13 @@ class DGLHeteroGraph(object):
If the graph only has one node type, one can just specify a list, If the graph only has one node type, one can just specify a list,
tensor, or any iterable of node IDs intead. tensor, or any iterable of node IDs intead.
The node ID array can be either an interger tensor or a bool tensor.
When a bool tensor is used, it is automatically converted to
an interger tensor using the semantic of np.where(nodes_idx == True).
Note: When using bool tensor, only backend (torch, tensorflow, mxnet)
tensors are supported.
Returns Returns
------- -------
G : DGLHeteroGraph G : DGLHeteroGraph
...@@ -2095,6 +2102,14 @@ class DGLHeteroGraph(object): ...@@ -2095,6 +2102,14 @@ class DGLHeteroGraph(object):
num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2}, num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')]) metagraph=[('user', 'game'), ('user', 'user')])
Get subgraphs using boolean mask tensor.
>>> sub_g = g.subgraph({'user': th.tensor([False, True, True])})
>>> print(sub_g)
Graph(num_nodes={'user': 2, 'game': 0},
num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')])
Get the original node/edge indices. Get the original node/edge indices.
>>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
...@@ -2121,7 +2136,21 @@ class DGLHeteroGraph(object): ...@@ -2121,7 +2136,21 @@ class DGLHeteroGraph(object):
assert len(self.ntypes) == 1, \ assert len(self.ntypes) == 1, \
'need a dict of node type and IDs for graph with multiple node types' 'need a dict of node type and IDs for graph with multiple node types'
nodes = {self.ntypes[0]: nodes} nodes = {self.ntypes[0]: nodes}
check_idtype_dict(self._idtype_str, nodes)
for ntype, v in nodes.items():
if F.is_tensor(v):
# Check if the v is a bool tensor
if F.dtype(v) is F.data_type_dict['bool']:
assert len(F.shape(v)) == 1, \
"dgl.subgraph only support 1D tensor as ID array"
nodes_idx = F.nonzero_1d(v)
nodes[ntype] = F.astype(nodes_idx,
ty=F.data_type_dict[self._idtype_str])
else:
check_same_dtype(self._idtype_str, v)
else:
v = F.tensor(v, dtype=F.data_type_dict[self._idtype_str])
induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str) induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str)
for ntype in self.ntypes] for ntype in self.ntypes]
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
...@@ -2147,6 +2176,14 @@ class DGLHeteroGraph(object): ...@@ -2147,6 +2176,14 @@ class DGLHeteroGraph(object):
If the graph only has one edge type, one can just specify a list, If the graph only has one edge type, one can just specify a list,
tensor, or any iterable of edge IDs intead. tensor, or any iterable of edge IDs intead.
The edge ID array can be either an interger tensor or a bool tensor.
When a bool tensor is used, it is automatically converted to
an interger tensor using the semantic of np.where(edges_idx == True).
Note: When using bool tensor, only backend (torch, tensorflow, mxnet)
tensors are supported.
preserve_nodes : bool preserve_nodes : bool
Whether to preserve all nodes or not. If false, all nodes Whether to preserve all nodes or not. If false, all nodes
without edges will be removed. (Default: False) without edges will be removed. (Default: False)
...@@ -2185,6 +2222,14 @@ class DGLHeteroGraph(object): ...@@ -2185,6 +2222,14 @@ class DGLHeteroGraph(object):
num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2}, num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')]) metagraph=[('user', 'game'), ('user', 'user')])
Get subgraphs using boolean mask tensor.
>>> sub_g = g.edge_subgraph({('user', 'follows', 'user'): th.tensor([False, True, True]),
>>> ('user', 'plays', 'game'): th.tensor([False, False, True, False])})
>>> sub_g
Graph(num_nodes={'user': 2, 'game': 1},
num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')])
Get the original node/edge indices. Get the original node/edge indices.
>>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph >>> sub_g['follows'].ndata[dgl.NID] # Get the node indices in the raw graph
...@@ -2211,7 +2256,21 @@ class DGLHeteroGraph(object): ...@@ -2211,7 +2256,21 @@ class DGLHeteroGraph(object):
assert len(self.canonical_etypes) == 1, \ assert len(self.canonical_etypes) == 1, \
'need a dict of edge type and IDs for graph with multiple edge types' 'need a dict of edge type and IDs for graph with multiple edge types'
edges = {self.canonical_etypes[0]: edges} edges = {self.canonical_etypes[0]: edges}
check_idtype_dict(self._idtype_str, edges)
for etype, v in edges.items():
if F.is_tensor(v):
# Check if the v is a bool tensor
if F.dtype(v) is F.data_type_dict['bool']:
assert len(F.shape(v)) == 1, \
"dgl.edge_subgraph only support 1D tensor as ID array"
edges_idx = F.nonzero_1d(v)
edges[etype] = F.astype(edges_idx,
ty=F.data_type_dict[self._idtype_str])
else:
check_same_dtype(self._idtype_str, v)
else:
v = F.tensor(v, dtype=F.data_type_dict[self._idtype_str])
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()} edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [ induced_edges = [
utils.toindex(edges.get(canonical_etype, []), self._idtype_str) utils.toindex(edges.get(canonical_etype, []), self._idtype_str)
......
...@@ -1022,6 +1022,46 @@ def test_transform(index_dtype): ...@@ -1022,6 +1022,46 @@ def test_transform(index_dtype):
assert new_g.number_of_edges() == 2 assert new_g.number_of_edges() == 2
assert F.asnumpy(new_g.has_edges_between([0, 1], [1, 2])).all() assert F.asnumpy(new_g.has_edges_between([0, 1], [1, 2])).all()
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="MXNet doesn't support bool tensor")
@parametrize_dtype
def test_subgraph_mask(index_dtype):
g = create_test_heterograph(index_dtype)
g_graph = g['follows']
g_bipartite = g['plays']
x = F.randn((3, 5))
y = F.randn((2, 4))
g.nodes['user'].data['h'] = x
g.edges['follows'].data['h'] = y
def _check_subgraph(g, sg):
assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes
assert sg.canonical_etypes == g.canonical_etypes
assert F.array_equal(F.tensor(sg.nodes['user'].data[dgl.NID]),
F.tensor([1, 2], F.int64))
assert F.array_equal(F.tensor(sg.nodes['game'].data[dgl.NID]),
F.tensor([0], F.int64))
assert F.array_equal(F.tensor(sg.edges['follows'].data[dgl.EID]),
F.tensor([1], F.int64))
assert F.array_equal(F.tensor(sg.edges['plays'].data[dgl.EID]),
F.tensor([1], F.int64))
assert F.array_equal(F.tensor(sg.edges['wishes'].data[dgl.EID]),
F.tensor([1], F.int64))
assert sg.number_of_nodes('developer') == 0
assert sg.number_of_edges('develops') == 0
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'][1:3])
assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'][1:2])
# backend boo input tensor
sg1 = g.subgraph({'user': F.tensor([False, True, True], dtype=F.data_type_dict['bool']),
'game': F.tensor([True, False, False, False], dtype=F.data_type_dict['bool'])})
_check_subgraph(g, sg1)
sg2 = g.edge_subgraph({'follows': F.tensor([False, True], dtype=F.data_type_dict['bool']),
'plays': F.tensor([False, True, False, False], dtype=F.data_type_dict['bool']),
'wishes': F.tensor([False, True], dtype=F.data_type_dict['bool'])})
_check_subgraph(g, sg2)
@parametrize_dtype @parametrize_dtype
def test_subgraph(index_dtype): def test_subgraph(index_dtype):
g = create_test_heterograph(index_dtype) g = create_test_heterograph(index_dtype)
...@@ -1057,6 +1097,24 @@ def test_subgraph(index_dtype): ...@@ -1057,6 +1097,24 @@ def test_subgraph(index_dtype):
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]}) sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2) _check_subgraph(g, sg2)
# backend tensor input
sg1 = g.subgraph({'user': F.tensor([1, 2], dtype=F.data_type_dict[index_dtype]),
'game': F.tensor([0], dtype=F.data_type_dict[index_dtype])})
_check_subgraph(g, sg1)
sg2 = g.edge_subgraph({'follows': F.tensor([1], dtype=F.data_type_dict[index_dtype]),
'plays': F.tensor([1], dtype=F.data_type_dict[index_dtype]),
'wishes': F.tensor([1], dtype=F.data_type_dict[index_dtype])})
_check_subgraph(g, sg2)
# numpy input
sg1 = g.subgraph({'user': np.array([1, 2]),
'game': np.array([0])})
_check_subgraph(g, sg1)
sg2 = g.edge_subgraph({'follows': np.array([1]),
'plays': np.array([1]),
'wishes': np.array([1])})
_check_subgraph(g, sg2)
def _check_subgraph_single_ntype(g, sg, preserve_nodes=False): def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):
assert sg.ntypes == g.ntypes assert sg.ntypes == g.ntypes
assert sg.etypes == g.etypes assert sg.etypes == g.etypes
...@@ -1871,7 +1929,8 @@ if __name__ == '__main__': ...@@ -1871,7 +1929,8 @@ if __name__ == '__main__':
# test_convert() # test_convert()
# test_to_device() # test_to_device()
# test_transform("int32") # test_transform("int32")
# test_subgraph() test_subgraph("int32")
test_subgraph_mask("int32")
# test_apply() # test_apply()
# test_level1() # test_level1()
# test_level2() # test_level2()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment