Commit 14d88497 authored by Minjie Wang's avatar Minjie Wang
Browse files

impl

parent 2c626b90
...@@ -13,23 +13,50 @@ typedef tvm::runtime::NDArray IdArray; ...@@ -13,23 +13,50 @@ typedef tvm::runtime::NDArray IdArray;
typedef tvm::runtime::NDArray DegreeArray; typedef tvm::runtime::NDArray DegreeArray;
typedef tvm::runtime::NDArray BoolArray; typedef tvm::runtime::NDArray BoolArray;
class DGLGraph; class Graph;
class DGLSubGraph;
/*! /*!
* \brief Base dgl graph class. * \brief Base dgl graph class.
* *
* DGLGraph is a directed graph. Vertices are integers enumerated from zero. Edges * DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not * are uniquely identified by the two endpoints. Multi-edge is currently not
* supported. * supported.
* *
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared" * Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* by removing all the vertices and edges. * by removing all the vertices and edges.
*
* When calling functions supporing multiple edges (e.g. AddEdges, HasEdges),
* the input edges are represented by two id arrays for source and destination
* vertex ids. In the general case, the two arrays should have the same length.
* If the length of src id array is one, it represents one-many connections.
* If the length of dst id array is one, it represents many-one connections.
*/ */
class DGLGraph { class Graph {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
DGLGraph() {} Graph() {}
/*! \brief default copy constructor */
Graph(const Graph& other) = default;
#ifndef _MSC_VER
/*! \brief default move constructor */
Graph(Graph&& other) = default;
#else
Graph(Graph&& other) {
adjlist_ = other.adjlist_;
read_only_ = other.read_only_;
num_edges_ = other.num_edges_;
other.clear();
}
#endif // _MSC_VER
/*! \brief default assign constructor */
Graph& operator=(const Graph& other) = default;
/*! \brief default destructor */
~Graph() = default;
/*! /*!
* \brief Add vertices to the graph. * \brief Add vertices to the graph.
* \note Since vertices are integers enumerated from zero, only the number of * \note Since vertices are integers enumerated from zero, only the number of
...@@ -37,46 +64,68 @@ class DGLGraph { ...@@ -37,46 +64,68 @@ class DGLGraph {
* \param num_vertices The number of vertices to be added. * \param num_vertices The number of vertices to be added.
*/ */
void AddVertices(uint64_t num_vertices); void AddVertices(uint64_t num_vertices);
/*! /*!
* \brief Add one edge to the graph. * \brief Add one edge to the graph.
* \param src The source vertex. * \param src The source vertex.
* \param dst The destination vertex. * \param dst The destination vertex.
*/ */
void AddEdge(dgl_id_t src, dgl_id_t dst); void AddEdge(dgl_id_t src, dgl_id_t dst);
/*! /*!
* \brief Add edges to the graph. * \brief Add edges to the graph.
* \param src_ids The source vertex id array. * \param src_ids The source vertex id array.
* \param dst_ids The destination vertex id array. * \param dst_ids The destination vertex id array.
*/ */
void AddEdges(IdArray src_ids, IdArray dst_ids); void AddEdges(IdArray src_ids, IdArray dst_ids);
/*! /*!
* \brief Clear the graph. Remove all vertices/edges. * \brief Clear the graph. Remove all vertices/edges.
*/ */
void Clear(); void Clear() {
adjlist_ = vector_view<EdgeList>();
read_only_ = false;
num_edges_ = 0;
}
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const; uint64_t NumVertices() const {
return adjlist_.size();
}
/*! \return the number of edges in the graph.*/ /*! \return the number of edges in the graph.*/
uint64_t NumEdges() const; uint64_t NumEdges() const {
return num_edges_;
}
/*! \return true if the given vertex is in the graph.*/ /*! \return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const; bool HasVertex(dgl_id_t vid) const {
return vid < NumVertices();
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
BoolArray HasVertices(IdArray vids) const; BoolArray HasVertices(IdArray vids) const;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
bool HasEdge(dgl_id_t src, dgl_id_t dst) const; bool HasEdge(dgl_id_t src, dgl_id_t dst) const;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/ /*! \return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdges(IdArray src_ids, IdArray dst_ids) const; BoolArray HasEdges(IdArray src_ids, IdArray dst_ids) const;
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \return the predecessor id array. * \return the predecessor id array.
*/ */
IdArray Predecessors(dgl_id_t vid) const; IdArray Predecessors(dgl_id_t vid) const;
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \return the successor id array. * \return the successor id array.
*/ */
IdArray Successors(dgl_id_t vid) const; IdArray Successors(dgl_id_t vid) const;
/*! /*!
* \brief Get the id of the given edge. * \brief Get the id of the given edge.
* \note Edges are associated with an integer id start from zero. * \note Edges are associated with an integer id start from zero.
...@@ -86,6 +135,7 @@ class DGLGraph { ...@@ -86,6 +135,7 @@ class DGLGraph {
* \return the edge id. * \return the edge id.
*/ */
dgl_id_t EdgeId(dgl_id_t src, dgl_id_t dst) const; dgl_id_t EdgeId(dgl_id_t src, dgl_id_t dst) const;
/*! /*!
* \brief Get the id of the given edges. * \brief Get the id of the given edges.
* \note Edges are associated with an integer id start from zero. * \note Edges are associated with an integer id start from zero.
...@@ -93,59 +143,77 @@ class DGLGraph { ...@@ -93,59 +143,77 @@ class DGLGraph {
* \return the edge id array. * \return the edge id array.
*/ */
IdArray EdgeIds(IdArray src, IdArray dst) const; IdArray EdgeIds(IdArray src, IdArray dst) const;
/*! /*!
* \brief Get the in edges of the vertex. * \brief Get the in edges of the vertex.
* \note The returned dst id array is filled with vid.
* \param vid The vertex id. * \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
std::pair<IdArray, IdArray> InEdges(dgl_id_t vid) const; std::pair<IdArray, IdArray> InEdges(dgl_id_t vid) const;
/*! /*!
* \brief Get the in edges of the vertices. * \brief Get the in edges of the vertices.
* \param vids The vertex id array. * \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
std::pair<IdArray, IdArray> InEdges(IdArray vids) const; std::pair<IdArray, IdArray> InEdges(IdArray vids) const;
/*! /*!
* \brief Get the out edges of the vertex. * \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
* \param vid The vertex id. * \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
std::pair<IdArray, IdArray> OutEdges(dgl_id_t vid) const; std::pair<IdArray, IdArray> OutEdges(dgl_id_t vid) const;
/*! /*!
* \brief Get the out edges of the vertices. * \brief Get the out edges of the vertices.
* \param vids The vertex id array. * \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
std::pair<IdArray, IdArray> OutEdges(IdArray vids) const; std::pair<IdArray, IdArray> OutEdges(IdArray vids) const;
/*! /*!
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
std::pair<IdArray, IdArray> Edges() const; std::pair<IdArray, IdArray> Edges() const;
/*! /*!
* \brief Get the in degree of the given vertex. * \brief Get the in degree of the given vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \return the in degree * \return the in degree
*/ */
uint64_t InDegree(dgl_id_t vid) const; uint64_t InDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return adjlist_[vid].pred.size();
}
/*! /*!
* \brief Get the in degrees of the given vertices. * \brief Get the in degrees of the given vertices.
* \param vid The vertex id array. * \param vid The vertex id array.
* \return the in degree array * \return the in degree array
*/ */
DegreeArray InDegrees(IdArray vids) const; DegreeArray InDegrees(IdArray vids) const;
/*! /*!
* \brief Get the out degree of the given vertex. * \brief Get the out degree of the given vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \return the out degree * \return the out degree
*/ */
uint64_t OutDegree(dgl_id_t vid) const; uint64_t OutDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return adjlist_[vid].succ.size();
}
/*! /*!
* \brief Get the out degrees of the given vertices. * \brief Get the out degrees of the given vertices.
* \param vid The vertex id array. * \param vid The vertex id array.
* \return the out degree array * \return the out degree array
*/ */
DegreeArray OutDegrees(IdArray vids) const; DegreeArray OutDegrees(IdArray vids) const;
/*! /*!
* \brief Construct the induced subgraph of the given vertices. * \brief Construct the induced subgraph of the given vertices.
* *
...@@ -162,7 +230,8 @@ class DGLGraph { ...@@ -162,7 +230,8 @@ class DGLGraph {
* \param vids The vertices in the subgraph. * \param vids The vertices in the subgraph.
* \return the induced subgraph * \return the induced subgraph
*/ */
DGLGraph Subgraph(IdArray vids) const; Graph Subgraph(IdArray vids) const;
/*! /*!
* \brief Construct the induced edge subgraph of the given edges. * \brief Construct the induced edge subgraph of the given edges.
* *
...@@ -179,7 +248,8 @@ class DGLGraph { ...@@ -179,7 +248,8 @@ class DGLGraph {
* \param vids The edges in the subgraph. * \param vids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
DGLGraph EdgeSubgraph(IdArray src, IdArray dst) const; Graph EdgeSubgraph(IdArray src, IdArray dst) const;
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
* *
...@@ -187,7 +257,7 @@ class DGLGraph { ...@@ -187,7 +257,7 @@ class DGLGraph {
* *
* \return the reversed graph * \return the reversed graph
*/ */
DGLGraph Reverse() const; Graph Reverse() const;
private: private:
/*! \brief Internal edge list type */ /*! \brief Internal edge list type */
...@@ -196,7 +266,7 @@ class DGLGraph { ...@@ -196,7 +266,7 @@ class DGLGraph {
vector_view<dgl_id_t> succ; vector_view<dgl_id_t> succ;
/*! \brief predecessor vertex list */ /*! \brief predecessor vertex list */
vector_view<dgl_id_t> pred; vector_view<dgl_id_t> pred;
/*! \brief (local) edge id property */ /*! \brief (local) succ edge id property */
std::vector<dgl_id_t> edge_id; std::vector<dgl_id_t> edge_id;
}; };
/*! \brief Adjacency list using vector storage */ /*! \brief Adjacency list using vector storage */
...@@ -210,6 +280,8 @@ class DGLGraph { ...@@ -210,6 +280,8 @@ class DGLGraph {
vector_view<EdgeList> adjlist_; vector_view<EdgeList> adjlist_;
/*! \brief read only flag */ /*! \brief read only flag */
bool read_only_{false}; bool read_only_{false};
/*! \brief number of edges */
uint64_t num_edges_ = 0;
}; };
} // namespace dgl } // namespace dgl
......
...@@ -15,10 +15,42 @@ namespace dgl { ...@@ -15,10 +15,42 @@ namespace dgl {
*/ */
template<typename ValueType> template<typename ValueType>
class vector_view { class vector_view {
struct vector_view_iterator;
public: public:
typedef vector_view_iterator iterator; /*! \brief iterator class */
class iterator : public std::iterator<std::forward_iterator_tag, ValueType> {
public:
/*! \brief iterator constructor */
iterator(const vector_view<ValueType>* vec, size_t pos): vec_(vec), pos_(pos) {}
/*! \brief move to next */
iterator& operator++() {
++pos_;
return *this;
}
/*! \brief move to next */
iterator operator++(int) {
iterator retval = *this;
++(*this);
return retval;
}
/*! \brief equal operator */
bool operator==(iterator other) const {
return vec_ == other.vec_ and pos_ == other.pos_;
}
/*! \brief not equal operator */
bool operator!=(iterator other) const {
return !(*this == other);
}
/*! \brief dereference operator */
const ValueType& operator*() const {
return (*vec_)[pos_];
}
private:
/*! \brief vector_view pointer */
const vector_view<ValueType>* vec_;
/*! \brief current position */
size_t pos_;
};
/*! \brief Default constructor. Create an empty vector. */ /*! \brief Default constructor. Create an empty vector. */
vector_view() vector_view()
: data_(std::make_shared<std::vector<ValueType> >()) {} : data_(std::make_shared<std::vector<ValueType> >()) {}
...@@ -28,7 +60,7 @@ class vector_view { ...@@ -28,7 +60,7 @@ class vector_view {
: data_(vec.data_), index_(index), is_view_(true) {} : data_(vec.data_), index_(index), is_view_(true) {}
/*! \brief constructor from a vector pointer */ /*! \brief constructor from a vector pointer */
vector_view(const std::shared_ptr<std::vector<ValueType> >& other) explicit vector_view(const std::shared_ptr<std::vector<ValueType> >& other)
: data_(other) {} : data_(other) {}
/*! \brief default copy constructor */ /*! \brief default copy constructor */
...@@ -53,7 +85,7 @@ class vector_view { ...@@ -53,7 +85,7 @@ class vector_view {
~vector_view() = default; ~vector_view() = default;
/*! \brief default assign constructor */ /*! \brief default assign constructor */
vector_view& operator=(const vector_view<ValueType>& other) = default; vector_view<ValueType>& operator=(const vector_view<ValueType>& other) = default;
/*! \return the size of the vector */ /*! \return the size of the vector */
size_t size() const { size_t size() const {
...@@ -87,11 +119,15 @@ class vector_view { ...@@ -87,11 +119,15 @@ class vector_view {
} }
} }
// TODO(minjie) /*! \return an iterator pointing at the first element */
iterator begin() const; iterator begin() const {
return iterator(this, 0);
}
// TODO(minjie) /*! \return an iterator pointing at the last element */
iterator end() const; iterator end() const {
return iterator(this, size());
}
// Modifiers // Modifiers
// NOTE: The modifiers are not allowed for view. // NOTE: The modifiers are not allowed for view.
...@@ -112,10 +148,17 @@ class vector_view { ...@@ -112,10 +148,17 @@ class vector_view {
data_ = std::make_shared<std::vector<ValueType> >(); data_ = std::make_shared<std::vector<ValueType> >();
} }
private: /*! \brief Resize the vector */
struct vector_view_iterator { void resize(size_t size) {
// TODO CHECK(!is_view_);
}; data_->resize(size);
}
/*! \brief Resize the vector with init value */
void resize(size_t size, const ValueType& val) {
CHECK(!is_view_);
data_->resize(size, val);
}
private: private:
/*! \brief pointer to the underlying vector data */ /*! \brief pointer to the underlying vector data */
......
#include <dgl/runtime/packed_func.h> // Graph class implementation
#include <dgl/runtime/registry.h> #include <algorithm>
#include <dgl/graph.h>
using namespace tvm; namespace dgl {
using namespace tvm::runtime; namespace {
inline bool IsValidIdArray(const IdArray& arr) {
return arr->ctx.device_type == kDLCPU && arr->ndim == 1
&& arr->dtype.code == kDLInt && arr->dtype.bits == 64;
}
} // namespace
void Graph::AddVertices(uint64_t num_vertices) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
adjlist_.resize(adjlist_.size() + num_vertices);
}
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(HasVertex(src) && HasVertex(dst))
<< "In valid vertices: " << src << " " << dst;
dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(eid);
adjlist_[dst].pred.push_back(src);
}
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
AddEdge(src_data[0], dst_data[i]);
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
AddEdge(src_data[i], dst_data[0]);
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
AddEdge(src_data[i], dst_data[i]);
}
}
}
BoolArray Graph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const uint64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
}
return rst;
}
// O(E)
bool Graph::HasEdge(dgl_id_t src, dgl_id_t dst) const {
if (!HasVertex(src) || !HasVertex(dst)) return false;
const auto& succ = adjlist_[src].succ;
return std::find(succ.begin(), succ.end(), dst) != succ.end();
}
// O(E*K) pretty slow
BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdge(src_data[0], dst_data[i])? 1 : 0;
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[0])? 1 : 0;
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[i])? 1 : 0;
}
}
return rst;
}
// The data is copy-out; support zero-copy?
IdArray Graph::Predecessors(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const auto& pred = adjlist_[vid].pred;
const int64_t len = pred.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = pred[i];
}
return rst;
}
// The data is copy-out; support zero-copy?
IdArray Graph::Successors(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const auto& succ = adjlist_[vid].succ;
const int64_t len = succ.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = succ[i];
}
return rst;
}
// O(E)
dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src)) << "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ;
for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) {
return adjlist_[src].edge_id[i];
}
}
LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
return 0;
}
// O(E*k) pretty slow
IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = src_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = EdgeId(src_data[0], dst_data[i]);
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[0]);
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[i]);
}
}
return rst;
}
// O(E)
std::pair<IdArray, IdArray> Graph::InEdges(dgl_id_t vid) const {
const auto& src = Predecessors(vid);
const auto srclen = src->shape[0];
IdArray dst = IdArray::Empty({srclen}, src->dtype, src->ctx);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
std::fill(dst_data, dst_data + srclen, vid);
return std::make_pair(src, dst);
}
// O(E)
std::pair<IdArray, IdArray> Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += adjlist_[vid_data[i]].pred.size();
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (int64_t i = 0; i < len; ++i) {
const auto& pred = adjlist_[vid_data[i]].pred;
for (size_t j = 0; j < pred.size(); ++j) {
*(src_ptr++) = pred[j];
*(dst_ptr++) = vid_data[i];
}
}
return std::make_pair(src, dst);
}
// O(E)
std::pair<IdArray, IdArray> Graph::OutEdges(dgl_id_t vid) const {
const auto& dst = Successors(vid);
const auto dstlen = dst->shape[0];
IdArray src = IdArray::Empty({dstlen}, dst->dtype, dst->ctx);
int64_t* src_data = static_cast<int64_t*>(src->data);
std::fill(src_data, src_data + dstlen, vid);
return std::make_pair(src, dst);
}
// O(E)
std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += adjlist_[vid_data[i]].succ.size();
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (int64_t i = 0; i < len; ++i) {
const auto& succ = adjlist_[vid_data[i]].succ;
for (size_t j = 0; j < succ.size(); ++j) {
*(src_ptr++) = vid_data[i];
*(dst_ptr++) = succ[j];
}
}
return std::make_pair(src, dst);
}
// O(E*log(E)) due to sorting
std::pair<IdArray, IdArray> Graph::Edges() const {
const int64_t len = num_edges_;
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples;
tuples.reserve(len);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
tuples.push_back(std::make_tuple(u, adjlist_[u].succ[i], adjlist_[u].edge_id[i]));
}
}
// sort according to edge ids
std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) {
return std::get<2>(t1) < std::get<2>(t2);
});
// make return arrays
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (int64_t i = 0; i < len; ++i) {
src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]);
}
return std::make_pair(src, dst);
}
// O(V)
DegreeArray Graph::InDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = adjlist_[vid].pred.size();
}
return rst;
}
// O(V)
DegreeArray Graph::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = adjlist_[vid].succ.size();
}
return rst;
}
Graph Graph::Subgraph(IdArray vids) const {
LOG(FATAL) << "not implemented";
return *this;
}
void MyAdd(TVMArgs args, TVMRetValue* rv) { Graph Graph::EdgeSubgraph(IdArray src, IdArray dst) const {
int a = args[0]; LOG(FATAL) << "not implemented";
int b = args[1]; return *this;
*rv = a + b;
} }
void CallPacked() { Graph Graph::Reverse() const {
PackedFunc myadd = PackedFunc(MyAdd); LOG(FATAL) << "not implemented";
int c = myadd(1, 2); return *this;
} }
TVM_REGISTER_GLOBAL("myadd") } // namespace dgl
.set_body(MyAdd);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment