Unverified Commit 0c2a2ea1 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bugfix] Handle a Corner Case of Batching after Removing Nodes/Edges (#2465)



* Update

* Update

* Update
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 08671d81
...@@ -227,7 +227,8 @@ class BlockSampler(object): ...@@ -227,7 +227,8 @@ class BlockSampler(object):
# to the mapping from the new graph to the old frontier. # to the mapping from the new graph to the old frontier.
# So we need to test if located_eids is empty, and do the remapping ourselves. # So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0: if len(located_eids) > 0:
frontier = transform.remove_edges(frontier, located_eids) frontier = transform.remove_edges(
frontier, located_eids, store_ids=True)
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID]) frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
else: else:
# (BarclayII) remove_edges only accepts removing one type of edges, # (BarclayII) remove_edges only accepts removing one type of edges,
...@@ -235,7 +236,8 @@ class BlockSampler(object): ...@@ -235,7 +236,8 @@ class BlockSampler(object):
new_eids = parent_eids.copy() new_eids = parent_eids.copy()
for k, v in located_eids.items(): for k, v in located_eids.items():
if len(v) > 0: if len(v) > 0:
frontier = transform.remove_edges(frontier, v, etype=k) frontier = transform.remove_edges(
frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID]) new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids frontier.edata[EID] = new_eids
......
...@@ -520,7 +520,7 @@ class DGLHeteroGraph(object): ...@@ -520,7 +520,7 @@ class DGLHeteroGraph(object):
self._edge_frames[etid].append(data) self._edge_frames[etid].append(data)
self._reset_cached_info() self._reset_cached_info()
def remove_edges(self, eids, etype=None): def remove_edges(self, eids, etype=None, store_ids=False):
r"""Remove multiple edges with the specified edge type r"""Remove multiple edges with the specified edge type
Nodes will not be removed. After removing edges, the rest Nodes will not be removed. After removing edges, the rest
...@@ -536,6 +536,10 @@ class DGLHeteroGraph(object): ...@@ -536,6 +536,10 @@ class DGLHeteroGraph(object):
etype : str or tuple of str, optional etype : str or tuple of str, optional
The type of the edges to remove. Can be omitted if there is The type of the edges to remove. Can be omitted if there is
only one edge type in the graph. only one edge type in the graph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Examples Examples
-------- --------
...@@ -602,12 +606,12 @@ class DGLHeteroGraph(object): ...@@ -602,12 +606,12 @@ class DGLHeteroGraph(object):
else: else:
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype) edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype)
sub_g = self.edge_subgraph(edges, preserve_nodes=True) sub_g = self.edge_subgraph(edges, preserve_nodes=True, store_ids=store_ids)
self._graph = sub_g._graph self._graph = sub_g._graph
self._node_frames = sub_g._node_frames self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames self._edge_frames = sub_g._edge_frames
def remove_nodes(self, nids, ntype=None): def remove_nodes(self, nids, ntype=None, store_ids=False):
r"""Remove multiple nodes with the specified node type r"""Remove multiple nodes with the specified node type
Edges that connect to the nodes will be removed as well. After removing Edges that connect to the nodes will be removed as well. After removing
...@@ -623,6 +627,10 @@ class DGLHeteroGraph(object): ...@@ -623,6 +627,10 @@ class DGLHeteroGraph(object):
ntype : str, optional ntype : str, optional
The type of the nodes to remove. Can be omitted if there is The type of the nodes to remove. Can be omitted if there is
only one node type in the graph. only one node type in the graph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Examples Examples
-------- --------
...@@ -694,7 +702,7 @@ class DGLHeteroGraph(object): ...@@ -694,7 +702,7 @@ class DGLHeteroGraph(object):
nodes[c_ntype] = self.nodes(c_ntype) nodes[c_ntype] = self.nodes(c_ntype)
# node_subgraph # node_subgraph
sub_g = self.subgraph(nodes) sub_g = self.subgraph(nodes, store_ids=store_ids)
self._graph = sub_g._graph self._graph = sub_g._graph
self._node_frames = sub_g._node_frames self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames self._edge_frames = sub_g._edge_frames
......
...@@ -17,7 +17,7 @@ from . import utils ...@@ -17,7 +17,7 @@ from . import utils
__all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph', __all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph',
'in_subgraph', 'out_subgraph'] 'in_subgraph', 'out_subgraph']
def node_subgraph(graph, nodes): def node_subgraph(graph, nodes, store_ids=True):
"""Return a subgraph induced on the given nodes. """Return a subgraph induced on the given nodes.
A node-induced subgraph is a subset of the nodes of a graph together with A node-induced subgraph is a subset of the nodes of a graph together with
...@@ -29,9 +29,6 @@ def node_subgraph(graph, nodes): ...@@ -29,9 +29,6 @@ def node_subgraph(graph, nodes):
* Copy the features of the extracted nodes and edges to the resulting graph. * Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed. The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted nodes and edges in the ``ndata`` and ``edata``
of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations them as the resulting graph. Thus, the resulting graph has the same set of relations
as the input one. as the input one.
...@@ -52,6 +49,10 @@ def node_subgraph(graph, nodes): ...@@ -52,6 +49,10 @@ def node_subgraph(graph, nodes):
If the graph is homogeneous, one can directly pass the above formats. If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes. and values being the nodes.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Returns Returns
------- -------
...@@ -137,11 +138,11 @@ def node_subgraph(graph, nodes): ...@@ -137,11 +138,11 @@ def node_subgraph(graph, nodes):
induced_nodes.append(_process_nodes(ntype, nids)) induced_nodes.append(_process_nodes(ntype, nids))
sgi = graph._graph.node_subgraph(induced_nodes) sgi = graph._graph.node_subgraph(induced_nodes)
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges) return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids)
DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph) DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph)
def edge_subgraph(graph, edges, preserve_nodes=False): def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
"""Return a subgraph induced on the given edges. """Return a subgraph induced on the given edges.
An edge-induced subgraph is equivalent to creating a new graph An edge-induced subgraph is equivalent to creating a new graph
...@@ -153,9 +154,6 @@ def edge_subgraph(graph, edges, preserve_nodes=False): ...@@ -153,9 +154,6 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
* Copy the features of the extracted nodes and edges to the resulting graph. * Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed. The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted nodes and edges in the ``ndata`` and ``edata``
of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations them as the resulting graph. Thus, the resulting graph has the same set of relations
as the input one. as the input one.
...@@ -177,8 +175,12 @@ def edge_subgraph(graph, edges, preserve_nodes=False): ...@@ -177,8 +175,12 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
Otherwise, the argument must be a dictionary with keys being edge types Otherwise, the argument must be a dictionary with keys being edge types
and values being the nodes. and values being the nodes.
preserve_nodes : bool, optional preserve_nodes : bool, optional
If true, do not relabel the incident nodes and remove the isolated nodes If True, do not relabel the incident nodes and remove the isolated nodes
in the extracted subgraph. (Default: False) in the extracted subgraph. (Default: False)
store_ids : bool, optional
If True, it will store the IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Returns Returns
------- -------
...@@ -278,7 +280,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False): ...@@ -278,7 +280,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
induced_edges.append(_process_edges(cetype, eids)) induced_edges.append(_process_edges(cetype, eids))
sgi = graph._graph.edge_subgraph(induced_edges, preserve_nodes) sgi = graph._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes induced_nodes = sgi.induced_nodes
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges) return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids)
DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph) DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph)
...@@ -632,7 +634,7 @@ DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph) ...@@ -632,7 +634,7 @@ DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph)
#################### Internal functions #################### #################### Internal functions ####################
def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges): def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges, store_ids=True):
"""Internal function to create a subgraph. """Internal function to create a subgraph.
Parameters Parameters
...@@ -647,14 +649,18 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges): ...@@ -647,14 +649,18 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges):
induced_edges : list[Tensor] or None induced_edges : list[Tensor] or None
Induced edge IDs. Will store it as the dgl.EID ndata unless it Induced edge IDs. Will store it as the dgl.EID ndata unless it
is None, which means the induced edge IDs are the same as the parent edge IDs. is None, which means the induced edge IDs are the same as the parent edge IDs.
store_ids : bool
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Returns Returns
------- -------
DGLGraph DGLGraph
Graph Graph
""" """
node_frames = utils.extract_node_subframes(parent, induced_nodes) node_frames = utils.extract_node_subframes(parent, induced_nodes, store_ids)
edge_frames = utils.extract_edge_subframes(parent, induced_edges) edge_frames = utils.extract_edge_subframes(parent, induced_edges, store_ids)
hsg = DGLHeteroGraph(sgi.graph, parent.ntypes, parent.etypes) hsg = DGLHeteroGraph(sgi.graph, parent.ntypes, parent.etypes)
utils.set_new_frames(hsg, node_frames=node_frames, edge_frames=edge_frames) utils.set_new_frames(hsg, node_frames=node_frames, edge_frames=edge_frames)
return hsg return hsg
......
...@@ -1116,7 +1116,7 @@ def add_edges(g, u, v, data=None, etype=None): ...@@ -1116,7 +1116,7 @@ def add_edges(g, u, v, data=None, etype=None):
g.add_edges(u, v, data=data, etype=etype) g.add_edges(u, v, data=data, etype=etype)
return g return g
def remove_edges(g, eids, etype=None): def remove_edges(g, eids, etype=None, store_ids=False):
r"""Remove the specified edges and return a new graph. r"""Remove the specified edges and return a new graph.
Also delete the features of the edges. The edges must exist in the graph. Also delete the features of the edges. The edges must exist in the graph.
...@@ -1135,6 +1135,10 @@ def remove_edges(g, eids, etype=None): ...@@ -1135,6 +1135,10 @@ def remove_edges(g, eids, etype=None):
triplet format in the graph. triplet format in the graph.
Can be omitted if the graph has only one type of edges. Can be omitted if the graph has only one type of edges.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Return Return
------ ------
...@@ -1179,11 +1183,11 @@ def remove_edges(g, eids, etype=None): ...@@ -1179,11 +1183,11 @@ def remove_edges(g, eids, etype=None):
remove_nodes remove_nodes
""" """
g = g.clone() g = g.clone()
g.remove_edges(eids, etype=etype) g.remove_edges(eids, etype=etype, store_ids=store_ids)
return g return g
def remove_nodes(g, nids, ntype=None): def remove_nodes(g, nids, ntype=None, store_ids=False):
r"""Remove the specified nodes and return a new graph. r"""Remove the specified nodes and return a new graph.
Also delete the features. Edges that connect from/to the nodes will be Also delete the features. Edges that connect from/to the nodes will be
...@@ -1197,6 +1201,10 @@ def remove_nodes(g, nids, ntype=None): ...@@ -1197,6 +1201,10 @@ def remove_nodes(g, nids, ntype=None):
ntype : str, optional ntype : str, optional
The type of the nodes to remove. Can be omitted if there is The type of the nodes to remove. Can be omitted if there is
only one node type in the graph. only one node type in the graph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Return Return
------ ------
...@@ -1247,7 +1255,7 @@ def remove_nodes(g, nids, ntype=None): ...@@ -1247,7 +1255,7 @@ def remove_nodes(g, nids, ntype=None):
remove_edges remove_edges
""" """
g = g.clone() g = g.clone()
g.remove_nodes(nids, ntype=ntype) g.remove_nodes(nids, ntype=ntype, store_ids=store_ids)
return g return g
def add_self_loop(g, etype=None): def add_self_loop(g, etype=None):
......
...@@ -754,7 +754,7 @@ def relabel(x): ...@@ -754,7 +754,7 @@ def relabel(x):
F.copy_to(F.arange(0, len(unique_x), dtype), ctx)) F.copy_to(F.arange(0, len(unique_x), dtype), ctx))
return unique_x, old_to_new return unique_x, old_to_new
def extract_node_subframes(graph, nodes): def extract_node_subframes(graph, nodes, store_ids=True):
"""Extract node features of the given nodes from :attr:`graph` """Extract node features of the given nodes from :attr:`graph`
and return them in frames. and return them in frames.
...@@ -769,8 +769,10 @@ def extract_node_subframes(graph, nodes): ...@@ -769,8 +769,10 @@ def extract_node_subframes(graph, nodes):
The graph to extract features from. The graph to extract features from.
nodes : list[Tensor] or None nodes : list[Tensor] or None
Node IDs. If not None, the list length must be equal to the number of node types Node IDs. If not None, the list length must be equal to the number of node types
in the graph. The returned frames store the node IDs in the ``dgl.NID`` field in the graph. If None, the whole frame is shallow-copied.
unless it is None, which means the whole frame is shallow-copied. store_ids : bool
If True, the returned frames will store :attr:`nodes` in the ``dgl.NID`` field
unless it is None.
Returns Returns
------- -------
...@@ -783,7 +785,8 @@ def extract_node_subframes(graph, nodes): ...@@ -783,7 +785,8 @@ def extract_node_subframes(graph, nodes):
node_frames = [] node_frames = []
for i, ind_nodes in enumerate(nodes): for i, ind_nodes in enumerate(nodes):
subf = graph._node_frames[i].subframe(ind_nodes) subf = graph._node_frames[i].subframe(ind_nodes)
subf[NID] = ind_nodes if store_ids:
subf[NID] = ind_nodes
node_frames.append(subf) node_frames.append(subf)
return node_frames return node_frames
...@@ -823,7 +826,7 @@ def extract_node_subframes_for_block(graph, srcnodes, dstnodes): ...@@ -823,7 +826,7 @@ def extract_node_subframes_for_block(graph, srcnodes, dstnodes):
node_frames.append(subf) node_frames.append(subf)
return node_frames return node_frames
def extract_edge_subframes(graph, edges): def extract_edge_subframes(graph, edges, store_ids=True):
"""Extract edge features of the given edges from :attr:`graph` """Extract edge features of the given edges from :attr:`graph`
and return them in frames. and return them in frames.
...@@ -838,8 +841,10 @@ def extract_edge_subframes(graph, edges): ...@@ -838,8 +841,10 @@ def extract_edge_subframes(graph, edges):
The graph to extract features from. The graph to extract features from.
edges : list[Tensor] or None edges : list[Tensor] or None
Edge IDs. If not None, the list length must be equal to the number of edge types Edge IDs. If not None, the list length must be equal to the number of edge types
in the graph. The returned frames store the edge IDs in the ``dgl.NID`` field in the graph. If None, the whole frame is shallow-copied.
unless it is None, which means the whole frame is shallow-copied. store_ids : bool
If True, the returned frames will store :attr:`edges` in the ``dgl.EID`` field
unless it is None.
Returns Returns
------- -------
...@@ -852,7 +857,8 @@ def extract_edge_subframes(graph, edges): ...@@ -852,7 +857,8 @@ def extract_edge_subframes(graph, edges):
edge_frames = [] edge_frames = []
for i, ind_edges in enumerate(edges): for i, ind_edges in enumerate(edges):
subf = graph._edge_frames[i].subframe(ind_edges) subf = graph._edge_frames[i].subframe(ind_edges)
subf[EID] = ind_edges if store_ids:
subf[EID] = ind_edges
edge_frames.append(subf) edge_frames.append(subf)
return edge_frames return edge_frames
......
import os
import backend as F import backend as F
import networkx as nx
import numpy as np import numpy as np
import dgl import dgl
from test_utils import parametrize_dtype from test_utils import parametrize_dtype
...@@ -18,6 +16,8 @@ def test_node_removal(idtype): ...@@ -18,6 +16,8 @@ def test_node_removal(idtype):
g.remove_nodes(range(4, 7)) g.remove_nodes(range(4, 7))
assert g.number_of_nodes() == 7 assert g.number_of_nodes() == 7
assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9])) assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9]))
assert dgl.NID not in g.ndata
assert dgl.EID not in g.edata
# add nodes # add nodes
g.add_nodes(3) g.add_nodes(3)
...@@ -25,9 +25,11 @@ def test_node_removal(idtype): ...@@ -25,9 +25,11 @@ def test_node_removal(idtype):
assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0])) assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0]))
# remove nodes # remove nodes
g.remove_nodes(range(1, 4)) g.remove_nodes(range(1, 4), store_ids=True)
assert g.number_of_nodes() == 7 assert g.number_of_nodes() == 7
assert F.array_equal(g.ndata['id'], F.tensor([0, 7, 8, 9, 0, 0, 0])) assert F.array_equal(g.ndata['id'], F.tensor([0, 7, 8, 9, 0, 0, 0]))
assert dgl.NID in g.ndata
assert dgl.EID in g.edata
@parametrize_dtype @parametrize_dtype
def test_multigraph_node_removal(idtype): def test_multigraph_node_removal(idtype):
...@@ -99,6 +101,8 @@ def test_edge_removal(idtype): ...@@ -99,6 +101,8 @@ def test_edge_removal(idtype):
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
assert g.number_of_edges() == 18 assert g.number_of_edges() == 18
assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25)))) assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25))))
assert dgl.NID not in g.ndata
assert dgl.EID not in g.edata
# add edges # add edges
g.add_edge(3, 3) g.add_edge(3, 3)
...@@ -107,10 +111,12 @@ def test_edge_removal(idtype): ...@@ -107,10 +111,12 @@ def test_edge_removal(idtype):
assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25)) + [0])) assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25)) + [0]))
# remove edges # remove edges
g.remove_edges(range(2, 10)) g.remove_edges(range(2, 10), store_ids=True)
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
assert g.number_of_edges() == 11 assert g.number_of_edges() == 11
assert F.array_equal(g.edata['id'], F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0])) assert F.array_equal(g.edata['id'], F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0]))
assert dgl.NID in g.ndata
assert dgl.EID in g.edata
@parametrize_dtype @parametrize_dtype
def test_node_and_edge_removal(idtype): def test_node_and_edge_removal(idtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment