Unverified commit 0c2a2ea1, authored by Mufei Li, committed by GitHub
Browse files

[Bugfix] Handle a Corner Case of Batching after Removing Nodes/Edges (#2465)



* Update

* Update

* Update
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 08671d81
......@@ -227,7 +227,8 @@ class BlockSampler(object):
# to the mapping from the new graph to the old frontier.
# So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0:
frontier = transform.remove_edges(frontier, located_eids)
frontier = transform.remove_edges(
frontier, located_eids, store_ids=True)
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
else:
# (BarclayII) remove_edges only accepts removing one type of edges,
......@@ -235,7 +236,8 @@ class BlockSampler(object):
new_eids = parent_eids.copy()
for k, v in located_eids.items():
if len(v) > 0:
frontier = transform.remove_edges(frontier, v, etype=k)
frontier = transform.remove_edges(
frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids
......
......@@ -520,7 +520,7 @@ class DGLHeteroGraph(object):
self._edge_frames[etid].append(data)
self._reset_cached_info()
def remove_edges(self, eids, etype=None):
def remove_edges(self, eids, etype=None, store_ids=False):
r"""Remove multiple edges with the specified edge type
Nodes will not be removed. After removing edges, the rest
......@@ -536,6 +536,10 @@ class DGLHeteroGraph(object):
etype : str or tuple of str, optional
The type of the edges to remove. Can be omitted if there is
only one edge type in the graph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Examples
--------
......@@ -602,12 +606,12 @@ class DGLHeteroGraph(object):
else:
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype)
sub_g = self.edge_subgraph(edges, preserve_nodes=True)
sub_g = self.edge_subgraph(edges, preserve_nodes=True, store_ids=store_ids)
self._graph = sub_g._graph
self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames
def remove_nodes(self, nids, ntype=None):
def remove_nodes(self, nids, ntype=None, store_ids=False):
r"""Remove multiple nodes with the specified node type
Edges that connect to the nodes will be removed as well. After removing
......@@ -623,6 +627,10 @@ class DGLHeteroGraph(object):
ntype : str, optional
The type of the nodes to remove. Can be omitted if there is
only one node type in the graph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Examples
--------
......@@ -694,7 +702,7 @@ class DGLHeteroGraph(object):
nodes[c_ntype] = self.nodes(c_ntype)
# node_subgraph
sub_g = self.subgraph(nodes)
sub_g = self.subgraph(nodes, store_ids=store_ids)
self._graph = sub_g._graph
self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames
......
......@@ -17,7 +17,7 @@ from . import utils
__all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph',
'in_subgraph', 'out_subgraph']
def node_subgraph(graph, nodes):
def node_subgraph(graph, nodes, store_ids=True):
"""Return a subgraph induced on the given nodes.
A node-induced subgraph is a subset of the nodes of a graph together with
......@@ -29,9 +29,6 @@ def node_subgraph(graph, nodes):
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted nodes and edges in the ``ndata`` and ``edata``
of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations
as the input one.
......@@ -52,6 +49,10 @@ def node_subgraph(graph, nodes):
If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Returns
-------
......@@ -137,11 +138,11 @@ def node_subgraph(graph, nodes):
induced_nodes.append(_process_nodes(ntype, nids))
sgi = graph._graph.node_subgraph(induced_nodes)
induced_edges = sgi.induced_edges
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges)
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids)
DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph)
def edge_subgraph(graph, edges, preserve_nodes=False):
def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
"""Return a subgraph induced on the given edges.
An edge-induced subgraph is equivalent to creating a new graph
......@@ -153,9 +154,6 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted nodes and edges in the ``ndata`` and ``edata``
of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations
as the input one.
......@@ -177,8 +175,12 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
Otherwise, the argument must be a dictionary with keys being edge types
and values being the nodes.
preserve_nodes : bool, optional
If true, do not relabel the incident nodes and remove the isolated nodes
If True, do not relabel the incident nodes and remove the isolated nodes
in the extracted subgraph. (Default: False)
store_ids : bool, optional
If True, it will store the IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Returns
-------
......@@ -278,7 +280,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
induced_edges.append(_process_edges(cetype, eids))
sgi = graph._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges)
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids)
DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph)
......@@ -632,7 +634,7 @@ DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph)
#################### Internal functions ####################
def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges):
def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges, store_ids=True):
"""Internal function to create a subgraph.
Parameters
......@@ -647,14 +649,18 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges):
induced_edges : list[Tensor] or None
Induced edge IDs. Will store it as the dgl.EID ndata unless it
is None, which means the induced edge IDs are the same as the parent edge IDs.
store_ids : bool
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Returns
-------
DGLGraph
Graph
"""
node_frames = utils.extract_node_subframes(parent, induced_nodes)
edge_frames = utils.extract_edge_subframes(parent, induced_edges)
node_frames = utils.extract_node_subframes(parent, induced_nodes, store_ids)
edge_frames = utils.extract_edge_subframes(parent, induced_edges, store_ids)
hsg = DGLHeteroGraph(sgi.graph, parent.ntypes, parent.etypes)
utils.set_new_frames(hsg, node_frames=node_frames, edge_frames=edge_frames)
return hsg
......
......@@ -1116,7 +1116,7 @@ def add_edges(g, u, v, data=None, etype=None):
g.add_edges(u, v, data=data, etype=etype)
return g
def remove_edges(g, eids, etype=None):
def remove_edges(g, eids, etype=None, store_ids=False):
r"""Remove the specified edges and return a new graph.
Also delete the features of the edges. The edges must exist in the graph.
......@@ -1135,6 +1135,10 @@ def remove_edges(g, eids, etype=None):
triplet format in the graph.
Can be omitted if the graph has only one type of edges.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Return
------
......@@ -1179,11 +1183,11 @@ def remove_edges(g, eids, etype=None):
remove_nodes
"""
g = g.clone()
g.remove_edges(eids, etype=etype)
g.remove_edges(eids, etype=etype, store_ids=store_ids)
return g
def remove_nodes(g, nids, ntype=None):
def remove_nodes(g, nids, ntype=None, store_ids=False):
r"""Remove the specified nodes and return a new graph.
Also delete the features. Edges that connect from/to the nodes will be
......@@ -1197,6 +1201,10 @@ def remove_nodes(g, nids, ntype=None):
ntype : str, optional
The type of the nodes to remove. Can be omitted if there is
only one node type in the graph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
Return
------
......@@ -1247,7 +1255,7 @@ def remove_nodes(g, nids, ntype=None):
remove_edges
"""
g = g.clone()
g.remove_nodes(nids, ntype=ntype)
g.remove_nodes(nids, ntype=ntype, store_ids=store_ids)
return g
def add_self_loop(g, etype=None):
......
......@@ -754,7 +754,7 @@ def relabel(x):
F.copy_to(F.arange(0, len(unique_x), dtype), ctx))
return unique_x, old_to_new
def extract_node_subframes(graph, nodes):
def extract_node_subframes(graph, nodes, store_ids=True):
"""Extract node features of the given nodes from :attr:`graph`
and return them in frames.
......@@ -769,8 +769,10 @@ def extract_node_subframes(graph, nodes):
The graph to extract features from.
nodes : list[Tensor] or None
Node IDs. If not None, the list length must be equal to the number of node types
in the graph. The returned frames store the node IDs in the ``dgl.NID`` field
unless it is None, which means the whole frame is shallow-copied.
in the graph. If None, the whole frame is shallow-copied.
store_ids : bool
If True, the returned frames will store :attr:`nodes` in the ``dgl.NID`` field
unless it is None.
Returns
-------
......@@ -783,7 +785,8 @@ def extract_node_subframes(graph, nodes):
node_frames = []
for i, ind_nodes in enumerate(nodes):
subf = graph._node_frames[i].subframe(ind_nodes)
subf[NID] = ind_nodes
if store_ids:
subf[NID] = ind_nodes
node_frames.append(subf)
return node_frames
......@@ -823,7 +826,7 @@ def extract_node_subframes_for_block(graph, srcnodes, dstnodes):
node_frames.append(subf)
return node_frames
def extract_edge_subframes(graph, edges):
def extract_edge_subframes(graph, edges, store_ids=True):
"""Extract edge features of the given edges from :attr:`graph`
and return them in frames.
......@@ -838,8 +841,10 @@ def extract_edge_subframes(graph, edges):
The graph to extract features from.
edges : list[Tensor] or None
Edge IDs. If not None, the list length must be equal to the number of edge types
in the graph. The returned frames store the edge IDs in the ``dgl.NID`` field
unless it is None, which means the whole frame is shallow-copied.
in the graph. If None, the whole frame is shallow-copied.
store_ids : bool
If True, the returned frames will store :attr:`edges` in the ``dgl.EID`` field
unless it is None.
Returns
-------
......@@ -852,7 +857,8 @@ def extract_edge_subframes(graph, edges):
edge_frames = []
for i, ind_edges in enumerate(edges):
subf = graph._edge_frames[i].subframe(ind_edges)
subf[EID] = ind_edges
if store_ids:
subf[EID] = ind_edges
edge_frames.append(subf)
return edge_frames
......
import os
import backend as F
import networkx as nx
import numpy as np
import dgl
from test_utils import parametrize_dtype
......@@ -18,6 +16,8 @@ def test_node_removal(idtype):
g.remove_nodes(range(4, 7))
assert g.number_of_nodes() == 7
assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9]))
assert dgl.NID not in g.ndata
assert dgl.EID not in g.edata
# add nodes
g.add_nodes(3)
......@@ -25,9 +25,11 @@ def test_node_removal(idtype):
assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0]))
# remove nodes
g.remove_nodes(range(1, 4))
g.remove_nodes(range(1, 4), store_ids=True)
assert g.number_of_nodes() == 7
assert F.array_equal(g.ndata['id'], F.tensor([0, 7, 8, 9, 0, 0, 0]))
assert dgl.NID in g.ndata
assert dgl.EID in g.edata
@parametrize_dtype
def test_multigraph_node_removal(idtype):
......@@ -99,6 +101,8 @@ def test_edge_removal(idtype):
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 18
assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25))))
assert dgl.NID not in g.ndata
assert dgl.EID not in g.edata
# add edges
g.add_edge(3, 3)
......@@ -107,10 +111,12 @@ def test_edge_removal(idtype):
assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25)) + [0]))
# remove edges
g.remove_edges(range(2, 10))
g.remove_edges(range(2, 10), store_ids=True)
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 11
assert F.array_equal(g.edata['id'], F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0]))
assert dgl.NID in g.ndata
assert dgl.EID in g.edata
@parametrize_dtype
def test_node_and_edge_removal(idtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment