Unverified Commit 724aa0ca authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Feature] FindEdge/FindEdges for Immutable Graph (#404)

* fix rgcn tutorial

* small fix

* upd

* findedge/s

* upd

* upd

* upd

* upd

* add test

* remove redundancy

* upd

* upd

* upd

* upd

* add edge_subgraph

* explicit cast

* add test immutable subg

* reformat

* reformat

* fix bug

* upd
parent c07ae34a
...@@ -32,6 +32,30 @@ class ImmutableGraph: public GraphInterface { ...@@ -32,6 +32,30 @@ class ImmutableGraph: public GraphInterface {
dgl_id_t edge_id; dgl_id_t edge_id;
}; };
// Edge list indexed by edge id;
struct EdgeList {
typedef std::shared_ptr<EdgeList> Ptr;
std::vector<dgl_id_t> src_points;
std::vector<dgl_id_t> dst_points;
EdgeList(int64_t len, dgl_id_t val) {
src_points.resize(len, val);
dst_points.resize(len, val);
}
void register_edge(dgl_id_t eid, dgl_id_t src, dgl_id_t dst) {
CHECK_LT(eid, src_points.size()) << "Invalid edge id " << eid;
src_points[eid] = src;
dst_points[eid] = dst;
}
static EdgeList::Ptr FromCSR(
const std::vector<int64_t>& indptr,
const std::vector<dgl_id_t>& indices,
const std::vector<dgl_id_t>& edge_ids,
bool in_csr);
};
struct CSR { struct CSR {
typedef std::shared_ptr<CSR> Ptr; typedef std::shared_ptr<CSR> Ptr;
std::vector<int64_t> indptr; std::vector<int64_t> indptr;
...@@ -79,6 +103,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -79,6 +103,7 @@ class ImmutableGraph: public GraphInterface {
void ReadAllEdges(std::vector<Edge> *edges) const; void ReadAllEdges(std::vector<Edge> *edges) const;
CSR::Ptr Transpose() const; CSR::Ptr Transpose() const;
std::pair<CSR::Ptr, IdArray> VertexSubgraph(IdArray vids) const; std::pair<CSR::Ptr, IdArray> VertexSubgraph(IdArray vids) const;
std::pair<CSR::Ptr, IdArray> EdgeSubgraph(IdArray eids, EdgeList::Ptr edge_list) const;
/* /*
* Construct a CSR from a list of edges. * Construct a CSR from a list of edges.
* *
...@@ -261,20 +286,14 @@ class ImmutableGraph: public GraphInterface { ...@@ -261,20 +286,14 @@ class ImmutableGraph: public GraphInterface {
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the destination.
*/ */
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const { std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const;
LOG(FATAL) << "FindEdge isn't supported in ImmutableGraph";
return std::pair<dgl_id_t, dgl_id_t>();
}
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is preserved.
*/ */
EdgeArray FindEdges(IdArray eids) const { EdgeArray FindEdges(IdArray eids) const;
LOG(FATAL) << "FindEdges isn't supported in ImmutableGraph";
return EdgeArray();
}
/*! /*!
* \brief Get the in edges of the vertex. * \brief Get the in edges of the vertex.
...@@ -496,6 +515,25 @@ class ImmutableGraph: public GraphInterface { ...@@ -496,6 +515,25 @@ class ImmutableGraph: public GraphInterface {
} }
} }
/*!
 * \brief Return the edge list, building it on first use.
 *
 * The edge list is required by FindEdge/FindEdges/EdgeSubgraph. It is not
 * created until one of them asks for it: the first request derives it from
 * one of the graph's CSR representations and caches it in edge_list_, and
 * every later request returns the cached copy.
 * \return shared pointer to the cached edge list.
 */
EdgeList::Ptr GetEdgeList() const {
  if (!edge_list_) {
    // Filling the cache from a const method requires casting away constness.
    ImmutableGraph *self = const_cast<ImmutableGraph *>(this);
    if (in_csr_) {
      self->edge_list_ = EdgeList::FromCSR(
          in_csr_->indptr, in_csr_->indices, in_csr_->edge_ids, true);
    } else {
      CHECK(out_csr_ != nullptr) << "one of the CSRs must exist";
      self->edge_list_ = EdgeList::FromCSR(
          out_csr_->indptr, out_csr_->indices, out_csr_->edge_ids, false);
    }
  }
  return edge_list_;
}
protected: protected:
DGLIdIters GetInEdgeIdRef(dgl_id_t src, dgl_id_t dst) const; DGLIdIters GetInEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
DGLIdIters GetOutEdgeIdRef(dgl_id_t src, dgl_id_t dst) const; DGLIdIters GetOutEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
...@@ -525,6 +563,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -525,6 +563,8 @@ class ImmutableGraph: public GraphInterface {
CSR::Ptr in_csr_; CSR::Ptr in_csr_;
// Store the out-edges. // Store the out-edges.
CSR::Ptr out_csr_; CSR::Ptr out_csr_;
// Store the edge list indexed by edge id
EdgeList::Ptr edge_list_;
/*! /*!
* \brief Whether if this is a multigraph. * \brief Whether if this is a multigraph.
* *
......
...@@ -3070,10 +3070,8 @@ class DGLGraph(DGLBaseGraph): ...@@ -3070,10 +3070,8 @@ class DGLGraph(DGLBaseGraph):
>>> G.number_of_nodes() >>> G.number_of_nodes()
8 8
""" """
if readonly_state == self._graph.is_readonly(): if readonly_state != self.is_readonly:
return self
self._graph.readonly(readonly_state) self._graph.readonly(readonly_state)
return self
def __repr__(self): def __repr__(self):
ret = ('DGLGraph(num_nodes={node}, num_edges={edge},\n' ret = ('DGLGraph(num_nodes={node}, num_edges={edge},\n'
......
...@@ -252,6 +252,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -252,6 +252,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
} }
Graph::EdgeArray Graph::FindEdges(IdArray eids) const { Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0]; int64_t len = eids->shape[0];
IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx); IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
...@@ -472,7 +473,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -472,7 +473,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
} }
Subgraph Graph::EdgeSubgraph(IdArray eids) const { Subgraph Graph::EdgeSubgraph(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
const auto len = eids->shape[0]; const auto len = eids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv; std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
......
...@@ -158,6 +158,25 @@ class HashTableChecker { ...@@ -158,6 +158,25 @@ class HashTableChecker {
} }
}; };
/*!
 * \brief Build an edge list (edge id -> endpoints) from the arrays of a CSR.
 *
 * Every slot is first filled with the number of CSR rows, which is not a
 * valid row index, so an edge id that never occurs in `edge_ids` can be
 * recognised by callers afterwards.
 *
 * \param indptr row offsets; row i owns the entries [indptr[i], indptr[i + 1]).
 * \param indices the other endpoint of every entry.
 * \param edge_ids the edge id of every entry.
 * \param in_csr if true, the row index is recorded as the destination and the
 *        entry in `indices` as the source; otherwise the other way round.
 * \return the newly created edge list with edge_ids.size() slots.
 */
ImmutableGraph::EdgeList::Ptr ImmutableGraph::EdgeList::FromCSR(
    const std::vector<int64_t>& indptr,
    const std::vector<dgl_id_t>& indices,
    const std::vector<dgl_id_t>& edge_ids,
    bool in_csr) {
  // An empty indptr means "no rows"; indptr.size() - 1 would wrap around on
  // the unsigned size type and send the loop below far out of bounds.
  const size_t num_rows = indptr.empty() ? 0 : indptr.size() - 1;
  const size_t num_edges = edge_ids.size();
  auto edge_list = std::make_shared<EdgeList>(num_edges, num_rows);
  for (size_t row = 0; row < num_rows; ++row) {
    for (int64_t j = indptr[row]; j < indptr[row + 1]; ++j) {
      dgl_id_t src = row, dst = indices[j];
      if (in_csr)
        std::swap(src, dst);
      // register_edge rejects edge ids that do not fit into the list.
      edge_list->register_edge(edge_ids[j], src, dst);
    }
  }
  return edge_list;
}
std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph( std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph(
IdArray vids) const { IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
...@@ -165,7 +184,7 @@ std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph ...@@ -165,7 +184,7 @@ std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph
const int64_t len = vids->shape[0]; const int64_t len = vids->shape[0];
HashTableChecker def_check(vid_data, len); HashTableChecker def_check(vid_data, len);
// check if varr is sorted. // check if vid_data is sorted.
CHECK(std::is_sorted(vid_data, vid_data + len)) << "The input vertex list has to be sorted"; CHECK(std::is_sorted(vid_data, vid_data + len)) << "The input vertex list has to be sorted";
// Collect the non-zero entries in from the original graph. // Collect the non-zero entries in from the original graph.
...@@ -197,6 +216,42 @@ std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph ...@@ -197,6 +216,42 @@ std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph
return std::pair<ImmutableGraph::CSR::Ptr, IdArray>(sub_csr, rst_eids); return std::pair<ImmutableGraph::CSR::Ptr, IdArray>(sub_csr, rst_eids);
} }
/*!
 * \brief Build the CSR of the subgraph formed by the given edges.
 *
 * The endpoints of every requested edge are looked up in `edge_list`.
 * Vertices are relabelled 0..n-1 in order of first appearance (the source of
 * an edge before its destination), and the i-th requested edge gets id i in
 * the subgraph.
 *
 * \param eids ids of the edges to keep.
 * \param edge_list edge id -> endpoints table used for the lookup.
 * \return the sub-CSR together with an array that maps every new vertex id
 *         to the vertex id it had in `edge_list`.
 */
std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::EdgeSubgraph(
    IdArray eids, EdgeList::Ptr edge_list) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  CHECK(edge_list != nullptr) << "Edge list must not be null";
  const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eids->data);
  const int64_t len = eids->shape[0];
  const size_t num_known_edges = edge_list->src_points.size();

  std::vector<dgl_id_t> nodes;
  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
  std::vector<Edge> edges;
  edges.reserve(len);

  for (int64_t i = 0; i < len; i++) {
    const dgl_id_t eid = eid_data[i];
    // Reject unknown edge ids instead of reading past the end of the arrays.
    CHECK_LT(eid, num_known_edges) << "Invalid edge id " << eid;
    const dgl_id_t src_id = edge_list->src_points[eid];
    const dgl_id_t dst_id = edge_list->dst_points[eid];

    // insert() returns pair<iterator, bool>; the bool tells whether the vertex
    // was new. The map size at insertion time is the next free new vertex id.
    auto src_pair = oldv2newv.insert(std::make_pair(src_id, oldv2newv.size()));
    auto dst_pair = oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size()));
    if (src_pair.second)
      nodes.push_back(src_id);
    if (dst_pair.second)
      nodes.push_back(dst_id);

    edges.push_back(Edge{src_pair.first->second, dst_pair.first->second,
                         static_cast<dgl_id_t>(i)});
  }

  const size_t n = oldv2newv.size();
  auto sub_csr = CSR::FromEdges(&edges, 0, n);

  // nodes[k] is the original id of new vertex k.
  IdArray rst_vids = IdArray::Empty({static_cast<int64_t>(nodes.size())},
                                    DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  dgl_id_t* vid_data = static_cast<dgl_id_t*>(rst_vids->data);
  std::copy(nodes.begin(), nodes.end(), vid_data);
  return std::make_pair(sub_csr, rst_vids);
}
ImmutableGraph::CSR::Ptr ImmutableGraph::CSR::FromEdges(std::vector<Edge> *edges, ImmutableGraph::CSR::Ptr ImmutableGraph::CSR::FromEdges(std::vector<Edge> *edges,
int sort_on, uint64_t num_nodes) { int sort_on, uint64_t num_nodes) {
CHECK(sort_on == 0 || sort_on == 1) << "we must sort on the first or the second vector"; CHECK(sort_on == 0 || sort_on == 1) << "we must sort on the first or the second vector";
...@@ -449,6 +504,34 @@ ImmutableGraph::EdgeArray ImmutableGraph::EdgeIds(IdArray src_ids, IdArray dst_i ...@@ -449,6 +504,34 @@ ImmutableGraph::EdgeArray ImmutableGraph::EdgeIds(IdArray src_ids, IdArray dst_i
return ImmutableGraph::EdgeArray{rst_src, rst_dst, rst_eid}; return ImmutableGraph::EdgeArray{rst_src, rst_dst, rst_eid};
} }
/*!
 * \brief Look up the endpoints of a single edge.
 * \param eid the edge id; must be smaller than NumEdges().
 * \return (source, destination) of the edge.
 */
std::pair<dgl_id_t, dgl_id_t> ImmutableGraph::FindEdge(dgl_id_t eid) const {
  const auto cached_edges = GetEdgeList();
  CHECK(eid < NumEdges()) << "Invalid edge id " << eid;
  const dgl_id_t row = cached_edges->src_points[eid];
  const dgl_id_t col = cached_edges->dst_points[eid];
  // Both endpoints must be valid vertex ids of this graph.
  CHECK(row < NumVertices() && col < NumVertices()) << "Invalid edge id " << eid;
  return std::make_pair(row, col);
}
/*!
 * \brief Look up the endpoints of several edges at once.
 * \param eids the edge ids to look up.
 * \return an EdgeArray holding the sources, the destinations and `eids`
 *         itself; entry i of each array belongs to eids[i].
 */
ImmutableGraph::EdgeArray ImmutableGraph::FindEdges(IdArray eids) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
  const int64_t num_queries = eids->shape[0];
  // The outputs share dtype and context with the input array.
  IdArray rst_src = IdArray::Empty({num_queries}, eids->dtype, eids->ctx);
  IdArray rst_dst = IdArray::Empty({num_queries}, eids->dtype, eids->ctx);
  const dgl_id_t* queried = static_cast<dgl_id_t*>(eids->data);
  dgl_id_t* src_out = static_cast<dgl_id_t*>(rst_src->data);
  dgl_id_t* dst_out = static_cast<dgl_id_t*>(rst_dst->data);
  for (int64_t k = 0; k < num_queries; ++k) {
    const std::pair<dgl_id_t, dgl_id_t> endpoints = ImmutableGraph::FindEdge(queried[k]);
    src_out[k] = endpoints.first;
    dst_out[k] = endpoints.second;
  }
  return ImmutableGraph::EdgeArray{rst_src, rst_dst, eids};
}
ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const { ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const {
int64_t rstlen = NumEdges(); int64_t rstlen = NumEdges();
IdArray rst_src = IdArray::Empty({rstlen}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst_src = IdArray::Empty({rstlen}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -506,8 +589,19 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const { ...@@ -506,8 +589,19 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
} }
// Build the subgraph made of the edges in `eids`: the endpoints are resolved
// through the edge list, the work is delegated to whichever CSR exists (the
// out-CSR is preferred), and the resulting CSR is wrapped in a new
// ImmutableGraph. `induced_edges` is `eids` itself and `induced_vertices` is
// the vertex id array returned by CSR::EdgeSubgraph.
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids) const { Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids) const {
LOG(FATAL) << "EdgeSubgraph isn't implemented in immutable graph"; Subgraph subg;
return Subgraph(); std::pair<CSR::Ptr, IdArray> ret;
// Edge id -> endpoints table, obtained through GetEdgeList().
auto edge_list = GetEdgeList();
if (out_csr_) {
ret = out_csr_->EdgeSubgraph(eids, edge_list);
// The sub-CSR is handed over as the out-CSR of the new graph.
subg.graph = GraphPtr(new ImmutableGraph(nullptr, ret.first, IsMultigraph()));
} else {
// NOTE(review): CSR::EdgeSubgraph builds its result with
// FromEdges(&edges, 0, n) regardless of which CSR it is called on; confirm
// that a CSR built with sort_on == 0 is valid as the in-CSR passed below.
ret = in_csr_->EdgeSubgraph(eids, edge_list);
subg.graph = GraphPtr(new ImmutableGraph(ret.first, nullptr, IsMultigraph()));
}
subg.induced_edges = eids;
subg.induced_vertices = ret.second;
return subg;
} }
ImmutableGraph::CSRArray ImmutableGraph::GetInCSRArray() const { ImmutableGraph::CSRArray ImmutableGraph::GetInCSRArray() const {
......
...@@ -184,6 +184,35 @@ def test_readonly(): ...@@ -184,6 +184,35 @@ def test_readonly():
assert g.number_of_edges() == 14 assert g.number_of_edges() == 14
assert F.shape(g.edata['x']) == (14, 4) assert F.shape(g.edata['x']) == (14, 4)
def test_find_edges():
    """find_edges must give the same answers before and after g.readonly().

    The graph is the path 0 -> 1 -> ... -> 9, so edge i goes from i to i + 1
    and there are exactly 9 edges (ids 0..8).
    """
    def check(graph):
        # Valid ids: endpoints come back in the order the ids were given.
        e = graph.find_edges([1, 3, 2, 4])
        assert e[0][0] == 1 and e[0][1] == 3 and e[0][2] == 2 and e[0][3] == 4
        assert e[1][0] == 2 and e[1][1] == 4 and e[1][2] == 3 and e[1][3] == 5

        # Edge id 10 does not exist and must be rejected with DGLError. The
        # flag is set before the call so that any other exception propagates
        # unchanged instead of tripping over an unbound or stale variable.
        raised = False
        try:
            graph.find_edges([10])
        except DGLError:
            raised = True
        assert raised, "find_edges must raise DGLError for an invalid edge id"

    g = dgl.DGLGraph()
    g.add_nodes(10)
    g.add_edges(range(9), range(1, 10))
    check(g)

    # Same expectations once the graph has been switched to read-only.
    g.readonly()
    check(g)
if __name__ == '__main__': if __name__ == '__main__':
test_graph_creation() test_graph_creation()
test_create_from_elist() test_create_from_elist()
...@@ -191,3 +220,4 @@ if __name__ == '__main__': ...@@ -191,3 +220,4 @@ if __name__ == '__main__':
test_incmat() test_incmat()
test_incmat_cache() test_incmat_cache()
test_readonly() test_readonly()
test_find_edges()
...@@ -32,7 +32,24 @@ def test_edge_subgraph(): ...@@ -32,7 +32,24 @@ def test_edge_subgraph():
assert sgi.induced_edges[e] in gi.edge_id( assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d]) sgi.induced_nodes[s], sgi.induced_nodes[d])
def test_immutable_edge_subgraph():
    """Every edge of an edge subgraph taken from a read-only graph index must
    map back onto an edge of the parent with the same endpoints."""
    parent = create_graph_index()
    parent.add_nodes(4)
    # Edges 0 and 1 are parallel (both 0 -> 1).
    for u, v in [(0, 1), (0, 1), (0, 2), (2, 3)]:
        parent.add_edge(u, v)
    parent.readonly()  # Make the graph readonly

    picked_eids = [3, 2]
    sub = parent.edge_subgraph(toindex(picked_eids))
    for u, v, eid in zip(*sub.edges()):
        parent_eid = sub.induced_edges[eid]
        parent_src = sub.induced_nodes[u]
        parent_dst = sub.induced_nodes[v]
        assert parent_eid in parent.edge_id(parent_src, parent_dst)
if __name__ == '__main__': if __name__ == '__main__':
test_node_subgraph() test_node_subgraph()
test_edge_subgraph() test_edge_subgraph()
test_immutable_edge_subgraph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment