Commit deb8a5a5 authored by Minjie Wang's avatar Minjie Wang
Browse files

WIP: graph interface

parent 9445e1ae
// DGL Graph interface
#ifndef DGL_DGLGRAPH_H_
#define DGL_DGLGRAPH_H_
#include <vector>
#include <stdint.h>
#include "runtime/ndarray.h"
namespace dgl {
typedef uint64_t dgl_id_t;
typedef tvm::runtime::NDArray DGLIdArray;
typedef tvm::runtime::NDArray DegreeArray;
/*!
* \brief Base dgl graph class.
* DGLGraph is a directed graph.
*/
class DGLGraph {
public:
void AddVertices(uint64_t num_vertices);
void AddEdge(dgl_id_t src, dgl_id_t dst);
void AddEdges(DGLIdArray src_ids, DGLIdArray dst_ids);
void Clear();
uint64_t NumVertices() const;
uint64_t NumEdges() const;
bool HasVertex(dgl_id_t vid) const;
tvm::runtime::NDArray HasVertices(DGLIdArray vids) const;
bool HasEdge(dgl_id_t src, dgl_id_t dst) const;
tvm::runtime::NDArray HasEdges(DGLIdArray src_ids, DGLIdArray dst_ids) const;
DGLIdArray Predecessors(dgl_id_t vid) const;
DGLIdArray Successors(dgl_id_t vid) const;
dgl_id_t GetEdgeId(dgl_id_t src, dgl_id_t dst) const;
DGLIdArray GetEdgeIds(DGLIdArray src, DGLIdArray dst) const;
std::pair<DGLIdArray, DGLIdArray> GetInEdge(dgl_id_t vid) const;
std::pair<DGLIdArray, DGLIdArray> GetInEdges(DGLIdArray vids) const;
std::pair<DGLIdArray, DGLIdArray> GetOutEdge(dgl_id_t vid) const;
std::pair<DGLIdArray, DGLIdArray> GetOutEdges(DGLIdArray vids) const;
uint64_t InDegree(dgl_id_t vid) const;
DegreeArray InDegrees(DGLIdArray vids) const;
uint64_t OutDegree(dgl_id_t vid) const;
DegreeArray OutDegrees(DGLIdArray vids) const;
DGLGraph Subgraph(DGLIdArray vids) const;
DGLGraph EdgeSubgraph(DGLIdArray src, DGLIdArray dst) const;
DGLGraph Reverse() const;
};
} // namespace dgl
#endif // DGL_DGLGRAPH_H_
...@@ -18,13 +18,6 @@ ...@@ -18,13 +18,6 @@
#include "module.h" #include "module.h"
#include "ndarray.h" #include "ndarray.h"
namespace HalideIR {
// Forward declare type for extensions
// The header works fine without depending on this.
struct Type;
struct Expr;
}
// Whether use TVM runtime in header only mode. // Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY #ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0 #define TVM_RUNTIME_HEADER_ONLY 0
...@@ -538,8 +531,6 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -538,8 +531,6 @@ class TVMArgValue : public TVMPODValue_ {
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type> std::is_class<TNodeRef>::value>::type>
inline bool IsNodeType() const; inline bool IsNodeType() const;
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
// get internal node ptr, if it is node // get internal node ptr, if it is node
inline std::shared_ptr<Node>& node_sptr(); inline std::shared_ptr<Node>& node_sptr();
}; };
...@@ -733,9 +724,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -733,9 +724,6 @@ class TVMRetValue : public TVMPODValue_ {
inline TNodeRef AsNodeRef() const; inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other); inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other); inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
// type related
inline operator HalideIR::Type() const;
inline TVMRetValue& operator=(const HalideIR::Type& other);
private: private:
template<typename T> template<typename T>
...@@ -1045,7 +1033,6 @@ class TVMArgsSetter { ...@@ -1045,7 +1033,6 @@ class TVMArgsSetter {
inline void operator()(size_t i, const T& value) const; inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
inline void operator()(size_t i, const HalideIR::Type& t) const;
private: private:
/*! \brief The values fields */ /*! \brief The values fields */
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc * \file c_runtime_api.cc
* \brief Device specific implementations * \brief Runtime API implementation
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment