Commit 7b5d4c58 authored by Minjie Wang's avatar Minjie Wang
Browse files

pass specialization test

parent b2e4bdc0
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include <stdint.h> #include <stdint.h>
#include "runtime/ndarray.h" #include "runtime/ndarray.h"
#include "./vector_view.h"
namespace dgl { namespace dgl {
...@@ -89,7 +88,8 @@ class Graph { ...@@ -89,7 +88,8 @@ class Graph {
* \brief Clear the graph. Remove all vertices/edges. * \brief Clear the graph. Remove all vertices/edges.
*/ */
void Clear() { void Clear() {
adjlist_ = vector_view<EdgeList>(); adjlist_.clear();
reverse_adjlist_.clear();
read_only_ = false; read_only_ = false;
num_edges_ = 0; num_edges_ = 0;
} }
...@@ -184,8 +184,9 @@ class Graph { ...@@ -184,8 +184,9 @@ class Graph {
/*! /*!
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If sorted is true, the id array is not returned. * \note If sorted is true, the returned edges list is sorted by their src and
* \param sorted Whether the returned edge list is sorted by their edge ids. * dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray Edges(bool sorted = false) const; EdgeArray Edges(bool sorted = false) const;
...@@ -197,7 +198,7 @@ class Graph { ...@@ -197,7 +198,7 @@ class Graph {
*/ */
uint64_t InDegree(dgl_id_t vid) const { uint64_t InDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return adjlist_[vid].pred.size(); return reverse_adjlist_[vid].succ.size();
} }
/*! /*!
...@@ -277,23 +278,22 @@ class Graph { ...@@ -277,23 +278,22 @@ class Graph {
/*! \brief Internal edge list type */ /*! \brief Internal edge list type */
struct EdgeList { struct EdgeList {
/*! \brief successor vertex list */ /*! \brief successor vertex list */
vector_view<dgl_id_t> succ; std::vector<dgl_id_t> succ;
/*! \brief predecessor vertex list */ /*! \brief predecessor vertex list */
vector_view<dgl_id_t> pred; std::vector<dgl_id_t> edge_id;
/*! \brief (local) succ edge id property */
std::vector<dgl_id_t> succ_edge_id;
/*! \brief (local) pred edge id property */
std::vector<dgl_id_t> pred_edge_id;
}; };
/*! \brief Adjacency list using vector storage */ typedef std::vector<EdgeList> AdjacencyList;
// TODO(minjie): adjacent list is good for graph mutation and finding pred/succ.
// It is not good for getting all the edges of the graph. If the graph is known /*! \brief adjacency list using vector storage */
// to be static, how to design a data structure to speed this up? This idea can AdjacencyList adjlist_;
// be further extended. For example, CSC/CSR graph storage is known to be more /*! \brief reverse adjacency list using vector storage */
// compact than adjlist, but is difficult to be mutated. Shall we switch to a CSR/CSC AdjacencyList reverse_adjlist_;
// graph structure at some point? When shall such conversion happen? Which one
// will more likely to be a bottleneck? memory or computation? /*! \brief all edges' src endpoints in their edge id order */
vector_view<EdgeList> adjlist_; std::vector<dgl_id_t> all_edges_src_;
/*! \brief all edges' dst endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_dst_;
/*! \brief read only flag */ /*! \brief read only flag */
bool read_only_{false}; bool read_only_{false};
/*! \brief number of edges */ /*! \brief number of edges */
......
...@@ -121,6 +121,10 @@ class DGLGraph(object): ...@@ -121,6 +121,10 @@ class DGLGraph(object):
""" """
return self._graph.number_of_nodes() return self._graph.number_of_nodes()
def __len__(self):
"""Return the number of nodes."""
return self.number_of_nodes()
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges. """Return the number of edges.
...@@ -145,6 +149,10 @@ class DGLGraph(object): ...@@ -145,6 +149,10 @@ class DGLGraph(object):
True if the node exists True if the node exists
""" """
return self.has_node(vid) return self.has_node(vid)
def __contains__(self, vid):
"""Same as has_node."""
return self.has_node(vid)
def has_nodes(self, vids): def has_nodes(self, vids):
"""Return true if the nodes exist. """Return true if the nodes exist.
...@@ -319,7 +327,7 @@ class DGLGraph(object): ...@@ -319,7 +327,7 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
sorted : bool sorted : bool
True if the returned edges are sorted by their ids. True if the returned edges are sorted by their src and dst ids.
Returns Returns
------- -------
...@@ -543,29 +551,12 @@ class DGLGraph(object): ...@@ -543,29 +551,12 @@ class DGLGraph(object):
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
num_edges = self.number_of_edges() self.set_e_repr_by_id(h_uv, eid=ALL)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
num_edges = max(len(u), len(v))
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
assert F.shape(h_uv)[0] == num_edges
# set
if u_is_all:
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
eid = self._graph.edge_ids(u, v) eid = self._graph.edge_ids(u, v)
if utils.is_dict_like(h_uv): self.set_e_repr_by_id(h_uv, eid=eid)
self._edge_frame[eid] = h_uv
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, h_uv, eid=ALL):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
...@@ -622,18 +613,12 @@ class DGLGraph(object): ...@@ -622,18 +613,12 @@ class DGLGraph(object):
if len(self.edge_attr_schemes()) == 0: if len(self.edge_attr_schemes()) == 0:
return dict() return dict()
if u_is_all: if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self.get_e_repr_by_id(eid=ALL)
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
eid = self._graph.edge_ids(u, v) eid = self._graph.edge_ids(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self.get_e_repr_by_id(eid=eid)
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
def pop_e_repr(self, key=__REPR__): def pop_e_repr(self, key=__REPR__):
"""Get and remove the specified edge repr. """Get and remove the specified edge repr.
...@@ -855,7 +840,7 @@ class DGLGraph(object): ...@@ -855,7 +840,7 @@ class DGLGraph(object):
def _batch_send(self, u, v, message_func): def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v, _ = self._graph.edges(sorted=True) u, v, _ = self._graph.edges()
self._msg_graph.add_edges(u, v) self._msg_graph.add_edges(u, v)
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
...@@ -920,7 +905,7 @@ class DGLGraph(object): ...@@ -920,7 +905,7 @@ class DGLGraph(object):
def _batch_update_edge(self, u, v, edge_func): def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self._graph.edges(sorted=True) u, v = self._graph.edges()
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
......
...@@ -21,6 +21,7 @@ class GraphIndex(object): ...@@ -21,6 +21,7 @@ class GraphIndex(object):
self.from_networkx(graph_data) self.from_networkx(graph_data)
elif graph_data is not None: elif graph_data is not None:
self.from_networkx(nx.DiGraph(graph_data)) self.from_networkx(nx.DiGraph(graph_data))
self._cache = {}
def __del__(self): def __del__(self):
"""Free this graph index object.""" """Free this graph index object."""
...@@ -35,6 +36,7 @@ class GraphIndex(object): ...@@ -35,6 +36,7 @@ class GraphIndex(object):
Number of nodes to be added. Number of nodes to be added.
""" """
_CAPI_DGLGraphAddVertices(self._handle, num); _CAPI_DGLGraphAddVertices(self._handle, num);
self._cache.clear()
def add_edge(self, u, v): def add_edge(self, u, v):
"""Add one edge. """Add one edge.
...@@ -47,6 +49,7 @@ class GraphIndex(object): ...@@ -47,6 +49,7 @@ class GraphIndex(object):
The dst node. The dst node.
""" """
_CAPI_DGLGraphAddEdge(self._handle, u, v); _CAPI_DGLGraphAddEdge(self._handle, u, v);
self._cache.clear()
def add_edges(self, u, v): def add_edges(self, u, v):
"""Add many edges. """Add many edges.
...@@ -61,10 +64,12 @@ class GraphIndex(object): ...@@ -61,10 +64,12 @@ class GraphIndex(object):
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
_CAPI_DGLGraphAddEdges(self._handle, u_array, v_array) _CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
self._cache.clear()
def clear(self): def clear(self):
"""Clear the graph.""" """Clear the graph."""
_CAPI_DGLGraphClear(self._handle) _CAPI_DGLGraphClear(self._handle)
self._cache.clear()
def number_of_nodes(self): def number_of_nodes(self):
"""Return the number of nodes. """Return the number of nodes.
...@@ -283,7 +288,7 @@ class GraphIndex(object): ...@@ -283,7 +288,7 @@ class GraphIndex(object):
Parameters Parameters
---------- ----------
sorted : bool sorted : bool
True if the returned edges are sorted by their ids. True if the returned edges are sorted by their src and dst ids.
Returns Returns
------- -------
...@@ -362,6 +367,25 @@ class GraphIndex(object): ...@@ -362,6 +367,25 @@ class GraphIndex(object):
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array)) return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array))
def adjacency_matrix(self):
"""Return the adjacency matrix representation of this graph.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
if not 'adj' in self._cache:
src, dst, _ = self.edges(sorted=False)
src = F.unsqueeze(src.tousertensor(), 0)
dst = F.unsqueeze(dst.tousertensor(), 0)
idx = F.pack([dst, src])
n = self.number_of_nodes()
dat = F.ones((self.number_of_edges(),))
mat = F.sparse_tensor(idx, dat, [n, n])
self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return self._cache['adj']
def to_networkx(self): def to_networkx(self):
"""Convert to networkx graph. """Convert to networkx graph.
......
...@@ -134,7 +134,7 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -134,7 +134,7 @@ class UpdateAllExecutor(BasicExecutor):
@property @property
def graph_idx(self): def graph_idx(self):
if self._graph_idx is None: if self._graph_idx is None:
self._graph_idx = self.g.cached_graph.adjmat() self._graph_idx = self.g._graph.adjacency_matrix()
return self._graph_idx return self._graph_idx
@property @property
...@@ -221,8 +221,8 @@ class SendRecvExecutor(BasicExecutor): ...@@ -221,8 +221,8 @@ class SendRecvExecutor(BasicExecutor):
def _build_adjmat(self): def _build_adjmat(self):
# handle graph index # handle graph index
new2old, old2new = utils.build_relabel_map(self.v) new2old, old2new = utils.build_relabel_map(self.v)
u = self.u.totensor() u = self.u.tousertensor()
v = self.v.totensor() v = self.v.tousertensor()
# TODO(minjie): should not directly use [] # TODO(minjie): should not directly use []
new_v = old2new[v] new_v = old2new[v]
n = self.g.number_of_nodes() n = self.g.number_of_nodes()
......
...@@ -13,6 +13,7 @@ inline bool IsValidIdArray(const IdArray& arr) { ...@@ -13,6 +13,7 @@ inline bool IsValidIdArray(const IdArray& arr) {
void Graph::AddVertices(uint64_t num_vertices) { void Graph::AddVertices(uint64_t num_vertices) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed."; CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
adjlist_.resize(adjlist_.size() + num_vertices); adjlist_.resize(adjlist_.size() + num_vertices);
reverse_adjlist_.resize(reverse_adjlist_.size() + num_vertices);
} }
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) { void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
...@@ -21,9 +22,11 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) { ...@@ -21,9 +22,11 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
<< "In valid vertices: " << src << " " << dst; << "In valid vertices: " << src << " " << dst;
dgl_id_t eid = num_edges_++; dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst); adjlist_[src].succ.push_back(dst);
adjlist_[src].succ_edge_id.push_back(eid); adjlist_[src].edge_id.push_back(eid);
adjlist_[dst].pred.push_back(src); reverse_adjlist_[dst].succ.push_back(src);
adjlist_[dst].pred_edge_id.push_back(eid); reverse_adjlist_[dst].edge_id.push_back(eid);
all_edges_src_.push_back(src);
all_edges_dst_.push_back(dst);
} }
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) { void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
...@@ -108,7 +111,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const { ...@@ -108,7 +111,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const { IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& pred = adjlist_[vid].pred; const auto& pred = reverse_adjlist_[vid].succ;
const int64_t len = pred.size(); const int64_t len = pred.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
...@@ -138,7 +141,7 @@ dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { ...@@ -138,7 +141,7 @@ dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
const auto& succ = adjlist_[src].succ; const auto& succ = adjlist_[src].succ;
for (size_t i = 0; i < succ.size(); ++i) { for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) { if (succ[i] == dst) {
return adjlist_[src].succ_edge_id[i]; return adjlist_[src].edge_id[i];
} }
} }
LOG(FATAL) << "invalid edge: " << src << " -> " << dst; LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
...@@ -179,7 +182,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -179,7 +182,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
// O(E) // O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const { Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].pred.size(); const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -187,8 +190,8 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const { ...@@ -187,8 +190,8 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
src_data[i] = adjlist_[vid].pred[i]; src_data[i] = reverse_adjlist_[vid].succ[i];
eid_data[i] = adjlist_[vid].pred_edge_id[i]; eid_data[i] = reverse_adjlist_[vid].edge_id[i];
} }
std::fill(dst_data, dst_data + len, vid); std::fill(dst_data, dst_data + len, vid);
return EdgeArray{src, dst, eid}; return EdgeArray{src, dst, eid};
...@@ -202,7 +205,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const { ...@@ -202,7 +205,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
int64_t rstlen = 0; int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i]; CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += adjlist_[vid_data[i]].pred.size(); rstlen += reverse_adjlist_[vid_data[i]].succ.size();
} }
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx); IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx); IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
...@@ -211,8 +214,8 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const { ...@@ -211,8 +214,8 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
int64_t* dst_ptr = static_cast<int64_t*>(dst->data); int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data); int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
const auto& pred = adjlist_[vid_data[i]].pred; const auto& pred = reverse_adjlist_[vid_data[i]].succ;
const auto& eids = adjlist_[vid_data[i]].pred_edge_id; const auto& eids = reverse_adjlist_[vid_data[i]].edge_id;
for (size_t j = 0; j < pred.size(); ++j) { for (size_t j = 0; j < pred.size(); ++j) {
*(src_ptr++) = pred[j]; *(src_ptr++) = pred[j];
*(dst_ptr++) = vid_data[i]; *(dst_ptr++) = vid_data[i];
...@@ -234,7 +237,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const { ...@@ -234,7 +237,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
dst_data[i] = adjlist_[vid].succ[i]; dst_data[i] = adjlist_[vid].succ[i];
eid_data[i] = adjlist_[vid].succ_edge_id[i]; eid_data[i] = adjlist_[vid].edge_id[i];
} }
std::fill(src_data, src_data + len, vid); std::fill(src_data, src_data + len, vid);
return EdgeArray{src, dst, eid}; return EdgeArray{src, dst, eid};
...@@ -258,7 +261,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const { ...@@ -258,7 +261,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
int64_t* eid_ptr = static_cast<int64_t*>(eid->data); int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
const auto& succ = adjlist_[vid_data[i]].succ; const auto& succ = adjlist_[vid_data[i]].succ;
const auto& eids = adjlist_[vid_data[i]].succ_edge_id; const auto& eids = adjlist_[vid_data[i]].edge_id;
for (size_t j = 0; j < succ.size(); ++j) { for (size_t j = 0; j < succ.size(); ++j) {
*(src_ptr++) = vid_data[i]; *(src_ptr++) = vid_data[i];
*(dst_ptr++) = succ[j]; *(dst_ptr++) = succ[j];
...@@ -279,22 +282,21 @@ Graph::EdgeArray Graph::Edges(bool sorted) const { ...@@ -279,22 +282,21 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple; typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples; std::vector<Tuple> tuples;
tuples.reserve(len); tuples.reserve(len);
for (dgl_id_t u = 0; u < NumVertices(); ++u) { for (uint64_t eid = 0; eid < num_edges_; ++eid) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) { tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);
tuples.emplace_back(u, adjlist_[u].succ[i], adjlist_[u].succ_edge_id[i]);
}
} }
// sort according to edge ids // sort according to src and dst ids
std::sort(tuples.begin(), tuples.end(), std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) { [] (const Tuple& t1, const Tuple& t2) {
return std::get<2>(t1) < std::get<2>(t2); return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
}); });
// make return arrays // make return arrays
int64_t* src_ptr = static_cast<int64_t*>(src->data); int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data); int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data); int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) { for (size_t i = 0; i < tuples.size(); ++i) {
src_ptr[i] = std::get<0>(tuples[i]); src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]); dst_ptr[i] = std::get<1>(tuples[i]);
eid_ptr[i] = std::get<2>(tuples[i]); eid_ptr[i] = std::get<2>(tuples[i]);
...@@ -303,12 +305,10 @@ Graph::EdgeArray Graph::Edges(bool sorted) const { ...@@ -303,12 +305,10 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
int64_t* src_ptr = static_cast<int64_t*>(src->data); int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data); int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data); int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (dgl_id_t u = 0; u < NumVertices(); ++u) { std::copy(all_edges_src_.begin(), all_edges_src_.end(), src_ptr);
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) { std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), dst_ptr);
*(src_ptr++) = u; for (uint64_t eid = 0; eid < num_edges_; ++eid) {
*(dst_ptr++) = adjlist_[u].succ[i]; eid_ptr[eid] = eid;
*(eid_ptr++) = adjlist_[u].succ_edge_id[i];
}
} }
} }
...@@ -325,7 +325,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const { ...@@ -325,7 +325,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const {
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i]; const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid; CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = adjlist_[vid].pred.size(); rst_data[i] = reverse_adjlist_[vid].succ.size();
} }
return rst; return rst;
} }
......
import torch as th
import numpy as np
import networkx as nx
from dgl import DGLGraph
from dgl.cached_graph import *
from dgl.utils import Index
def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def test_basics():
g = DGLGraph()
g.add_edge(0, 1)
g.add_edge(1, 2)
g.add_edge(1, 3)
g.add_edge(2, 4)
g.add_edge(2, 5)
g.add_edge(0, 2)
cg = create_cached_graph(g)
u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))
query = Index(th.tensor([0, 1, 2, 5]))
s, d, orphan = cg.in_edges(query)
check_eq(s.totensor(), th.tensor([0, 0, 1, 2]))
check_eq(d.totensor(), th.tensor([1, 2, 2, 5]))
assert orphan.tolist() == [0]
s, d, orphan = cg.out_edges(query)
check_eq(s.totensor(), th.tensor([0, 0, 1, 1, 2, 2]))
check_eq(d.totensor(), th.tensor([1, 2, 2, 3, 4, 5]))
assert orphan.tolist() == [5]
if __name__ == '__main__':
test_basics()
...@@ -7,8 +7,7 @@ D = 5 ...@@ -7,8 +7,7 @@ D = 5
def generate_graph(): def generate_graph():
g = dgl.DGLGraph() g = dgl.DGLGraph()
for i in range(10): g.add_nodes(10)
g.add_node(i) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment