"git@developer.sourcefind.cn:OpenDAS/pytorch-encoding.git" did not exist on "abcee3c9316e634dae93b1923dfeda403ade7888"
Unverified Commit baa16231 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Feature] Remove nodes/edges. (#599)

* upd

* upd

* reformat

* upd

* upd

* add test

* fix arange

* fix slight bug

* upd

* trigger

* upd docs

* upd

* upd

* upd

* change subgraph to be raw data wrapper

* upd

* fix test
parent e7389d7c
...@@ -45,6 +45,15 @@ Querying graph structure ...@@ -45,6 +45,15 @@ Querying graph structure
DGLGraph.out_degree DGLGraph.out_degree
DGLGraph.out_degrees DGLGraph.out_degrees
Removing nodes and edges
------------------------
.. autosummary::
:toctree: ../../generated/
DGLGraph.remove_nodes
DGLGraph.remove_edges
Transforming graph Transforming graph
------------------ ------------------
......
...@@ -308,7 +308,7 @@ class Graph: public GraphInterface { ...@@ -308,7 +308,7 @@ class Graph: public GraphInterface {
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
Subgraph EdgeSubgraph(IdArray eids) const override; Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
......
...@@ -324,7 +324,7 @@ class GraphInterface { ...@@ -324,7 +324,7 @@ class GraphInterface {
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
virtual Subgraph EdgeSubgraph(IdArray eids) const = 0; virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0;
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
......
...@@ -159,7 +159,7 @@ class CSR : public GraphInterface { ...@@ -159,7 +159,7 @@ class CSR : public GraphInterface {
Subgraph VertexSubgraph(IdArray vids) const override; Subgraph VertexSubgraph(IdArray vids) const override;
Subgraph EdgeSubgraph(IdArray eids) const override { Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override {
LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph." LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
<< " Please use COO graph instead."; << " Please use COO graph instead.";
return {}; return {};
...@@ -409,7 +409,7 @@ class COO : public GraphInterface { ...@@ -409,7 +409,7 @@ class COO : public GraphInterface {
return {}; return {};
} }
Subgraph EdgeSubgraph(IdArray eids) const override; Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
GraphPtr Reverse() const override { GraphPtr Reverse() const override {
return Transpose(); return Transpose();
...@@ -810,7 +810,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -810,7 +810,7 @@ class ImmutableGraph: public GraphInterface {
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
Subgraph EdgeSubgraph(IdArray eids) const override; Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
......
...@@ -1101,6 +1101,57 @@ class DGLGraph(DGLBaseGraph): ...@@ -1101,6 +1101,57 @@ class DGLGraph(DGLBaseGraph):
self._msg_index = self._msg_index.append_zeros(num) self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num) self._msg_frame.add_rows(num)
def remove_nodes(self, vids):
    """Delete the given nodes from this graph in place.

    The surviving node set is the current node set minus ``vids``.  The
    graph structure is replaced by the node subgraph built over the
    survivors, and both feature frames are rebuilt from the induced
    node/edge ids reported by that subgraph.

    Parameters
    ----------
    vids : list, tensor
        The ids of the nodes to remove.

    Raises
    ------
    DGLError
        If the graph is read-only.
    """
    if self.is_readonly:
        raise DGLError("remove_nodes is not supported by read-only graph.")

    every_node = utils.toindex(self.nodes())
    doomed = utils.toindex(vids)
    sub = self._graph.node_subgraph(utils.set_diff(every_node, doomed))

    # Node features: keep only the rows of the surviving nodes.
    if not isinstance(self._node_frame, FrameRef):
        self._node_frame = FrameRef(self._node_frame, sub.induced_nodes)
    else:
        self._node_frame = FrameRef(Frame(self._node_frame[sub.induced_nodes]))

    # Edge features: keep only the rows of the edges the subgraph retained.
    if not isinstance(self._edge_frame, FrameRef):
        self._edge_frame = FrameRef(self._edge_frame, sub.induced_edges)
    else:
        self._edge_frame = FrameRef(Frame(self._edge_frame[sub.induced_edges]))

    # Swap in the new structure last, after the frames have been re-indexed.
    self._graph = sub.graph
def remove_edges(self, eids):
    """Delete the given edges from this graph in place.

    The surviving edge set is ``range(number_of_edges())`` minus ``eids``.
    The graph structure is replaced by the edge subgraph over the
    survivors, requested with ``preserve_nodes=True``, and both feature
    frames are rebuilt from the induced node/edge ids of that subgraph.

    Parameters
    ----------
    eids : list, tensor
        The ids of the edges to remove.

    Raises
    ------
    DGLError
        If the graph is read-only.
    """
    if self.is_readonly:
        raise DGLError("remove_edges is not supported by read-only graph.")

    every_edge = utils.toindex(range(self.number_of_edges()))
    doomed = utils.toindex(eids)
    sub = self._graph.edge_subgraph(
        utils.set_diff(every_edge, doomed), preserve_nodes=True)

    # Node features are re-indexed by the subgraph's induced node ids.
    if not isinstance(self._node_frame, FrameRef):
        self._node_frame = FrameRef(self._node_frame, sub.induced_nodes)
    else:
        self._node_frame = FrameRef(Frame(self._node_frame[sub.induced_nodes]))

    # Edge features: keep only the rows of the surviving edges.
    if not isinstance(self._edge_frame, FrameRef):
        self._edge_frame = FrameRef(self._edge_frame, sub.induced_edges)
    else:
        self._edge_frame = FrameRef(Frame(self._edge_frame[sub.induced_edges]))

    # Swap in the new structure last, after the frames have been re-indexed.
    self._graph = sub.graph
def clear(self): def clear(self):
"""Remove all nodes and edges, as well as their features, from the """Remove all nodes and edges, as well as their features, from the
graph. graph.
...@@ -2813,7 +2864,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2813,7 +2864,7 @@ class DGLGraph(DGLBaseGraph):
from . import subgraph from . import subgraph
induced_nodes = utils.toindex(nodes) induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi) return subgraph.DGLSubGraph(self, sgi)
def subgraphs(self, nodes): def subgraphs(self, nodes):
"""Return a list of subgraphs, each induced in the corresponding given """Return a list of subgraphs, each induced in the corresponding given
...@@ -2841,10 +2892,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -2841,10 +2892,9 @@ class DGLGraph(DGLBaseGraph):
from . import subgraph from . import subgraph
induced_nodes = [utils.toindex(n) for n in nodes] induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes) sgis = self._graph.node_subgraphs(induced_nodes)
return [subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi) return [subgraph.DGLSubGraph(self, sgi) for sgi in sgis]
for sgi in sgis]
def edge_subgraph(self, edges): def edge_subgraph(self, edges, preserve_nodes=False):
"""Return the subgraph induced on given edges. """Return the subgraph induced on given edges.
Parameters Parameters
...@@ -2852,6 +2902,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -2852,6 +2902,10 @@ class DGLGraph(DGLBaseGraph):
edges : list, or iterable edges : list, or iterable
An edge ID array to construct subgraph. An edge ID array to construct subgraph.
All edges must exist in the subgraph. All edges must exist in the subgraph.
preserve_nodes : bool
Indicates whether to preserve all nodes or not.
If true, keep the nodes which have no edge connected in the subgraph;
If false, all nodes without edge connected to it would be removed.
Returns Returns
------- -------
...@@ -2880,6 +2934,15 @@ class DGLGraph(DGLBaseGraph): ...@@ -2880,6 +2934,15 @@ class DGLGraph(DGLBaseGraph):
tensor([0, 1, 4]) tensor([0, 1, 4])
>>> SG.parent_eid >>> SG.parent_eid
tensor([0, 4]) tensor([0, 4])
>>> SG = G.edge_subgraph([0, 4], preserve_nodes=True)
>>> SG.nodes()
tensor([0, 1, 2, 3, 4])
>>> SG.edges()
(tensor([0, 4]), tensor([1, 0]))
>>> SG.parent_nid
tensor([0, 1, 2, 3, 4])
>>> SG.parent_eid
tensor([0, 4])
See Also See Also
-------- --------
...@@ -2888,8 +2951,8 @@ class DGLGraph(DGLBaseGraph): ...@@ -2888,8 +2951,8 @@ class DGLGraph(DGLBaseGraph):
""" """
from . import subgraph from . import subgraph
induced_edges = utils.toindex(edges) induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges) sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi) return subgraph.DGLSubGraph(self, sgi)
def adjacency_matrix_scipy(self, transpose=False, fmt='csr'): def adjacency_matrix_scipy(self, transpose=False, fmt='csr'):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
......
...@@ -516,7 +516,8 @@ class GraphIndex(object): ...@@ -516,7 +516,8 @@ class GraphIndex(object):
v_array = v.todgltensor() v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array) rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array)
induced_edges = utils.toindex(rst(2)) induced_edges = utils.toindex(rst(2))
return SubgraphIndex(rst(0), self, v, induced_edges) gidx = GraphIndex(rst(0))
return SubgraphIndex(gidx, self, v, induced_edges)
def node_subgraphs(self, vs_arr): def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs. """Return the induced node subgraphs.
...@@ -536,13 +537,17 @@ class GraphIndex(object): ...@@ -536,13 +537,17 @@ class GraphIndex(object):
gis.append(self.node_subgraph(v)) gis.append(self.node_subgraph(v))
return gis return gis
def edge_subgraph(self, e): def edge_subgraph(self, e, preserve_nodes=False):
"""Return the induced edge subgraph. """Return the induced edge subgraph.
Parameters Parameters
---------- ----------
e : utils.Index e : utils.Index
The edges. The edges.
preserve_nodes : bool
Indicates whether to preserve all nodes or not.
If true, keep the nodes which have no edge connected in the subgraph;
If false, all nodes without edge connected to it would be removed.
Returns Returns
------- -------
...@@ -550,9 +555,10 @@ class GraphIndex(object): ...@@ -550,9 +555,10 @@ class GraphIndex(object):
The subgraph index. The subgraph index.
""" """
e_array = e.todgltensor() e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array) rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array, preserve_nodes)
induced_nodes = utils.toindex(rst(1)) induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e) gidx = GraphIndex(rst(0))
return SubgraphIndex(gidx, self, induced_nodes, e)
@utils.cached_member(cache='_cache', prefix='scipy_adj') @utils.cached_member(cache='_cache', prefix='scipy_adj')
def adjacency_matrix_scipy(self, transpose, fmt): def adjacency_matrix_scipy(self, transpose, fmt):
...@@ -870,59 +876,25 @@ class GraphIndex(object): ...@@ -870,59 +876,25 @@ class GraphIndex(object):
handle = _CAPI_DGLImmutableGraphAsNumBits(self._handle, int(bits)) handle = _CAPI_DGLImmutableGraphAsNumBits(self._handle, int(bits))
return GraphIndex(handle) return GraphIndex(handle)
class SubgraphIndex(GraphIndex): class SubgraphIndex(object):
"""Graph index for subgraph. """Internal subgraph data structure.
Parameters Parameters
---------- ----------
handle : GraphIndexHandle graph : GraphIndex
The capi handle. The graph structure of this subgraph.
paranet : GraphIndex parent : GraphIndex
The parent graph index. The parent graph index.
induced_nodes : utils.Index induced_nodes : utils.Index
The parent node ids in this subgraph. The parent node ids in this subgraph.
induced_edges : utils.Index induced_edges : utils.Index
The parent edge ids in this subgraph. The parent edge ids in this subgraph.
""" """
def __init__(self, handle, parent, induced_nodes, induced_edges): def __init__(self, graph, parent, induced_nodes, induced_edges):
super(SubgraphIndex, self).__init__(handle) self.graph = graph
self._parent = parent self.parent = parent
self._induced_nodes = induced_nodes self.induced_nodes = induced_nodes
self._induced_edges = induced_edges self.induced_edges = induced_edges
def add_nodes(self, num):
"""Add nodes. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
@property
def induced_nodes(self):
"""Return parent node ids.
Returns
-------
utils.Index
The parent node ids.
"""
return self._induced_nodes
@property
def induced_edges(self):
"""Return parent edge ids.
Returns
-------
utils.Index
The parent edge ids.
"""
return self._induced_edges
def __getstate__(self): def __getstate__(self):
raise NotImplementedError( raise NotImplementedError(
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
from .frame import Frame, FrameRef from .frame import Frame, FrameRef
from .graph import DGLGraph from .graph import DGLGraph
from . import utils from . import utils
from .base import DGLError
from .graph_index import map_to_subgraph_nid from .graph_index import map_to_subgraph_nid
class DGLSubGraph(DGLGraph): class DGLSubGraph(DGLGraph):
...@@ -32,35 +33,31 @@ class DGLSubGraph(DGLGraph): ...@@ -32,35 +33,31 @@ class DGLSubGraph(DGLGraph):
---------- ----------
parent : DGLGraph parent : DGLGraph
The parent graph The parent graph
parent_nid : utils.Index sgi : SubgraphIndex
The induced parent node ids in this subgraph. Internal subgraph data structure.
parent_eid : utils.Index
The induced parent edge ids in this subgraph.
graph_idx : GraphIndex
The graph index.
shared : bool, optional shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph. Whether the subgraph shares node/edge features with the parent graph.
""" """
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False): def __init__(self, parent, sgi, shared=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx, super(DGLSubGraph, self).__init__(graph_data=sgi.graph,
readonly=graph_idx.is_readonly()) readonly=True)
if shared: if shared:
raise DGLError('Shared mode is not yet supported.') raise DGLError('Shared mode is not yet supported.')
self._parent = parent self._parent = parent
self._parent_nid = parent_nid self._parent_nid = sgi.induced_nodes
self._parent_eid = parent_eid self._parent_eid = sgi.induced_edges
# override APIs # override APIs
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only.""" """Add nodes. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, data=None): def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only.""" """Add one edge. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, data=None): def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only.""" """Add many edges. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.') raise DGLError('Readonly graph. Mutation is not allowed.')
@property @property
......
...@@ -252,6 +252,29 @@ def zero_index(size): ...@@ -252,6 +252,29 @@ def zero_index(size):
""" """
return Index(F.zeros((size,), dtype=F.int64, ctx=F.cpu())) return Index(F.zeros((size,), dtype=F.int64, ctx=F.cpu()))
def set_diff(ar1, ar2):
    """Compute the set difference ``ar1 \\ ar2`` of two index arrays.

    Both inputs are converted to numpy, passed through ``np.setdiff1d``
    (which yields the sorted, unique values of the first array that do not
    occur in the second), and the result is wrapped back into an index.

    Parameters
    ----------
    ar1 : utils.Index
        Input index array.
    ar2 : utils.Index
        Index array whose values are to be excluded.

    Returns
    -------
    utils.Index
        The values of ``ar1`` that are not present in ``ar2``.
    """
    return toindex(np.setdiff1d(ar1.tonumpy(), ar2.tonumpy()))
class LazyDict(Mapping): class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage.""" """A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys): def __init__(self, fn, keys):
......
...@@ -467,37 +467,56 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -467,37 +467,56 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
return rst; return rst;
} }
Subgraph Graph::EdgeSubgraph(IdArray eids) const { Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array."; CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
const auto len = eids->shape[0]; const auto len = eids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<dgl_id_t> nodes; std::vector<dgl_id_t> nodes;
const int64_t* eid_data = static_cast<int64_t*>(eids->data); const int64_t* eid_data = static_cast<int64_t*>(eids->data);
for (int64_t i = 0; i < len; ++i) {
dgl_id_t src_id = all_edges_src_[eid_data[i]];
dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
if (oldv2newv.insert(std::make_pair(src_id, oldv2newv.size())).second)
nodes.push_back(src_id);
if (oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size())).second)
nodes.push_back(dst_id);
}
Subgraph rst; Subgraph rst;
rst.graph = std::make_shared<Graph>(IsMultigraph()); if (!preserve_nodes) {
rst.induced_edges = eids; std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
rst.graph->AddVertices(nodes.size());
for (int64_t i = 0; i < len; ++i) {
const dgl_id_t src_id = all_edges_src_[eid_data[i]];
const dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
if (oldv2newv.insert(std::make_pair(src_id, oldv2newv.size())).second)
nodes.push_back(src_id);
if (oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size())).second)
nodes.push_back(dst_id);
}
for (int64_t i = 0; i < len; ++i) { rst.graph = std::make_shared<Graph>(IsMultigraph());
dgl_id_t src_id = all_edges_src_[eid_data[i]]; rst.induced_edges = eids;
dgl_id_t dst_id = all_edges_dst_[eid_data[i]]; rst.graph->AddVertices(nodes.size());
rst.graph->AddEdge(oldv2newv[src_id], oldv2newv[dst_id]);
} for (int64_t i = 0; i < len; ++i) {
const dgl_id_t src_id = all_edges_src_[eid_data[i]];
const dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
rst.graph->AddEdge(oldv2newv[src_id], oldv2newv[dst_id]);
}
rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
} else {
rst.graph = std::make_shared<Graph>(IsMultigraph());
rst.induced_edges = eids;
rst.graph->AddVertices(NumVertices());
for (int64_t i = 0; i < len; ++i) {
dgl_id_t src_id = all_edges_src_[eid_data[i]];
dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
rst.graph->AddEdge(src_id, dst_id);
}
rst.induced_vertices = IdArray::Empty( for (int64_t i = 0; i < NumVertices(); ++i)
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); nodes.push_back(i);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
}
return rst; return rst;
} }
......
...@@ -444,7 +444,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") ...@@ -444,7 +444,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const GraphInterface *gptr = static_cast<GraphInterface*>(ghandle); const GraphInterface *gptr = static_cast<GraphInterface*>(ghandle);
const IdArray eids = args[1]; const IdArray eids = args[1];
*rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids)); bool preserve_nodes = args[2];
*rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids, preserve_nodes));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
......
...@@ -549,8 +549,8 @@ COO::EdgeArray COO::Edges(const std::string &order) const { ...@@ -549,8 +549,8 @@ COO::EdgeArray COO::Edges(const std::string &order) const {
return EdgeArray{src_, dst_, rst_eid}; return EdgeArray{src_, dst_, rst_eid};
} }
Subgraph COO::EdgeSubgraph(IdArray eids) const { Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)); CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data); const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data); const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data);
const dgl_id_t* eids_data = static_cast<dgl_id_t*>(eids->data); const dgl_id_t* eids_data = static_cast<dgl_id_t*>(eids->data);
...@@ -558,32 +558,50 @@ Subgraph COO::EdgeSubgraph(IdArray eids) const { ...@@ -558,32 +558,50 @@ Subgraph COO::EdgeSubgraph(IdArray eids) const {
IdArray new_dst = NewIdArray(eids->shape[0]); IdArray new_dst = NewIdArray(eids->shape[0]);
dgl_id_t* new_src_data = static_cast<dgl_id_t*>(new_src->data); dgl_id_t* new_src_data = static_cast<dgl_id_t*>(new_src->data);
dgl_id_t* new_dst_data = static_cast<dgl_id_t*>(new_dst->data); dgl_id_t* new_dst_data = static_cast<dgl_id_t*>(new_dst->data);
dgl_id_t newid = 0; if (!preserve_nodes) {
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv; dgl_id_t newid = 0;
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
for (int64_t i = 0; i < eids->shape[0]; ++i) {
const dgl_id_t eid = eids_data[i]; for (int64_t i = 0; i < eids->shape[0]; ++i) {
const dgl_id_t src = src_data[eid]; const dgl_id_t eid = eids_data[i];
const dgl_id_t dst = dst_data[eid]; const dgl_id_t src = src_data[eid];
if (!oldv2newv.count(src)) { const dgl_id_t dst = dst_data[eid];
oldv2newv[src] = newid++; if (!oldv2newv.count(src)) {
oldv2newv[src] = newid++;
}
if (!oldv2newv.count(dst)) {
oldv2newv[dst] = newid++;
}
*(new_src_data++) = oldv2newv[src];
*(new_dst_data++) = oldv2newv[dst];
} }
if (!oldv2newv.count(dst)) {
oldv2newv[dst] = newid++; // induced nodes
IdArray induced_nodes = NewIdArray(newid);
dgl_id_t* induced_nodes_data = static_cast<dgl_id_t*>(induced_nodes->data);
for (const auto& kv : oldv2newv) {
induced_nodes_data[kv.second] = kv.first;
} }
*(new_src_data++) = oldv2newv[src];
*(new_dst_data++) = oldv2newv[dst];
}
// induced nodes COOPtr subcoo(new COO(newid, new_src, new_dst));
IdArray induced_nodes = NewIdArray(newid); return Subgraph{subcoo, induced_nodes, eids};
dgl_id_t* induced_nodes_data = static_cast<dgl_id_t*>(induced_nodes->data); } else {
for (const auto& kv : oldv2newv) { for (int64_t i = 0; i < eids->shape[0]; ++i) {
induced_nodes_data[kv.second] = kv.first; const dgl_id_t eid = eids_data[i];
} const dgl_id_t src = src_data[eid];
const dgl_id_t dst = dst_data[eid];
*(new_src_data++) = src;
*(new_dst_data++) = dst;
}
COOPtr subcoo(new COO(newid, new_src, new_dst)); IdArray induced_nodes = NewIdArray(NumVertices());
return Subgraph{subcoo, induced_nodes, eids}; dgl_id_t* induced_nodes_data = static_cast<dgl_id_t*>(induced_nodes->data);
for (int64_t i = 0; i < NumVertices(); ++i)
*(induced_nodes_data++) = i;
COOPtr subcoo(new COO(NumVertices(), new_src, new_dst));
return Subgraph{subcoo, induced_nodes, eids};
}
} }
// complexity: time O(E + V), space O(1) // complexity: time O(E + V), space O(1)
...@@ -696,9 +714,9 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const { ...@@ -696,9 +714,9 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
sg.induced_vertices, sg.induced_edges}; sg.induced_vertices, sg.induced_edges};
} }
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids) const { Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
// We prefer to generate a subgraph from out-csr. // We prefer to generate a subgraph from out-csr.
auto sg = GetCOO()->EdgeSubgraph(eids); auto sg = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph); COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
return Subgraph{GraphPtr(new ImmutableGraph(subcoo)), return Subgraph{GraphPtr(new ImmutableGraph(subcoo)),
sg.induced_vertices, sg.induced_edges}; sg.induced_vertices, sg.induced_edges};
......
...@@ -122,8 +122,8 @@ def test_node_subgraph(): ...@@ -122,8 +122,8 @@ def test_node_subgraph():
randv = np.unique(randv1) randv = np.unique(randv1)
subg = g.node_subgraph(utils.toindex(randv)) subg = g.node_subgraph(utils.toindex(randv))
subig = ig.node_subgraph(utils.toindex(randv)) subig = ig.node_subgraph(utils.toindex(randv))
check_basics(subg, subig) check_basics(subg.graph, subig.graph)
check_graph_equal(subg, subig) check_graph_equal(subg.graph, subig.graph)
assert F.sum(map_to_subgraph_nid(subg, utils.toindex(randv1[0:10])).tousertensor() assert F.sum(map_to_subgraph_nid(subg, utils.toindex(randv1[0:10])).tousertensor()
== map_to_subgraph_nid(subig, utils.toindex(randv1[0:10])).tousertensor(), 0) == 10 == map_to_subgraph_nid(subig, utils.toindex(randv1[0:10])).tousertensor(), 0) == 10
...@@ -136,8 +136,8 @@ def test_node_subgraph(): ...@@ -136,8 +136,8 @@ def test_node_subgraph():
subgs.append(g.node_subgraph(utils.toindex(randv))) subgs.append(g.node_subgraph(utils.toindex(randv)))
subigs= ig.node_subgraphs(randvs) subigs= ig.node_subgraphs(randvs)
for i in range(4): for i in range(4):
check_basics(subg, subig) check_basics(subg.graph, subig.graph)
check_graph_equal(subgs[i], subigs[i]) check_graph_equal(subgs[i].graph, subigs[i].graph)
def test_create_graph(): def test_create_graph():
elist = [(1, 2), (0, 1), (0, 2)] elist = [(1, 2), (0, 1), (0, 2)]
......
import os
import backend as F
import networkx as nx
import numpy as np
import torch as th
import dgl
def test_node_removal():
    """Node features must follow their nodes through remove/add/remove."""
    graph = dgl.DGLGraph()
    graph.add_nodes(10)
    graph.add_edge(0, 0)
    assert graph.number_of_nodes() == 10
    graph.ndata['id'] = F.arange(0, 10)

    # Drop nodes 4..6; the surviving features keep their relative order.
    graph.remove_nodes(range(4, 7))
    after_first_removal = [0, 1, 2, 3, 7, 8, 9]
    assert graph.number_of_nodes() == 7
    assert F.array_equal(graph.ndata['id'], F.tensor(after_first_removal))

    # Append three nodes; the test expects their 'id' feature to be zero.
    graph.add_nodes(3)
    after_growth = after_first_removal + [0, 0, 0]
    assert graph.number_of_nodes() == 10
    assert F.array_equal(graph.ndata['id'], F.tensor(after_growth))

    # Drop nodes 1..3 of the renumbered graph.
    graph.remove_nodes(range(1, 4))
    after_second_removal = [0, 7, 8, 9, 0, 0, 0]
    assert graph.number_of_nodes() == 7
    assert F.array_equal(graph.ndata['id'], F.tensor(after_second_removal))
def test_multigraph_node_removal():
    """Removing nodes of a multigraph also drops all their parallel self-loops."""
    graph = dgl.DGLGraph(multigraph=True)

    def expect_size(num_nodes, num_edges):
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    # Five nodes, each carrying two parallel self-loops.
    graph.add_nodes(5)
    for node in range(5):
        for _ in range(2):
            graph.add_edge(node, node)
    expect_size(5, 10)

    # remove nodes
    graph.remove_nodes([2, 3])
    expect_size(3, 6)

    # add a node, plus two more self-loops on node 1
    graph.add_nodes(1)
    for _ in range(2):
        graph.add_edge(1, 1)
    expect_size(4, 8)

    # remove nodes
    graph.remove_nodes([0])
    expect_size(3, 6)
def test_multigraph_edge_removal():
    """Removing edges of a multigraph must leave the node count untouched."""
    graph = dgl.DGLGraph(multigraph=True)

    def expect_size(num_nodes, num_edges):
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    # Five nodes, each carrying two parallel self-loops.
    graph.add_nodes(5)
    for node in range(5):
        for _ in range(2):
            graph.add_edge(node, node)
    expect_size(5, 10)

    # remove edges
    graph.remove_edges([2, 3])
    expect_size(5, 8)

    # add two more self-loops on node 1
    for _ in range(2):
        graph.add_edge(1, 1)
    expect_size(5, 10)

    # remove edges
    graph.remove_edges([0, 1])
    expect_size(5, 8)
def test_edge_removal():
    """Edge features must follow their edges through remove/add/remove."""
    graph = dgl.DGLGraph()
    graph.add_nodes(5)
    # Fully connected 5x5 graph (self-loops included): 25 edges, ids 0..24.
    for src in range(5):
        for dst in range(5):
            graph.add_edge(src, dst)
    graph.edata['id'] = F.arange(0, 25)

    # Remove edges 13..19.
    graph.remove_edges(range(13, 20))
    survivors = list(range(13)) + list(range(20, 25))
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 18
    assert F.array_equal(graph.edata['id'], F.tensor(survivors))

    # Add one edge; the test expects its 'id' feature to be zero.
    graph.add_edge(3, 3)
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 19
    assert F.array_equal(graph.edata['id'], F.tensor(survivors + [0]))

    # Remove edges 2..9 of the renumbered graph.
    graph.remove_edges(range(2, 10))
    final_ids = [0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0]
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 11
    assert F.array_equal(graph.edata['id'], F.tensor(final_ids))
def test_node_and_edge_removal():
    """Interleave node/edge removal with node/edge insertion and track sizes."""
    graph = dgl.DGLGraph()

    def expect_size(num_nodes, num_edges):
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    # Fully connected 10x10 graph (self-loops included): 100 edges.
    graph.add_nodes(10)
    for src in range(10):
        for dst in range(10):
            graph.add_edge(src, dst)
    graph.edata['id'] = F.arange(0, 100)
    expect_size(10, 100)

    # remove nodes: 8 nodes remain, hence 8 * 8 edges
    graph.remove_nodes([2, 4])
    expect_size(8, 64)

    # remove edges
    graph.remove_edges(range(10, 20))
    expect_size(8, 54)

    # add nodes
    graph.add_nodes(2)
    expect_size(10, 54)

    # add edges among the two new nodes
    for src in range(8, 10):
        for dst in range(8, 10):
            graph.add_edge(src, dst)
    expect_size(10, 58)

    # remove edges
    graph.remove_edges(range(10, 20))
    expect_size(10, 48)
def test_node_frame():
    """Multi-column node features keep only the surviving rows after removal."""
    graph = dgl.DGLGraph()
    graph.add_nodes(10)

    features = np.random.rand(10, 3)
    # Rows expected to survive once nodes 3..6 are gone.
    expected = np.take(features, [0, 1, 2, 7, 8, 9], axis=0)
    graph.ndata['h'] = F.zerocopy_from_numpy(features)

    # remove nodes
    graph.remove_nodes(range(3, 7))
    assert F.allclose(graph.ndata['h'], F.zerocopy_from_numpy(expected))
def test_edge_frame():
    """Multi-column edge features keep only the surviving rows after removal."""
    graph = dgl.DGLGraph()
    graph.add_nodes(10)

    # Ring graph: edge i goes from node i to node (i + 1) % 10.
    sources = list(range(10))
    targets = sources[1:] + sources[:1]
    graph.add_edges(sources, targets)

    features = np.random.rand(10, 3)
    # Rows expected to survive once edges 3..6 are gone.
    expected = np.take(features, [0, 1, 2, 7, 8, 9], axis=0)
    graph.edata['h'] = F.zerocopy_from_numpy(features)

    # remove edges
    graph.remove_edges(range(3, 7))
    assert F.allclose(graph.edata['h'], F.zerocopy_from_numpy(expected))
if __name__ == '__main__':
    # Run every test in this module, in a fixed order, when executed as a script.
    for run_test in (
            test_node_removal,
            test_edge_removal,
            test_multigraph_node_removal,
            test_multigraph_edge_removal,
            test_node_and_edge_removal,
            test_node_frame,
            test_edge_frame,
    ):
        run_test()
...@@ -13,7 +13,7 @@ def test_node_subgraph(): ...@@ -13,7 +13,7 @@ def test_node_subgraph():
sub2par_nodemap = [2, 0, 3] sub2par_nodemap = [2, 0, 3]
sgi = gi.node_subgraph(toindex(sub2par_nodemap)) sgi = gi.node_subgraph(toindex(sub2par_nodemap))
for s, d, e in zip(*sgi.edges()): for s, d, e in zip(*sgi.graph.edges()):
assert sgi.induced_edges[e] in gi.edge_id( assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d]) sgi.induced_nodes[s], sgi.induced_nodes[d])
...@@ -28,10 +28,28 @@ def test_edge_subgraph(): ...@@ -28,10 +28,28 @@ def test_edge_subgraph():
sub2par_edgemap = [3, 2] sub2par_edgemap = [3, 2]
sgi = gi.edge_subgraph(toindex(sub2par_edgemap)) sgi = gi.edge_subgraph(toindex(sub2par_edgemap))
for s, d, e in zip(*sgi.edges()): for s, d, e in zip(*sgi.graph.edges()):
assert sgi.induced_edges[e] in gi.edge_id( assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d]) sgi.induced_nodes[s], sgi.induced_nodes[d])
def test_edge_subgraph_preserve_nodes():
gi = create_graph_index(None, True, False)
gi.add_nodes(4)
gi.add_edge(0, 1)
gi.add_edge(0, 1)
gi.add_edge(0, 2)
gi.add_edge(2, 3)
sub2par_edgemap = [3, 2]
sgi = gi.edge_subgraph(toindex(sub2par_edgemap), preserve_nodes=True)
assert len(sgi.induced_nodes.tonumpy()) == 4
for s, d, e in zip(*sgi.graph.edges()):
assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d])
def test_immutable_edge_subgraph(): def test_immutable_edge_subgraph():
gi = create_graph_index(None, True, False) gi = create_graph_index(None, True, False)
gi.add_nodes(4) gi.add_nodes(4)
...@@ -44,7 +62,25 @@ def test_immutable_edge_subgraph(): ...@@ -44,7 +62,25 @@ def test_immutable_edge_subgraph():
sub2par_edgemap = [3, 2] sub2par_edgemap = [3, 2]
sgi = gi.edge_subgraph(toindex(sub2par_edgemap)) sgi = gi.edge_subgraph(toindex(sub2par_edgemap))
for s, d, e in zip(*sgi.edges()): for s, d, e in zip(*sgi.graph.edges()):
assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d])
def test_immutable_edge_subgraph_preserve_nodes():
gi = create_graph_index(None, True, False)
gi.add_nodes(4)
gi.add_edge(0, 1)
gi.add_edge(0, 1)
gi.add_edge(0, 2)
gi.add_edge(2, 3)
gi.readonly()
sub2par_edgemap = [3, 2]
sgi = gi.edge_subgraph(toindex(sub2par_edgemap), preserve_nodes=True)
assert len(sgi.induced_nodes.tonumpy()) == 4
for s, d, e in zip(*sgi.graph.edges()):
assert sgi.induced_edges[e] in gi.edge_id( assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d]) sgi.induced_nodes[s], sgi.induced_nodes[d])
...@@ -52,4 +88,6 @@ def test_immutable_edge_subgraph(): ...@@ -52,4 +88,6 @@ def test_immutable_edge_subgraph():
if __name__ == '__main__': if __name__ == '__main__':
test_node_subgraph() test_node_subgraph()
test_edge_subgraph() test_edge_subgraph()
test_edge_subgraph_preserve_nodes()
test_immutable_edge_subgraph() test_immutable_edge_subgraph()
test_immutable_edge_subgraph_preserve_nodes()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment