Commit 14d88497 authored by Minjie Wang's avatar Minjie Wang
Browse files

impl

parent 2c626b90
......@@ -13,23 +13,50 @@ typedef tvm::runtime::NDArray IdArray;
typedef tvm::runtime::NDArray DegreeArray;
typedef tvm::runtime::NDArray BoolArray;
class DGLGraph;
class DGLSubGraph;
class Graph;
/*!
* \brief Base dgl graph class.
*
* DGLGraph is a directed graph. Vertices are integers enumerated from zero. Edges
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
* supported.
*
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* by removing all the vertices and edges.
*
* When calling functions supporting multiple edges (e.g. AddEdges, HasEdges),
* the input edges are represented by two id arrays for source and destination
* vertex ids. In the general case, the two arrays should have the same length.
* If the length of src id array is one, it represents one-many connections.
* If the length of dst id array is one, it represents many-one connections.
*/
class DGLGraph {
class Graph {
public:
/*! \brief default constructor */
DGLGraph() {}
/*! \brief default constructor */
Graph() {}
/*! \brief default copy constructor */
Graph(const Graph& other) = default;
#ifndef _MSC_VER
/*! \brief default move constructor */
Graph(Graph&& other) = default;
#else
/*!
 * \brief Hand-written move constructor used on the MSVC side of the
 *        _MSC_VER guard, where the defaulted form is not used.
 * \param other The graph whose state is taken over; it is reset afterwards.
 */
Graph(Graph&& other) {
  adjlist_ = other.adjlist_;
  read_only_ = other.read_only_;
  num_edges_ = other.num_edges_;
  // The reset member of this class is spelled Clear(); the lower-case
  // spelling does not exist and broke compilation of this branch.
  other.Clear();
}
#endif // _MSC_VER
/*! \brief default copy assignment operator */
Graph& operator=(const Graph& other) = default;
/*! \brief default destructor */
~Graph() = default;
/*!
* \brief Add vertices to the graph.
* \note Since vertices are integers enumerated from zero, only the number of
......@@ -37,46 +64,68 @@ class DGLGraph {
* \param num_vertices The number of vertices to be added.
*/
void AddVertices(uint64_t num_vertices);
/*!
* \brief Add one edge to the graph.
* \param src The source vertex.
* \param dst The destination vertex.
*/
void AddEdge(dgl_id_t src, dgl_id_t dst);
/*!
* \brief Add edges to the graph.
* \param src_ids The source vertex id array.
* \param dst_ids The destination vertex id array.
*/
void AddEdges(IdArray src_ids, IdArray dst_ids);
/*!
* \brief Clear the graph. Remove all vertices/edges.
*/
void Clear();
void Clear() {
  // Reset the scalar bookkeeping, then swap in a fresh, empty adjacency list.
  num_edges_ = 0;
  read_only_ = false;
  adjlist_ = vector_view<EdgeList>();
}
/*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const;
uint64_t NumVertices() const {
// AddVertices grows adjlist_ by one entry per vertex, so its size is the vertex count.
return adjlist_.size();
}
/*! \return the number of edges in the graph.*/
uint64_t NumEdges() const;
uint64_t NumEdges() const {
// num_edges_ is the running counter that AddEdge bumps for every inserted edge.
return num_edges_;
}
/*! \return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const;
bool HasVertex(dgl_id_t vid) const {
  // Vertex ids are enumerated from zero, so any id below the count exists.
  return NumVertices() > vid;
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
BoolArray HasVertices(IdArray vids) const;
/*! \return true if the given edge is in the graph.*/
bool HasEdge(dgl_id_t src, dgl_id_t dst) const;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdges(IdArray src_ids, IdArray dst_ids) const;
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid) const;
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid) const;
/*!
* \brief Get the id of the given edge.
* \note Edges are associated with an integer id start from zero.
......@@ -86,6 +135,7 @@ class DGLGraph {
* \return the edge id.
*/
dgl_id_t EdgeId(dgl_id_t src, dgl_id_t dst) const;
/*!
* \brief Get the id of the given edges.
* \note Edges are associated with an integer id start from zero.
......@@ -93,59 +143,77 @@ class DGLGraph {
* \return the edge id array.
*/
IdArray EdgeIds(IdArray src, IdArray dst) const;
/*!
* \brief Get the in edges of the vertex.
* \note The returned dst id array is filled with vid.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> InEdges(dgl_id_t vid) const;
/*!
* \brief Get the in edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> InEdges(IdArray vids) const;
/*!
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> OutEdges(dgl_id_t vid) const;
/*!
* \brief Get the out edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> OutEdges(IdArray vids) const;
/*!
* \brief Get all the edges in the graph.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> Edges() const;
/*!
* \brief Get the in degree of the given vertex.
* \param vid The vertex id.
* \return the in degree
*/
uint64_t InDegree(dgl_id_t vid) const;
uint64_t InDegree(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  // The in degree is the length of the vertex's predecessor list.
  const auto& incoming = adjlist_[vid].pred;
  return incoming.size();
}
/*!
* \brief Get the in degrees of the given vertices.
* \param vid The vertex id array.
* \return the in degree array
*/
DegreeArray InDegrees(IdArray vids) const;
/*!
* \brief Get the out degree of the given vertex.
* \param vid The vertex id.
* \return the out degree
*/
uint64_t OutDegree(dgl_id_t vid) const;
uint64_t OutDegree(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  // The out degree is the length of the vertex's successor list.
  const auto& outgoing = adjlist_[vid].succ;
  return outgoing.size();
}
/*!
* \brief Get the out degrees of the given vertices.
* \param vid The vertex id array.
* \return the out degree array
*/
DegreeArray OutDegrees(IdArray vids) const;
/*!
* \brief Construct the induced subgraph of the given vertices.
*
......@@ -162,7 +230,8 @@ class DGLGraph {
* \param vids The vertices in the subgraph.
* \return the induced subgraph
*/
DGLGraph Subgraph(IdArray vids) const;
Graph Subgraph(IdArray vids) const;
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
......@@ -179,7 +248,8 @@ class DGLGraph {
* \param vids The edges in the subgraph.
* \return the induced edge subgraph
*/
DGLGraph EdgeSubgraph(IdArray src, IdArray dst) const;
Graph EdgeSubgraph(IdArray src, IdArray dst) const;
/*!
* \brief Return a new graph with all the edges reversed.
*
......@@ -187,7 +257,7 @@ class DGLGraph {
*
* \return the reversed graph
*/
DGLGraph Reverse() const;
Graph Reverse() const;
private:
/*! \brief Internal edge list type */
......@@ -196,7 +266,7 @@ class DGLGraph {
vector_view<dgl_id_t> succ;
/*! \brief predecessor vertex list */
vector_view<dgl_id_t> pred;
/*! \brief (local) edge id property */
/*! \brief (local) succ edge id property */
std::vector<dgl_id_t> edge_id;
};
/*! \brief Adjacency list using vector storage */
......@@ -210,6 +280,8 @@ class DGLGraph {
vector_view<EdgeList> adjlist_;
/*! \brief read only flag */
bool read_only_{false};
/*! \brief number of edges */
uint64_t num_edges_ = 0;
};
} // namespace dgl
......
......@@ -15,10 +15,42 @@ namespace dgl {
*/
template<typename ValueType>
class vector_view {
struct vector_view_iterator;
public:
typedef vector_view_iterator iterator;
/*!
 * \brief Forward iterator over a vector_view.
 *
 * The iterator keeps a pointer to the view plus a position. Dereferencing
 * goes through the view's operator[] and yields a const reference, so the
 * elements cannot be modified through it.
 */
class iterator : public std::iterator<std::forward_iterator_tag, ValueType> {
 public:
  /*!
   * \brief iterator constructor
   * \param vec The view being iterated.
   * \param pos The starting position inside the view.
   */
  iterator(const vector_view<ValueType>* vec, size_t pos): vec_(vec), pos_(pos) {}
  /*! \brief pre-increment: move to next */
  iterator& operator++() {
    ++pos_;
    return *this;
  }
  /*! \brief post-increment: move to next, return the previous position */
  iterator operator++(int) {
    iterator retval = *this;
    ++(*this);
    return retval;
  }
  /*! \brief equal operator: same view and same position */
  bool operator==(iterator other) const {
    // Use && rather than the alternative token `and`, which MSVC's default
    // mode does not accept without <ciso646>.
    return vec_ == other.vec_ && pos_ == other.pos_;
  }
  /*! \brief not equal operator */
  bool operator!=(iterator other) const {
    return !(*this == other);
  }
  /*! \brief dereference operator */
  const ValueType& operator*() const {
    return (*vec_)[pos_];
  }

 private:
  /*! \brief vector_view pointer */
  const vector_view<ValueType>* vec_;
  /*! \brief current position */
  size_t pos_;
};
/*! \brief Default constructor. Create an empty vector. */
vector_view()
: data_(std::make_shared<std::vector<ValueType> >()) {}
......@@ -28,7 +60,7 @@ class vector_view {
: data_(vec.data_), index_(index), is_view_(true) {}
/*! \brief constructor from a vector pointer */
vector_view(const std::shared_ptr<std::vector<ValueType> >& other)
explicit vector_view(const std::shared_ptr<std::vector<ValueType> >& other)
: data_(other) {}
/*! \brief default copy constructor */
......@@ -53,7 +85,7 @@ class vector_view {
~vector_view() = default;
/*! \brief default assign constructor */
vector_view& operator=(const vector_view<ValueType>& other) = default;
vector_view<ValueType>& operator=(const vector_view<ValueType>& other) = default;
/*! \return the size of the vector */
size_t size() const {
......@@ -87,11 +119,15 @@ class vector_view {
}
}
// TODO(minjie)
iterator begin() const;
/*! \return an iterator positioned at index zero of this view */
iterator begin() const {
  const size_t first = 0;
  return iterator(this, first);
}
// TODO(minjie)
iterator end() const;
/*! \return an iterator positioned one past the last element (index size()) */
iterator end() const {
  const size_t past_last = size();
  return iterator(this, past_last);
}
// Modifiers
// NOTE: The modifiers are not allowed for view.
......@@ -112,10 +148,17 @@ class vector_view {
data_ = std::make_shared<std::vector<ValueType> >();
}
private:
struct vector_view_iterator {
// TODO
};
/*!
 * \brief Resize the underlying vector.
 * \note Fails the CHECK when this object is a view.
 * \param size The new number of elements.
 */
void resize(size_t size) {
  CHECK(!is_view_);
  auto& storage = *data_;
  storage.resize(size);
}
/*!
 * \brief Resize the underlying vector, passing val on for new elements.
 * \note Fails the CHECK when this object is a view.
 * \param size The new number of elements.
 * \param val The value forwarded to the underlying resize.
 */
void resize(size_t size, const ValueType& val) {
  CHECK(!is_view_);
  auto& storage = *data_;
  storage.resize(size, val);
}
private:
/*! \brief pointer to the underlying vector data */
......
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
// Graph class implementation
#include <algorithm>
#include <dgl/graph.h>
using namespace tvm;
using namespace tvm::runtime;
namespace dgl {
namespace {
// Returns true when `arr` is a one-dimensional CPU array of 64-bit integers
// (kDLInt); these are the properties the id-array based graph APIs check.
inline bool IsValidIdArray(const IdArray& arr) {
  if (arr->ctx.device_type != kDLCPU) {
    return false;
  }
  if (arr->ndim != 1) {
    return false;
  }
  if (arr->dtype.code != kDLInt) {
    return false;
  }
  return arr->dtype.bits == 64;
}
} // namespace
// Appends `num_vertices` new vertices by growing the adjacency list.
// Fails the CHECK when the graph is flagged read-only.
void Graph::AddVertices(uint64_t num_vertices) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  const auto grown_size = adjlist_.size() + num_vertices;
  adjlist_.resize(grown_size);
}
// Inserts the directed edge src -> dst and assigns it the next edge id.
// Both endpoints must already exist and the graph must not be read-only.
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  CHECK(HasVertex(src) && HasVertex(dst))
    << "In valid vertices: " << src << " " << dst;
  // Take the current counter value as this edge's id, then advance it.
  const dgl_id_t eid = num_edges_;
  ++num_edges_;
  // succ and edge_id of the source are appended together so that position k
  // in one corresponds to position k in the other.
  adjlist_[src].succ.push_back(dst);
  adjlist_[src].edge_id.push_back(eid);
  // The destination records the source as a predecessor.
  adjlist_[dst].pred.push_back(src);
}
// Adds a batch of edges given as source / destination id arrays.
//
// Both arrays must pass IsValidIdArray. Three layouts are accepted:
//   - src has length one: one-many, the single source is paired with every dst;
//   - dst has length one: many-one, every source is paired with the single dst;
//   - otherwise: many-many, the arrays must have equal length and are paired
//     element-wise.
// Every pair is inserted through AddEdge, so its checks apply to each edge.
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  // The destination length has to be read from dst_ids. Reading it from
  // src_ids made both lengths always equal, which broke the broadcast cases
  // and turned the many-many length check into a no-op.
  const auto dstlen = dst_ids->shape[0];
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
  if (srclen == 1) {
    // one-many
    for (int64_t i = 0; i < dstlen; ++i) {
      AddEdge(src_data[0], dst_data[i]);
    }
  } else if (dstlen == 1) {
    // many-one
    for (int64_t i = 0; i < srclen; ++i) {
      AddEdge(src_data[i], dst_data[0]);
    }
  } else {
    // many-many
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
    for (int64_t i = 0; i < srclen; ++i) {
      AddEdge(src_data[i], dst_data[i]);
    }
  }
}
// Returns a 0-1 array with one entry per id in `vids`: 1 when the id is below
// the current vertex count, 0 otherwise. The result reuses the dtype and
// context of `vids`.
BoolArray Graph::HasVertices(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const uint64_t nverts = NumVertices();
  const auto len = vids->shape[0];
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  for (int64_t idx = 0; idx < len; ++idx) {
    if (vid_data[idx] < nverts) {
      rst_data[idx] = 1;
    } else {
      rst_data[idx] = 0;
    }
  }
  return rst;
}
// Returns true when the edge src -> dst exists. Unknown endpoints yield false.
// The lookup is a linear scan over the successor list of `src`.
bool Graph::HasEdge(dgl_id_t src, dgl_id_t dst) const {
  if (!(HasVertex(src) && HasVertex(dst))) {
    return false;
  }
  const auto& out_nbrs = adjlist_[src].succ;
  for (auto it = out_nbrs.begin(); it != out_nbrs.end(); ++it) {
    if (*it == dst) {
      return true;
    }
  }
  return false;
}
// Batched edge-existence query; each pair is answered by HasEdge, i.e. by a
// linear scan of the source's successor list, so this is slow for big batches.
//
// Both arrays must pass IsValidIdArray. Layouts: src of length one
// (one-many), dst of length one (many-one), or equal lengths (many-many,
// paired element-wise). Returns a 0-1 array with one entry per queried pair,
// using the dtype and context of `src_ids`.
BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  // The destination length has to be read from dst_ids. Reading it from
  // src_ids made both lengths always equal, so the result size and the
  // broadcast branches were wrong.
  const auto dstlen = dst_ids->shape[0];
  const auto rstlen = std::max(srclen, dstlen);
  BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
  if (srclen == 1) {
    // one-many
    for (int64_t i = 0; i < dstlen; ++i) {
      rst_data[i] = HasEdge(src_data[0], dst_data[i])? 1 : 0;
    }
  } else if (dstlen == 1) {
    // many-one
    for (int64_t i = 0; i < srclen; ++i) {
      rst_data[i] = HasEdge(src_data[i], dst_data[0])? 1 : 0;
    }
  } else {
    // many-many
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
    for (int64_t i = 0; i < srclen; ++i) {
      rst_data[i] = HasEdge(src_data[i], dst_data[i])? 1 : 0;
    }
  }
  return rst;
}
// Returns the predecessor ids of `vid` as a freshly allocated int64 CPU array.
// The ids are copied out of the adjacency list (no zero-copy yet).
IdArray Graph::Predecessors(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  const auto& in_nbrs = adjlist_[vid].pred;
  const int64_t count = in_nbrs.size();
  IdArray result = IdArray::Empty({count}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(result->data);
  int64_t k = 0;
  while (k < count) {
    out[k] = in_nbrs[k];
    ++k;
  }
  return result;
}
// Returns the successor ids of `vid` as a freshly allocated int64 CPU array.
// The ids are copied out of the adjacency list (no zero-copy yet).
IdArray Graph::Successors(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  const auto& out_nbrs = adjlist_[vid].succ;
  const int64_t count = out_nbrs.size();
  IdArray result = IdArray::Empty({count}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* out = static_cast<int64_t*>(result->data);
  int64_t k = 0;
  while (k < count) {
    out[k] = out_nbrs[k];
    ++k;
  }
  return result;
}
// Returns the id of the edge src -> dst, found by a linear scan over the
// successor list of `src`. A missing source vertex fails the CHECK; a missing
// edge is reported through LOG(FATAL).
dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid edge: " << src << " -> " << dst;
  const auto& out_nbrs = adjlist_[src].succ;
  const size_t degree = out_nbrs.size();
  // Advance to the first successor equal to dst, or to the end of the list.
  size_t pos = 0;
  while (pos < degree && out_nbrs[pos] != dst) {
    ++pos;
  }
  if (pos < degree) {
    // edge_id is aligned position-by-position with succ.
    return adjlist_[src].edge_id[pos];
  }
  LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
  return 0;
}
// Batched edge-id lookup; each pair is resolved by EdgeId, i.e. by a linear
// scan of the source's successor list, so this is slow for big batches.
//
// Both arrays must pass IsValidIdArray. Layouts: src of length one
// (one-many), dst of length one (many-one), or equal lengths (many-many,
// paired element-wise). Returns one edge id per queried pair, using the dtype
// and context of `src_ids`. EdgeId's own failure handling applies per pair.
IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  // The destination length has to be read from dst_ids. Reading it from
  // src_ids made both lengths always equal, so the result size and the
  // broadcast branches were wrong.
  const auto dstlen = dst_ids->shape[0];
  const auto rstlen = std::max(srclen, dstlen);
  IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
  if (srclen == 1) {
    // one-many
    for (int64_t i = 0; i < dstlen; ++i) {
      rst_data[i] = EdgeId(src_data[0], dst_data[i]);
    }
  } else if (dstlen == 1) {
    // many-one
    for (int64_t i = 0; i < srclen; ++i) {
      rst_data[i] = EdgeId(src_data[i], dst_data[0]);
    }
  } else {
    // many-many
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
    for (int64_t i = 0; i < srclen; ++i) {
      rst_data[i] = EdgeId(src_data[i], dst_data[i]);
    }
  }
  return rst;
}
// Returns the in edges of `vid` as a (src, dst) pair of id arrays: src holds
// the predecessors of `vid`, dst has the same length and is filled with `vid`.
std::pair<IdArray, IdArray> Graph::InEdges(dgl_id_t vid) const {
  const IdArray sources = Predecessors(vid);
  const auto count = sources->shape[0];
  IdArray targets = IdArray::Empty({count}, sources->dtype, sources->ctx);
  int64_t* target_data = static_cast<int64_t*>(targets->data);
  for (int64_t k = 0; k < count; ++k) {
    target_data[k] = vid;
  }
  return std::make_pair(sources, targets);
}
// Returns the in edges of every vertex in `vids` as a (src, dst) pair of id
// arrays, grouped in the order the vertices appear in `vids`. Each listed
// vertex must exist. The outputs reuse the dtype and context of `vids`.
std::pair<IdArray, IdArray> Graph::InEdges(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  const auto len = vids->shape[0];
  // First pass: validate the ids and total up the predecessor counts so the
  // outputs can be allocated at their final size.
  int64_t rstlen = 0;
  for (int64_t i = 0; i < len; ++i) {
    CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
    rstlen += adjlist_[vid_data[i]].pred.size();
  }
  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  int64_t* src_out = static_cast<int64_t*>(src->data);
  int64_t* dst_out = static_cast<int64_t*>(dst->data);
  // Second pass: emit one (predecessor, vertex) pair per in edge.
  int64_t written = 0;
  for (int64_t i = 0; i < len; ++i) {
    const int64_t v = vid_data[i];
    const auto& in_nbrs = adjlist_[v].pred;
    const size_t degree = in_nbrs.size();
    for (size_t j = 0; j < degree; ++j, ++written) {
      src_out[written] = in_nbrs[j];
      dst_out[written] = v;
    }
  }
  return std::make_pair(src, dst);
}
// Returns the out edges of `vid` as a (src, dst) pair of id arrays: dst holds
// the successors of `vid`, src has the same length and is filled with `vid`.
std::pair<IdArray, IdArray> Graph::OutEdges(dgl_id_t vid) const {
  const IdArray targets = Successors(vid);
  const auto count = targets->shape[0];
  IdArray sources = IdArray::Empty({count}, targets->dtype, targets->ctx);
  int64_t* source_data = static_cast<int64_t*>(sources->data);
  for (int64_t k = 0; k < count; ++k) {
    source_data[k] = vid;
  }
  return std::make_pair(sources, targets);
}
// Returns the out edges of every vertex in `vids` as a (src, dst) pair of id
// arrays, grouped in the order the vertices appear in `vids`. Each listed
// vertex must exist. The outputs reuse the dtype and context of `vids`.
std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  const auto len = vids->shape[0];
  // First pass: validate the ids and total up the successor counts so the
  // outputs can be allocated at their final size.
  int64_t rstlen = 0;
  for (int64_t i = 0; i < len; ++i) {
    CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
    rstlen += adjlist_[vid_data[i]].succ.size();
  }
  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  int64_t* src_out = static_cast<int64_t*>(src->data);
  int64_t* dst_out = static_cast<int64_t*>(dst->data);
  // Second pass: emit one (vertex, successor) pair per out edge.
  int64_t written = 0;
  for (int64_t i = 0; i < len; ++i) {
    const int64_t v = vid_data[i];
    const auto& out_nbrs = adjlist_[v].succ;
    const size_t degree = out_nbrs.size();
    for (size_t j = 0; j < degree; ++j, ++written) {
      src_out[written] = v;
      dst_out[written] = out_nbrs[j];
    }
  }
  return std::make_pair(src, dst);
}
// Returns every edge as parallel (src, dst) int64 CPU arrays ordered by edge
// id. The edges are gathered from the per-vertex successor lists and then
// sorted by id, which makes the sort the dominant cost.
std::pair<IdArray, IdArray> Graph::Edges() const {
  // One record per stored edge: both endpoints plus the id used as sort key.
  struct EdgeRecord {
    int64_t src;
    int64_t dst;
    int64_t eid;
  };
  const int64_t len = num_edges_;
  std::vector<EdgeRecord> records;
  records.reserve(len);
  for (dgl_id_t u = 0; u < NumVertices(); ++u) {
    const auto& out = adjlist_[u];
    for (size_t k = 0; k < out.succ.size(); ++k) {
      EdgeRecord rec;
      rec.src = u;
      rec.dst = out.succ[k];
      rec.eid = out.edge_id[k];
      records.push_back(rec);
    }
  }
  // Order the records by edge id.
  std::sort(records.begin(), records.end(),
            [] (const EdgeRecord& lhs, const EdgeRecord& rhs) {
              return lhs.eid < rhs.eid;
            });
  // Split the sorted records into the two output arrays.
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* src_out = static_cast<int64_t*>(src->data);
  int64_t* dst_out = static_cast<int64_t*>(dst->data);
  for (int64_t k = 0; k < len; ++k) {
    const EdgeRecord& rec = records[k];
    src_out[k] = rec.src;
    dst_out[k] = rec.dst;
  }
  return std::make_pair(src, dst);
}
// Returns the in degree (predecessor count) of every vertex in `vids`, one
// entry per input id. Each listed vertex must exist. The result reuses the
// dtype and context of `vids`.
DegreeArray Graph::InDegrees(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  const auto len = vids->shape[0];
  DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
  int64_t* degrees = static_cast<int64_t*>(rst->data);
  for (int64_t idx = 0; idx < len; ++idx) {
    const auto vid = vid_data[idx];
    CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
    const auto& in_nbrs = adjlist_[vid].pred;
    degrees[idx] = in_nbrs.size();
  }
  return rst;
}
// Returns the out degree (successor count) of every vertex in `vids`, one
// entry per input id. Each listed vertex must exist. The result reuses the
// dtype and context of `vids`.
DegreeArray Graph::OutDegrees(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  const auto len = vids->shape[0];
  DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
  int64_t* degrees = static_cast<int64_t*>(rst->data);
  for (int64_t idx = 0; idx < len; ++idx) {
    const auto vid = vid_data[idx];
    CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
    const auto& out_nbrs = adjlist_[vid].succ;
    degrees[idx] = out_nbrs.size();
  }
  return rst;
}
// Placeholder for the induced vertex subgraph: not implemented yet.
// `vids` is currently unused.
Graph Graph::Subgraph(IdArray vids) const {
// Report the missing implementation as a fatal error.
LOG(FATAL) << "not implemented";
// Return statement that follows the fatal log; yields a copy of this graph.
return *this;
}
void MyAdd(TVMArgs args, TVMRetValue* rv) {
int a = args[0];
int b = args[1];
*rv = a + b;
// Placeholder for the induced edge subgraph: not implemented yet.
// `src` and `dst` are currently unused.
Graph Graph::EdgeSubgraph(IdArray src, IdArray dst) const {
// Report the missing implementation as a fatal error.
LOG(FATAL) << "not implemented";
// Return statement that follows the fatal log; yields a copy of this graph.
return *this;
}
void CallPacked() {
PackedFunc myadd = PackedFunc(MyAdd);
int c = myadd(1, 2);
// Placeholder for the edge-reversed graph: not implemented yet.
Graph Graph::Reverse() const {
// Report the missing implementation as a fatal error.
LOG(FATAL) << "not implemented";
// Return statement that follows the fatal log; yields a copy of this graph.
return *this;
}
TVM_REGISTER_GLOBAL("myadd")
.set_body(MyAdd);
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment