Unverified Commit ff519f98 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[API] Standardize Subgraph APIs (#2929)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Fix

* Update

* Fix subgraph tests

* Capture stdout for distributed test

* Capture stdout for distributed test

* Update

* Update

* Update

* Update subgraph.cc
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-17.us-west-2.compute.internal>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent fc69f6ce
...@@ -439,7 +439,7 @@ all edge types, so that it can work on heterogeneous graphs as well. ...@@ -439,7 +439,7 @@ all edge types, so that it can work on heterogeneous graphs as well.
# Return a new graph with the same nodes as the original graph as a # Return a new graph with the same nodes as the original graph as a
# frontier # frontier
frontier = dgl.edge_subgraph(new_edge_masks, preserve_nodes=True) frontier = dgl.edge_subgraph(new_edge_masks, relabel_nodes=False)
return frontier return frontier
def __len__(self): def __len__(self):
......
...@@ -381,7 +381,7 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所 ...@@ -381,7 +381,7 @@ DGL确保块的输出节点将始终出现在输入节点中。如下代码所
new_edges_masks[etype] = edge_mask.bool() new_edges_masks[etype] = edge_mask.bool()
# 返回一个与初始图有相同节点的图作为边界 # 返回一个与初始图有相同节点的图作为边界
frontier = dgl.edge_subgraph(new_edge_masks, preserve_nodes=True) frontier = dgl.edge_subgraph(new_edge_masks, relabel_nodes=False)
return frontier return frontier
def __len__(self): def __len__(self):
......
...@@ -18,7 +18,7 @@ def drop_edge(graph, drop_prob): ...@@ -18,7 +18,7 @@ def drop_edge(graph, drop_prob):
masks = th.bernoulli(1 - mask_rates) masks = th.bernoulli(1 - mask_rates)
edge_idx = masks.nonzero().squeeze(1) edge_idx = masks.nonzero().squeeze(1)
sg = dgl.edge_subgraph(graph, edge_idx, preserve_nodes=True) sg = dgl.edge_subgraph(graph, edge_idx, relabel_nodes=False)
return sg return sg
......
...@@ -58,7 +58,7 @@ class RGAT(nn.Module): ...@@ -58,7 +58,7 @@ class RGAT(nn.Module):
mfg = dgl.block_to_graph(mfg) mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst) x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes): for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, preserve_nodes=True) subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels) x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
x = self.norms[i](x_skip) x = self.norms[i](x_skip)
x = F.elu(x) x = F.elu(x)
......
...@@ -61,7 +61,7 @@ class RGAT(nn.Module): ...@@ -61,7 +61,7 @@ class RGAT(nn.Module):
mfg = dgl.block_to_graph(mfg) mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst) x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes): for j in range(self.num_etypes):
subg = mfg.edge_subgraph(mfg.edata['etype'] == j, preserve_nodes=True) subg = mfg.edge_subgraph(mfg.edata['etype'] == j, relabel_nodes=False)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels) x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
x = self.norms[i](x_skip) x = self.norms[i](x_skip)
x = F.elu(x) x = F.elu(x)
......
...@@ -31,10 +31,7 @@ def train_test_split_by_time(df, timestamp, user): ...@@ -31,10 +31,7 @@ def train_test_split_by_time(df, timestamp, user):
def build_train_graph(g, train_indices, utype, itype, etype, etype_rev): def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
train_g = g.edge_subgraph( train_g = g.edge_subgraph(
{etype: train_indices, etype_rev: train_indices}, {etype: train_indices, etype_rev: train_indices},
preserve_nodes=True) relabel_nodes=False)
# remove the induced node IDs - should be assigned by model instead
del train_g.nodes[utype].data[dgl.NID]
del train_g.nodes[itype].data[dgl.NID]
# copy features # copy features
for ntype in g.ntypes: for ntype in g.ntypes:
......
...@@ -174,7 +174,7 @@ class TemporalEdgeCollator(EdgeCollator): ...@@ -174,7 +174,7 @@ class TemporalEdgeCollator(EdgeCollator):
def _collate_with_negative_sampling(self, items): def _collate_with_negative_sampling(self, items):
items = _prepare_tensor(self.g_sampling, items, 'items', False) items = _prepare_tensor(self.g_sampling, items, 'items', False)
# Here node id will not change # Here node id will not change
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True) pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[dgl.EID] induced_edges = pair_graph.edata[dgl.EID]
neg_srcdst_raw = self.negative_sampler(self.g, items) neg_srcdst_raw = self.negative_sampler(self.g, items)
...@@ -546,7 +546,7 @@ class FastTemporalEdgeCollator(EdgeCollator): ...@@ -546,7 +546,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
def _collate_with_negative_sampling(self, items): def _collate_with_negative_sampling(self, items):
items = _prepare_tensor(self.g_sampling, items, 'items', False) items = _prepare_tensor(self.g_sampling, items, 'items', False)
# Here node id will not change # Here node id will not change
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True) pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[dgl.EID] induced_edges = pair_graph.edata[dgl.EID]
neg_srcdst_raw = self.negative_sampler(self.g, items) neg_srcdst_raw = self.negative_sampler(self.g, items)
......
...@@ -99,7 +99,7 @@ def dgl_main(): ...@@ -99,7 +99,7 @@ def dgl_main():
# create train graph # create train graph
train_edge_idx = torch.tensor(train_edge_idx).to(device) train_edge_idx = torch.tensor(train_edge_idx).to(device)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True) train_graph = dgl.edge_subgraph(graph, train_edge_idx, relabel_nodes=False)
train_graph = train_graph.to(device) train_graph = train_graph.to(device)
adj = train_graph.adjacency_matrix().to_dense().to(device) adj = train_graph.adjacency_matrix().to_dense().to(device)
......
...@@ -685,19 +685,23 @@ HeteroGraphPtr CreateFromCSC( ...@@ -685,19 +685,23 @@ HeteroGraphPtr CreateFromCSC(
* \brief Extract the subgraph of the in edges of the given nodes. * \brief Extract the subgraph of the in edges of the given nodes.
* \param graph Graph * \param graph Graph
* \param nodes Node IDs of each type * \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest ones
* \return Subgraph containing only the in edges. The returned graph has the same * \return Subgraph containing only the in edges. The returned graph has the same
* schema as the original one. * schema as the original one.
*/ */
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes); HeteroSubgraph InEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false);
/*! /*!
* \brief Extract the subgraph of the out edges of the given nodes. * \brief Extract the subgraph of the out edges of the given nodes.
* \param graph Graph * \param graph Graph
* \param nodes Node IDs of each type * \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest ones
* \return Subgraph containing only the out edges. The returned graph has the same * \return Subgraph containing only the out edges. The returned graph has the same
* schema as the original one. * schema as the original one.
*/ */
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes); HeteroSubgraph OutEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false);
/*! /*!
* \brief Joint union multiple graphs into one graph. * \brief Joint union multiple graphs into one graph.
......
...@@ -3563,7 +3563,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3563,7 +3563,7 @@ class DGLGraph(DGLBaseGraph):
tensor([0, 1, 4]) tensor([0, 1, 4])
>>> SG.parent_eid >>> SG.parent_eid
tensor([0, 4]) tensor([0, 4])
>>> SG = G.edge_subgraph([0, 4], preserve_nodes=True) >>> SG = G.edge_subgraph([0, 4], relabel_nodes=False)
>>> SG.nodes() >>> SG.nodes()
tensor([0, 1, 2, 3, 4]) tensor([0, 1, 2, 3, 4])
>>> SG.edges() >>> SG.edges()
......
...@@ -422,13 +422,13 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -422,13 +422,13 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>> # build train_g >>> # build train_g
>>> train_edges = train_set >>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges, >>> train_g = g.edge_subgraph(train_edges,
preserve_nodes=True) relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges]; >>> train_g.edata['e_type'] = e_type[train_edges];
>>> >>>
>>> # build val_g >>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges]) >>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges, >>> val_g = g.edge_subgraph(val_edges,
preserve_nodes=True) relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges]; >>> val_g.edata['e_type'] = e_type[val_edges];
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
...@@ -558,13 +558,13 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -558,13 +558,13 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> # build train_g >>> # build train_g
>>> train_edges = train_set >>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges, >>> train_g = g.edge_subgraph(train_edges,
preserve_nodes=True) relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges]; >>> train_g.edata['e_type'] = e_type[train_edges];
>>> >>>
>>> # build val_g >>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges]) >>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges, >>> val_g = g.edge_subgraph(val_edges,
preserve_nodes=True) relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges]; >>> val_g.edata['e_type'] = e_type[val_edges];
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
...@@ -694,13 +694,13 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -694,13 +694,13 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> # build train_g >>> # build train_g
>>> train_edges = train_set >>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges, >>> train_g = g.edge_subgraph(train_edges,
preserve_nodes=True) relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges]; >>> train_g.edata['e_type'] = e_type[train_edges];
>>> >>>
>>> # build val_g >>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges]) >>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges, >>> val_g = g.edge_subgraph(val_edges,
preserve_nodes=True) relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges]; >>> val_g.edata['e_type'] = e_type[val_edges];
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
......
...@@ -696,7 +696,7 @@ class EdgeCollator(Collator): ...@@ -696,7 +696,7 @@ class EdgeCollator(Collator):
else: else:
items = _prepare_tensor(self.g_sampling, items, 'items', self._is_distributed) items = _prepare_tensor(self.g_sampling, items, 'items', self._is_distributed)
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True) pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[EID] induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items) neg_srcdst = self.negative_sampler(self.g, items)
......
...@@ -622,7 +622,7 @@ class DGLHeteroGraph(object): ...@@ -622,7 +622,7 @@ class DGLHeteroGraph(object):
else: else:
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype) edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype)
sub_g = self.edge_subgraph(edges, preserve_nodes=True, store_ids=store_ids) sub_g = self.edge_subgraph(edges, relabel_nodes=False, store_ids=store_ids)
self._graph = sub_g._graph self._graph = sub_g._graph
self._node_frames = sub_g._node_frames self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames self._edge_frames = sub_g._edge_frames
...@@ -4294,7 +4294,7 @@ class DGLHeteroGraph(object): ...@@ -4294,7 +4294,7 @@ class DGLHeteroGraph(object):
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges') eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
if core.is_builtin(func): if core.is_builtin(func):
if not is_all(eid): if not is_all(eid):
g = g.edge_subgraph(eid, preserve_nodes=True) g = g.edge_subgraph(eid, relabel_nodes=False)
edata = core.invoke_gsddmm(g, func) edata = core.invoke_gsddmm(g, func)
else: else:
edata = core.invoke_edge_udf(g, eid, etype, func) edata = core.invoke_edge_udf(g, eid, etype, func)
......
...@@ -843,7 +843,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -843,7 +843,7 @@ class HeteroGraphIndex(ObjectBase):
raise DGLError('Invalid incidence matrix type: %s' % str(typestr)) raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
return inc, shuffle_idx return inc, shuffle_idx
def node_subgraph(self, induced_nodes): def node_subgraph(self, induced_nodes, relabel_nodes):
"""Return the induced node subgraph. """Return the induced node subgraph.
Parameters Parameters
...@@ -851,6 +851,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -851,6 +851,9 @@ class HeteroGraphIndex(ObjectBase):
induced_nodes : list of utils.Index induced_nodes : list of utils.Index
Induced nodes. The length should be equal to the number of Induced nodes. The length should be equal to the number of
node types in this heterograph. node types in this heterograph.
relabel_nodes : bool
If True, the extracted subgraph will only have the nodes in the specified node set
and it will relabel the nodes in order.
Returns Returns
------- -------
...@@ -858,7 +861,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -858,7 +861,7 @@ class HeteroGraphIndex(ObjectBase):
The subgraph index. The subgraph index.
""" """
vids = [F.to_dgl_nd(nodes) for nodes in induced_nodes] vids = [F.to_dgl_nd(nodes) for nodes in induced_nodes]
return _CAPI_DGLHeteroVertexSubgraph(self, vids) return _CAPI_DGLHeteroVertexSubgraph(self, vids, relabel_nodes)
def edge_subgraph(self, induced_edges, preserve_nodes): def edge_subgraph(self, induced_edges, preserve_nodes):
"""Return the induced edge subgraph. """Return the induced edge subgraph.
......
...@@ -342,7 +342,7 @@ class RelGraphConv(nn.Module): ...@@ -342,7 +342,7 @@ class RelGraphConv(nn.Module):
# list, where each element is the number of edges of the type. # list, where each element is the number of edges of the type.
# Sort the graph based on the etypes # Sort the graph based on the etypes
sorted_etypes, index = th.sort(etypes) sorted_etypes, index = th.sort(etypes)
g = edge_subgraph(g, index, preserve_nodes=True) g = edge_subgraph(g, index, relabel_nodes=False)
# Create a new etypes to be an integer list of number of edges. # Create a new etypes to be an integer list of number of edges.
pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device)) pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device))
num = th.tensor([len(etypes)], device=g.device) num = th.tensor([len(etypes)], device=g.device)
......
...@@ -199,7 +199,7 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False): ...@@ -199,7 +199,7 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
eid = F.astype(induced_edges[0], F.int64) + max_eid * F.astype(inner_edge == 0, F.int64) eid = F.astype(induced_edges[0], F.int64) + max_eid * F.astype(inner_edge == 0, F.int64)
_, index = F.sort_1d(eid) _, index = F.sort_1d(eid)
subg1 = edge_subgraph(subg1, index, preserve_nodes=True) subg1 = edge_subgraph(subg1, index, relabel_nodes=False)
subg1.ndata[NID] = induced_nodes[0] subg1.ndata[NID] = induced_nodes[0]
subg1.edata[EID] = F.gather_row(induced_edges[0], index) subg1.edata[EID] = F.gather_row(induced_edges[0], index)
else: else:
......
...@@ -6,7 +6,7 @@ For stochastic subgraph extraction, please see functions under :mod:`dgl.samplin ...@@ -6,7 +6,7 @@ For stochastic subgraph extraction, please see functions under :mod:`dgl.samplin
from collections.abc import Mapping from collections.abc import Mapping
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import graph_index from . import graph_index
from . import heterograph_index from . import heterograph_index
...@@ -17,17 +17,13 @@ from . import utils ...@@ -17,17 +17,13 @@ from . import utils
__all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph', __all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph',
'in_subgraph', 'out_subgraph'] 'in_subgraph', 'out_subgraph']
def node_subgraph(graph, nodes, store_ids=True): def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True):
"""Return a subgraph induced on the given nodes. """Return a subgraph induced on the given nodes.
A node-induced subgraph is a subset of the nodes of a graph together with A node-induced subgraph is a graph with edges whose endpoints are both in the
any edges whose endpoints are both in this subset. In addition to extracting specified node set. In addition to extracting the subgraph, DGL also copies
the subgraph, DGL conducts the following: the features of the extracted nodes and edges to the resulting graph. The copy
is *lazy* and incurs data movement only when needed.
* Relabel the extracted nodes to IDs starting from zero.
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations them as the resulting graph. Thus, the resulting graph has the same set of relations
...@@ -48,11 +44,15 @@ def node_subgraph(graph, nodes, store_ids=True): ...@@ -48,11 +44,15 @@ def node_subgraph(graph, nodes, store_ids=True):
If the graph is homogeneous, one can directly pass the above formats. If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes. and values being the node IDs in the above formats.
relabel_nodes : bool, optional
If True, the extracted subgraph will only have the nodes in the specified node set
and it will relabel the nodes in order.
store_ids : bool, optional store_ids : bool, optional
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata`` If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
respectively. also store the raw IDs of the specified nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns Returns
------- -------
...@@ -144,23 +144,20 @@ def node_subgraph(graph, nodes, store_ids=True): ...@@ -144,23 +144,20 @@ def node_subgraph(graph, nodes, store_ids=True):
for ntype in graph.ntypes: for ntype in graph.ntypes:
nids = nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), graph.device)) nids = nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), graph.device))
induced_nodes.append(_process_nodes(ntype, nids)) induced_nodes.append(_process_nodes(ntype, nids))
sgi = graph._graph.node_subgraph(induced_nodes) sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes)
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids) induced_nodes = sgi.induced_nodes if relabel_nodes else None
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph) DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph)
def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True): def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **deprecated_kwargs):
"""Return a subgraph induced on the given edges. """Return a subgraph induced on the given edges.
An edge-induced subgraph is equivalent to creating a new graph An edge-induced subgraph is equivalent to creating a new graph using the given
with the same number of nodes using the given edges. In addition to extracting edges. In addition to extracting the subgraph, DGL also copies the features
the subgraph, DGL conducts the following: of the extracted nodes and edges to the resulting graph. The copy is *lazy*
and incurs data movement only when needed.
* Relabel the incident nodes to IDs starting from zero. Isolated nodes are removed.
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations them as the resulting graph. Thus, the resulting graph has the same set of relations
...@@ -170,7 +167,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True): ...@@ -170,7 +167,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph to extract the subgraph from. The graph to extract the subgraph from.
edges : dict[(str, str, str), edges] edges : edges or dict[(str, str, str), edges]
The edges to form the subgraph. The allowed edges formats are: The edges to form the subgraph. The allowed edges formats are:
* Int Tensor: Each element is an edge ID. The tensor must have the same device type * Int Tensor: Each element is an edge ID. The tensor must have the same device type
...@@ -181,14 +178,15 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True): ...@@ -181,14 +178,15 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
If the graph is homogeneous, one can directly pass the above formats. If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being edge types Otherwise, the argument must be a dictionary with keys being edge types
and values being the edge IDs. and values being the edge IDs in the above formats.
preserve_nodes : bool, optional relabel_nodes : bool, optional
If True, do not relabel the incident nodes and remove the isolated nodes If True, it will remove the isolated nodes and relabel the incident nodes in the
in the extracted subgraph. (Default: False) extracted subgraph.
store_ids : bool, optional store_ids : bool, optional
If True, it will store the IDs of the extracted nodes and edges in the ``ndata`` If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
respectively. also store the raw IDs of the incident nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns Returns
------- -------
...@@ -227,10 +225,10 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True): ...@@ -227,10 +225,10 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
Extract a subgraph without node relabeling. Extract a subgraph without node relabeling.
>>> sg = dgl.edge_subgraph(g, [0, 4], preserve_nodes=True) >>> sg = dgl.edge_subgraph(g, [0, 4], relabel_nodes=False)
>>> sg >>> sg
Graph(num_nodes=5, num_edges=2, Graph(num_nodes=5, num_edges=2,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)} ndata_schemes={}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}) edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges() >>> sg.edges()
(tensor([0, 4]), tensor([1, 0])) (tensor([0, 4]), tensor([1, 0]))
...@@ -276,7 +274,11 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True): ...@@ -276,7 +274,11 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
-------- --------
node_subgraph node_subgraph
""" """
if graph.is_block and not preserve_nodes: if len(deprecated_kwargs) != 0:
dgl_warning(
"Key word argument preserve_nodes is deprecated. Use relabel_nodes instead.")
relabel_nodes = not deprecated_kwargs.get('preserve_nodes')
if graph.is_block and relabel_nodes:
raise DGLError('Extracting subgraph from a block graph is not allowed.') raise DGLError('Extracting subgraph from a block graph is not allowed.')
if not isinstance(edges, Mapping): if not isinstance(edges, Mapping):
assert len(graph.canonical_etypes) == 1, \ assert len(graph.canonical_etypes) == 1, \
...@@ -294,25 +296,20 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True): ...@@ -294,25 +296,20 @@ def edge_subgraph(graph, edges, preserve_nodes=False, store_ids=True):
for cetype in graph.canonical_etypes: for cetype in graph.canonical_etypes:
eids = edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), graph.device)) eids = edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), graph.device))
induced_edges.append(_process_edges(cetype, eids)) induced_edges.append(_process_edges(cetype, eids))
sgi = graph._graph.edge_subgraph(induced_edges, preserve_nodes) sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes)
induced_nodes = sgi.induced_nodes induced_nodes = sgi.induced_nodes if relabel_nodes else None
return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids) return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph) DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph)
def in_subgraph(g, nodes): def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
"""Return the subgraph induced on the inbound edges of all the edge types of the """Return the subgraph induced on the inbound edges of all the edge types of the
given nodes. given nodes.
An edge-induced subgraph is equivalent to creating a new graph An in subgraph is equivalent to creating a new graph using the incoming edges of the
with the same number of nodes using the given edges. In addition to extracting given nodes. In addition to extracting the subgraph, DGL also copies the features of
the subgraph, DGL conducts the following: the extracted nodes and edges to the resulting graph. The copy is *lazy* and incurs
data movement only when needed.
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted edges in the ``edata``
of the resulting graph under name ``dgl.EID``.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations them as the resulting graph. Thus, the resulting graph has the same set of relations
...@@ -320,18 +317,26 @@ def in_subgraph(g, nodes): ...@@ -320,18 +317,26 @@ def in_subgraph(g, nodes):
Parameters Parameters
---------- ----------
g : DGLGraph graph : DGLGraph
The input graph. The input graph.
nodes : nodes or dict[str, nodes] nodes : nodes or dict[str, nodes]
The nodes to form the subgraph. The allowed nodes formats are: The nodes to form the subgraph. The allowed nodes formats are:
* Int Tensor: Each element is an ID. The tensor must have the same device type * Int Tensor: Each element is a node ID. The tensor must have the same device type
and ID data type as the graph's. and ID data type as the graph's.
* iterable[int]: Each element is an ID. * iterable[int]: Each element is a node ID.
If the graph is homogeneous, one can directly pass the above formats. If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes. and values being the node IDs in the above formats.
relabel_nodes : bool, optional
If True, it will remove the isolated nodes and relabel the rest nodes in the
extracted subgraph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns Returns
------- -------
...@@ -371,6 +376,21 @@ def in_subgraph(g, nodes): ...@@ -371,6 +376,21 @@ def in_subgraph(g, nodes):
tensor([[2, 3], tensor([[2, 3],
[8, 9]]) [8, 9]])
Extract a subgraph with node labeling.
>>> sg = dgl.in_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64}
edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 3]), tensor([2, 0]))
>>> sg.edata[dgl.EID] # original edge IDs
tensor([1, 4])
>>> sg.ndata[dgl.NID] # original node IDs
tensor([0, 1, 2, 4])
Extract a subgraph from a heterogeneous graph. Extract a subgraph from a heterogeneous graph.
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
...@@ -386,39 +406,35 @@ def in_subgraph(g, nodes): ...@@ -386,39 +406,35 @@ def in_subgraph(g, nodes):
-------- --------
out_subgraph out_subgraph
""" """
if g.is_block: if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.') raise DGLError('Extracting subgraph of a block graph is not allowed.')
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(g.ntypes) > 1: if len(graph.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes} nodes = {graph.ntypes[0] : nodes}
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes') nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
nodes_all_types = [] nodes_all_types = []
for ntype in g.ntypes: for ntype in graph.ntypes:
if ntype in nodes: if ntype in nodes:
nodes_all_types.append(F.to_dgl_nd(nodes[ntype])) nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
else: else:
nodes_all_types.append(nd.NULL[g._idtype_str]) nodes_all_types.append(nd.NULL[graph._idtype_str])
sgi = _CAPI_DGLInSubgraph(g._graph, nodes_all_types) sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
return _create_hetero_subgraph(g, sgi, None, induced_edges) return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph) DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph)
def out_subgraph(g, nodes): def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True):
"""Return the subgraph induced on the out-bound edges of all the edge types of the """Return the subgraph induced on the outbound edges of all the edge types of the
given nodes. given nodes.
An edge-induced subgraph is equivalent to creating a new graph An out subgraph is equivalent to creating a new graph using the outcoming edges of
with the same number of nodes using the given edges. In addition to extracting the given nodes. In addition to extracting the subgraph, DGL also copies the features
the subgraph, DGL conducts the following: of the extracted nodes and edges to the resulting graph. The copy is *lazy* and incurs
data movement only when needed.
* Copy the features of the extracted nodes and edges to the resulting graph.
The copy is *lazy* and incurs data movement only when needed.
* Store the IDs of the extracted edges in the ``edata``
of the resulting graph under name ``dgl.EID``.
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
them as the resulting graph. Thus, the resulting graph has the same set of relations them as the resulting graph. Thus, the resulting graph has the same set of relations
...@@ -426,7 +442,7 @@ def out_subgraph(g, nodes): ...@@ -426,7 +442,7 @@ def out_subgraph(g, nodes):
Parameters Parameters
---------- ----------
g : DGLGraph graph : DGLGraph
The input graph. The input graph.
nodes : nodes or dict[str, nodes] nodes : nodes or dict[str, nodes]
The nodes to form the subgraph. The allowed nodes formats are: The nodes to form the subgraph. The allowed nodes formats are:
...@@ -437,7 +453,15 @@ def out_subgraph(g, nodes): ...@@ -437,7 +453,15 @@ def out_subgraph(g, nodes):
If the graph is homogeneous, one can directly pass the above formats. If the graph is homogeneous, one can directly pass the above formats.
Otherwise, the argument must be a dictionary with keys being node types Otherwise, the argument must be a dictionary with keys being node types
and values being the nodes. and values being the node IDs in the above formats.
relabel_nodes : bool, optional
If True, it will remove the isolated nodes and relabel the rest nodes in the
extracted subgraph.
store_ids : bool, optional
If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting
graph under name ``dgl.NID``.
Returns Returns
------- -------
...@@ -477,6 +501,21 @@ def out_subgraph(g, nodes): ...@@ -477,6 +501,21 @@ def out_subgraph(g, nodes):
tensor([[4, 5], tensor([[4, 5],
[0, 1]]) [0, 1]])
Extract a subgraph with node labeling.
>>> sg = dgl.out_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([2, 0]), tensor([3, 1]))
>>> sg.edata[dgl.EID] # original edge IDs
tensor([2, 0])
>>> sg.ndata[dgl.NID] # original node IDs
tensor([0, 1, 2, 3])
Extract a subgraph from a heterogeneous graph. Extract a subgraph from a heterogeneous graph.
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
...@@ -492,23 +531,24 @@ def out_subgraph(g, nodes): ...@@ -492,23 +531,24 @@ def out_subgraph(g, nodes):
-------- --------
in_subgraph in_subgraph
""" """
if g.is_block: if graph.is_block:
raise DGLError('Extracting subgraph of a block graph is not allowed.') raise DGLError('Extracting subgraph of a block graph is not allowed.')
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(g.ntypes) > 1: if len(graph.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes} nodes = {graph.ntypes[0] : nodes}
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes') nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
nodes_all_types = [] nodes_all_types = []
for ntype in g.ntypes: for ntype in graph.ntypes:
if ntype in nodes: if ntype in nodes:
nodes_all_types.append(F.to_dgl_nd(nodes[ntype])) nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
else: else:
nodes_all_types.append(nd.NULL[g._idtype_str]) nodes_all_types.append(nd.NULL[graph._idtype_str])
sgi = _CAPI_DGLOutSubgraph(g._graph, nodes_all_types) sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)
induced_nodes = sgi.induced_nodes if relabel_nodes else None
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
return _create_hetero_subgraph(g, sgi, None, induced_edges) return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids)
DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph) DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph)
...@@ -701,9 +741,10 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges, store_ids ...@@ -701,9 +741,10 @@ def _create_hetero_subgraph(parent, sgi, induced_nodes, induced_edges, store_ids
Induced edge IDs. Will store it as the dgl.EID ndata unless it Induced edge IDs. Will store it as the dgl.EID ndata unless it
is None, which means the induced edge IDs are the same as the parent edge IDs. is None, which means the induced edge IDs are the same as the parent edge IDs.
store_ids : bool store_ids : bool
If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata`` If True and induced_nodes is not None, it will store the raw IDs of the extracted
and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``.
respectively. If True and induced_edges is not None, it will store the raw IDs of the extracted
edges in the ``edata`` of the resulting graph under name ``dgl.EID``.
Returns Returns
------- -------
......
...@@ -379,6 +379,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph") ...@@ -379,6 +379,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> vids = args[1]; List<Value> vids = args[1];
bool relabel_nodes = args[2];
CHECK(relabel_nodes) << "Node subgraph only supports relabel_nodes=True.";
std::vector<IdArray> vid_vec; std::vector<IdArray> vid_vec;
vid_vec.reserve(vids.size()); vid_vec.reserve(vids.size());
for (Value val : vids) { for (Value val : vids) {
...@@ -649,8 +651,9 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph") ...@@ -649,8 +651,9 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]); const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2];
std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
*ret = InEdgeGraph(hg.sptr(), nodes); *ret = InEdgeGraph(hg.sptr(), nodes, relabel_nodes);
*rv = HeteroGraphRef(ret); *rv = HeteroGraphRef(ret);
}); });
...@@ -658,8 +661,9 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph") ...@@ -658,8 +661,9 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]); const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2];
std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
*ret = OutEdgeGraph(hg.sptr(), nodes); *ret = OutEdgeGraph(hg.sptr(), nodes, relabel_nodes);
*rv = HeteroGraphRef(ret); *rv = HeteroGraphRef(ret);
}); });
......
...@@ -8,7 +8,27 @@ using namespace dgl::runtime; ...@@ -8,7 +8,27 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) { HeteroSubgraph InEdgeGraphRelabelNodes(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<IdArray> eids(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_type_t dst_vtype = pair.second;
if (aten::IsNullArray(vids[dst_vtype])) {
eids[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
} else {
const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
eids[etype] = earr.id;
}
}
return graph->EdgeSubgraph(eids, false);
}
HeteroSubgraph InEdgeGraphNoRelabelNodes(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
// TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
CHECK_EQ(vids.size(), graph->NumVertexTypes()) CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types."; << "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes()); std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
...@@ -43,7 +63,36 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray ...@@ -43,7 +63,36 @@ HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray
return ret; return ret;
} }
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) { HeteroSubgraph InEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
if (relabel_nodes) {
return InEdgeGraphRelabelNodes(graph, vids);
} else {
return InEdgeGraphNoRelabelNodes(graph, vids);
}
}
HeteroSubgraph OutEdgeGraphRelabelNodes(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<IdArray> eids(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
if (aten::IsNullArray(vids[src_vtype])) {
eids[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
} else {
const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});
eids[etype] = earr.id;
}
}
return graph->EdgeSubgraph(eids, false);
}
HeteroSubgraph OutEdgeGraphNoRelabelNodes(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
// TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
CHECK_EQ(vids.size(), graph->NumVertexTypes()) CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types."; << "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes()); std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
...@@ -78,4 +127,13 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra ...@@ -78,4 +127,13 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra
return ret; return ret;
} }
HeteroSubgraph OutEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
if (relabel_nodes) {
return OutEdgeGraphRelabelNodes(graph, vids);
} else {
return OutEdgeGraphNoRelabelNodes(graph, vids);
}
}
} // namespace dgl } // namespace dgl
...@@ -273,16 +273,16 @@ def test_set_batch_info(idtype): ...@@ -273,16 +273,16 @@ def test_set_batch_info(idtype):
assert subg_n2.num_edges() == subg2.num_edges() assert subg_n2.num_edges() == subg2.num_edges()
# test homogeneous edge subgraph # test homogeneous edge subgraph
sg_e = dgl.edge_subgraph(bg, list(range(40, 70)) + list(range(150, 200)), preserve_nodes=True) sg_e = dgl.edge_subgraph(bg, list(range(40, 70)) + list(range(150, 200)), relabel_nodes=False)
induced_nodes = sg_e.ndata['_ID'] induced_nodes = F.arange(0, bg.num_nodes(), idtype)
induced_edges = sg_e.edata['_ID'] induced_edges = sg_e.edata['_ID']
new_batch_num_nodes = _get_subgraph_batch_info(bg.ntypes, [induced_nodes], batch_num_nodes) new_batch_num_nodes = _get_subgraph_batch_info(bg.ntypes, [induced_nodes], batch_num_nodes)
new_batch_num_edges = _get_subgraph_batch_info(bg.canonical_etypes, [induced_edges], batch_num_edges) new_batch_num_edges = _get_subgraph_batch_info(bg.canonical_etypes, [induced_edges], batch_num_edges)
sg_e.set_batch_num_nodes(new_batch_num_nodes) sg_e.set_batch_num_nodes(new_batch_num_nodes)
sg_e.set_batch_num_edges(new_batch_num_edges) sg_e.set_batch_num_edges(new_batch_num_edges)
subg_e1, subg_e2 = dgl.unbatch(sg_e) subg_e1, subg_e2 = dgl.unbatch(sg_e)
subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), preserve_nodes=True) subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), relabel_nodes=False)
subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), preserve_nodes=True) subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), relabel_nodes=False)
assert subg_e1.num_nodes() == subg1.num_nodes() assert subg_e1.num_nodes() == subg1.num_nodes()
assert subg_e2.num_nodes() == subg2.num_nodes() assert subg_e2.num_nodes() == subg2.num_nodes()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment