Unverified Commit ff519f98 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[API] Standardize Subgraph APIs (#2929)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Fix

* Update

* Fix subgraph tests

* Capture stdout for distributed test

* Capture stdout for distributed test

* Update

* Update

* Update

* Update subgraph.cc
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-17.us-west-2.compute.internal>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent fc69f6ce
......@@ -439,7 +439,7 @@ all edge types, so that it can work on heterogeneous graphs as well.
# Return a new graph with the same nodes as the original graph as a
# frontier
frontier = dgl.edge_subgraph(new_edge_masks, preserve_nodes=True)
frontier = dgl.edge_subgraph(new_edge_masks, relabel_nodes=False)
return frontier
def __len__(self):
......
......@@ -381,7 +381,7 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所
new_edges_masks[etype] = edge_mask.bool()
# 返回一个与初始图有相同节点的图作为边界
frontier = dgl.edge_subgraph(new_edge_masks, preserve_nodes=True)
frontier = dgl.edge_subgraph(new_edge_masks, relabel_nodes=False)
return frontier
def __len__(self):
......
......@@ -18,7 +18,7 @@ def drop_edge(graph, drop_prob):
masks = th.bernoulli(1 - mask_rates)
edge_idx = masks.nonzero().squeeze(1)
sg = dgl.edge_subgraph(graph, edge_idx, preserve_nodes=True)
sg = dgl.edge_subgraph(graph, edge_idx, relabel_nodes=False)
return sg
......
......@@ -58,7 +58,7 @@ class RGAT(nn.Module):
mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, preserve_nodes=True)
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
x = self.norms[i](x_skip)
x = F.elu(x)
......
......@@ -61,7 +61,7 @@ class RGAT(nn.Module):
mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, preserve_nodes=True)
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
x = self.norms[i](x_skip)
x = F.elu(x)
......
......@@ -31,10 +31,7 @@ def train_test_split_by_time(df, timestamp, user):
def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
train_g = g.edge_subgraph(
{etype: train_indices, etype_rev: train_indices},
preserve_nodes=True)
# remove the induced node IDs - should be assigned by model instead
del train_g.nodes[utype].data[dgl.NID]
del train_g.nodes[itype].data[dgl.NID]
relabel_nodes=False)
# copy features
for ntype in g.ntypes:
......
......@@ -174,7 +174,7 @@ class TemporalEdgeCollator(EdgeCollator):
def _collate_with_negative_sampling(self, items):
items = _prepare_tensor(self.g_sampling, items, 'items', False)
# Here node id will not change
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[dgl.EID]
neg_srcdst_raw = self.negative_sampler(self.g, items)
......@@ -546,7 +546,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
def _collate_with_negative_sampling(self, items):
items = _prepare_tensor(self.g_sampling, items, 'items', False)
# Here node id will not change
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[dgl.EID]
neg_srcdst_raw = self.negative_sampler(self.g, items)
......
......@@ -99,7 +99,7 @@ def dgl_main():
# create train graph
train_edge_idx = torch.tensor(train_edge_idx).to(device)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, relabel_nodes=False)
train_graph = train_graph.to(device)
adj = train_graph.adjacency_matrix().to_dense().to(device)
......
......@@ -685,19 +685,23 @@ HeteroGraphPtr CreateFromCSC(
* \brief Extract the subgraph of the in edges of the given nodes.
* \param graph Graph
* \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the remaining ones
* \return Subgraph containing only the in edges. The returned graph has the same
* schema as the original one.
*/
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes);
HeteroSubgraph InEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false);
/*!
* \brief Extract the subgraph of the out edges of the given nodes.
* \param graph Graph
* \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the remaining ones
* \return Subgraph containing only the out edges. The returned graph has the same
* schema as the original one.
*/
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes);
HeteroSubgraph OutEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false);
/*!
* \brief Joint union multiple graphs into one graph.
......
......@@ -3563,7 +3563,7 @@ class DGLGraph(DGLBaseGraph):
tensor([0, 1, 4])
>>> SG.parent_eid
tensor([0, 4])
>>> SG = G.edge_subgraph([0, 4], preserve_nodes=True)
>>> SG = G.edge_subgraph([0, 4], relabel_nodes=False)
>>> SG.nodes()
tensor([0, 1, 2, 3, 4])
>>> SG.edges()
......
......@@ -422,13 +422,13 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>> # build train_g
>>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges,
preserve_nodes=True)
relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges];
>>>
>>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges,
preserve_nodes=True)
relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges];
>>>
>>> # Train, Validation and Test
......@@ -558,13 +558,13 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> # build train_g
>>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges,
preserve_nodes=True)
relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges];
>>>
>>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges,
preserve_nodes=True)
relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges];
>>>
>>> # Train, Validation and Test
......@@ -694,13 +694,13 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> # build train_g
>>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges,
preserve_nodes=True)
relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges];
>>>
>>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges,
preserve_nodes=True)
relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges];
>>>
>>> # Train, Validation and Test
......
......@@ -696,7 +696,7 @@ class EdgeCollator(Collator):
else:
items = _prepare_tensor(self.g_sampling, items, 'items', self._is_distributed)
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items)
......
......@@ -622,7 +622,7 @@ class DGLHeteroGraph(object):
else:
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype)
sub_g = self.edge_subgraph(edges, preserve_nodes=True, store_ids=store_ids)
sub_g = self.edge_subgraph(edges, relabel_nodes=False, store_ids=store_ids)
self._graph = sub_g._graph
self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames
......@@ -4294,7 +4294,7 @@ class DGLHeteroGraph(object):
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
if core.is_builtin(func):
if not is_all(eid):
g = g.edge_subgraph(eid, preserve_nodes=True)
g = g.edge_subgraph(eid, relabel_nodes=False)
edata = core.invoke_gsddmm(g, func)
else:
edata = core.invoke_edge_udf(g, eid, etype, func)
......
......@@ -843,7 +843,7 @@ class HeteroGraphIndex(ObjectBase):
raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
return inc, shuffle_idx
def node_subgraph(self, induced_nodes):
def node_subgraph(self, induced_nodes, relabel_nodes):
"""Return the induced node subgraph.
Parameters
......@@ -851,6 +851,9 @@ class HeteroGraphIndex(ObjectBase):
induced_nodes : list of utils.Index
Induced nodes. The length should be equal to the number of
node types in this heterograph.
relabel_nodes : bool
If True, the extracted subgraph will only have the nodes in the specified node set
and it will relabel the nodes in order.
Returns
-------
......@@ -858,7 +861,7 @@ class HeteroGraphIndex(ObjectBase):
The subgraph index.
"""
vids = [F.to_dgl_nd(nodes) for nodes in induced_nodes]
return _CAPI_DGLHeteroVertexSubgraph(self, vids)
return _CAPI_DGLHeteroVertexSubgraph(self, vids, relabel_nodes)
def edge_subgraph(self, induced_edges, preserve_nodes):
"""Return the induced edge subgraph.
......
......@@ -342,7 +342,7 @@ class RelGraphConv(nn.Module):
# list, where each element is the number of edges of the type.
# Sort the graph based on the etypes
sorted_etypes, index = th.sort(etypes)
g = edge_subgraph(g, index, preserve_nodes=True)
g = edge_subgraph(g, index, relabel_nodes=False)
# Create a new etypes to be an integer list of number of edges.
pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device))
num = th.tensor([len(etypes)], device=g.device)
......
......@@ -199,7 +199,7 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
eid = F.astype(induced_edges[0], F.int64) + max_eid * F.astype(inner_edge == 0, F.int64)
_, index = F.sort_1d(eid)
subg1 = edge_subgraph(subg1, index, preserve_nodes=True)
subg1 = edge_subgraph(subg1, index, relabel_nodes=False)
subg1.ndata[NID] = induced_nodes[0]
subg1.edata[EID] = F.gather_row(induced_edges[0], index)
else:
......
......@@ -6,7 +6,7 @@ For stochastic subgraph extraction, please see functions under :mod:`dgl.samplin
from collections.abc import Mapping
from ._ffi.function import _init_api
from .base import DGLError
from .base import DGLError, dgl_warning
from . import backend as F
from . import graph_index
from . import heterograph_index
......@@ -17,17 +17,13 @@ from . import utils
__all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph',
'in_subgraph', 'out_subgraph']
def node_subgraph(graph, nodes, store_ids=True):
def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True):
"""Return a subgraph induced on the given nodes.
A node-induced subgraph is a subset of the nodes of a graph together with
any edges whose endpoints are both in this subset. In addition to extracting
the subgraph, DGL conducts the following:
* Relabel the extracted nodes to IDs starting from zero.
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
A node-induced subgraph is a graph with edges whose endpoints are both in the
specified node set. In addition to extracting the subgraph, DGL also copies
the features of the extracted nodes and edges to the resulting graph. The copy
is *lazy* and incurs data movement only when needed.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations
......@@ -48,11 +44,15 @@ def node_subgraph(graph, nodes, store_ids=True):
If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes.
and values being the node IDs in the above formats.
relabel_nodes : bool, optional
If True, the extracted subgraph will only have the nodes in the specified node set
and it will relabel the nodes in order.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the specified nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns
-------
......@@ -144,23 +144,20 @@ def node_subgraph(graph, nodes, store_ids=True):
for ntype in graph.ntypes:
nids = nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), graph.device))
induced_nodes.append(_process_nodes(ntype, nids))
sgi = graph._graph.node_subgraph(induced_nodes)
sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes)
induced_edges = sgi.induced_edges
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph)
def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **deprecated_kwargs):
"""Return a subgraph induced on the given edges.
An edge-induced subgraph is equivalent to creating a new graph
with the same number of nodes using the given edges. In addition to extracting
the subgraph, DGL conducts the following:
* Relabel the incident nodes to IDs starting from zero. Isolated nodes are removed.
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
An edge-induced subgraph is equivalent to creating a new graph using the given
edges. In addition to extracting the subgraph, DGL also copies the features
of the extracted nodes and edges to the resulting graph. The copy is *lazy*
and incurs data movement only when needed.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations
......@@ -170,7 +167,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
----------
graph : DGLGraph
The graph to extract the subgraph from.
edges : dict[(str, str, str), edges]
edges : edges or dict[(str, str, str), edges]
The edges to form the subgraph. The allowed edges formats are:
* Int Tensor: Each element is an edge ID. The tensor must have the same device type
......@@ -181,14 +178,15 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being edge types
and values being the edge IDs.
preserve_nodes : bool, optional
If True, do not relabel the incident nodes and remove the isolated nodes
in the extracted subgraph. (Default: False)
and values being the edge IDs in the above formats.
relabel_nodes : bool, optional
If True, it will remove the isolated nodes and relabel the incident nodes in the
extracted subgraph.
store_ids : bool, optional
If True, it will store the IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the incident nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns
-------
......@@ -227,10 +225,10 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
Extract a subgraph without node relabeling.
>>> sg = dgl.edge_subgraph(g, [0, 4], preserve_nodes=True)
>>> sg = dgl.edge_subgraph(g, [0, 4], relabel_nodes=False)
>>> sg
Graph(num_nodes=5, num_edges=2,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
ndata_schemes={}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([0, 4]), tensor([1, 0]))
......@@ -276,7 +274,11 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
--------
node_subgraph
"""
if graph.is_block and not preserve_nodes:
if len(deprecated_kwargs) != 0:
dgl_warning(
"Key word argument preserve_nodes is deprecated. Use relabel_nodes instead.")
relabel_nodes = not deprecated_kwargs.get('preserve_nodes')
if graph.is_block and relabel_nodes:
raise DGLError('Extracting subgraph from a block graph is not allowed.')
if not isinstance(edges, Mapping):
assert len(graph.canonical_etypes) == 1, \
......@@ -294,25 +296,20 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
for cetype in graph.canonical_etypes:
eids = edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), graph.device))
induced_edges.append(_process_edges(cetype, eids))
sgi = graph._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids)
sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph)
def in_subgraph(g, nodes):
def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
"""Return the subgraph induced on the inbound edges of all the edge types of the
given nodes.
An edge-induced subgraph is equivalent to creating a new graph
with the same number of nodes using the given edges. In addition to extracting
the subgraph, DGL conducts the following:
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted edges in the ``edata``
of the resulting graph under name ``dgl.EID``.
An in subgraph is equivalent to creating a new graph using the incoming edges of the
given nodes. In addition to extracting the subgraph, DGL also copies the features of
the extracted nodes and edges to the resulting graph. The copy is *lazy* and incurs
data movement only when needed.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations
......@@ -320,18 +317,26 @@ def in_subgraph(g, nodes):
Parameters
----------
g : DGLGraph
graph : DGLGraph
The input graph.
nodes : nodes or dict[str, nodes]
The nodes to form the subgraph. The allowed nodes formats are:
* Int Tensor: Each element is an ID. The tensor must have the same device type
* Int Tensor: Each element is a node ID. The tensor must have the same device type
and ID data type as the graph's.
* iterable[int]: Each element is an ID.
* iterable[int]: Each element is a node ID.
If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes.
and values being the node IDs in the above formats.
relabel_nodes : bool, optional
If True, it will remove the isolated nodes and relabel the remaining nodes in the
extracted subgraph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns
-------
......@@ -371,6 +376,21 @@ def in_subgraph(g, nodes):
tensor([[2, 3],
[8, 9]])
Extract a subgraph with node relabeling.
>>> sg = dgl.in_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 3]), tensor([2, 0]))
>>> sg.edata[dgl.EID] # original edge IDs
tensor([1, 4])
>>> sg.ndata[dgl.NID] # original node IDs
tensor([0, 1, 2, 4])
Extract a subgraph from a heterogeneous graph.
>>> g = dgl.heterograph({
......@@ -386,39 +406,35 @@ def in_subgraph(g, nodes):
--------
out_subgraph
"""
if g.is_block:
if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.')
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
if len(graph.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes')
nodes = {graph.ntypes[0] : nodes}
nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
nodes_all_types = []
for ntype in g.ntypes:
for ntype in graph.ntypes:
if ntype in nodes:
nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
else:
nodes_all_types.append(nd.NULL[g._idtype_str])
nodes_all_types.append(nd.NULL[graph._idtype_str])
sgi = _CAPI_DGLInSubgraph(g._graph, nodes_all_types)
sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
induced_edges = sgi.induced_edges
return _create_hetero_subgraph(g, sgi, None, induced_edges)
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph)
def out_subgraph(g, nodes):
"""Return the subgraph induced on the out-bound edges of all the edge types of the
def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
"""Return the subgraph induced on the outbound edges of all the edge types of the
given nodes.
An edge-induced subgraph is equivalent to creating a new graph
with the same number of nodes using the given edges. In addition to extracting
the subgraph, DGL conducts the following:
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted edges in the ``edata``
of the resulting graph under name ``dgl.EID``.
An out subgraph is equivalent to creating a new graph using the outgoing edges of
the given nodes. In addition to extracting the subgraph, DGL also copies the features
of the extracted nodes and edges to the resulting graph. The copy is *lazy* and incurs
data movement only when needed.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations
......@@ -426,7 +442,7 @@ def out_subgraph(g, nodes):
Parameters
----------
g : DGLGraph
graph : DGLGraph
The input graph.
nodes : nodes or dict[str, nodes]
The nodes to form the subgraph. The allowed nodes formats are:
......@@ -437,7 +453,15 @@ def out_subgraph(g, nodes):
If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes.
and values being the node IDs in the above formats.
relabel_nodes : bool, optional
If True, it will remove the isolated nodes and relabel the remaining nodes in the
extracted subgraph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns
-------
......@@ -477,6 +501,21 @@ def out_subgraph(g, nodes):
tensor([[4, 5],
[0, 1]])
Extract a subgraph with node relabeling.
>>> sg = dgl.out_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([2, 0]), tensor([3, 1]))
>>> sg.edata[dgl.EID] # original edge IDs
tensor([2, 0])
>>> sg.ndata[dgl.NID] # original node IDs
tensor([0, 1, 2, 3])
Extract a subgraph from a heterogeneous graph.
>>> g = dgl.heterograph({
......@@ -492,23 +531,24 @@ def out_subgraph(g, nodes):
--------
in_subgraph
"""
if g.is_block:
if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.')
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
if len(graph.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes')
nodes = {graph.ntypes[0] : nodes}
nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
nodes_all_types = []
for ntype in g.ntypes:
for ntype in graph.ntypes:
if ntype in nodes:
nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
else:
nodes_all_types.append(nd.NULL[g._idtype_str])
nodes_all_types.append(nd.NULL[graph._idtype_str])
sgi = _CAPI_DGLOutSubgraph(g._graph, nodes_all_types)
sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
induced_edges = sgi.induced_edges
return _create_hetero_subgraph(g, sgi, None, induced_edges)
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph)
......@@ -701,9 +741,10 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges, store_ids
Induced edge IDs. Will store it as the dgl.EID edata unless it
is None, which means the induced edge IDs are the same as the parent edge IDs.
store_ids : bool
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,
respectively.
If True and induced_nodes is not None, it will store the raw IDs of the extracted
nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``.
If True and induced_edges is not None, it will store the raw IDs of the extracted
edges in the ``edata`` of the resulting graph under name ``dgl.EID``.
Returns
-------
......
......@@ -379,6 +379,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> vids = args[1];
bool relabel_nodes = args[2];
CHECK(relabel_nodes) << "Node subgraph only supports relabel_nodes=True.";
std::vector<IdArray> vid_vec;
vid_vec.reserve(vids.size());
for (Value val : vids) {
......@@ -649,8 +651,9 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2];
std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
*ret = InEdgeGraph(hg.sptr(), nodes);
*ret = InEdgeGraph(hg.sptr(), nodes, relabel_nodes);
*rv = HeteroGraphRef(ret);
});
......@@ -658,8 +661,9 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2];
std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
*ret = OutEdgeGraph(hg.sptr(), nodes);
*ret = OutEdgeGraph(hg.sptr(), nodes, relabel_nodes);
*rv = HeteroGraphRef(ret);
});
......
......@@ -8,7 +8,27 @@ using namespace dgl::runtime;
namespace dgl {
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
/*!
 * \brief Build the in-edge subgraph of the given nodes through EdgeSubgraph.
 *
 * For every edge type, the IDs of the edges entering the requested nodes of
 * the destination vertex type are gathered; an edge type whose destination
 * vertex type has a null node array contributes an empty ID array.
 *
 * \param graph The graph to extract from.
 * \param vids Node IDs, one array per vertex type.
 * \return The subgraph produced by EdgeSubgraph on the gathered edge IDs.
 */
HeteroSubgraph InEdgeGraphRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";

  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<IdArray> induced_eids;
  induced_eids.reserve(num_etypes);

  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    // FindEdge yields the (source, destination) vertex types of this edge type.
    const dgl_type_t dst_vtype = graph->meta_graph()->FindEdge(etype).second;
    const IdArray& seeds = vids[dst_vtype];
    if (!aten::IsNullArray(seeds)) {
      induced_eids.push_back(graph->InEdges(etype, {seeds}).id);
    } else {
      // No seed nodes of this destination type: select no edges.
      induced_eids.push_back(IdArray::Empty({0}, graph->DataType(), graph->Context()));
    }
  }

  // NOTE(review): the second argument is presumably preserve_nodes, so `false`
  // requests relabeling -- confirm against the EdgeSubgraph declaration.
  return graph->EdgeSubgraph(induced_eids, false);
}
HeteroSubgraph InEdgeGraphNoRelabelNodes(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
// TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
......@@ -43,7 +63,36 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray
return ret;
}
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
/*!
 * \brief Extract the subgraph of the in edges of the given nodes.
 *
 * Dispatches to the relabeling or the non-relabeling implementation
 * depending on \a relabel_nodes.
 *
 * \param graph The graph to extract from.
 * \param vids Node IDs, one array per vertex type.
 * \param relabel_nodes Selects which of the two implementations is used.
 * \return The subgraph returned by the selected implementation.
 */
HeteroSubgraph InEdgeGraph(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
  return relabel_nodes
    ? InEdgeGraphRelabelNodes(graph, vids)
    : InEdgeGraphNoRelabelNodes(graph, vids);
}
/*!
 * \brief Build the out-edge subgraph of the given nodes through EdgeSubgraph.
 *
 * For every edge type, the IDs of the edges leaving the requested nodes of
 * the source vertex type are gathered; an edge type whose source vertex type
 * has a null node array contributes an empty ID array.
 *
 * \param graph The graph to extract from.
 * \param vids Node IDs, one array per vertex type.
 * \return The subgraph produced by EdgeSubgraph on the gathered edge IDs.
 */
HeteroSubgraph OutEdgeGraphRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";

  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<IdArray> induced_eids;
  induced_eids.reserve(num_etypes);

  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    // FindEdge yields the (source, destination) vertex types of this edge type.
    const dgl_type_t src_vtype = graph->meta_graph()->FindEdge(etype).first;
    const IdArray& seeds = vids[src_vtype];
    if (!aten::IsNullArray(seeds)) {
      induced_eids.push_back(graph->OutEdges(etype, {seeds}).id);
    } else {
      // No seed nodes of this source type: select no edges.
      induced_eids.push_back(IdArray::Empty({0}, graph->DataType(), graph->Context()));
    }
  }

  // NOTE(review): the second argument is presumably preserve_nodes, so `false`
  // requests relabeling -- confirm against the EdgeSubgraph declaration.
  return graph->EdgeSubgraph(induced_eids, false);
}
HeteroSubgraph OutEdgeGraphNoRelabelNodes(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
// TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
......@@ -78,4 +127,13 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra
return ret;
}
/*!
 * \brief Extract the subgraph of the out edges of the given nodes.
 *
 * Dispatches to the relabeling or the non-relabeling implementation
 * depending on \a relabel_nodes.
 *
 * \param graph The graph to extract from.
 * \param vids Node IDs, one array per vertex type.
 * \param relabel_nodes Selects which of the two implementations is used.
 * \return The subgraph returned by the selected implementation.
 */
HeteroSubgraph OutEdgeGraph(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
  return relabel_nodes
    ? OutEdgeGraphRelabelNodes(graph, vids)
    : OutEdgeGraphNoRelabelNodes(graph, vids);
}
} // namespace dgl
......@@ -273,16 +273,16 @@ def test_set_batch_info(idtype):
assert subg_n2.num_edges() == subg2.num_edges()
# test homogeneous edge subgraph
sg_e = dgl.edge_subgraph(bg, list(range(40, 70)) + list(range(150, 200)), preserve_nodes=True)
induced_nodes = sg_e.ndata['_ID']
sg_e = dgl.edge_subgraph(bg, list(range(40, 70)) + list(range(150, 200)), relabel_nodes=False)
induced_nodes = F.arange(0, bg.num_nodes(), idtype)
induced_edges = sg_e.edata['_ID']
new_batch_num_nodes = _get_subgraph_batch_info(bg.ntypes, [induced_nodes], batch_num_nodes)
new_batch_num_edges = _get_subgraph_batch_info(bg.canonical_etypes, [induced_edges], batch_num_edges)
sg_e.set_batch_num_nodes(new_batch_num_nodes)
sg_e.set_batch_num_edges(new_batch_num_edges)
subg_e1, subg_e2 = dgl.unbatch(sg_e)
subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), preserve_nodes=True)
subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), preserve_nodes=True)
subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), relabel_nodes=False)
subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), relabel_nodes=False)
assert subg_e1.num_nodes() == subg1.num_nodes()
assert subg_e2.num_nodes() == subg2.num_nodes()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment